diff --git a/layers.py b/layers.py index d698dc5..757196c 100644 --- a/layers.py +++ b/layers.py @@ -68,7 +68,7 @@ class Linear(Layer): self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.fc(input_data))