Fix dropouts and typos in ViT
parent 0cf142571b
commit 1bac46219b

1 changed file with 13 additions and 13 deletions
@@ -1,6 +1,7 @@
 """
 Data efficent image transformer (deit)
 from https://github.com/facebookresearch/deit, https://arxiv.org/abs/2012.12877
 And Vit : https://arxiv.org/abs/2010.11929
 """

@@ -29,7 +30,7 @@ class PatchEmbed(nn.Module):
         return self.projector(input_data).flatten(2).transpose(1, 2)


-class Attention(nn.Module):
+class SelfAttention(nn.Module):
     def __init__(self, dim: int, head_count: int, qkv_bias: bool, qk_scale: float,
                  attention_drop_rate: float, projection_drop_rate: float):
         super().__init__()
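The rename only touches the class header; the attention math itself is outside this diff. For context, a module with the attributes defined in the next hunk typically implements standard multi-head scaled dot-product attention, roughly as in this hypothetical sketch (storing self.head_count in __init__ is an assumption):

import torch
import torch.nn as nn


class SelfAttentionSketch(nn.Module):
    def __init__(self, dim: int, head_count: int, qkv_bias: bool = False, qk_scale: float = None,
                 attention_drop_rate: float = 0.0, projection_drop_rate: float = 0.0):
        super().__init__()
        self.head_count = head_count                        # assumed to be stored by the real module
        self.scale = qk_scale or (dim // head_count) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attention_drop = nn.Dropout(attention_drop_rate)
        self.projector = nn.Linear(dim, dim)
        self.projection_drop = nn.Dropout(projection_drop_rate)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, channel_count = input_data.shape
        # Project to queries, keys and values and split the heads.
        qkv = (self.qkv(input_data)
               .reshape(batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count)
               .permute(2, 0, 3, 1, 4))
        query, key, value = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention, with dropout on the attention weights.
        attention = self.attention_drop((query @ key.transpose(-2, -1) * self.scale).softmax(dim=-1))
        output = (attention @ value).transpose(1, 2).reshape(batch_size, sequence_length, channel_count)
        return self.projection_drop(self.projector(output))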
@@ -38,9 +39,9 @@ class Attention(nn.Module):
         self.scale = qk_scale or head_dim ** -0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attention_drop = nn.Dropout(attention_drop_rate) if attention_drop_rate > 0.0 else nn.Identity()
+        self.attention_drop = nn.Dropout(attention_drop_rate)
         self.projector = nn.Linear(dim, dim)
-        self.projection_drop = nn.Dropout(projection_drop_rate) if projection_drop_rate > 0.0 else nn.Identity()
+        self.projection_drop = nn.Dropout(projection_drop_rate)

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, channel_count = input_data.shape
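The two replaced lines only simplify the construction: nn.Dropout(0.0) already passes its input through unchanged in both training and evaluation mode, so the `if rate > 0.0 else nn.Identity()` guard was redundant. A quick check:

import torch
import torch.nn as nn

x = torch.randn(2, 5, 8)
drop = nn.Dropout(0.0)

drop.train()
assert torch.equal(drop(x), x)   # p=0.0: nothing is zeroed and the 1/(1-p) scale is 1
drop.eval()
assert torch.equal(drop(x), x)   # dropout is a no-op in eval mode regardless of p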
@@ -62,7 +63,7 @@ class Block(nn.Module):
         super().__init__()

         self.norm1 = norm_layer(dim)
-        self.attention = Attention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
+        self.attention = SelfAttention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
         self.norm2 = norm_layer(dim)
         self.mlp = nn.Sequential(
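Only the constructor changes here; the block's forward is outside this diff. A pre-norm block wired from these layers usually applies each branch through a residual connection with stochastic depth, along the lines of this assumed sketch:

# Hypothetical Block.forward for the layers built above (not shown in the diff):
# normalize, apply the branch, pass through drop_path, add back to the input.
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
    input_data = input_data + self.drop_path(self.attention(self.norm1(input_data)))
    return input_data + self.drop_path(self.mlp(self.norm2(input_data)))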
@@ -105,7 +106,7 @@ class VissionTransformer(nn.Module):

         self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.distillation_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
-        self.position_embedings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
+        self.position_embeddings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
         self.position_drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity()

         depth_path_drop_rates = np.linspace(0, drop_path_rate, depth) if drop_path_rate > 0.0 else [0.0] * depth
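The renamed parameter keeps its shape, (1, patch_count + token_count, embed_dim): one position vector per patch plus one per prepended token. Assuming token_count is 2 when the distillation token is used and 1 otherwise, the bookkeeping looks like this for a hypothetical DeiT-Tiny-sized model:

import torch
import torch.nn as nn

embed_dim, patch_count, distilled = 192, 196, True         # e.g. 224x224 input with 16x16 patches
token_count = 2 if distilled else 1                        # class token (+ distillation token)

class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
distillation_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
position_embeddings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))

print(position_embeddings.shape)                           # torch.Size([1, 198, 192])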
@@ -130,17 +131,16 @@ class VissionTransformer(nn.Module):

         # Init weights
         nn.init.trunc_normal_(self.class_token, std=0.02)
-        nn.init.trunc_normal_(self.position_embedings, std=0.02)
+        nn.init.trunc_normal_(self.position_embeddings, std=0.02)
         if self.distilled:
             nn.init.trunc_normal_(self.distillation_token, std=0.02)

         # Applying weights initialization made no difference so far
         self.apply(partial(self._init_weights, head_bias=-math.log(self.class_count)))

-
     @torch.jit.ignore
     def no_weight_decay(self) -> dict:
-        return {'class_token', 'distillation_token', 'position_embedings'}
+        return {'class_token', 'distillation_token', 'position_embeddings'}

     def get_classifier(self):
         return self.head if self.distillation_token is None else (self.head, self.head_distilled)
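no_weight_decay() is the usual hook for keeping the tokens and position embeddings out of weight decay when building optimizer parameter groups; a hypothetical consumer (not part of this repository) might look like this:

import torch

def parameter_groups(model, weight_decay: float = 0.05):
    skip = model.no_weight_decay()
    decay, no_decay = [], []
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        if any(token in name for token in skip) or parameter.ndim <= 1:
            no_decay.append(parameter)        # tokens, position embeddings, norms, biases
        else:
            decay.append(parameter)
    return [{'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0.0}]

# optimizer = torch.optim.AdamW(parameter_groups(model), lr=5e-4)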
@@ -152,13 +152,13 @@ class VissionTransformer(nn.Module):
             self.embed_dim, self.class_count) if class_count > 0 and self.distilled else nn.Identity()

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
-        embedings = self.patch_embed(input_data)
-        class_token = self.class_token.expand(embedings.shape[0], -1, -1)
+        embeddings = self.patch_embed(input_data)
+        class_token = self.class_token.expand(embeddings.shape[0], -1, -1)

         if self.distilled:
             block_output = self.norm(self.blocks(self.position_drop(
-                torch.cat((class_token, self.distillation_token.expand(embedings.shape[0], -1, -1), embedings), dim=1)
-                + self.position_embedings)))
+                torch.cat((class_token, self.distillation_token.expand(embeddings.shape[0], -1, -1), embeddings), dim=1)
+                + self.position_embeddings)))
             distilled_head_output = self.head_distilled(block_output[:, 1])
             head_output = self.head(block_output[:, 0])
             if self.training and not torch.jit.is_scripting():
@@ -166,7 +166,7 @@ class VissionTransformer(nn.Module):
             return (head_output + distilled_head_output) / 2.0

         block_output = self.norm(self.blocks(self.position_drop(
-            torch.cat((class_token, embedings), dim=1) + self.position_embedings)))
+            torch.cat((class_token, embeddings), dim=1) + self.position_embeddings)))
         return self.head(self.pre_logits(block_output[:, 0]))

     @staticmethod
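In the distilled path above, index 0 of the token sequence is the class token and index 1 the distillation token, which is why block_output[:, 0] feeds self.head and block_output[:, 1] feeds self.head_distilled; outside training the two classifier outputs are averaged. A small sketch of that ordering (values are made up):

import torch

batch_size, patch_count, embed_dim, class_count = 4, 196, 192, 1000
class_token = torch.zeros(batch_size, 1, embed_dim)
distillation_token = torch.ones(batch_size, 1, embed_dim)
patch_embeddings = torch.randn(batch_size, patch_count, embed_dim)

sequence = torch.cat((class_token, distillation_token, patch_embeddings), dim=1)
print(sequence.shape)                                          # torch.Size([4, 198, 192])
assert torch.equal(sequence[:, 0], class_token[:, 0])          # -> self.head
assert torch.equal(sequence[:, 1], distillation_token[:, 0])   # -> self.head_distilled

# At evaluation time the two heads are combined as in the diff:
head_output = torch.randn(batch_size, class_count)
distilled_head_output = torch.randn(batch_size, class_count)
prediction = (head_output + distilled_head_output) / 2.0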