diff --git a/layers.py b/layers.py
index 3966c43..d27697e 100644
--- a/layers.py
+++ b/layers.py
@@ -38,20 +38,35 @@ class Layer(nn.Module):
             output = self.batch_norm(output)
         return output
 
+    @staticmethod
+    def add_weight_decay(module: nn.Module, weight_decay: float, exclude=()):
+        decay = []
+        no_decay = []
+        for name, param in module.named_parameters():
+            if not param.requires_grad:
+                continue
+            # 1-D parameters (norm scales), biases and explicitly excluded
+            # parameters are exempt from weight decay.
+            if len(param.shape) == 1 or name.endswith('.bias') or name in exclude:
+                no_decay.append(param)
+            else:
+                decay.append(param)
+        return [
+            {'params': no_decay, 'weight_decay': 0.0},
+            {'params': decay, 'weight_decay': weight_decay}]
+
 
 class Linear(Layer):
     def __init__(self, in_channels: int, out_channels: int, activation=0,
                  use_batch_norm: bool = None, **kwargs):
         super().__init__(activation)
-        self.fc = nn.Linear(in_channels, out_channels, bias=not self.batch_norm, **kwargs)
         use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.linear = nn.Linear(in_channels, out_channels, bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
             track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
-        return super().forward(self.fc(input_data))
+        return super().forward(self.linear(input_data))
 
 
 class Conv1d(Layer):
@@ -59,9 +74,9 @@ class Conv1d(Layer):
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
         super().__init__(activation)
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
-                              bias=not self.use_batch_norm, **kwargs)
         use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
+                              bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
@@ -72,30 +87,30 @@ class Conv2d(Layer):
-    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
-                 stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
-        super().__init__(activation, use_batch_norm)
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple[int, int]] = 3,
+                 stride: Union[int, tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
+        super().__init__(activation)
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
-                              bias=not self.use_batch_norm, **kwargs)
         use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
+                              bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
 
 
 class Conv3d(Layer):
-    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, tuple[int, int, int]] = 3,
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
         super().__init__(activation)
-        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
-                              bias=not self.use_batch_norm, **kwargs)
         use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
+                              bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm3d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
@@ -110,10 +125,10 @@ class Deconv2d(Layer):
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
         super().__init__(activation)
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
         self.deconv = nn.ConvTranspose2d(
             in_channels, out_channels, kernel_size, stride=stride,
-            bias=not self.use_batch_norm, **kwargs)
-        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+            bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
@@ -121,3 +136,18 @@
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.deconv(input_data))
+
+
+class DropPath(nn.Module):
+    def __init__(self, drop_prob: float = 0.0):
+        super().__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        # Stochastic depth: identity at eval time or with zero probability.
+        if self.drop_prob == 0.0 or not self.training:
+            return input_data
+        keep_prob = 1 - self.drop_prob
+        # One Bernoulli draw per sample, broadcast over all other dims.
+        shape = (input_data.shape[0],) + (1,) * (input_data.ndim - 1)
+        random_tensor = keep_prob + torch.rand(shape, dtype=input_data.dtype, device=input_data.device)
+        random_tensor.floor_()  # binarize
+        return input_data.div(keep_prob) * random_tensor
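
Note on the layers.py hunks above: `add_weight_decay` builds optimizer parameter groups and `DropPath` implements stochastic depth (identity outside training). A minimal smoke-test sketch, assuming layers.py is importable as `layers`; the toy model and hyperparameters are illustrative only, not part of the patch:

    import torch
    import torch.nn as nn

    from layers import DropPath, Layer

    # Hypothetical toy model; any nn.Module works.
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))

    # Biases and 1-D parameters land in the zero-decay group.
    groups = Layer.add_weight_decay(model, weight_decay=1e-4)
    optimizer = torch.optim.AdamW(groups, lr=1e-3)

    # DropPath passes inputs through unchanged in eval mode.
    drop = DropPath(0.1).eval()
    x = torch.randn(4, 16)
    assert torch.equal(drop(x), x)
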
diff --git a/transformer/vision_transformer.py b/transformer/vision_transformer.py
index 2e5ef4b..f066417 100644
--- a/transformer/vision_transformer.py
+++ b/transformer/vision_transformer.py
@@ -1,23 +1,42 @@
+from functools import partial
+import math
+
+import numpy as np
 import torch
 import torch.nn as nn
 
+from ..layers import DropPath, Layer
+
+
+class PatchEmbed(nn.Module):
+    def __init__(self, image_shape: tuple[int, int], patch_size: int = 16,
+                 in_channels: int = 3, embed_dim: int = 768):
+        super().__init__()
+        patch_count = (image_shape[0] // patch_size) * (image_shape[1] // patch_size)
+        self.image_shape = image_shape
+        self.patch_size = patch_size
+        self.patch_count = patch_count
+
+        # Non-overlapping patch projection: one conv step per patch.
+        self.projector = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        # (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, patch_count, embed_dim)
+        return self.projector(input_data).flatten(2).transpose(1, 2)
+
 
 class Attention(nn.Module):
-    def __init__(self, dim: int, head_count: int = None, qkv_bias: bool = False, qk_scale: float = None,
-                 attention_drop: float = None, projection_drop: float = None):
+    def __init__(self, dim: int, head_count: int, qkv_bias: bool, qk_scale: float,
+                 attention_drop_rate: float, projection_drop_rate: float):
         super().__init__()
         self.head_count = head_count
         head_dim = dim // head_count
         self.scale = qk_scale or head_dim ** -0.5
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attention_drop = nn.Dropout(
-            attention_drop if attention_drop is not None else VisionTransformer.ATTENTION_DROP)
+        self.attention_drop = nn.Dropout(attention_drop_rate) if attention_drop_rate > 0.0 else nn.Identity()
         self.projector = nn.Linear(dim, dim)
-        self.projection_drop = nn.Dropout(
-            projection_drop if projection_drop is not None else VisionTransformer.PROJECTION_DROP)
+        self.projection_drop = nn.Dropout(projection_drop_rate) if projection_drop_rate > 0.0 else nn.Identity()
 
-    def foward(self, input_data: torch.Tensor) -> torch.Tensor:
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, channel_count = input_data.shape
         qkv = self.qkv(input_data).reshape(
             batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute(
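
The hunk above adds `PatchEmbed` and tightens the `Attention` signature (plus the `foward` typo fix). A hedged shape check for the patch embedding; the sizes below are illustrative:

    import torch

    from transformer.vision_transformer import PatchEmbed

    # 224x224 RGB input with 16x16 patches -> (224/16)^2 = 196 tokens.
    embed = PatchEmbed((224, 224), patch_size=16, in_channels=3, embed_dim=768)
    tokens = embed(torch.randn(2, 3, 224, 224))
    assert tokens.shape == (2, 196, 768)
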
@@ -29,12 +48,133 @@ class Attention(nn.Module):
             (attention @ value).transpose(1, 2).reshape(batch_size, sequence_length, channel_count)))
 
 
-class VisionTransformer(nn.Module):
-    HEAD_COUNT = 8
-    MLP_RATIO = 4.0
-    QKV_BIAS = False
-    ATTENTION_DROP = 0.0
-    PROJECTION_DROP = 0.0
+class Block(nn.Module):
+    def __init__(self, dim: int, head_count: int, mlp_ratio: float,
+                 qkv_bias: bool, qk_scale: float, drop_rate: float,
+                 attention_drop_rate: float, drop_path_rate: float,
+                 norm_layer=0, activation=0):
+        super().__init__()
 
-    def __init__(self, dim: int, head_count: int, mlp_ratio: float = None,
-                 qkv_bias: bool = None
+        self.norm1 = norm_layer(dim)
+        self.attention = Attention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, int(dim * mlp_ratio)),
+            activation(),
+            nn.Linear(int(dim * mlp_ratio), dim),
+            nn.Dropout(drop_rate))
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        # Pre-norm residual blocks with stochastic depth on both branches.
+        out = input_data + self.drop_path(self.attention(self.norm1(input_data)))
+        return out + self.drop_path(self.mlp(self.norm2(out)))
+
+
+class VisionTransformer(nn.Module):
+    QK_SCALE = None
+    ACTIVATION = 0
+    NORM_LAYER = partial(nn.LayerNorm, eps=1e-6)
+
+    def __init__(self, image_shape: tuple[int, int, int], class_count: int, depth: int,
+                 patch_size: int = 16, embed_dim: int = 768,
+                 head_count: int = 8, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: float = None,
+                 representation_size=None, distilled: bool = False, drop_rate: float = 0.0,
+                 attention_drop_rate: float = 0.0, drop_path_rate: float = 0.0, embed_layer=PatchEmbed,
+                 norm_layer=0, activation=0):
+        super().__init__()
+        qk_scale = qk_scale if qk_scale is not None else self.QK_SCALE
+        activation = activation if activation != 0 else self.ACTIVATION
+        activation = activation if activation != 0 else Layer.ACTIVATION
+        norm_layer = norm_layer if norm_layer != 0 else self.NORM_LAYER
+
+        self.class_count = class_count
+        self.feature_count = self.embed_dim = embed_dim
+        self.distilled = distilled
+
+        self.patch_embed = embed_layer(image_shape[1:], patch_size=patch_size,
+                                       in_channels=image_shape[0], embed_dim=embed_dim)
+        patch_count = self.patch_embed.patch_count
+        token_count = 2 if distilled else 1
+
+        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.distillation_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+        self.position_embeddings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
+        self.position_drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity()
+
+        # Stochastic depth rate grows linearly with block index.
+        drop_path_rates = np.linspace(0, drop_path_rate, depth) if drop_path_rate > 0.0 else [0.0] * depth
+        self.blocks = nn.Sequential(*[
+            Block(embed_dim, head_count, mlp_ratio, qkv_bias, qk_scale, drop_rate, attention_drop_rate,
+                  pdr, norm_layer, activation) for pdr in drop_path_rates])
+        self.norm = norm_layer(embed_dim)
+
+        # Representation layer
+        if representation_size and not distilled:
+            self.feature_count = representation_size
+            self.pre_logits = nn.Sequential(
+                nn.Linear(embed_dim, representation_size),
+                nn.Tanh())
+        else:
+            self.pre_logits = nn.Identity()
+
+        # Final classifier
+        self.head = nn.Linear(self.feature_count, class_count) if class_count > 0 else nn.Identity()
+        self.head_distilled = nn.Linear(
+            self.embed_dim, self.class_count) if class_count > 0 and distilled else nn.Identity()
+
+        # Init weights
+        nn.init.trunc_normal_(self.class_token, std=0.02)
+        nn.init.trunc_normal_(self.position_embeddings, std=0.02)
+        if self.distilled:
+            nn.init.trunc_normal_(self.distillation_token, std=0.02)
+
+        # nn.Module.apply() does not pass module names, so walk
+        # named_modules() to let _init_weights see them.
+        head_bias = -math.log(class_count) if class_count > 0 else 0.0
+        for name, submodule in self.named_modules():
+            self._init_weights(submodule, name, head_bias=head_bias)
+
+    @torch.jit.ignore
+    def no_weight_decay(self) -> set:
+        return {'class_token', 'distillation_token', 'position_embeddings'}
+
+    def get_classifier(self):
+        return self.head if self.distillation_token is None else (self.head, self.head_distilled)
+
+    def reset_classifier(self, class_count: int):
+        self.class_count = class_count
+        self.head = nn.Linear(self.feature_count, class_count) if class_count > 0 else nn.Identity()
+        self.head_distilled = nn.Linear(
+            self.embed_dim, self.class_count) if class_count > 0 and self.distilled else nn.Identity()
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        embeddings = self.patch_embed(input_data)
+        class_token = self.class_token.expand(embeddings.shape[0], -1, -1)
+
+        if self.distilled:
+            block_output = self.norm(self.blocks(self.position_drop(
+                torch.cat((class_token, self.distillation_token.expand(embeddings.shape[0], -1, -1), embeddings),
+                          dim=1) + self.position_embeddings)))
+            distilled_head_output = self.head_distilled(block_output[:, 1])
+            head_output = self.head(block_output[:, 0])
+            if self.training and not torch.jit.is_scripting():
+                return head_output, distilled_head_output
+            return (head_output + distilled_head_output) / 2.0
+
+        block_output = self.norm(self.blocks(self.position_drop(
+            torch.cat((class_token, embeddings), dim=1) + self.position_embeddings)))
+        return self.head(self.pre_logits(block_output[:, 0]))
+
+    @staticmethod
+    def _init_weights(module: nn.Module, name: str = '', head_bias: float = 0.0):
+        if isinstance(module, nn.Linear):
+            if name.startswith('head'):
+                nn.init.zeros_(module.weight)
+                nn.init.constant_(module.bias, head_bias)
+            elif name.startswith('pre_logits'):
+                nn.init.xavier_normal_(module.weight)
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Conv2d):
+            nn.init.xavier_normal_(module.weight)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.LayerNorm):
+            nn.init.ones_(module.weight)
+            nn.init.zeros_(module.bias)
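
End-to-end sanity check for the new `VisionTransformer`, assuming the package imports resolve from the repo root; the ViT-Tiny-like hyperparameters below are illustrative, not defaults introduced by this patch:

    import torch
    import torch.nn as nn

    from layers import Layer
    from transformer.vision_transformer import VisionTransformer

    # image_shape is (C, H, W); depth/head_count/embed_dim are toy values.
    model = VisionTransformer(
        image_shape=(3, 224, 224), class_count=10, depth=4,
        patch_size=16, embed_dim=192, head_count=3, drop_path_rate=0.1,
        activation=nn.GELU)

    logits = model(torch.randn(2, 3, 224, 224))
    assert logits.shape == (2, 10)

    # Token and position parameters skip weight decay via the new helper.
    groups = Layer.add_weight_decay(model, 1e-4, exclude=model.no_weight_decay())
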