diff --git a/layers.py b/layers.py
index f1531f7..4a5296e 100644
--- a/layers.py
+++ b/layers.py
@@ -19,7 +19,7 @@ class Layer(nn.Module):
     ACTIVATION = F.leaky_relu
     BATCH_NORM = True
-    BATCH_NORM_TRAINING = False
+    BATCH_NORM_TRAINING = True
     BATCH_NORM_MOMENTUM = 0.01
 
     IS_TRAINING = False
@@ -28,7 +28,7 @@ class Layer(nn.Module):
     LOGGER = DummyLogger()
 
     def __init__(self, activation, batch_norm):
-        super(Layer, self).__init__()
+        super().__init__()
         self.name = 'Layer'
         self.info = LayerInfo()
@@ -55,7 +55,7 @@ class Conv1d(Layer):
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
@@ -87,7 +87,7 @@ class Conv3d(Layer):
         self.batch_norm = nn.BatchNorm3d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
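
Note on the track_running_stats change above: with the old expression, the default BATCH_NORM_TRAINING = False still produced track_running_stats=True, so flipping the constant to True while dropping the "not" keeps the effective flag value the same and simply removes the double negation. The sketch below is not part of the patch; it only illustrates what this flag controls in stock PyTorch (nn.BatchNorm1d), under the assumption that the layers are used with the usual train()/eval() switching.

# Minimal illustration of torch.nn.BatchNorm1d's track_running_stats flag
# (hypothetical standalone example, not code from layers.py).
import torch
import torch.nn as nn

x = torch.randn(8, 4)  # (batch, channels) input for BatchNorm1d

# track_running_stats=True: the layer keeps running_mean / running_var
# buffers, updates them in train(), and uses them instead of batch
# statistics in eval().
bn_tracked = nn.BatchNorm1d(4, momentum=0.01, track_running_stats=True)
bn_tracked.train()
bn_tracked(x)            # updates the running buffers
bn_tracked.eval()
y_eval = bn_tracked(x)   # normalizes with the running estimates

# track_running_stats=False: no buffers exist (running_mean is None) and
# the current batch's mean/variance are used in both train() and eval().
bn_batch = nn.BatchNorm1d(4, momentum=0.01, track_running_stats=False)
bn_batch.eval()
y_batch = bn_batch(x)    # still uses this batch's statistics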