Fix validation writer in Trainer

* Add last step summary (forced)
Corentin 2021-01-27 02:04:12 +09:00
commit 976f5cef49


@@ -177,8 +177,10 @@ class Trainer:
         if self.verbose:
             print()
+        self.save_summaries(force_summary=True)
         train_stop_time = time.time()
         self.writer_train.close()
+        self.writer_val.close()
         torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
         memory_peak, gpu_memory = resource_usage()
@@ -186,7 +188,7 @@ class Trainer:
             f'Training time : {train_stop_time - train_start_time:.03f}s\n'
             f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')
 
-    def save_summaries(self):
+    def save_summaries(self, force_summary=False):
         global_step = self.batch_generator_train.global_step
         if self.batch_generator_train.epoch < self.epoch_skip:
             return
@@ -224,7 +226,7 @@ class Trainer:
         self.network.train(True)
 
         # Add summaries
-        if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
+        if force_summary or self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
             self.writer_train.add_scalar(
                 'loss', self.running_loss / self.running_count, global_step=global_step)
             self.writer_val.add_scalar(
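
The condition change in the last hunk is what makes the forced call in the first hunk effective: force_summary=True bypasses the summary_period modulus, so the final partial averaging window is still written out before both writers are closed. Below is a minimal standalone sketch of that gating, not the actual Trainer code: the saved list and local step counter are placeholders for the SummaryWriter calls and batch_generator_train state in the real class.

# Sketch only: mirrors the force_summary gate from the diff, with placeholder
# bookkeeping instead of TensorBoard writers and the real batch generator.
class SummarySketch:
    def __init__(self, summary_period=10):
        self.summary_period = summary_period
        self.step = 0                 # stands in for batch_generator_train.step
        self.running_loss = 0.0
        self.running_count = 0
        self.saved = []               # stands in for writer.add_scalar calls

    def accumulate(self, loss):
        self.running_loss += loss
        self.running_count += 1

    def save_summaries(self, force_summary=False):
        # Write on the last step of each period, or whenever the caller forces
        # it (e.g. once at the very end of training, before closing writers).
        if force_summary or self.step % self.summary_period == (self.summary_period - 1):
            if self.running_count:
                self.saved.append((self.step, self.running_loss / self.running_count))
                self.running_loss, self.running_count = 0.0, 0
        self.step += 1


if __name__ == '__main__':
    sketch = SummarySketch(summary_period=10)
    for _ in range(25):
        sketch.accumulate(loss=1.0)
        sketch.save_summaries()
    # Final forced flush: without it the average over steps 20-24 would be lost,
    # which is what the save_summaries(force_summary=True) call in the commit fixes.
    sketch.save_summaries(force_summary=True)
    print(sketch.saved)   # summaries at steps 9, 19 and a forced one at step 25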