Vision Transformer

Commit 06db437aa4 by Corentin, 2021-05-22 01:18:39 +09:00
2 changed files with 200 additions and 30 deletions


@@ -38,20 +38,35 @@ class Layer(nn.Module):
output = self.batch_norm(output)
return output
@staticmethod
def add_weight_decay(module: nn.Module, weight_decay: float, exclude=()):
decay = []
no_decay = []
for name, param in module.named_parameters():
if not param.requires_grad:
continue
if len(param.shape) == 1 or name.endswith('.bias') or name in exclude:
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.0},
{'params': decay, 'weight_decay': weight_decay}]
class Linear(Layer):
def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = None, **kwargs):
super().__init__(activation)
self.fc = nn.Linear(in_channels, out_channels, bias=not self.batch_norm, **kwargs)
use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
self.linear = nn.Linear(in_channels, out_channels, bias=not use_batch_norm, **kwargs)
self.batch_norm = nn.BatchNorm1d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return super().forward(self.fc(input_data))
return super().forward(self.linear(input_data))
class Conv1d(Layer):
@@ -59,9 +74,9 @@ class Conv1d(Layer):
stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
super().__init__(activation)
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
bias=not self.use_batch_norm, **kwargs)
use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
bias=not use_batch_norm, **kwargs)
self.batch_norm = nn.BatchNorm1d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
@@ -72,30 +87,30 @@ class Conv1d(Layer):
class Conv2d(Layer):
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
super().__init__(activation, use_batch_norm)
def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple[int, int]] = 3,
stride: Union[int, tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
super().__init__(activation)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
bias=not self.use_batch_norm, **kwargs)
use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
bias=not use_batch_norm, **kwargs)
self.batch_norm = nn.BatchNorm2d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return super().forward(self.conv(input_data))
class Conv3d(Layer):
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple[int, int, int]] = 3,
stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
super().__init__(activation)
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
bias=not self.use_batch_norm, **kwargs)
use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
bias=not use_batch_norm, **kwargs)
self.batch_norm = nn.BatchNorm3d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
@@ -110,10 +125,10 @@ class Deconv2d(Layer):
stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
super().__init__(activation)
use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
self.deconv = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size, stride=stride,
bias=not self.use_batch_norm, **kwargs)
use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
bias=not use_batch_norm, **kwargs)
self.batch_norm = nn.BatchNorm2d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
@@ -121,3 +136,18 @@ class Deconv2d(Layer):
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return super().forward(self.deconv(input_data))
class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
# Stochastic depth is a no-op when disabled or at inference time.
if self.drop_prob is None or self.drop_prob == 0.0 or not self.training:
return input_data
keep_prob = 1 - self.drop_prob
shape = (input_data.shape[0],) + (1,) * (input_data.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=input_data.dtype, device=input_data.device)
random_tensor.floor_() # binarize
return input_data.div(keep_prob) * random_tensor
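The helpers added above are self-contained, so a short usage sketch may help; it is only an illustration, not part of the commit. The package path, the toy model and the AdamW settings below are assumptions; only Layer.add_weight_decay and DropPath come from the file above.

import torch
import torch.nn as nn
from mypackage.layers import DropPath, Layer  # hypothetical import path

# Toy model: 2-D Linear weights get weight decay, while biases and the 1-D
# BatchNorm parameters fall into the no-decay group.
model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.Linear(32, 10))
param_groups = Layer.add_weight_decay(model, weight_decay=1e-4)
optimizer = torch.optim.AdamW(param_groups, lr=3e-4)

# DropPath (stochastic depth) zeroes whole samples with probability drop_prob
# during training and rescales the survivors by 1 / keep_prob.
drop_path = DropPath(drop_prob=0.1)
x = torch.randn(8, 197, 768)
y = drop_path(x)                      # training mode: some samples are zeroed
drop_path.eval()
assert torch.equal(drop_path(x), x)   # inference: identity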


@@ -1,23 +1,42 @@
from functools import partial
import math
import numpy as np
import torch
import torch.nn as nn
from ..layers import DropPath, Layer
class PatchEmbed(nn.Module):
def __init__(self, image_shape: tuple[int, int], patch_size: int = 16,
in_channels: int = 3, embed_dim: int = 768):
super().__init__()
patch_count = (image_shape[0] // patch_size) * (image_shape[1] // patch_size)
self.image_shape = image_shape
self.patch_size = patch_size
self.patch_count = patch_count
self.projector = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return self.projector(input_data).flatten(2).transpose(1, 2)
class Attention(nn.Module):
def __init__(self, dim: int, head_count: int = None, qkv_bias: bool = False, qk_scale: float = None,
attention_drop: float = None, projection_drop: float = None):
def __init__(self, dim: int, head_count: int, qkv_bias: bool, qk_scale: float,
attention_drop_rate: float, projection_drop_rate: float):
super().__init__()
self.head_count = head_count
head_dim = dim // head_count
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attention_drop = nn.Dropout(
attention_drop if attention_drop is not None else VisionTransformer.ATTENTION_DROP)
self.attention_drop = nn.Dropout(attention_drop_rate) if attention_drop_rate > 0.0 else nn.Identity()
self.projector = nn.Linear(dim, dim)
self.projection_drop = nn.Dropout(
projection_drop if projection_drop is not None else VisionTransformer.PROJECTION_DROP)
self.projection_drop = nn.Dropout(projection_drop_rate) if projection_drop_rate > 0.0 else nn.Identity()
def foward(self, input_data: torch.Tensor) -> torch.Tensor:
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, channel_count = input_data.shape
qkv = self.qkv(input_data).reshape(
batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute(
@@ -29,12 +48,133 @@ class Attention(nn.Module):
(attention @ value).transpose(1, 2).reshape(batch_size, sequence_length, channel_count)))
class VisionTransformer(nn.Module):
HEAD_COUNT = 8
MLP_RATIO = 4.0
QKV_BIAS = False
ATTENTION_DROP = 0.0
PROJECTION_DROP = 0.0
class Block(nn.Module):
def __init__(self, dim: int, head_count: int, mlp_ratio: float,
qkv_bias: bool, qk_scale: float, drop_rate: float,
attention_drop_rate: float, drop_path_rate: float,
norm_layer=0, activation=0):
super().__init__()
def __init__(self, dim: int, head_count: int, mlp_ratio: float = None,
qkv_bias: bool = None
self.norm1 = norm_layer(dim)
self.attention = Attention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = nn.Sequential(
nn.Linear(dim, int(dim * mlp_ratio)),
activation(),
nn.Linear(int(dim * mlp_ratio), dim),
nn.Dropout(drop_rate))
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
out = input_data + self.drop_path(self.attention(self.norm1(input_data)))
return out + self.drop_path(self.mlp(self.norm2(out)))
class VisionTransformer(nn.Module):
QK_SCALE = None
ACTIVATION = 0
NORM_LAYER = nn.LayerNorm
def __init__(self, image_shape: tuple[int, int, int], class_count: int, depth: int,
patch_size: int = 16, embed_dim: int = 768,
head_count: int = 8, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: float = None,
representation_size=None, distilled: bool = False, drop_rate: float = 0.0,
attention_drop_rate: float = 0.0, drop_path_rate: float = 0.0, embed_layer=PatchEmbed,
norm_layer=0, activation=0):
super().__init__()
qk_scale = qk_scale if qk_scale is not None else self.QK_SCALE
activation = activation if activation != 0 else self.ACTIVATION
activation = activation if activation != 0 else Layer.ACTIVATION
norm_layer = norm_layer if norm_layer != 0 else self.NORM_LAYER
self.class_count = class_count
self.feature_count = self.embed_dim = embed_dim
self.distilled = distilled
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.patch_embed = embed_layer(image_shape[1:], patch_size=patch_size,
in_channels=image_shape[0], embed_dim=embed_dim)
patch_count = self.patch_embed.patch_count
token_count = 2 if distilled else 1
self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.distillation_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
self.position_embedings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
self.position_drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity()
depth_path_drop_rates = np.linspace(0, drop_path_rate, depth) if drop_path_rate > 0.0 else [0.0] * depth
self.blocks = nn.Sequential(*[
Block(embed_dim, head_count, mlp_ratio, qkv_bias, qk_scale, drop_rate, attention_drop_rate,
pdr, norm_layer, activation) for pdr in depth_path_drop_rates])
self.norm = norm_layer(embed_dim)
# Representation Layer
if representation_size and not distilled:
self.feature_count = representation_size
self.pre_logits = nn.Sequential(
nn.Linear(embed_dim, representation_size),
nn.Tanh())
else:
self.pre_logits = nn.Identity()
# Final classifier
self.head = nn.Linear(self.feature_count, class_count) if class_count > 0 else nn.Identity()
self.head_distilled = nn.Linear(
self.embed_dim, self.class_count) if class_count > 0 and distilled else nn.Identity()
# Init weights
nn.init.trunc_normal_(self.class_token, std=0.02)
nn.init.trunc_normal_(self.position_embedings, std=0.02)
if self.distilled:
nn.init.trunc_normal_(self.distillation_token, std=0.02)
head_bias = -math.log(self.class_count) if self.class_count > 0 else 0.0
# self.apply() would pass an empty module name, so walk named_modules() to let
# _init_weights apply its head / pre_logits branches.
for module_name, module in self.named_modules():
self._init_weights(module, module_name, head_bias=head_bias)
@torch.jit.ignore
def no_weight_decay(self) -> set:
return {'class_token', 'distillation_token', 'position_embedings'}
def get_classifier(self):
return self.head if self.distillation_token is None else (self.head, self.head_distilled)
def reset_classifier(self, class_count: int):
self.class_count = class_count
self.head = nn.Linear(self.feature_count, class_count) if class_count > 0 else nn.Identity()
self.head_distilled = nn.Linear(
self.embed_dim, self.class_count) if class_count > 0 and self.distilled else nn.Identity()
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
embedings = self.patch_embed(input_data)
class_token = self.class_token.expand(embedings.shape[0], -1, -1)
if self.distilled:
block_output = self.norm(self.blocks(self.position_drop(
torch.cat((class_token, self.distillation_token.expand(embedings.shape[0], -1, -1), embedings), dim=1)
+ self.position_embedings)))
distilled_head_output = self.head_distilled(block_output[:, 1])
head_output = self.head(block_output[:, 0])
if self.training and not torch.jit.is_scripting():
return head_output, distilled_head_output
return (head_output + distilled_head_output) / 2.0
block_output = self.norm(self.blocks(self.position_drop(
torch.cat((class_token, embedings), dim=1) + self.position_embedings)))
return self.head(self.pre_logits(block_output[:, 0]))
@staticmethod
def _init_weights(module: nn.Module, name: str = '', head_bias: float = 0.0):
if isinstance(module, nn.Linear):
if name.startswith('head'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
elif name.startswith('pre_logits'):
nn.init.xavier_normal_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
nn.init.xavier_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
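To make the class above concrete, here is a minimal instantiation sketch. The import paths and the ViT-Base-style hyperparameters are illustrative assumptions; only the constructor arguments themselves come from the code above.

import torch
from mypackage.layers import Layer                           # hypothetical import path
from mypackage.vision_transformer import VisionTransformer   # hypothetical import path

# ViT-Base-like setup: 3x224x224 input, 16x16 patches, 12 blocks of width 768.
model = VisionTransformer(
    image_shape=(3, 224, 224),
    class_count=1000,
    depth=12,
    embed_dim=768,
    head_count=12,
    mlp_ratio=4.0,
    qkv_bias=True,
    drop_path_rate=0.1,
    norm_layer=torch.nn.LayerNorm,
    activation=torch.nn.GELU)

images = torch.randn(2, 3, 224, 224)
logits = model(images)   # shape: (2, 1000)

# The parameter names returned by no_weight_decay() (class token, position
# embeddings) can be excluded from decay via the helper from the layers file.
param_groups = Layer.add_weight_decay(
    model, weight_decay=0.05, exclude=model.no_weight_decay())
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)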