diff --git a/trainer.py b/trainer.py index 772cef3..9d8c05d 100644 --- a/trainer.py +++ b/trainer.py @@ -177,8 +177,10 @@ class Trainer: if self.verbose: print() + self.save_summaries(force_summary=True) train_stop_time = time.time() self.writer_train.close() + self.writer_val.close() torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt')) memory_peak, gpu_memory = resource_usage() @@ -186,7 +188,7 @@ class Trainer: f'Training time : {train_stop_time - train_start_time:.03f}s\n' f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}') - def save_summaries(self): + def save_summaries(self, force_summary=False): global_step = self.batch_generator_train.global_step if self.batch_generator_train.epoch < self.epoch_skip: return @@ -224,7 +226,7 @@ class Trainer: self.network.train(True) # Add summaries - if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1): + if force_summary or self.batch_generator_train.step % self.summary_period == (self.summary_period - 1): self.writer_train.add_scalar( 'loss', self.running_loss / self.running_count, global_step=global_step) self.writer_val.add_scalar(