From 8ecf1752650635d736621d4a18aea0af1a12dcf0 Mon Sep 17 00:00:00 2001
From: Corentin
Date: Sat, 13 Feb 2021 03:02:03 +0900
Subject: [PATCH] Trainer source saving argument

* Saving network state instead of whole instance
---
 trainer.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/trainer.py b/trainer.py
index ddf8c6b..d825e91 100644
--- a/trainer.py
+++ b/trainer.py
@@ -23,7 +23,7 @@ class Trainer:
                  epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
                  data_dtype=None, label_dtype=None,
                  train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
-                 logger=DummyLogger(), verbose=True):
+                 logger=DummyLogger(), verbose=True, save_src=True):
         super().__init__()
         self.device = device
         self.output_dir = output_dir
@@ -76,14 +76,15 @@ class Trainer:
         if summary_per_epoch % image_per_epoch == 0:
             self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
 
-        torch.save(network, os.path.join(output_dir, 'model_init.pt'))
+        torch.save(network.state_dict(), os.path.join(output_dir, 'model_init.pt'))
         # Save source files
-        for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
-                os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
-            dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
-            if not os.path.exists(dirname):
-                os.makedirs(dirname)
-            shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
+        if save_src:
+            for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
+                    os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
+                dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
+                if not os.path.exists(dirname):
+                    os.makedirs(dirname)
+                shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
 
         # Initialize training loop variables
         self.batch_inputs = batch_inputs
@@ -251,10 +252,10 @@ class Trainer:
                     f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
                     f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
                 if self.accuracy_fn is not None:
-                    torch.save(self.network, os.path.join(
+                    torch.save(self.network.state_dict(), os.path.join(
                         self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
                 else:
-                    torch.save(self.network, os.path.join(self.output_dir, f'step_{global_step}.pt'))
+                    torch.save(self.network.state_dict(), os.path.join(self.output_dir, f'step_{global_step}.pt'))
                 self.benchmark_time = time.time()
                 self.benchmark_step = 0
                 self.running_loss = 0.0
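
Since the checkpoints written by this Trainer now contain only a state dict rather than a pickled module instance, restoring one requires rebuilding the network first and then loading the parameters into it. A minimal loading sketch under that assumption; `MyNetwork`, its import path, and the checkpoint path are hypothetical placeholders, not names from this patch:

    import os

    import torch

    from src.my_network import MyNetwork  # hypothetical: import the actual network class used with Trainer

    # Rebuild the architecture with the same constructor arguments that were used for training,
    # then restore the saved parameters into the fresh instance.
    network = MyNetwork()
    state_dict = torch.load(os.path.join('runs/example', 'model_init.pt'), map_location='cpu')
    network.load_state_dict(state_dict)
    network.eval()

The trade-off is the usual one for state_dict checkpoints: the files no longer depend on unpickling the module class from its original import path, but the loading side must know the network class and its constructor arguments.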