Fix typos

This commit is contained in:
Corentin 2021-05-25 14:06:22 +09:00
commit 0cf142571b

View file

@ -1,3 +1,9 @@
"""
Data efficent image transformer (deit)
from https://github.com/facebookresearch/deit, https://arxiv.org/abs/2012.12877
"""
from functools import partial from functools import partial
import math import math
@ -41,7 +47,7 @@ class Attention(nn.Module):
qkv = self.qkv(input_data).reshape( qkv = self.qkv(input_data).reshape(
batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute( batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute(
2, 0, 3, 1, 4) 2, 0, 3, 1, 4)
# (output shape : 3, batch_size, head_ctoun, sequence_lenght, channel_count / head_count) # (output shape : 3, batch_size, head_count, sequence_lenght, channel_count / head_count)
query, key, value = qkv[0], qkv[1], qkv[2] query, key, value = qkv[0], qkv[1], qkv[2]
attention = self.attention_drop(((query @ key.transpose(-2, -1)) * self.scale).softmax(dim=-1)) attention = self.attention_drop(((query @ key.transpose(-2, -1)) * self.scale).softmax(dim=-1))
return self.projection_drop(self.projector( return self.projection_drop(self.projector(
@ -76,7 +82,7 @@ class VissionTransformer(nn.Module):
NORM_LAYER = nn.LayerNorm NORM_LAYER = nn.LayerNorm
def __init__(self, image_shape: tuple[int, int, int], class_count: int, depth: int, def __init__(self, image_shape: tuple[int, int, int], class_count: int, depth: int,
path_size: int = 16, embed_dim: int = 768, patch_size: int = 16, embed_dim: int = 768,
head_count: int = 8, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: float = None, head_count: int = 8, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: float = None,
representation_size=None, distilled: bool = False, drop_rate: float = 0.0, representation_size=None, distilled: bool = False, drop_rate: float = 0.0,
attention_drop_rate: float = 0.0, drop_path_rate: float = 0.0, embed_layer=PatchEmbed, attention_drop_rate: float = 0.0, drop_path_rate: float = 0.0, embed_layer=PatchEmbed,
@ -92,7 +98,7 @@ class VissionTransformer(nn.Module):
self.distilled = distilled self.distilled = distilled
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
self.patch_embed = embed_layer(image_shape[1:], patch_size=path_size, self.patch_embed = embed_layer(image_shape[1:], patch_size=patch_size,
in_channels=image_shape[0], embed_dim=embed_dim) in_channels=image_shape[0], embed_dim=embed_dim)
patch_count = self.patch_embed.patch_count patch_count = self.patch_embed.patch_count
token_count = 2 if distilled else 1 token_count = 2 if distilled else 1
@ -128,6 +134,7 @@ class VissionTransformer(nn.Module):
if self.distilled: if self.distilled:
nn.init.trunc_normal_(self.distillation_token, std=0.02) nn.init.trunc_normal_(self.distillation_token, std=0.02)
# Applying weights initialization made no difference so far
self.apply(partial(self._init_weights, head_bias=-math.log(self.class_count))) self.apply(partial(self._init_weights, head_bias=-math.log(self.class_count)))
@ -171,10 +178,11 @@ class VissionTransformer(nn.Module):
elif name.startswith('pre_logits'): elif name.startswith('pre_logits'):
nn.init.xavier_normal_(module.weight) nn.init.xavier_normal_(module.weight)
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d): # pytorch init for conv is fine
nn.init.xavier_normal_(module.weight) # elif isinstance(module, nn.Conv2d):
if module.bias is not None: # nn.init.xavier_normal_(module.weight)
nn.init.zeros_(module.bias) # if module.bias is not None:
# nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm): elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight) nn.init.ones_(module.weight)
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)