Fix batch norm for conv2d

* Change tensorboard folders
* Save root scripts in output code
Corentin 2021-01-30 02:11:06 +09:00
commit 42ae4474dd
2 changed files with 10 additions and 7 deletions

@@ -85,7 +85,7 @@ class Conv2d(Layer):
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
@@ -118,7 +118,7 @@ class Deconv2d(Layer):
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.deconv(input_data))
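
For context on the flag being flipped above: in PyTorch, nn.BatchNorm2d with track_running_stats=True keeps running_mean/running_var buffers that are used for normalization in eval mode, while track_running_stats=False drops those buffers and normalizes with the current batch statistics even in eval mode. The sketch below only illustrates that behaviour; Layer.BATCH_NORM_TRAINING and Layer.BATCH_NORM_MOMENTUM are repo-specific constants not shown in this diff, so plain values are used instead.

import torch
import torch.nn as nn

# Standalone illustration of track_running_stats, independent of this repo.
bn_tracking = nn.BatchNorm2d(8, momentum=0.1, track_running_stats=True)
bn_batch_only = nn.BatchNorm2d(8, momentum=0.1, track_running_stats=False)

x = torch.randn(4, 8, 16, 16)
bn_tracking(x)                       # updates the running_mean / running_var buffers
print(bn_tracking.running_mean[:3])  # non-zero after one forward pass
print(bn_batch_only.running_mean)    # None: no running statistics are kept

bn_tracking.eval()
bn_batch_only.eval()
bn_tracking(x)    # eval mode normalizes with the stored running statistics
bn_batch_only(x)  # eval mode still normalizes with the current batch statistics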

@@ -44,8 +44,8 @@ class Trainer:
         self.optimizer = optimizer
         self.criterion = criterion
         self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
-        self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
-        self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)
+        self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30)
+        self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30)
 
         # Save network graph
         batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
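
Both writers now log under a shared tensorboard/ subfolder of the run's output directory, which keeps the event files separate from the other artifacts (checkpoints, code snapshot) written to output_dir. A minimal sketch with a placeholder output directory:

import os
from torch.utils.tensorboard import SummaryWriter

output_dir = 'runs/example'  # placeholder, not a path from the repo
writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30)
writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30)

for step in range(3):
    writer_train.add_scalar('loss', 1.0 / (step + 1), step)
    writer_val.add_scalar('loss', 1.2 / (step + 1), step)

writer_train.close()
writer_val.close()
# Launch with: tensorboard --logdir runs/example/tensorboard
# The 'train' and 'val' subfolders show up as two runs sharing the same tags.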
@@ -79,7 +79,7 @@ class Trainer:
         torch.save(network, os.path.join(output_dir, 'model_init.pt'))
         # Save source files
         for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
-                os.path.join('src', '**', '*.py'), recursive=True):
+                os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
             dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
             if not os.path.exists(dirname):
                 os.makedirs(dirname)
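
The source snapshot now also picks up *.py scripts at the repository root, alongside everything under config/ and src/, and mirrors them into output_dir/code/. A standalone sketch of the same idea; the actual copy call sits below the context shown in the hunk, so the shutil.copy step here is an assumption:

import glob
import os
import shutil

output_dir = 'runs/example'  # placeholder

entries = (glob.glob(os.path.join('config', '**', '*.py'), recursive=True)
           + glob.glob(os.path.join('src', '**', '*.py'), recursive=True)
           + glob.glob('*.py'))  # root-level scripts, e.g. a train.py (assumed name)

for entry in entries:
    dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
    os.makedirs(dirname, exist_ok=True)  # same effect as the exists() check in the hunk
    shutil.copy(entry, dirname)          # assumed: the copy statement is not part of the diff context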
@@ -181,7 +181,6 @@ class Trainer:
         train_stop_time = time.time()
         self.writer_train.close()
         self.writer_val.close()
-        torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
 
         memory_peak, gpu_memory = resource_usage()
         self.logger.info(
@@ -251,7 +250,11 @@ class Trainer:
             print(f'Step {global_step}, '
                   f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
                   f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
-            torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
+            if self.accuracy_fn is not None:
+                torch.save(self.network, os.path.join(
+                    self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
+            else:
+                torch.save(self.network, os.path.join(self.output_dir, f'step_{global_step}.pt'))
             self.benchmark_time = time.time()
             self.benchmark_step = 0
             self.running_loss = 0.0
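
Periodic checkpoints are now named by global step, and the running validation accuracy is embedded in the filename when self.accuracy_fn is set (e.g. step_5000_acc_0.9312.pt rather than model_5000.pt); with the model_final.pt save removed above, these per-step files are the checkpoints written during the training loop (model_init.pt is still saved up front). Since torch.save is handed the whole module rather than a state_dict, each file captures the network object itself. A minimal sketch of the naming logic with placeholder values:

import os
import torch
import torch.nn as nn

output_dir = 'runs/example'           # placeholder
network = nn.Linear(4, 2)             # stand-in for the trained network
accuracy_fn = None                    # set to a callable to get accuracy-tagged filenames
global_step = 5000                    # placeholder step counter
val_accuracy, val_count = 93.12, 100  # placeholder running sums from the validation pass

os.makedirs(output_dir, exist_ok=True)
if accuracy_fn is not None:
    path = os.path.join(output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt')
else:
    path = os.path.join(output_dir, f'step_{global_step}.pt')
torch.save(network, path)  # saves the full module object, matching the trainer's behaviour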