Fix typos
This commit is contained in:
parent
06db437aa4
commit
0cf142571b
1 changed files with 15 additions and 7 deletions
|
|
@ -1,3 +1,9 @@
|
|||
"""
|
||||
Data efficent image transformer (deit)
|
||||
from https://github.com/facebookresearch/deit, https://arxiv.org/abs/2012.12877
|
||||
"""
|
||||
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
|
|
@ -41,7 +47,7 @@ class Attention(nn.Module):
|
|||
qkv = self.qkv(input_data).reshape(
|
||||
batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute(
|
||||
2, 0, 3, 1, 4)
|
||||
# (output shape : 3, batch_size, head_ctoun, sequence_lenght, channel_count / head_count)
|
||||
# (output shape : 3, batch_size, head_count, sequence_lenght, channel_count / head_count)
|
||||
query, key, value = qkv[0], qkv[1], qkv[2]
|
||||
attention = self.attention_drop(((query @ key.transpose(-2, -1)) * self.scale).softmax(dim=-1))
|
||||
return self.projection_drop(self.projector(
|
||||
|
|
@ -76,7 +82,7 @@ class VissionTransformer(nn.Module):
|
|||
NORM_LAYER = nn.LayerNorm
|
||||
|
||||
def __init__(self, image_shape: tuple[int, int, int], class_count: int, depth: int,
|
||||
path_size: int = 16, embed_dim: int = 768,
|
||||
patch_size: int = 16, embed_dim: int = 768,
|
||||
head_count: int = 8, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: float = None,
|
||||
representation_size=None, distilled: bool = False, drop_rate: float = 0.0,
|
||||
attention_drop_rate: float = 0.0, drop_path_rate: float = 0.0, embed_layer=PatchEmbed,
|
||||
|
|
@ -92,7 +98,7 @@ class VissionTransformer(nn.Module):
|
|||
self.distilled = distilled
|
||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||
|
||||
self.patch_embed = embed_layer(image_shape[1:], patch_size=path_size,
|
||||
self.patch_embed = embed_layer(image_shape[1:], patch_size=patch_size,
|
||||
in_channels=image_shape[0], embed_dim=embed_dim)
|
||||
patch_count = self.patch_embed.patch_count
|
||||
token_count = 2 if distilled else 1
|
||||
|
|
@ -128,6 +134,7 @@ class VissionTransformer(nn.Module):
|
|||
if self.distilled:
|
||||
nn.init.trunc_normal_(self.distillation_token, std=0.02)
|
||||
|
||||
# Applying weights initialization made no difference so far
|
||||
self.apply(partial(self._init_weights, head_bias=-math.log(self.class_count)))
|
||||
|
||||
|
||||
|
|
@ -171,10 +178,11 @@ class VissionTransformer(nn.Module):
|
|||
elif name.startswith('pre_logits'):
|
||||
nn.init.xavier_normal_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Conv2d):
|
||||
nn.init.xavier_normal_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
# pytorch init for conv is fine
|
||||
# elif isinstance(module, nn.Conv2d):
|
||||
# nn.init.xavier_normal_(module.weight)
|
||||
# if module.bias is not None:
|
||||
# nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
nn.init.ones_(module.weight)
|
||||
nn.init.zeros_(module.bias)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue