From a4280a1b78dffcea39ff918dc1731855aacc67e9 Mon Sep 17 00:00:00 2001
From: Hoel Bagard
Date: Fri, 22 Jan 2021 12:38:07 +0900
Subject: [PATCH] Fixed issues: layers now use self.use_batch_norm instead of
 the class default, and fixed the batch norm check in Layer's forward

---
 layers.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/layers.py b/layers.py
index 1f916d0..3bedb24 100644
--- a/layers.py
+++ b/layers.py
@@ -21,9 +21,8 @@ class Layer(nn.Module):
 
     def __init__(self, activation, use_batch_norm):
         super().__init__()
-        self.name = 'Layer'
-        # Preload default
+        self.batch_norm: Optional[nn.Module] = None
         self.activation = Layer.ACTIVATION if activation == 0 else activation
         self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
 
@@ -31,7 +30,8 @@ class Layer(nn.Module):
         output = input_data
         if self.activation is not None:
             output = self.activation(output)
-        if self.use_batch_norm is not None:
+        if self.use_batch_norm:
+            # It is assumed here that if using batch norm, then self.batch_norm has been instantiated.
             output = self.batch_norm(output)
         return output
 
@@ -44,7 +44,7 @@ class Linear(Layer):
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None)
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.fc(input_data))
@@ -60,7 +60,7 @@ class Conv1d(Layer):
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None)
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
@@ -92,7 +92,7 @@ class Conv3d(Layer):
         self.batch_norm = nn.BatchNorm3d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None)
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
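
For reference, below is a minimal, self-contained sketch of the pattern this patch converges on. The class constants, default argument values, and the Linear signature are assumptions for illustration (the real ones live elsewhere in layers.py, outside the hunks above), and the Optional[nn.Module] annotation assumes `from typing import Optional` at the top of the file:

from typing import Optional

import torch
import torch.nn as nn


class Layer(nn.Module):
    # Class-level defaults: assumed values, the real ones are defined elsewhere in layers.py.
    USE_BATCH_NORM = True
    BATCH_NORM_MOMENTUM = 0.1
    BATCH_NORM_TRAINING = True
    ACTIVATION = torch.relu

    def __init__(self, activation=0, use_batch_norm=None):
        super().__init__()
        # None until a subclass conditionally instantiates a BatchNorm module.
        self.batch_norm: Optional[nn.Module] = None
        self.activation = Layer.ACTIVATION if activation == 0 else activation
        self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        output = input_data
        if self.activation is not None:
            output = self.activation(output)
        if self.use_batch_norm:  # truthiness, not `is not None`: False now actually disables batch norm
            output = self.batch_norm(output)
        return output


class Linear(Layer):
    def __init__(self, in_channels, out_channels, activation=0, use_batch_norm=None):
        super().__init__(activation, use_batch_norm)
        self.fc = nn.Linear(in_channels, out_channels)
        # The `)` before `if` closes the BatchNorm1d call: the module is built
        # only when this instance (not the class default) asks for batch norm.
        self.batch_norm = nn.BatchNorm1d(
            out_channels,
            momentum=Layer.BATCH_NORM_MOMENTUM,
            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        return super().forward(self.fc(input_data))


layer = Linear(8, 4, use_batch_norm=False)
assert layer.batch_norm is None           # the per-instance flag wins over Layer.USE_BATCH_NORM
print(layer(torch.randn(2, 8)).shape)     # torch.Size([2, 4])

The forward fix matters because self.use_batch_norm is always a bool after __init__, so the old `is not None` check was always true and use_batch_norm=False would still call self.batch_norm(output) on None.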