From 7f900109721d99667d9b2e46994b20b38618d17e Mon Sep 17 00:00:00 2001
From: Corentin
Date: Thu, 26 Nov 2020 18:50:38 +0900
Subject: [PATCH] Add Conv3d layer

---
 layers.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/layers.py b/layers.py
index 699fc86..f1531f7 100644
--- a/layers.py
+++ b/layers.py
@@ -50,7 +50,7 @@ class Conv1d(Layer):
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
         super().__init__(activation, batch_norm)
 
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
                               bias=not self.batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
@@ -66,7 +66,7 @@ class Conv2d(Layer):
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs):
         super().__init__(activation, batch_norm)
 
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                               bias=not self.batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
@@ -77,6 +77,22 @@ class Conv2d(Layer):
         return super().forward(self.conv(input_data))
 
 
+class Conv3d(Layer):
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
+                 stride: Union[int, Tuple[int, int, int]] = 1, activation=0, batch_norm=None, **kwargs):
+        super().__init__(activation, batch_norm)
+
+        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
+                              bias=not self.batch_norm, **kwargs)
+        self.batch_norm = nn.BatchNorm3d(
+            out_channels,
+            momentum=Layer.BATCH_NORM_MOMENTUM,
+            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        return super().forward(self.conv(input_data))
+
+
 class Linear(Layer):
     def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs):
         super().__init__(activation, batch_norm)
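
A minimal usage sketch of the new wrapper (not part of the patch). It assumes
the patched layers.py is importable, and that Layer.__init__(activation,
batch_norm) records the batch_norm argument on self before self.conv is
built, as the bias=not self.batch_norm expression in all three wrappers
implies. Input shapes follow PyTorch's nn.Conv3d convention of
(N, C, D, H, W):

    import torch
    from layers import Conv3d  # the class added by this patch

    # batch_norm=True drops the conv bias and attaches a BatchNorm3d instead
    conv = Conv3d(in_channels=1, out_channels=8, kernel_size=3, stride=1,
                  batch_norm=True)

    x = torch.randn(2, 1, 16, 32, 32)  # two single-channel volumes
    y = conv(x)                        # -> torch.Size([2, 8, 14, 30, 30]) with no padding

Disabling the bias when batch norm follows is a common PyTorch idiom: the
shift (beta) parameter of BatchNorm3d makes a separate convolution bias
redundant, which is why all three wrappers pass bias=not self.batch_norm.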