From 4b786943f5d54d1c59a5d07a71a5a7679d3d5855 Mon Sep 17 00:00:00 2001
From: corentin
Date: Tue, 1 Dec 2020 11:43:02 +0900
Subject: [PATCH] Fix batch norm training default value

---
 layers.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/layers.py b/layers.py
index f1531f7..4a5296e 100644
--- a/layers.py
+++ b/layers.py
@@ -19,7 +19,7 @@ class Layer(nn.Module):
 
     ACTIVATION = F.leaky_relu
     BATCH_NORM = True
-    BATCH_NORM_TRAINING = False
+    BATCH_NORM_TRAINING = True
     BATCH_NORM_MOMENTUM = 0.01
 
     IS_TRAINING = False
@@ -28,7 +28,7 @@ class Layer(nn.Module):
     LOGGER = DummyLogger()
 
     def __init__(self, activation, batch_norm):
-        super(Layer, self).__init__()
+        super().__init__()
 
         self.name = 'Layer'
         self.info = LayerInfo()
@@ -55,7 +55,7 @@ class Conv1d(Layer):
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
@@ -87,7 +87,7 @@ class Conv3d(Layer):
         self.batch_norm = nn.BatchNorm3d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
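
For context: in each BatchNorm hunk the two edits cancel out for track_running_stats, which evaluates to True both before (not False) and after (True), so the visible effect of the diff is to remove a double negation and to flip the value that BATCH_NORM_TRAINING itself reports wherever else the flag is read. Below is a minimal sketch, assuming stock PyTorch nn.BatchNorm1d rather than anything from layers.py, of what the flag controls: with track_running_stats=True the layer accumulates running mean/var during training and normalizes with them in eval(), while with False it normalizes every batch with that batch's own statistics, even in eval().

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    x = torch.randn(8, 4) * 5 + 3  # a deliberately shifted/scaled training batch

    tracking = nn.BatchNorm1d(4, momentum=0.01, track_running_stats=True)
    untracked = nn.BatchNorm1d(4, momentum=0.01, track_running_stats=False)

    tracking(x)   # training-mode pass: updates running mean/var
    untracked(x)  # training-mode pass: no statistics are stored

    tracking.eval()
    untracked.eval()

    y = torch.randn(8, 4)
    # `tracking` normalizes y with the stored running statistics;
    # `untracked` still uses the mean/var of y itself.
    print(tracking(y).mean().item(), untracked(y).mean().item())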