From 54000b6c3405ebd54358f13fb1e26eb107698730 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Thu, 21 Jan 2021 20:36:22 +0900 Subject: [PATCH] Fixed default use_batch_norm value --- layers.py | 50 +++++++++++++++++++++++++------------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/layers.py b/layers.py index 0d7ae78..1f916d0 100644 --- a/layers.py +++ b/layers.py @@ -2,49 +2,49 @@ from typing import Union, Tuple import torch import torch.nn as nn -import torch.nn.functional as F from .utils.logger import DummyLogger class Layer(nn.Module): # Default layer arguments - ACTIVATION = F.leaky_relu + ACTIVATION = torch.nn.LeakyReLU + ACTIVATION_KWARGS = {"negative_slope": 0.1} - BATCH_NORM = True + USE_BATCH_NORM = True BATCH_NORM_TRAINING = True BATCH_NORM_MOMENTUM = 0.01 IS_TRAINING = False METRICS = False - VERBOSE = 0 LOGGER = DummyLogger() - def __init__(self, activation): + def __init__(self, activation, use_batch_norm): super().__init__() self.name = 'Layer' # Preload default self.activation = Layer.ACTIVATION if activation == 0 else activation + self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm def forward(self, input_data: torch.Tensor) -> torch.Tensor: output = input_data if self.activation is not None: output = self.activation(output) - if self.batch_norm is not None: + if self.use_batch_norm is not None: output = self.batch_norm(output) return output class Linear(Layer): - def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.fc = nn.Linear(in_channels, out_channels, **kwargs) self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.fc(input_data)) @@ -52,15 +52,15 @@ class Linear(Layer): class Conv1d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, - bias=not self.batch_norm, **kwargs) + bias=not Layer.USE_BATCH_NORM, **kwargs) self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -68,15 +68,15 @@ class Conv1d(Layer): class Conv2d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, - bias=not self.batch_norm, **kwargs) + bias=not Layer.USE_BATCH_NORM, **kwargs) self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=not Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -84,15 +84,15 @@ class Conv2d(Layer): class Conv3d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, - bias=not self.batch_norm, **kwargs) + bias=not Layer.USE_BATCH_NORM, **kwargs) self.batch_norm = nn.BatchNorm3d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -100,16 +100,16 @@ class Conv3d(Layer): class Deconv2d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.deconv = nn.ConvTranspose2d( in_channels, out_channels, kernel_size, stride=stride, - bias=not self.batch_norm, **kwargs) + bias=not Layer.USE_BATCH_NORM, **kwargs) self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=not Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.deconv(input_data))