Fix last training loop

This commit is contained in:
Corentin 2021-03-01 23:56:15 +09:00
commit 1fe6cd796e

View file

@ -187,6 +187,8 @@ class Trainer:
self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs)
labels = self.batch_labels if not self.data_is_label else self.processed_inputs
loss = self.criterion(self.network_outputs, labels)
self.train_loss = loss.item()
self.train_accuracy = self.accuracy_fn(
@ -195,6 +197,7 @@ class Trainer:
self.running_accuracy += self.train_accuracy
self.running_count += len(self.batch_generator_train.batch_data)
self.benchmark_step += 1
self.batch_generator_train.next_batch()
self.save_summaries(force_summary=True)
train_stop_time = time.time()
self.writer_train.close()