from typing import Union, Tuple

import torch
import torch.nn as nn

from .layers import LayerInfo, Layer


def _batch_norm(num_features: int) -> nn.BatchNorm2d:
    """Build an ``nn.BatchNorm2d`` using the shared ``Layer`` batch-norm settings.

    ``Layer.BATCH_NORM_MOMENTUM`` and ``Layer.BATCH_NORM_TRAINING`` are read at
    call time, i.e. when the owning block is constructed.

    Args:
        num_features: number of channels to normalise.
    """
    return nn.BatchNorm2d(
        num_features,
        momentum=Layer.BATCH_NORM_MOMENTUM,
        track_running_stats=not Layer.BATCH_NORM_TRAINING)


class ResBlock(Layer):
    """Residual block: (conv -> BN -> LeakyReLU -> conv -> BN) added to the input.

    Because ``seq(x)`` is summed with ``x``, ``seq`` must preserve the input's
    shape: ``in_channels`` must equal ``out_channels`` and the convolutions must
    keep the spatial size unchanged.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                 activation=None, **kwargs):
        """
        Args:
            in_channels: channels of the input tensor.
            out_channels: channels produced by both convolutions.
            kernel_size: kernel size used by both convolutions.
            activation: forwarded to ``Layer.__init__``; ``None`` is replaced by
                ``0`` (its meaning is defined by ``Layer`` — not visible here).
            **kwargs: extra keyword arguments for the FIRST ``nn.Conv2d`` only
                (e.g. ``padding``, ``stride``).
        """
        super().__init__(activation if activation is not None else 0, False)
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      bias=False, **kwargs),
            _batch_norm(out_channels),
            torch.nn.LeakyReLU(),
            # NOTE(review): padding is fixed at 1, which preserves H/W only for
            # kernel_size=3. For other kernel sizes (stride 1) the first conv's
            # ``padding`` kwarg must equal kernel_size - 2, or the residual sum
            # in forward() fails. Consider padding=kernel_size // 2.
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size,
                      bias=False, padding=1),
            _batch_norm(out_channels))
        # ``self.batch_norm`` was set by ``Layer.__init__`` (presumably from the
        # ``False`` passed above — TODO confirm in Layer). An extra BatchNorm is
        # created only when it is truthy; otherwise it is cleared to None.
        self.batch_norm = _batch_norm(out_channels) if self.batch_norm else None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Add the residual branch to the input and pass the sum to ``Layer.forward``."""
        return super().forward(input_data + self.seq(input_data))


class ResBottleneck(Layer):
    """Bottleneck residual block with a projection shortcut.

    Main path: 1x1 conv -> BN -> LeakyReLU -> kxk conv (strided) -> BN ->
    LeakyReLU -> 1x1 conv widening to ``planes * out_channels`` -> BN.
    Shortcut (``downsample``): strided 1x1 conv to ``planes * out_channels``
    channels -> BN. The two are summed in :meth:`forward`.
    """

    def __init__(self, in_channels: int, out_channels: int, planes: int = 1,
                 kernel_size: int = 3, stride: Union[int, Tuple[int, int]] = 1,
                 activation=None, **kwargs):
        """
        Args:
            in_channels: channels of the input tensor.
            out_channels: width of the inner (bottleneck) convolutions.
            planes: channel multiplier. The block outputs
                ``planes * out_channels`` channels.
            kernel_size: kernel size of the middle convolution.
            stride: stride of the middle convolution and of the shortcut conv.
            activation: forwarded to ``Layer.__init__``; ``None`` is replaced by
                ``0`` (its meaning is defined by ``Layer`` — not visible here).
            **kwargs: extra keyword arguments for the middle ``nn.Conv2d`` only.
                Its output size must match the shortcut's
                (typically ``padding=kernel_size // 2``).
        """
        super().__init__(activation if activation is not None else 0, False)
        self.batch_norm = None
        expanded_channels = planes * out_channels
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            _batch_norm(out_channels),
            torch.nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size,
                      stride=stride, bias=False, **kwargs),
            _batch_norm(out_channels),
            torch.nn.LeakyReLU(),
            nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False),
            # Must normalise the widened channel count produced by the conv
            # above. Using ``out_channels`` here failed for any planes != 1.
            _batch_norm(expanded_channels))
        # Projection shortcut so channel count and stride match ``seq``'s output.
        # The conv keeps its default bias (unlike the convs above) to preserve
        # existing parameter/state_dict layout.
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, expanded_channels, stride=stride, kernel_size=1),
            _batch_norm(expanded_channels))

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Sum the projected shortcut and the main path, then apply ``Layer.forward``."""
        return super().forward(self.downsample(input_data) + self.seq(input_data))