Improve ResNet layers
parent 9b0c8ec49d
commit 8d13de5711
1 changed file with 30 additions and 44 deletions
residual.py | 74
@@ -3,65 +3,51 @@ from typing import Union, Tuple
 import torch
 import torch.nn as nn

-from .layers import LayerInfo, Layer
+from .layers import Conv2d, LayerInfo, Layer


 class ResBlock(Layer):
-    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
-                 activation=None, **kwargs):
+    def __init__(self, in_channels: int, out_channels: int = -1, kernel_size: int = 3, padding: int = 1,
+                 stride: Union[int, Tuple[int, int]] = 1, activation=None, batch_norm=None, **kwargs):
         super().__init__(activation if activation is not None else 0, False)
+        self.batch_norm = None
+        if out_channels == -1:
+            out_channels = in_channels

         self.seq = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=False, **kwargs),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING),
-            torch.nn.LeakyReLU(),
-            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, bias=False, padding=1),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING))
-        self.batch_norm = nn.BatchNorm2d(
-            out_channels,
-            momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
+            Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, **kwargs),
+            Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
+                   activation=None, batch_norm=batch_norm))
+        self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
+            out_channels != in_channels or stride != 1) else None

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        if self.residual is not None:
+            return super().forward(self.residual(input_data) + self.seq(input_data))
         return super().forward(input_data + self.seq(input_data))


 class ResBottleneck(Layer):
-    def __init__(self, in_channels: int, out_channels: int, planes: int = 1, kernel_size: int = 3,
-                 stride: Union[int, Tuple[int, int]] = 1, activation=None, **kwargs):
+    def __init__(self, in_channels: int, out_channels: int = -1, bottleneck_channels: int = -1, kernel_size: int = 3,
+                 stride: Union[int, Tuple[int, int]] = 1, padding=1,
+                 activation=None, batch_norm=None, **kwargs):
         super().__init__(activation if activation is not None else 0, False)
         self.batch_norm = None
+        if out_channels == -1:
+            out_channels = in_channels
+        if bottleneck_channels == -1:
+            bottleneck_channels = in_channels // 4

         self.seq = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING),
-            torch.nn.LeakyReLU(),
-            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=False, **kwargs),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING),
-            torch.nn.LeakyReLU(),
-            nn.Conv2d(out_channels, planes * out_channels, kernel_size=1, bias=False),
-            nn.BatchNorm2d(
-                out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING))
-        self.downsample = nn.Sequential(
-            nn.Conv2d(in_channels, planes * out_channels, stride=stride, kernel_size=1),
-            nn.BatchNorm2d(
-                planes * out_channels,
-                momentum=Layer.BATCH_NORM_MOMENTUM,
-                track_running_stats=not Layer.BATCH_NORM_TRAINING))
+            Conv2d(in_channels, bottleneck_channels, kernel_size=1),
+            Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=kernel_size,
+                   stride=stride, padding=padding, **kwargs),
+            Conv2d(bottleneck_channels, out_channels, kernel_size=1,
+                   activation=None, batch_norm=batch_norm))
+        self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
+            out_channels != in_channels or stride != 1) else None

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
-        return super().forward(self.downsample(input_data) + self.seq(input_data))
+        if self.residual is not None:
+            return super().forward(self.residual(input_data) + self.seq(input_data))
+        return super().forward(input_data + self.seq(input_data))
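
For reference, a minimal usage sketch of the new interface (not part of the commit). It assumes the package name `mynet` (hypothetical; the repo's actual package name is not shown here) and that `Layer` subclasses `nn.Module`, so blocks are callable; the expected shapes then follow directly from the convolution parameters in the diff.

import torch

from mynet.residual import ResBlock, ResBottleneck  # package name is hypothetical

x = torch.randn(1, 64, 32, 32)

# Identity shortcut: out_channels defaults to in_channels and stride stays 1,
# so self.residual is None and forward() adds the input to the branch directly.
block = ResBlock(64)
print(block(x).shape)  # expected: torch.Size([1, 64, 32, 32])

# Projection shortcut: changing channels or striding triggers the 1x1 Conv2d
# on the residual path so the two branches can be summed.
down = ResBlock(64, out_channels=128, stride=2)
print(down(x).shape)  # expected: torch.Size([1, 128, 16, 16])

# Bottleneck: squeezes to in_channels // 4 (here 16) with a 1x1 conv, applies
# the 3x3 conv, then expands back to out_channels with another 1x1 conv.
bottleneck = ResBottleneck(64)
print(bottleneck(x).shape)  # expected: torch.Size([1, 64, 32, 32])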