Fix batch norm for conv2d
* Change tensorboard folders * Save root scripts in output code
This commit is contained in:
parent
976f5cef49
commit
42ae4474dd
2 changed files with 10 additions and 7 deletions
13
trainer.py
13
trainer.py
|
|
@ -44,8 +44,8 @@ class Trainer:
|
|||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
|
||||
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
|
||||
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)
|
||||
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30)
|
||||
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30)
|
||||
|
||||
# Save network graph
|
||||
batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
|
||||
|
|
@ -79,7 +79,7 @@ class Trainer:
|
|||
torch.save(network, os.path.join(output_dir, 'model_init.pt'))
|
||||
# Save source files
|
||||
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
|
||||
os.path.join('src', '**', '*.py'), recursive=True):
|
||||
os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
|
||||
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
|
@ -181,7 +181,6 @@ class Trainer:
|
|||
train_stop_time = time.time()
|
||||
self.writer_train.close()
|
||||
self.writer_val.close()
|
||||
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
|
||||
|
||||
memory_peak, gpu_memory = resource_usage()
|
||||
self.logger.info(
|
||||
|
|
@ -251,7 +250,11 @@ class Trainer:
|
|||
print(f'Step {global_step}, '
|
||||
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
|
||||
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
|
||||
torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
|
||||
if self.accuracy_fn is not None:
|
||||
torch.save(self.network, os.path.join(
|
||||
self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
|
||||
else:
|
||||
torch.save(self.network, os.path.join(self.output_dir, f'step_{global_step}.pt'))
|
||||
self.benchmark_time = time.time()
|
||||
self.benchmark_step = 0
|
||||
self.running_loss = 0.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue