Fix last training loop
This commit is contained in:
parent
50c395a07f
commit
1fe6cd796e
1 changed files with 3 additions and 0 deletions
|
|
@ -187,6 +187,8 @@ class Trainer:
|
|||
|
||||
self.processed_inputs = self.train_pre_process(self.batch_inputs)
|
||||
self.network_outputs = self.network(self.processed_inputs)
|
||||
labels = self.batch_labels if not self.data_is_label else self.processed_inputs
|
||||
loss = self.criterion(self.network_outputs, labels)
|
||||
|
||||
self.train_loss = loss.item()
|
||||
self.train_accuracy = self.accuracy_fn(
|
||||
|
|
@ -195,6 +197,7 @@ class Trainer:
|
|||
self.running_accuracy += self.train_accuracy
|
||||
self.running_count += len(self.batch_generator_train.batch_data)
|
||||
self.benchmark_step += 1
|
||||
self.batch_generator_train.next_batch()
|
||||
self.save_summaries(force_summary=True)
|
||||
train_stop_time = time.time()
|
||||
self.writer_train.close()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue