Skip to content

Commit da733c7

Browse files
mvpatel2000Bandish Shah
authored andcommitted
extend test and patch bug (#2028)
1 parent 3618c63 commit da733c7

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

composer/trainer/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2604,10 +2604,11 @@ def eval(
26042604
evaluators = self.state.evaluators
26052605

26062606
for evaluator in evaluators:
2607+
eval_subset_num_batches = evaluator.subset_num_batches if subset_num_batches == -1 else subset_num_batches
26072608
self._eval_loop(
26082609
dataloader=evaluator.dataloader,
26092610
dataloader_label=evaluator.label,
2610-
subset_num_batches=subset_num_batches,
2611+
subset_num_batches=eval_subset_num_batches,
26112612
metrics=self.state.eval_metrics[evaluator.label],
26122613
)
26132614
if eval_passed_in:

tests/trainer/test_trainer_eval.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,23 +103,27 @@ def test_trainer_eval_loop():
103103
assert trainer.state.eval_metrics['eval']['MulticlassAccuracy'].compute() != 0.0
104104

105105

106-
def test_trainer_eval_subset_num_batches():
106+
@pytest.mark.parametrize('evaluator_on_init,subset_on_init', [[True, True], [True, False], [False, False]])
107+
def test_trainer_eval_subset_num_batches(evaluator_on_init: bool, subset_on_init: bool):
108+
dataset = RandomClassificationDataset()
109+
eval_dataloader = DataLoader(
110+
dataset=dataset,
111+
sampler=dist.get_sampler(dataset),
112+
)
113+
107114
# Construct the trainer
108115
event_counter_callback = EventCounterCallback()
109116
trainer = Trainer(
110117
model=SimpleModel(),
111118
callbacks=[event_counter_callback],
119+
eval_dataloader=eval_dataloader if evaluator_on_init else None,
120+
eval_subset_num_batches=1 if subset_on_init else -1,
112121
)
113122

114123
# Evaluate the model
115-
dataset = RandomClassificationDataset()
116-
eval_dataloader = DataLoader(
117-
dataset=dataset,
118-
sampler=dist.get_sampler(dataset),
119-
)
120124
trainer.eval(
121-
eval_dataloader=eval_dataloader,
122-
subset_num_batches=1,
125+
eval_dataloader=eval_dataloader if not evaluator_on_init else None,
126+
subset_num_batches=1 if not subset_on_init else -1,
123127
)
124128

125129
# Ensure that just one batch was evaluated

0 commit comments

Comments
 (0)