Add `save_src` argument to Trainer to make source-file saving optional
* Save the network's state_dict instead of pickling the whole model instance
This commit is contained in:
parent
42ae4474dd
commit
8ecf175265
1 changed file with 11 additions and 10 deletions
21
trainer.py
21
trainer.py
|
|
@@ -23,7 +23,7 @@ class Trainer:
|
||||||
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
|
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
|
||||||
data_dtype=None, label_dtype=None,
|
data_dtype=None, label_dtype=None,
|
||||||
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
|
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
|
||||||
logger=DummyLogger(), verbose=True):
|
logger=DummyLogger(), verbose=True, save_src=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
|
|
@@ -76,14 +76,15 @@ class Trainer:
|
||||||
if summary_per_epoch % image_per_epoch == 0:
|
if summary_per_epoch % image_per_epoch == 0:
|
||||||
self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
|
self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
|
||||||
|
|
||||||
torch.save(network, os.path.join(output_dir, 'model_init.pt'))
|
torch.save(network.state_dict(), os.path.join(output_dir, 'model_init.pt'))
|
||||||
# Save source files
|
# Save source files
|
||||||
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
|
if save_src:
|
||||||
os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
|
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
|
||||||
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
|
os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
|
||||||
if not os.path.exists(dirname):
|
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
|
||||||
os.makedirs(dirname)
|
if not os.path.exists(dirname):
|
||||||
shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
|
os.makedirs(dirname)
|
||||||
|
shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
|
||||||
|
|
||||||
# Initialize training loop variables
|
# Initialize training loop variables
|
||||||
self.batch_inputs = batch_inputs
|
self.batch_inputs = batch_inputs
|
||||||
|
|
@@ -251,10 +252,10 @@ class Trainer:
|
||||||
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
|
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
|
||||||
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
|
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
|
||||||
if self.accuracy_fn is not None:
|
if self.accuracy_fn is not None:
|
||||||
torch.save(self.network, os.path.join(
|
torch.save(self.network.state_dict(), os.path.join(
|
||||||
self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
|
self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
|
||||||
else:
|
else:
|
||||||
torch.save(self.network, os.path.join(self.output_dir, f'step_{global_step}.pt'))
|
torch.save(self.network.state_dict(), os.path.join(self.output_dir, f'step_{global_step}.pt'))
|
||||||
self.benchmark_time = time.time()
|
self.benchmark_time = time.time()
|
||||||
self.benchmark_step = 0
|
self.benchmark_step = 0
|
||||||
self.running_loss = 0.0
|
self.running_loss = 0.0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue