Add save_src argument to Trainer to toggle source saving
* Save the network state_dict instead of the whole network instance
This commit is contained in:
parent
42ae4474dd
commit
8ecf175265
1 changed file with 11 additions and 10 deletions
21
trainer.py
21
trainer.py
|
|
@@ -23,7 +23,7 @@ class Trainer:
|
|||
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
|
||||
data_dtype=None, label_dtype=None,
|
||||
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
|
||||
logger=DummyLogger(), verbose=True):
|
||||
logger=DummyLogger(), verbose=True, save_src=True):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.output_dir = output_dir
|
||||
|
|
@@ -76,14 +76,15 @@ class Trainer:
|
|||
if summary_per_epoch % image_per_epoch == 0:
|
||||
self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
|
||||
|
||||
torch.save(network, os.path.join(output_dir, 'model_init.pt'))
|
||||
torch.save(network.state_dict(), os.path.join(output_dir, 'model_init.pt'))
|
||||
# Save source files
|
||||
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
|
||||
os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
|
||||
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
|
||||
if save_src:
|
||||
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
|
||||
os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
|
||||
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
|
||||
|
||||
# Initialize training loop variables
|
||||
self.batch_inputs = batch_inputs
|
||||
|
|
@@ -251,10 +252,10 @@ class Trainer:
|
|||
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
|
||||
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
|
||||
if self.accuracy_fn is not None:
|
||||
torch.save(self.network, os.path.join(
|
||||
torch.save(self.network.state_dict(), os.path.join(
|
||||
self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
|
||||
else:
|
||||
torch.save(self.network, os.path.join(self.output_dir, f'step_{global_step}.pt'))
|
||||
torch.save(self.network.state_dict(), os.path.join(self.output_dir, f'step_{global_step}.pt'))
|
||||
self.benchmark_time = time.time()
|
||||
self.benchmark_step = 0
|
||||
self.running_loss = 0.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue