From 7f4a1620337516271be729acbd58dbd82bc3ad56 Mon Sep 17 00:00:00 2001 From: Corentin Risselin Date: Tue, 7 Jul 2020 11:04:57 +0900 Subject: [PATCH] Fix Batch norm tracking --- layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/layers.py b/layers.py index d698dc5..757196c 100644 --- a/layers.py +++ b/layers.py @@ -68,7 +68,7 @@ class Linear(Layer): self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.fc(input_data))