diff --git a/trainer.py b/trainer.py index 12abd73..0412c56 100644 --- a/trainer.py +++ b/trainer.py @@ -197,7 +197,7 @@ class Trainer: val_inputs = torch.as_tensor( self.batch_generator_val.batch_data, dtype=self.data_dtype, device=self.device) val_labels = torch.as_tensor( - self.batch_generator_val.batch_data, dtype=self.label_dtype, device=self.device) + self.batch_generator_val.batch_label, dtype=self.label_dtype, device=self.device) val_pre_process = self.pre_process(val_inputs) val_outputs = self.network(val_pre_process)