Fix batch norm for conv2d

* Change tensorboard folders
* Save root scripts in output code
This commit is contained in:
Corentin 2021-01-30 02:11:06 +09:00
commit 42ae4474dd
2 changed files with 10 additions and 7 deletions

View file

@ -44,8 +44,8 @@ class Trainer:
self.optimizer = optimizer
self.criterion = criterion
self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30)
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30)
# Save network graph
batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
@ -79,7 +79,7 @@ class Trainer:
torch.save(network, os.path.join(output_dir, 'model_init.pt'))
# Save source files
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
os.path.join('src', '**', '*.py'), recursive=True):
os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
if not os.path.exists(dirname):
os.makedirs(dirname)
@ -181,7 +181,6 @@ class Trainer:
train_stop_time = time.time()
self.writer_train.close()
self.writer_val.close()
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
memory_peak, gpu_memory = resource_usage()
self.logger.info(
@ -251,7 +250,11 @@ class Trainer:
print(f'Step {global_step}, '
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
if self.accuracy_fn is not None:
torch.save(self.network, os.path.join(
self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
else:
torch.save(self.network, os.path.join(self.output_dir, f'step_{global_step}.pt'))
self.benchmark_time = time.time()
self.benchmark_step = 0
self.running_loss = 0.0