Vision Transformer
parent 90abb84710
commit 06db437aa4

2 changed files with 200 additions and 30 deletions

layers.py (60 changes)
@@ -38,20 +38,35 @@ class Layer(nn.Module):
             output = self.batch_norm(output)
         return output

+    @staticmethod
+    def add_weight_decay(module: nn.Module, weight_decay: float, exclude=()):
+        decay = []
+        no_decay = []
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                continue
+            if len(param.shape) == 1 or name.endswith('.bias') or name in exclude:
+                no_decay.append(param)
+            else:
+                decay.append(param)
+        return [
+            {'params': no_decay, 'weight_decay': 0.0},
+            {'params': decay, 'weight_decay': weight_decay}]
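The helper added above builds two optimizer parameter groups so that biases and 1-D parameters (batch-norm and layer-norm weights) are exempt from weight decay. A minimal usage sketch (hypothetical, assuming the Layer class above is importable):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.Linear(32, 10))
# Biases and the BatchNorm1d affine parameters land in the zero-decay group.
param_groups = Layer.add_weight_decay(model, weight_decay=1e-4)
optimizer = torch.optim.AdamW(param_groups, lr=3e-4)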


 class Linear(Layer):
     def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = None, **kwargs):
         super().__init__(activation)

-        self.fc = nn.Linear(in_channels, out_channels, bias=not self.batch_norm, **kwargs)
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.linear = nn.Linear(in_channels, out_channels, bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
             track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
-        return super().forward(self.fc(input_data))
+        return super().forward(self.linear(input_data))
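The bias=not use_batch_norm pattern drops the affine bias whenever batch norm follows, since the shift parameter of BatchNorm1d makes it redundant. A hedged sketch of the resulting behaviour (hypothetical usage, assuming Layer's defaults):

layer = Linear(64, 128, use_batch_norm=True)
assert layer.linear.bias is None              # bias folded into the batch norm
assert isinstance(layer.batch_norm, nn.BatchNorm1d)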


 class Conv1d(Layer):

@@ -59,9 +74,9 @@ class Conv1d(Layer):
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
         super().__init__(activation)

-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
-                              bias=not self.use_batch_norm, **kwargs)
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
+                              bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
@@ -72,30 +87,30 @@ class Conv1d(Layer):


 class Conv2d(Layer):
-    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
-                 stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
-        super().__init__(activation, use_batch_norm)
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple[int, int]] = 3,
+                 stride: Union[int, tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
+        super().__init__(activation)

-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
-                              bias=not self.use_batch_norm, **kwargs)
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
+                              bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))


 class Conv3d(Layer):
-    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple[int, int, int]] = 3,
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
         super().__init__(activation)

-        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
-                              bias=not self.use_batch_norm, **kwargs)
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
+                              bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm3d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
@@ -110,10 +125,10 @@ class Deconv2d(Layer):
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
         super().__init__(activation)

+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
         self.deconv = nn.ConvTranspose2d(
             in_channels, out_channels, kernel_size, stride=stride,
-            bias=not self.use_batch_norm, **kwargs)
+            bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,

@@ -121,3 +136,18 @@ class Deconv2d(Layer):

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.deconv(input_data))
+
+
+class DropPath(nn.Module):
+    def __init__(self, drop_prob: float = 0.0):
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        if self.drop_prob == 0.0 or not self.training:
+            return input_data
+        keep_prob = 1 - self.drop_prob
+        shape = (input_data.shape[0],) + (1,) * (input_data.ndim - 1)
+        random_tensor = keep_prob + torch.rand(shape, dtype=input_data.dtype, device=input_data.device)
+        random_tensor.floor_()  # binarize: keep the whole residual branch with probability keep_prob
+        return input_data.div(keep_prob) * random_tensor
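DropPath implements stochastic depth: during training each sample's residual branch is zeroed with probability drop_prob, and survivors are scaled by 1 / keep_prob so the expected value is unchanged; at evaluation time it is the identity (the training-mode guard added above follows the standard formulation). A quick sanity check (hypothetical usage):

import torch

drop = DropPath(drop_prob=0.25)
drop.train()
out = drop(torch.ones(10000, 8))
print(out.mean())                 # ~1.0: rescaling preserves the expectation
drop.eval()
assert torch.equal(drop(torch.ones(4, 8)), torch.ones(4, 8))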


@@ -1,23 +1,42 @@
 from functools import partial
+import math

 import numpy as np
 import torch
 import torch.nn as nn

 from ..layers import DropPath, Layer


 class PatchEmbed(nn.Module):
     def __init__(self, image_shape: tuple[int, int], patch_size: int = 16,
                  in_channels: int = 3, embed_dim: int = 768):
         super().__init__()
         patch_count = (image_shape[0] // patch_size) * (image_shape[1] // patch_size)
         self.image_shape = image_shape
         self.patch_size = patch_size
         self.patch_count = patch_count

         self.projector = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return self.projector(input_data).flatten(2).transpose(1, 2)
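PatchEmbed slices the image into non-overlapping patches by running a convolution whose kernel and stride both equal the patch size; flatten(2).transpose(1, 2) then turns the feature map into a token sequence. A shape walk-through (hypothetical usage):

import torch

embed = PatchEmbed(image_shape=(224, 224), patch_size=16, in_channels=3, embed_dim=768)
tokens = embed(torch.randn(2, 3, 224, 224))
print(tokens.shape)               # torch.Size([2, 196, 768]); 196 = (224 // 16) ** 2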


 class Attention(nn.Module):
-    def __init__(self, dim: int, head_count: int = None, qkv_bias: bool = False, qk_scale: float = None,
-                 attention_drop: float = None, projection_drop: float = None):
+    def __init__(self, dim: int, head_count: int, qkv_bias: bool, qk_scale: float,
+                 attention_drop_rate: float, projection_drop_rate: float):
         super().__init__()
         self.head_count = head_count
         head_dim = dim // head_count
         self.scale = qk_scale or head_dim ** -0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attention_drop = nn.Dropout(
-            attention_drop if attention_drop is not None else VisionTransformer.ATTENTION_DROP)
+        self.attention_drop = nn.Dropout(attention_drop_rate) if attention_drop_rate > 0.0 else nn.Identity()
         self.projector = nn.Linear(dim, dim)
-        self.projection_drop = nn.Dropout(
-            projection_drop if projection_drop is not None else VisionTransformer.PROJECTION_DROP)
+        self.projection_drop = nn.Dropout(projection_drop_rate) if projection_drop_rate > 0.0 else nn.Identity()

-    def foward(self, input_data: torch.Tensor) -> torch.Tensor:
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, channel_count = input_data.shape
         qkv = self.qkv(input_data).reshape(
             batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute(

@@ -29,12 +48,133 @@ class Attention(nn.Module):
             (attention @ value).transpose(1, 2).reshape(batch_size, sequence_length, channel_count)))
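The permute above arranges qkv as three tensors of shape (batch, heads, tokens, head_dim); the body elided by the hunk computes scaled dot-product attention before the final line merges the heads back. A standalone shape illustration (not the file's code):

import torch

batch_size, sequence_length, dim, head_count = 2, 196, 768, 8
qkv = torch.randn(batch_size, sequence_length, 3, head_count, dim // head_count).permute(2, 0, 3, 1, 4)
query, key, value = qkv[0], qkv[1], qkv[2]              # each (2, 8, 196, 96)
attention = ((query @ key.transpose(-2, -1)) * (dim // head_count) ** -0.5).softmax(dim=-1)
merged = (attention @ value).transpose(1, 2).reshape(batch_size, sequence_length, dim)
print(merged.shape)                                     # torch.Size([2, 196, 768])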


-class VisionTransformer(nn.Module):
-    HEAD_COUNT = 8
-    MLP_RATIO = 4.0
-    QKV_BIAS = False
-    ATTENTION_DROP = 0.0
-    PROJECTION_DROP = 0.0
-
-    def __init__(self, dim: int, head_count: int, mlp_ratio: float = None,
-                 qkv_bias: bool = None
+class Block(nn.Module):
+    def __init__(self, dim: int, head_count: int, mlp_ratio: float,
+                 qkv_bias: bool, qk_scale: float, drop_rate: float,
+                 attention_drop_rate: float, drop_path_rate: float,
+                 norm_layer=0, activation=0):
+        super().__init__()
+
+        self.norm1 = norm_layer(dim)
+        self.attention = Attention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, int(dim * mlp_ratio)),
+            activation(),
+            nn.Linear(int(dim * mlp_ratio), dim),
+            nn.Dropout(drop_rate))
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        out = input_data + self.drop_path(self.attention(self.norm1(input_data)))
+        return out + self.drop_path(self.mlp(self.norm2(out)))
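Block is a standard pre-norm transformer block: LayerNorm is applied before the attention and MLP sub-layers, each wrapped in a residual connection with optional drop-path. A hypothetical construction, assuming an activation callable such as nn.GELU and the full module above:

import torch
import torch.nn as nn

block = Block(dim=768, head_count=8, mlp_ratio=4.0, qkv_bias=True, qk_scale=None,
              drop_rate=0.0, attention_drop_rate=0.0, drop_path_rate=0.1,
              norm_layer=nn.LayerNorm, activation=nn.GELU)
out = block(torch.randn(2, 197, 768))   # shape is preserved: (2, 197, 768)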
+
+
+class VisionTransformer(nn.Module):
+    QK_SCALE = None
+    ACTIVATION = 0
+    NORM_LAYER = nn.LayerNorm
+
+    def __init__(self, image_shape: tuple[int, int, int], class_count: int, depth: int,
+                 patch_size: int = 16, embed_dim: int = 768,
+                 head_count: int = 8, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: float = None,
+                 representation_size=None, distilled: bool = False, drop_rate: float = 0.0,
+                 attention_drop_rate: float = 0.0, drop_path_rate: float = 0.0, embed_layer=PatchEmbed,
+                 norm_layer=0, activation=0):
+        super().__init__()
+        qk_scale = qk_scale if qk_scale is not None else self.QK_SCALE
-        activation = activation if activation != 0 else self.ACTIVATION
+        activation = activation if activation != 0 else Layer.ACTIVATION
+        norm_layer = norm_layer if norm_layer != 0 else self.NORM_LAYER
+
+        self.class_count = class_count
+        self.feature_count = self.embed_dim = embed_dim
+        self.distilled = distilled
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+
+        self.patch_embed = embed_layer(image_shape[1:], patch_size=patch_size,
+                                       in_channels=image_shape[0], embed_dim=embed_dim)
+        patch_count = self.patch_embed.patch_count
+        token_count = 2 if distilled else 1
+
+        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.distillation_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+        self.position_embedings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
+        self.position_drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity()
+
+        depth_path_drop_rates = np.linspace(0, drop_path_rate, depth) if drop_path_rate > 0.0 else [0.0] * depth
+        self.blocks = nn.Sequential(*[
+            Block(embed_dim, head_count, mlp_ratio, qkv_bias, qk_scale, drop_rate, attention_drop_rate,
+                  pdr, norm_layer, activation) for pdr in depth_path_drop_rates])
+        self.norm = norm_layer(embed_dim)
+
+        # Representation layer
+        if representation_size and not distilled:
+            self.feature_count = representation_size
+            self.pre_logits = nn.Sequential(
+                nn.Linear(embed_dim, representation_size),
+                nn.Tanh())
+        else:
+            self.pre_logits = nn.Identity()
+
+        # Final classifier
+        self.head = nn.Linear(self.feature_count, class_count) if class_count > 0 else nn.Identity()
+        self.head_distilled = nn.Linear(
+            self.embed_dim, self.class_count) if class_count > 0 and distilled else nn.Identity()
+
+        # Init weights
+        nn.init.trunc_normal_(self.class_token, std=0.02)
+        nn.init.trunc_normal_(self.position_embedings, std=0.02)
+        if self.distilled:
+            nn.init.trunc_normal_(self.distillation_token, std=0.02)
+
+        # nn.Module.apply() does not pass module names, so _init_weights is
+        # driven from named_modules() to let the head/pre_logits branches fire.
+        head_bias = -math.log(self.class_count) if self.class_count > 0 else 0.0
+        for module_name, submodule in self.named_modules():
+            self._init_weights(submodule, module_name, head_bias=head_bias)
+
+    @torch.jit.ignore
+    def no_weight_decay(self) -> set:
+        return {'class_token', 'distillation_token', 'position_embedings'}
+
+    def get_classifier(self):
+        return self.head if self.distillation_token is None else (self.head, self.head_distilled)
+
+    def reset_classifier(self, class_count: int):
+        self.class_count = class_count
+        self.head = nn.Linear(self.feature_count, class_count) if class_count > 0 else nn.Identity()
+        self.head_distilled = nn.Linear(
+            self.embed_dim, self.class_count) if class_count > 0 and self.distilled else nn.Identity()
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        embedings = self.patch_embed(input_data)
+        class_token = self.class_token.expand(embedings.shape[0], -1, -1)
+
+        if self.distilled:
+            block_output = self.norm(self.blocks(self.position_drop(
+                torch.cat((class_token, self.distillation_token.expand(embedings.shape[0], -1, -1), embedings),
+                          dim=1)
+                + self.position_embedings)))
+            distilled_head_output = self.head_distilled(block_output[:, 1])
+            head_output = self.head(block_output[:, 0])
+            if self.training and not torch.jit.is_scripting():
+                return head_output, distilled_head_output
+            return (head_output + distilled_head_output) / 2.0
+
+        block_output = self.norm(self.blocks(self.position_drop(
+            torch.cat((class_token, embedings), dim=1) + self.position_embedings)))
+        return self.head(self.pre_logits(block_output[:, 0]))
+
+    @staticmethod
+    def _init_weights(module: nn.Module, name: str = '', head_bias: float = 0.0):
+        if isinstance(module, nn.Linear):
+            if name.startswith('head'):
+                nn.init.zeros_(module.weight)
+                nn.init.constant_(module.bias, head_bias)
+            elif name.startswith('pre_logits'):
+                nn.init.xavier_normal_(module.weight)
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Conv2d):
+            nn.init.xavier_normal_(module.weight)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.LayerNorm):
+            nn.init.ones_(module.weight)
+            nn.init.zeros_(module.bias)
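An end-to-end sketch of the model above (hypothetical usage; assumes the full module is importable and that Layer.ACTIVATION resolves to an activation class such as nn.GELU): a small ViT over 32x32 RGB images with 10 classes.

import torch

model = VisionTransformer(image_shape=(3, 32, 32), class_count=10, depth=4,
                          patch_size=8, embed_dim=192, head_count=3)
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)                     # torch.Size([2, 10])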