diff --git a/transformer/vision_transformer.py b/transformer/vision_transformer.py
index f066417..04195e2 100644
--- a/transformer/vision_transformer.py
+++ b/transformer/vision_transformer.py
@@ -1,3 +1,9 @@
+"""
+Data-efficient image Transformer (DeiT)
+from https://github.com/facebookresearch/deit, https://arxiv.org/abs/2012.12877
+"""
+
+
 from functools import partial
 import math
 
@@ -41,7 +47,7 @@ class Attention(nn.Module):
         qkv = self.qkv(input_data).reshape(
             batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute(
             2, 0, 3, 1, 4)
-        # (output shape : 3, batch_size, head_ctoun, sequence_lenght, channel_count / head_count)
+        # (output shape: 3, batch_size, head_count, sequence_length, channel_count / head_count)
         query, key, value = qkv[0], qkv[1], qkv[2]
         attention = self.attention_drop(((query @ key.transpose(-2, -1)) * self.scale).softmax(dim=-1))
         return self.projection_drop(self.projector(
@@ -76,7 +82,7 @@ class VissionTransformer(nn.Module):
     NORM_LAYER = nn.LayerNorm
 
     def __init__(self, image_shape: tuple[int, int, int], class_count: int, depth: int,
-                 path_size: int = 16, embed_dim: int = 768,
+                 patch_size: int = 16, embed_dim: int = 768,
                  head_count: int = 8, mlp_ratio: float = 4.0, qkv_bias: bool = True,
                  qk_scale: float = None, representation_size=None, distilled: bool = False,
                  drop_rate: float = 0.0, attention_drop_rate: float = 0.0, drop_path_rate: float = 0.0, embed_layer=PatchEmbed,
@@ -92,7 +98,7 @@ class VissionTransformer(nn.Module):
         self.distilled = distilled
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
 
-        self.patch_embed = embed_layer(image_shape[1:], patch_size=path_size,
+        self.patch_embed = embed_layer(image_shape[1:], patch_size=patch_size,
                                        in_channels=image_shape[0], embed_dim=embed_dim)
         patch_count = self.patch_embed.patch_count
         token_count = 2 if distilled else 1
@@ -128,6 +134,7 @@ class VissionTransformer(nn.Module):
         if self.distilled:
             nn.init.trunc_normal_(self.distillation_token, std=0.02)
 
+        # Applying weight initialization has made no difference so far
         self.apply(partial(self._init_weights, head_bias=-math.log(self.class_count)))
 
 
@@ -171,10 +178,11 @@ class VissionTransformer(nn.Module):
         elif name.startswith('pre_logits'):
             nn.init.xavier_normal_(module.weight)
             nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Conv2d):
-            nn.init.xavier_normal_(module.weight)
-            if module.bias is not None:
-                nn.init.zeros_(module.bias)
+        # PyTorch's default init for conv is fine
+        # elif isinstance(module, nn.Conv2d):
+        #     nn.init.xavier_normal_(module.weight)
+        #     if module.bias is not None:
+        #         nn.init.zeros_(module.bias)
         elif isinstance(module, nn.LayerNorm):
             nn.init.ones_(module.weight)
             nn.init.zeros_(module.bias)
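
Note: the comment corrected above describes the tensor layout produced by the qkv reshape/permute in Attention.forward. A minimal standalone sketch of that reshape, with illustrative sizes that are assumptions and not taken from the repo:

import torch

# Illustrative sizes only; the real values come from the input batch and the Attention config.
batch_size, sequence_length, channel_count, head_count = 2, 197, 768, 8

# Stand-in for self.qkv(input_data): a linear projection to 3 * channel_count features.
qkv_linear_out = torch.randn(batch_size, sequence_length, 3 * channel_count)
qkv = qkv_linear_out.reshape(
    batch_size, sequence_length, 3, head_count, channel_count // head_count).permute(2, 0, 3, 1, 4)
print(qkv.shape)  # torch.Size([3, 2, 8, 197, 96]) == (3, batch_size, head_count, sequence_length, channel_count / head_count)
query, key, value = qkv[0], qkv[1], qkv[2]  # each: (batch_size, head_count, sequence_length, channel_count / head_count)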