From 9ac6fb64e8484002c4543829ed5870c77469520c Mon Sep 17 00:00:00 2001 From: Corentin Date: Wed, 13 Jan 2021 00:14:04 +0900 Subject: [PATCH] Fix trainer validation loop --- trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trainer.py b/trainer.py index 12abd73..0412c56 100644 --- a/trainer.py +++ b/trainer.py @@ -197,7 +197,7 @@ class Trainer: val_inputs = torch.as_tensor( self.batch_generator_val.batch_data, dtype=self.data_dtype, device=self.device) val_labels = torch.as_tensor( - self.batch_generator_val.batch_data, dtype=self.label_dtype, device=self.device) + self.batch_generator_val.batch_label, dtype=self.label_dtype, device=self.device) val_pre_process = self.pre_process(val_inputs) val_outputs = self.network(val_pre_process)