Trainer last summary fix + memory utils

This commit is contained in:
Corentin 2021-02-25 02:18:02 +09:00
commit 50c395a07f
3 changed files with 54 additions and 3 deletions

View file

@ -178,6 +178,23 @@ class Trainer:
if self.verbose:
print()
# Small training loop for last metrics
for _ in range(20):
self.batch_inputs = torch.as_tensor(
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
self.batch_labels = torch.as_tensor(
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs)
self.train_loss = loss.item()
self.train_accuracy = self.accuracy_fn(
self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
self.running_loss += self.train_loss
self.running_accuracy += self.train_accuracy
self.running_count += len(self.batch_generator_train.batch_data)
self.benchmark_step += 1
self.save_summaries(force_summary=True)
train_stop_time = time.time()
self.writer_train.close()