Fix validation writer in Trainer
* Add last step summary (forced)
This commit is contained in:
parent
d315f342a4
commit
976f5cef49
1 changed files with 4 additions and 2 deletions
|
|
@ -177,8 +177,10 @@ class Trainer:
|
|||
if self.verbose:
|
||||
print()
|
||||
|
||||
self.save_summaries(force_summary=True)
|
||||
train_stop_time = time.time()
|
||||
self.writer_train.close()
|
||||
self.writer_val.close()
|
||||
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
|
||||
|
||||
memory_peak, gpu_memory = resource_usage()
|
||||
|
|
@ -186,7 +188,7 @@ class Trainer:
|
|||
f'Training time : {train_stop_time - train_start_time:.03f}s\n'
|
||||
f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')
|
||||
|
||||
def save_summaries(self):
|
||||
def save_summaries(self, force_summary=False):
|
||||
global_step = self.batch_generator_train.global_step
|
||||
if self.batch_generator_train.epoch < self.epoch_skip:
|
||||
return
|
||||
|
|
@ -224,7 +226,7 @@ class Trainer:
|
|||
self.network.train(True)
|
||||
|
||||
# Add summaries
|
||||
if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
|
||||
if force_summary or self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
|
||||
self.writer_train.add_scalar(
|
||||
'loss', self.running_loss / self.running_count, global_step=global_step)
|
||||
self.writer_val.add_scalar(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue