Trainer source saving argument

* Saving network state instead of whole instance
This commit is contained in:
Corentin 2021-02-13 03:02:03 +09:00
commit 8ecf175265

View file

@ -23,7 +23,7 @@ class Trainer:
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
data_dtype=None, label_dtype=None,
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
logger=DummyLogger(), verbose=True):
logger=DummyLogger(), verbose=True, save_src=True):
super().__init__()
self.device = device
self.output_dir = output_dir
@ -76,8 +76,9 @@ class Trainer:
if summary_per_epoch % image_per_epoch == 0:
self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
torch.save(network, os.path.join(output_dir, 'model_init.pt'))
torch.save(network.state_dict(), os.path.join(output_dir, 'model_init.pt'))
# Save source files
if save_src:
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
@ -251,10 +252,10 @@ class Trainer:
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
if self.accuracy_fn is not None:
torch.save(self.network, os.path.join(
torch.save(self.network.state_dict(), os.path.join(
self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
else:
torch.save(self.network, os.path.join(self.output_dir, f'step_{global_step}.pt'))
torch.save(self.network.state_dict(), os.path.join(self.output_dir, f'step_{global_step}.pt'))
self.benchmark_time = time.time()
self.benchmark_step = 0
self.running_loss = 0.0