Fix batch norm for conv2d
* Change tensorboard folders
* Save root scripts in output code
This commit is contained in:
parent
976f5cef49
commit
42ae4474dd
2 changed files with 10 additions and 7 deletions
|
|
@@ -85,7 +85,7 @@ class Conv2d(Layer):
|
||||||
self.batch_norm = nn.BatchNorm2d(
|
self.batch_norm = nn.BatchNorm2d(
|
||||||
out_channels,
|
out_channels,
|
||||||
momentum=Layer.BATCH_NORM_MOMENTUM,
|
momentum=Layer.BATCH_NORM_MOMENTUM,
|
||||||
track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
|
track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
|
||||||
|
|
||||||
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
|
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
|
||||||
return super().forward(self.conv(input_data))
|
return super().forward(self.conv(input_data))
|
||||||
|
|
@@ -118,7 +118,7 @@ class Deconv2d(Layer):
|
||||||
self.batch_norm = nn.BatchNorm2d(
|
self.batch_norm = nn.BatchNorm2d(
|
||||||
out_channels,
|
out_channels,
|
||||||
momentum=Layer.BATCH_NORM_MOMENTUM,
|
momentum=Layer.BATCH_NORM_MOMENTUM,
|
||||||
track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
|
track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
|
||||||
|
|
||||||
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
|
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
|
||||||
return super().forward(self.deconv(input_data))
|
return super().forward(self.deconv(input_data))
|
||||||
|
|
|
||||||
13
trainer.py
13
trainer.py
|
|
@@ -44,8 +44,8 @@ class Trainer:
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self.criterion = criterion
|
self.criterion = criterion
|
||||||
self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
|
self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
|
||||||
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
|
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30)
|
||||||
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)
|
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30)
|
||||||
|
|
||||||
# Save network graph
|
# Save network graph
|
||||||
batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
|
batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
|
||||||
|
|
@@ -79,7 +79,7 @@ class Trainer:
|
||||||
torch.save(network, os.path.join(output_dir, 'model_init.pt'))
|
torch.save(network, os.path.join(output_dir, 'model_init.pt'))
|
||||||
# Save source files
|
# Save source files
|
||||||
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
|
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
|
||||||
os.path.join('src', '**', '*.py'), recursive=True):
|
os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
|
||||||
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
|
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
|
||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
|
|
@@ -181,7 +181,6 @@ class Trainer:
|
||||||
train_stop_time = time.time()
|
train_stop_time = time.time()
|
||||||
self.writer_train.close()
|
self.writer_train.close()
|
||||||
self.writer_val.close()
|
self.writer_val.close()
|
||||||
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
|
|
||||||
|
|
||||||
memory_peak, gpu_memory = resource_usage()
|
memory_peak, gpu_memory = resource_usage()
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
|
|
@@ -251,7 +250,11 @@ class Trainer:
|
||||||
print(f'Step {global_step}, '
|
print(f'Step {global_step}, '
|
||||||
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
|
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
|
||||||
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
|
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
|
||||||
torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
|
if self.accuracy_fn is not None:
|
||||||
|
torch.save(self.network, os.path.join(
|
||||||
|
self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
|
||||||
|
else:
|
||||||
|
torch.save(self.network, os.path.join(self.output_dir, f'step_{global_step}.pt'))
|
||||||
self.benchmark_time = time.time()
|
self.benchmark_time = time.time()
|
||||||
self.benchmark_step = 0
|
self.benchmark_step = 0
|
||||||
self.running_loss = 0.0
|
self.running_loss = 0.0
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue