From 1fe6cd796e2a6f42563f9b52dc6ed3399e5eb711 Mon Sep 17 00:00:00 2001 From: Corentin Date: Mon, 1 Mar 2021 23:56:15 +0900 Subject: [PATCH] Fix last training loop --- trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trainer.py b/trainer.py index bccf875..4f94396 100644 --- a/trainer.py +++ b/trainer.py @@ -187,6 +187,8 @@ class Trainer: self.processed_inputs = self.train_pre_process(self.batch_inputs) self.network_outputs = self.network(self.processed_inputs) + labels = self.batch_labels if not self.data_is_label else self.processed_inputs + loss = self.criterion(self.network_outputs, labels) self.train_loss = loss.item() self.train_accuracy = self.accuracy_fn( @@ -195,6 +197,7 @@ class Trainer: self.running_accuracy += self.train_accuracy self.running_count += len(self.batch_generator_train.batch_data) self.benchmark_step += 1 + self.batch_generator_train.next_batch() self.save_summaries(force_summary=True) train_stop_time = time.time() self.writer_train.close()