Fix dropouts and typos in ViT

This commit is contained in:
Corentin 2021-08-17 15:54:35 +09:00
commit 1bac46219b

View file

@ -1,6 +1,7 @@
"""
Data efficent image transformer (deit)
from https://github.com/facebookresearch/deit, https://arxiv.org/abs/2012.12877
And Vit : https://arxiv.org/abs/2010.11929
"""
@ -29,7 +30,7 @@ class PatchEmbed(nn.Module):
return self.projector(input_data).flatten(2).transpose(1, 2)
class Attention(nn.Module):
class SelfAttention(nn.Module):
def __init__(self, dim: int, head_count: int, qkv_bias: bool, qk_scale: float,
attention_drop_rate: float, projection_drop_rate: float):
super().__init__()
@ -38,9 +39,9 @@ class Attention(nn.Module):
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attention_drop = nn.Dropout(attention_drop_rate) if attention_drop_rate > 0.0 else nn.Identity()
self.attention_drop = nn.Dropout(attention_drop_rate)
self.projector = nn.Linear(dim, dim)
self.projection_drop = nn.Dropout(projection_drop_rate) if projection_drop_rate > 0.0 else nn.Identity()
self.projection_drop = nn.Dropout(projection_drop_rate)
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, channel_count = input_data.shape
@ -62,7 +63,7 @@ class Block(nn.Module):
super().__init__()
self.norm1 = norm_layer(dim)
self.attention = Attention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
self.attention = SelfAttention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = nn.Sequential(
@ -105,7 +106,7 @@ class VissionTransformer(nn.Module):
self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.distillation_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
self.position_embedings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
self.position_embeddings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
self.position_drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity()
depth_path_drop_rates = np.linspace(0, drop_path_rate, depth) if drop_path_rate > 0.0 else [0.0] * depth
@ -130,17 +131,16 @@ class VissionTransformer(nn.Module):
# Init weights
nn.init.trunc_normal_(self.class_token, std=0.02)
nn.init.trunc_normal_(self.position_embedings, std=0.02)
nn.init.trunc_normal_(self.position_embeddings, std=0.02)
if self.distilled:
nn.init.trunc_normal_(self.distillation_token, std=0.02)
# Applying weights initialization made no difference so far
self.apply(partial(self._init_weights, head_bias=-math.log(self.class_count)))
@torch.jit.ignore
def no_weight_decay(self) -> dict:
return {'class_token', 'distillation_token', 'position_embedings'}
return {'class_token', 'distillation_token', 'position_embeddings'}
def get_classifier(self):
return self.head if self.distillation_token is None else (self.head, self.head_distilled)
@ -152,13 +152,13 @@ class VissionTransformer(nn.Module):
self.embed_dim, self.class_count) if class_count > 0 and self.distilled else nn.Identity()
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
embedings = self.patch_embed(input_data)
class_token = self.class_token.expand(embedings.shape[0], -1, -1)
embeddings = self.patch_embed(input_data)
class_token = self.class_token.expand(embeddings.shape[0], -1, -1)
if self.distilled:
block_output = self.norm(self.blocks(self.position_drop(
torch.cat((class_token, self.distillation_token.expand(embedings.shape[0], -1, -1), embedings), dim=1)
+ self.position_embedings)))
torch.cat((class_token, self.distillation_token.expand(embeddings.shape[0], -1, -1), embeddings), dim=1)
+ self.position_embeddings)))
distilled_head_output = self.head_distilled(block_output[:, 1])
head_output = self.head(block_output[:, 0])
if self.training and not torch.jit.is_scripting():
@ -166,7 +166,7 @@ class VissionTransformer(nn.Module):
return (head_output + distilled_head_output) / 2.0
block_output = self.norm(self.blocks(self.position_drop(
torch.cat((class_token, embedings), dim=1) + self.position_embedings)))
torch.cat((class_token, embeddings), dim=1) + self.position_embeddings)))
return self.head(self.pre_logits(block_output[:, 0]))
@staticmethod