diff --git a/residual.py b/residual.py
index bdc14c8..e35a009 100644
--- a/residual.py
+++ b/residual.py
@@ -3,65 +3,51 @@
 from typing import Union, Tuple
 
 import torch
 import torch.nn as nn
 
-from .layers import LayerInfo, Layer
+from .layers import Conv2d, LayerInfo, Layer
 
 
 class ResBlock(Layer):
-    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
-                 activation=None, **kwargs):
+    def __init__(self, in_channels: int, out_channels: int = -1, kernel_size: int = 3, padding: int = 1,
+                 stride: Union[int, Tuple[int, int]] = 1, activation=None, batch_norm=None, **kwargs):
         super().__init__(activation if activation is not None else 0, False)
+        self.batch_norm = None
+        if out_channels == -1:
+            out_channels = in_channels
         self.seq = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=False, **kwargs),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING),
-            torch.nn.LeakyReLU(),
-            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, bias=False, padding=1),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING))
-        self.batch_norm = nn.BatchNorm2d(
-            out_channels,
-            momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+            Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, **kwargs),
+            Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
+                   activation=None, batch_norm=batch_norm))
+        self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
+            out_channels != in_channels or stride != 1) else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        if self.residual is not None:
+            return super().forward(self.residual(input_data) + self.seq(input_data))
         return super().forward(input_data + self.seq(input_data))
 
 
 class ResBottleneck(Layer):
-    def __init__(self, in_channels: int, out_channels: int, planes: int = 1, kernel_size: int = 3,
-                 stride: Union[int, Tuple[int, int]] = 1, activation=None, **kwargs):
+    def __init__(self, in_channels: int, out_channels: int = -1, bottleneck_channels: int = -1, kernel_size: int = 3,
+                 stride: Union[int, Tuple[int, int]] = 1, padding=1,
+                 activation=None, batch_norm=None, **kwargs):
         super().__init__(activation if activation is not None else 0, False)
         self.batch_norm = None
+        if out_channels == -1:
+            out_channels = in_channels
+        if bottleneck_channels == -1:
+            bottleneck_channels = in_channels // 4
         self.seq = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING),
-            torch.nn.LeakyReLU(),
-            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=False, **kwargs),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING),
-            torch.nn.LeakyReLU(),
-            nn.Conv2d(out_channels, planes * out_channels, kernel_size=1, bias=False),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING))
-        self.downsample = nn.Sequential(
-            nn.Conv2d(in_channels, planes * out_channels, stride=stride, kernel_size=1),
-            nn.BatchNorm2d(
-                planes * out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING))
+            Conv2d(in_channels, bottleneck_channels, kernel_size=1),
+            Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=kernel_size,
+                   stride=stride, padding=padding, **kwargs),
+            Conv2d(bottleneck_channels, out_channels, kernel_size=1,
+                   activation=None, batch_norm=batch_norm))
+        self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
+            out_channels != in_channels or stride != 1) else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
-        return super().forward(self.downsample(input_data) + self.seq(input_data))
+        if self.residual is not None:
+            return super().forward(self.residual(input_data) + self.seq(input_data))
+        return super().forward(input_data + self.seq(input_data))
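
For reviewers who want to exercise the patched blocks: a minimal, self-contained sketch follows. The `Conv2d` and `Layer` classes below are hypothetical stand-ins for the project's `.layers` module (not shown in this patch); they assume `Conv2d` is a conv + optional batch-norm + optional activation wrapper, and that `Layer.forward` applies the block-level activation, with `0` as the "use the default activation" sentinel. The real wrappers may differ.

    import torch
    import torch.nn as nn

    class Layer(nn.Module):
        """Stand-in (assumption): applies an optional activation in forward()."""
        def __init__(self, activation=0, batch_norm=False):
            super().__init__()
            # 0 is assumed to mean "default activation"; None disables it.
            self.activation = nn.LeakyReLU() if activation == 0 else activation

        def forward(self, input_data):
            return self.activation(input_data) if self.activation is not None else input_data

    class Conv2d(Layer):
        """Stand-in (assumption): convolution + optional BatchNorm + optional activation."""
        def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                     padding=0, activation=0, batch_norm=None, **kwargs):
            super().__init__(activation, batch_norm)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding, bias=False, **kwargs)
            self.bn = nn.BatchNorm2d(out_channels) if batch_norm else None

        def forward(self, input_data):
            out = self.conv(input_data)
            if self.bn is not None:
                out = self.bn(out)
            return super().forward(out)

    # With those stand-ins in scope, the patched classes above type-check and run:
    x = torch.randn(2, 64, 32, 32)
    block = ResBlock(64)                                  # same shape: identity shortcut
    down = ResBottleneck(64, out_channels=128, stride=2)  # shape change: 1x1 projection shortcut
    assert block(x).shape == (2, 64, 32, 32)
    assert down(x).shape == (2, 128, 16, 16)

The `self.residual` 1x1 convolution plays the role of ResNet's projection shortcut: it is only created when the skip path must change channel count or spatial resolution, so the common same-shape case keeps the cheap identity skip.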