From 1bac46219b42fe41ba3568fdde3ca364b02e46e9 Mon Sep 17 00:00:00 2001
From: Corentin
Date: Tue, 17 Aug 2021 15:54:35 +0900
Subject: [PATCH] Fix dropouts and typos in ViT

---
 transformer/vision_transformer.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/transformer/vision_transformer.py b/transformer/vision_transformer.py
index 04195e2..bdca511 100644
--- a/transformer/vision_transformer.py
+++ b/transformer/vision_transformer.py
@@ -1,6 +1,7 @@
 """
 Data efficent image transformer (deit) from https://github.com/facebookresearch/deit, https://arxiv.org/abs/2012.12877
+And ViT: https://arxiv.org/abs/2010.11929
 """


@@ -29,7 +30,7 @@ class PatchEmbed(nn.Module):
         return self.projector(input_data).flatten(2).transpose(1, 2)


-class Attention(nn.Module):
+class SelfAttention(nn.Module):
     def __init__(self, dim: int, head_count: int, qkv_bias: bool, qk_scale: float, attention_drop_rate: float,
                  projection_drop_rate: float):
         super().__init__()
@@ -38,9 +39,9 @@ class Attention(nn.Module):
         self.scale = qk_scale or head_dim ** -0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attention_drop = nn.Dropout(attention_drop_rate) if attention_drop_rate > 0.0 else nn.Identity()
+        self.attention_drop = nn.Dropout(attention_drop_rate)
         self.projector = nn.Linear(dim, dim)
-        self.projection_drop = nn.Dropout(projection_drop_rate) if projection_drop_rate > 0.0 else nn.Identity()
+        self.projection_drop = nn.Dropout(projection_drop_rate)

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, channel_count = input_data.shape
@@ -62,7 +63,7 @@ class Block(nn.Module):
         super().__init__()

         self.norm1 = norm_layer(dim)
-        self.attention = Attention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
+        self.attention = SelfAttention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
         self.norm2 = norm_layer(dim)
         self.mlp = nn.Sequential(
@@ -105,7 +106,7 @@ class VissionTransformer(nn.Module):

         self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.distillation_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
-        self.position_embedings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
+        self.position_embeddings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
         self.position_drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity()

         depth_path_drop_rates = np.linspace(0, drop_path_rate, depth) if drop_path_rate > 0.0 else [0.0] * depth
@@ -130,17 +131,16 @@ class VissionTransformer(nn.Module):

         # Init weights
         nn.init.trunc_normal_(self.class_token, std=0.02)
-        nn.init.trunc_normal_(self.position_embedings, std=0.02)
+        nn.init.trunc_normal_(self.position_embeddings, std=0.02)
         if self.distilled:
             nn.init.trunc_normal_(self.distillation_token, std=0.02)
         # Applying weights initialization made no difference so far
         self.apply(partial(self._init_weights, head_bias=-math.log(self.class_count)))

-    @torch.jit.ignore
     def no_weight_decay(self) -> dict:
-        return {'class_token', 'distillation_token', 'position_embedings'}
+        return {'class_token', 'distillation_token', 'position_embeddings'}

     def get_classifier(self):
         return self.head if self.distillation_token is None else (self.head, self.head_distilled)
@@ -152,13 +152,13 @@ class VissionTransformer(nn.Module):
             self.embed_dim, self.class_count) if class_count > 0 and self.distilled else nn.Identity()

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
-        embedings = self.patch_embed(input_data)
-        class_token = self.class_token.expand(embedings.shape[0], -1, -1)
+        embeddings = self.patch_embed(input_data)
+        class_token = self.class_token.expand(embeddings.shape[0], -1, -1)

         if self.distilled:
             block_output = self.norm(self.blocks(self.position_drop(
-                torch.cat((class_token, self.distillation_token.expand(embedings.shape[0], -1, -1), embedings), dim=1)
-                + self.position_embedings)))
+                torch.cat((class_token, self.distillation_token.expand(embeddings.shape[0], -1, -1), embeddings), dim=1)
+                + self.position_embeddings)))
             distilled_head_output = self.head_distilled(block_output[:, 1])
             head_output = self.head(block_output[:, 0])
             if self.training and not torch.jit.is_scripting():
@@ -166,7 +166,7 @@ class VissionTransformer(nn.Module):
             return (head_output + distilled_head_output) / 2.0

         block_output = self.norm(self.blocks(self.position_drop(
-            torch.cat((class_token, embedings), dim=1) + self.position_embedings)))
+            torch.cat((class_token, embeddings), dim=1) + self.position_embeddings)))
         return self.head(self.pre_logits(block_output[:, 0]))

     @staticmethod
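
Note on the dropout change: nn.Dropout(p=0.0) is already a pass-through in both training and eval mode, so replacing the "nn.Dropout(rate) if rate > 0.0 else nn.Identity()" pattern with a plain nn.Dropout(rate) keeps the forward behaviour identical for a zero rate while simplifying the module. A minimal check, as a sketch assuming only stock PyTorch; the tensor shape below is an illustrative (batch, tokens, embed_dim) ViT-style example, not taken from this repo's config:

    import torch
    from torch import nn

    x = torch.randn(2, 197, 768)  # illustrative (batch, tokens, embed_dim) shape

    # p=0.0 keeps every element and rescales by 1 / (1 - 0.0) = 1.0,
    # so the module matches nn.Identity() in train and eval mode alike.
    drop = nn.Dropout(p=0.0)
    drop.train()
    assert torch.equal(drop(x), x)
    drop.eval()
    assert torch.equal(drop(x), x)

    # A non-zero rate behaves as usual: active in training, identity in eval.
    drop = nn.Dropout(p=0.1)
    drop.eval()
    assert torch.equal(drop(x), x)

The position_drop and DropPath guards are left untouched by this patch; only the attention and projection dropouts inside SelfAttention are simplified.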
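
The Attention -> SelfAttention rename touches the class definition and its call site in Block, so a quick shape check is enough to confirm the renamed module still behaves as before. A sketch, assuming the repository root is on PYTHONPATH; the constructor keywords come from the hunk above, and the concrete values are arbitrary ViT-Base-like settings:

    import torch
    from transformer.vision_transformer import SelfAttention

    attention = SelfAttention(dim=768, head_count=12, qkv_bias=True, qk_scale=None,
                              attention_drop_rate=0.0, projection_drop_rate=0.0)
    tokens = torch.randn(2, 197, 768)               # (batch, sequence_length, channels)
    assert attention(tokens).shape == tokens.shape  # self-attention preserves the sequence shape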