# torch_utils/residual.py
# 2021-04-29 19:45:32 +09:00
#
# 53 lines
# 2.5 KiB
# Python

from typing import Union, Tuple
import torch
import torch.nn as nn
from .layers import Conv2d, LayerInfo, Layer
class ResBlock(Layer):
    """Basic two-convolution residual block.

    Main path: ``Conv2d(in, in, kernel_size, stride, padding, **kwargs)`` followed
    by a 3x3 ``Conv2d(in, out)`` built with ``activation=None``.  Shortcut: the
    identity when the block keeps both the channel count and unit stride,
    otherwise a strided 1x1 ``Conv2d`` projection.  The sum of both paths is
    handed to ``Layer.forward``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``-1`` (default) means
            ``in_channels``.
        kernel_size: Kernel size of the first convolution (the second one is
            always 3x3).
        padding: Padding of the first convolution (the second one always uses 1).
        stride: Stride of the first convolution and of the projection shortcut.
            ``1`` and ``(1, 1)`` are treated identically.
        activation: Forwarded to ``Layer.__init__``; ``None`` is replaced by ``0``
            before forwarding.
        batch_norm: Forwarded only to the second convolution.
        **kwargs: Extra keyword arguments for the first convolution only.
    """

    def __init__(self, in_channels: int, out_channels: int = -1, kernel_size: int = 3, padding: int = 1,
                 stride: Union[int, Tuple[int, int]] = 1, activation=None, batch_norm=None, **kwargs):
        super().__init__(activation if activation is not None else 0, False)
        # NOTE(review): presumably prevents Layer from normalising the summed output;
        # normalisation is configured on the convolutions instead — confirm against Layer.
        self.batch_norm = None
        if out_channels == -1:
            out_channels = in_channels
        self.seq = nn.Sequential(
            Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, **kwargs),
            # activation=None: the block's own activation (given to Layer) is presumably
            # applied in Layer.forward after the residual addition — confirm against Layer.
            Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
                   activation=None, batch_norm=batch_norm))
        # `stride != 1` alone is True for the tuple (1, 1), which would add a needless
        # learned projection even though no downsampling happens.
        downsamples = stride not in (1, (1, 1))
        self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
            out_channels != in_channels or downsamples) else None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Add the shortcut to the main path and pass the sum to ``Layer.forward``."""
        shortcut = input_data if self.residual is None else self.residual(input_data)
        return super().forward(shortcut + self.seq(input_data))
class ResBottleneck(Layer):
    """Bottleneck residual block: 1x1 reduce, kxk conv, 1x1 expand.

    Main path: ``Conv2d(in, b, 1)`` -> ``Conv2d(b, b, kernel_size, stride, padding,
    **kwargs)`` -> ``Conv2d(b, out, 1)`` built with ``activation=None``.  Shortcut:
    the identity when the block keeps both the channel count and unit stride,
    otherwise a strided 1x1 ``Conv2d`` projection.  The sum of both paths is
    handed to ``Layer.forward``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; ``-1`` (default) means
            ``in_channels``.
        bottleneck_channels: Width ``b`` of the inner convolutions; ``-1``
            (default) means ``in_channels // 4``, but never less than 1.
        kernel_size: Kernel size of the middle convolution.
        stride: Stride of the middle convolution and of the projection shortcut.
            ``1`` and ``(1, 1)`` are treated identically.
        padding: Padding of the middle convolution.
        activation: Forwarded to ``Layer.__init__``; ``None`` is replaced by ``0``
            before forwarding.
        batch_norm: Forwarded only to the last (expanding) convolution.
        **kwargs: Extra keyword arguments for the middle convolution only.
    """

    def __init__(self, in_channels: int, out_channels: int = -1, bottleneck_channels: int = -1, kernel_size: int = 3,
                 stride: Union[int, Tuple[int, int]] = 1, padding: int = 1,
                 activation=None, batch_norm=None, **kwargs):
        super().__init__(activation if activation is not None else 0, False)
        # NOTE(review): presumably prevents Layer from normalising the summed output;
        # normalisation is configured on the convolutions instead — confirm against Layer.
        self.batch_norm = None
        if out_channels == -1:
            out_channels = in_channels
        if bottleneck_channels == -1:
            # in_channels // 4 is 0 for in_channels < 4, which would build a
            # zero-width bottleneck; keep at least one channel.
            bottleneck_channels = max(1, in_channels // 4)
        self.seq = nn.Sequential(
            Conv2d(in_channels, bottleneck_channels, kernel_size=1),
            Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=kernel_size,
                   stride=stride, padding=padding, **kwargs),
            # activation=None: the block's own activation (given to Layer) is presumably
            # applied in Layer.forward after the residual addition — confirm against Layer.
            Conv2d(bottleneck_channels, out_channels, kernel_size=1,
                   activation=None, batch_norm=batch_norm))
        # `stride != 1` alone is True for the tuple (1, 1), which would add a needless
        # learned projection even though no downsampling happens.
        downsamples = stride not in (1, (1, 1))
        self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
            out_channels != in_channels or downsamples) else None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Add the shortcut to the main path and pass the sum to ``Layer.forward``."""
        shortcut = input_data if self.residual is None else self.residual(input_data)
        return super().forward(shortcut + self.seq(input_data))