From 42ae4474dda4139aa8db0c4d75fbd8d4da44b0e9 Mon Sep 17 00:00:00 2001 From: Corentin Date: Sat, 30 Jan 2021 02:11:06 +0900 Subject: [PATCH] Fix batch norm for conv2d * Change tensorboard folders * Save root scripts in output code --- layers.py | 4 ++-- trainer.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/layers.py b/layers.py index 10df5f2..142aac0 100644 --- a/layers.py +++ b/layers.py @@ -85,7 +85,7 @@ class Conv2d(Layer): self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -118,7 +118,7 @@ class Deconv2d(Layer): self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.deconv(input_data)) diff --git a/trainer.py b/trainer.py index 9d8c05d..ddf8c6b 100644 --- a/trainer.py +++ b/trainer.py @@ -44,8 +44,8 @@ class Trainer: self.optimizer = optimizer self.criterion = criterion self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None - self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30) - self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30) + self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30) + self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30) # Save network graph batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device) @@ -79,7 +79,7 @@ class Trainer: torch.save(network, os.path.join(output_dir, 'model_init.pt')) # Save source files for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob( - os.path.join('src', '**', '*.py'), recursive=True): + os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'): dirname = os.path.join(output_dir, 'code', os.path.dirname(entry)) if not os.path.exists(dirname): os.makedirs(dirname) @@ -181,7 +181,6 @@ class Trainer: train_stop_time = time.time() self.writer_train.close() self.writer_val.close() - torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt')) memory_peak, gpu_memory = resource_usage() self.logger.info( @@ -251,7 +250,11 @@ class Trainer: print(f'Step {global_step}, ' f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, ' f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec') - torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt')) + if self.accuracy_fn is not None: + torch.save(self.network, os.path.join( + self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt')) + else: + torch.save(self.network, os.path.join(self.output_dir, f'step_{global_step}.pt')) self.benchmark_time = time.time() self.benchmark_step = 0 self.running_loss = 0.0