Fix dropouts and typos in ViT
parent 0cf142571b
commit 1bac46219b

1 changed file with 13 additions and 13 deletions
@@ -1,6 +1,7 @@
 """
 Data efficent image transformer (deit)
 from https://github.com/facebookresearch/deit, https://arxiv.org/abs/2012.12877
+And Vit : https://arxiv.org/abs/2010.11929
 """


@@ -29,7 +30,7 @@ class PatchEmbed(nn.Module):
         return self.projector(input_data).flatten(2).transpose(1, 2)


-class Attention(nn.Module):
+class SelfAttention(nn.Module):
     def __init__(self, dim: int, head_count: int, qkv_bias: bool, qk_scale: float,
                  attention_drop_rate: float, projection_drop_rate: float):
         super().__init__()
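Note (not part of the diff): the unchanged context line above shows how the patch projector output becomes a token sequence. A minimal shape sketch, assuming the usual Conv2d projector with kernel_size == stride == patch_size and illustrative sizes:

import torch
import torch.nn as nn

# Illustrative sketch only; patch_size, embed_dim and the image size are assumed values.
patch_size, embed_dim = 16, 768
projector = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
images = torch.randn(2, 3, 224, 224)
# (B, embed_dim, 14, 14) -> flatten(2) -> (B, embed_dim, 196) -> transpose(1, 2) -> (B, 196, embed_dim)
tokens = projector(images).flatten(2).transpose(1, 2)
assert tokens.shape == (2, (224 // patch_size) ** 2, embed_dim)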
@@ -38,9 +39,9 @@ class Attention(nn.Module):
         self.scale = qk_scale or head_dim ** -0.5

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attention_drop = nn.Dropout(attention_drop_rate) if attention_drop_rate > 0.0 else nn.Identity()
+        self.attention_drop = nn.Dropout(attention_drop_rate)
         self.projector = nn.Linear(dim, dim)
-        self.projection_drop = nn.Dropout(projection_drop_rate) if projection_drop_rate > 0.0 else nn.Identity()
+        self.projection_drop = nn.Dropout(projection_drop_rate)

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, channel_count = input_data.shape
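The dropout change above removes the nn.Identity() fallback; this is safe because nn.Dropout(p=0.0) already passes its input through unchanged, even in training mode. A quick illustrative check, not part of the commit:

import torch
import torch.nn as nn

# With p=0.0 no elements are zeroed and the 1/(1-p) rescaling factor is 1,
# so Dropout behaves exactly like Identity in both train and eval mode.
x = torch.randn(2, 4, 8)
drop = nn.Dropout(p=0.0)
drop.train()
assert torch.equal(drop(x), x)
drop.eval()
assert torch.equal(drop(x), x)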
@@ -62,7 +63,7 @@ class Block(nn.Module):
         super().__init__()

         self.norm1 = norm_layer(dim)
-        self.attention = Attention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
+        self.attention = SelfAttention(dim, head_count, qkv_bias, qk_scale, attention_drop_rate, drop_rate)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
         self.norm2 = norm_layer(dim)
         self.mlp = nn.Sequential(
@@ -105,7 +106,7 @@ class VissionTransformer(nn.Module):

         self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.distillation_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
-        self.position_embedings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
+        self.position_embeddings = nn.Parameter(torch.zeros(1, patch_count + token_count, embed_dim))
         self.position_drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity()

         depth_path_drop_rates = np.linspace(0, drop_path_rate, depth) if drop_path_rate > 0.0 else [0.0] * depth
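Note: the unchanged context line above builds the stochastic-depth schedule, ramping the drop-path probability linearly from 0 at the first block to drop_path_rate at the last. An illustrative example with assumed values, not taken from this file:

import numpy as np

depth, drop_path_rate = 12, 0.1  # assumed values
rates = np.linspace(0, drop_path_rate, depth)
print(rates)                 # 12 evenly spaced values from 0.0 to 0.1
print(rates[0], rates[-1])   # 0.0 0.1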
@@ -130,17 +131,16 @@ class VissionTransformer(nn.Module):

         # Init weights
         nn.init.trunc_normal_(self.class_token, std=0.02)
-        nn.init.trunc_normal_(self.position_embedings, std=0.02)
+        nn.init.trunc_normal_(self.position_embeddings, std=0.02)
         if self.distilled:
             nn.init.trunc_normal_(self.distillation_token, std=0.02)

         # Applying weights initialization made no difference so far
         self.apply(partial(self._init_weights, head_bias=-math.log(self.class_count)))

-
     @torch.jit.ignore
     def no_weight_decay(self) -> dict:
-        return {'class_token', 'distillation_token', 'position_embedings'}
+        return {'class_token', 'distillation_token', 'position_embeddings'}

     def get_classifier(self):
         return self.head if self.distillation_token is None else (self.head, self.head_distilled)
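The typo fix in no_weight_decay() matters beyond spelling: the returned names are typically matched against model.named_parameters() when building optimizer parameter groups, so the old 'position_embedings' entry would silently fail to exclude the renamed position_embeddings parameter from weight decay. A hedged sketch of that common pattern; TinyModel, the learning rate and the decay value are stand-ins for illustration, not taken from this repository:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.position_embeddings = nn.Parameter(torch.zeros(1, 4, 8))
        self.head = nn.Linear(8, 2)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'position_embeddings'}

model = TinyModel()
skip = model.no_weight_decay()
decay, no_decay = [], []
for name, param in model.named_parameters():
    if name in skip or param.ndim <= 1:   # also skip biases / 1-D params
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = torch.optim.AdamW(
    [{'params': decay, 'weight_decay': 0.05},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=1e-3,
)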
@@ -152,13 +152,13 @@ class VissionTransformer(nn.Module):
             self.embed_dim, self.class_count) if class_count > 0 and self.distilled else nn.Identity()

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
-        embedings = self.patch_embed(input_data)
-        class_token = self.class_token.expand(embedings.shape[0], -1, -1)
+        embeddings = self.patch_embed(input_data)
+        class_token = self.class_token.expand(embeddings.shape[0], -1, -1)

         if self.distilled:
             block_output = self.norm(self.blocks(self.position_drop(
-                torch.cat((class_token, self.distillation_token.expand(embedings.shape[0], -1, -1), embedings), dim=1)
-                + self.position_embedings)))
+                torch.cat((class_token, self.distillation_token.expand(embeddings.shape[0], -1, -1), embeddings), dim=1)
+                + self.position_embeddings)))
             distilled_head_output = self.head_distilled(block_output[:, 1])
             head_output = self.head(block_output[:, 0])
             if self.training and not torch.jit.is_scripting():
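For reference, the distilled forward path above concatenates the class token, the distillation token, and the patch embeddings before adding the position embeddings; the head then reads token 0 and the distilled head reads token 1. A shape-only sketch with assumed sizes, not the repository's code:

import torch

# Assumed sizes for illustration: 196 patches (224x224 image, 16x16 patches), embed_dim 768.
batch_size, patch_count, embed_dim = 2, 196, 768
embeddings = torch.randn(batch_size, patch_count, embed_dim)
class_token = torch.zeros(1, 1, embed_dim).expand(batch_size, -1, -1)
distillation_token = torch.zeros(1, 1, embed_dim).expand(batch_size, -1, -1)
position_embeddings = torch.zeros(1, patch_count + 2, embed_dim)  # token_count == 2 when distilled

tokens = torch.cat((class_token, distillation_token, embeddings), dim=1) + position_embeddings
assert tokens.shape == (batch_size, patch_count + 2, embed_dim)
# tokens[:, 0] feeds self.head, tokens[:, 1] feeds self.head_distilled (after the blocks and norm).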
@@ -166,7 +166,7 @@ class VissionTransformer(nn.Module):
                 return (head_output + distilled_head_output) / 2.0

         block_output = self.norm(self.blocks(self.position_drop(
-            torch.cat((class_token, embedings), dim=1) + self.position_embedings)))
+            torch.cat((class_token, embeddings), dim=1) + self.position_embeddings)))
         return self.head(self.pre_logits(block_output[:, 0]))

     @staticmethod