Fix validation writer in Trainer
* Add last step summary (forced)
This commit is contained in:
parent
d315f342a4
commit
976f5cef49
1 changed files with 4 additions and 2 deletions
|
|
@ -177,8 +177,10 @@ class Trainer:
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
self.save_summaries(force_summary=True)
|
||||||
train_stop_time = time.time()
|
train_stop_time = time.time()
|
||||||
self.writer_train.close()
|
self.writer_train.close()
|
||||||
|
self.writer_val.close()
|
||||||
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
|
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
|
||||||
|
|
||||||
memory_peak, gpu_memory = resource_usage()
|
memory_peak, gpu_memory = resource_usage()
|
||||||
|
|
@ -186,7 +188,7 @@ class Trainer:
|
||||||
f'Training time : {train_stop_time - train_start_time:.03f}s\n'
|
f'Training time : {train_stop_time - train_start_time:.03f}s\n'
|
||||||
f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')
|
f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')
|
||||||
|
|
||||||
def save_summaries(self):
|
def save_summaries(self, force_summary=False):
|
||||||
global_step = self.batch_generator_train.global_step
|
global_step = self.batch_generator_train.global_step
|
||||||
if self.batch_generator_train.epoch < self.epoch_skip:
|
if self.batch_generator_train.epoch < self.epoch_skip:
|
||||||
return
|
return
|
||||||
|
|
@ -224,7 +226,7 @@ class Trainer:
|
||||||
self.network.train(True)
|
self.network.train(True)
|
||||||
|
|
||||||
# Add summaries
|
# Add summaries
|
||||||
if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
|
if force_summary or self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
|
||||||
self.writer_train.add_scalar(
|
self.writer_train.add_scalar(
|
||||||
'loss', self.running_loss / self.running_count, global_step=global_step)
|
'loss', self.running_loss / self.running_count, global_step=global_step)
|
||||||
self.writer_val.add_scalar(
|
self.writer_val.add_scalar(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue