Fix last training loop
This commit is contained in:
parent
50c395a07f
commit
1fe6cd796e
1 changed files with 3 additions and 0 deletions
|
|
@ -187,6 +187,8 @@ class Trainer:
|
||||||
|
|
||||||
self.processed_inputs = self.train_pre_process(self.batch_inputs)
|
self.processed_inputs = self.train_pre_process(self.batch_inputs)
|
||||||
self.network_outputs = self.network(self.processed_inputs)
|
self.network_outputs = self.network(self.processed_inputs)
|
||||||
|
labels = self.batch_labels if not self.data_is_label else self.processed_inputs
|
||||||
|
loss = self.criterion(self.network_outputs, labels)
|
||||||
|
|
||||||
self.train_loss = loss.item()
|
self.train_loss = loss.item()
|
||||||
self.train_accuracy = self.accuracy_fn(
|
self.train_accuracy = self.accuracy_fn(
|
||||||
|
|
@ -195,6 +197,7 @@ class Trainer:
|
||||||
self.running_accuracy += self.train_accuracy
|
self.running_accuracy += self.train_accuracy
|
||||||
self.running_count += len(self.batch_generator_train.batch_data)
|
self.running_count += len(self.batch_generator_train.batch_data)
|
||||||
self.benchmark_step += 1
|
self.benchmark_step += 1
|
||||||
|
self.batch_generator_train.next_batch()
|
||||||
self.save_summaries(force_summary=True)
|
self.save_summaries(force_summary=True)
|
||||||
train_stop_time = time.time()
|
train_stop_time = time.time()
|
||||||
self.writer_train.close()
|
self.writer_train.close()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue