Residual blocks, precache for BatchGenerator
This commit is contained in:
parent
7f4a162033
commit
5081cf63fe
3 changed files with 139 additions and 4 deletions
67
residual.py
Normal file
67
residual.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
from typing import Union, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .layers import LayerInfo, Layer
|
||||
|
||||
|
||||
class ResBlock(Layer):
    """Basic two-convolution residual block: ``y = activation(x + F(x))``.

    ``F(x)`` is conv -> BN -> LeakyReLU -> conv -> BN.  The skip connection
    is the identity, so ``F(x)`` must preserve the input's shape: callers
    are expected to use ``in_channels == out_channels`` and, for
    ``kernel_size > 1``, to supply padding for the *first* convolution via
    ``**kwargs`` (e.g. ``padding=1``) — otherwise the residual addition in
    :meth:`forward` fails with a spatial-size mismatch.

    Args:
        in_channels: number of channels of the block input.
        out_channels: number of channels produced by both convolutions.
        kernel_size: kernel size of both convolutions (default ``3``).
        activation: activation passed through to ``Layer``; ``0`` is used
            when ``None`` is given (mirrors the original behavior).
        **kwargs: extra ``nn.Conv2d`` options applied to the first
            convolution only.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                 activation=None, **kwargs):
        super().__init__(activation if activation is not None else 0, False)

        self.seq = nn.Sequential(
            # First conv: padding (if any) comes from **kwargs.
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=False, **kwargs),
            nn.BatchNorm2d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING),
            torch.nn.LeakyReLU(),
            # Second conv: "same" padding for odd kernel sizes.  Was
            # hard-coded padding=1, which is only correct for
            # kernel_size == 3; kernel_size // 2 is identical for the
            # default and correct for other odd kernels.
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, bias=False,
                      padding=kernel_size // 2),
            nn.BatchNorm2d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING))
        # NOTE(review): the conditional reads ``self.batch_norm`` before this
        # assignment — presumably Layer.__init__ sets it from the ``False``
        # argument above, which would make this resolve to None; confirm
        # against Layer's constructor.
        self.batch_norm = nn.BatchNorm2d(
            out_channels,
            momentum=Layer.BATCH_NORM_MOMENTUM,
            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Add the residual branch to the identity skip, then apply Layer's activation."""
        return super().forward(input_data + self.seq(input_data))
|
||||
|
||||
|
||||
class ResBottleneck(Layer):
    """Bottleneck residual block: 1x1 -> kxk -> 1x1 with a projection skip.

    ``y = activation(downsample(x) + F(x))`` where ``F`` reduces to
    ``out_channels`` with a 1x1 conv, applies a (possibly strided) kxk conv,
    then expands to ``planes * out_channels`` with a 1x1 conv; each conv is
    followed by BatchNorm, with LeakyReLU between stages.  ``downsample``
    projects the input to the same shape (``planes * out_channels``, same
    stride) so the addition is always valid.

    Args:
        in_channels: number of channels of the block input.
        out_channels: width of the bottleneck (inner) convolutions.
        planes: expansion factor of the final 1x1 convolution (default ``1``).
        kernel_size: kernel size of the middle convolution (default ``3``).
        stride: stride of the middle convolution and of the projection skip.
        activation: activation passed through to ``Layer``; ``0`` is used
            when ``None`` is given (mirrors ResBlock).
        **kwargs: extra ``nn.Conv2d`` options applied to the middle
            convolution only (e.g. ``padding=1``).
    """

    def __init__(self, in_channels: int, out_channels: int, planes: int = 1, kernel_size: int = 3,
                 stride: Union[int, Tuple[int, int]] = 1, activation=None, **kwargs):
        super().__init__(activation if activation is not None else 0, False)
        self.batch_norm = None

        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING),
            torch.nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=False, **kwargs),
            nn.BatchNorm2d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING),
            torch.nn.LeakyReLU(),
            nn.Conv2d(out_channels, planes * out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(
                # BUG FIX: was ``out_channels``, which mismatched the
                # ``planes * out_channels`` output of the conv above and
                # raised at runtime whenever planes != 1.
                planes * out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING))
        self.downsample = nn.Sequential(
            # NOTE(review): this conv keeps its default bias=True, unlike the
            # bias=False convs above — redundant before BatchNorm, but left
            # unchanged to keep the parameter set (and checkpoints) intact.
            nn.Conv2d(in_channels, planes * out_channels, stride=stride, kernel_size=1),
            nn.BatchNorm2d(
                planes * out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING))

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Add the bottleneck branch to the projection skip, then apply Layer's activation."""
        return super().forward(self.downsample(input_data) + self.seq(input_data))
|
||||
Loading…
Add table
Add a link
Reference in a new issue