Fix typos
parent 06db437aa4
commit 0cf142571b
1 changed file with 15 additions and 7 deletions
@@ -1,3 +1,9 @@
+"""
+Data efficent image transformer (deit)
+
+from https://github.com/facebookresearch/deit, https://arxiv.org/abs/2012.12877
+
+"""
 from functools import partial
 import math
 
@@ -41,7 +47,7 @@ class Attention(nn.Module):
         qkv = self.qkv(input_data).reshape(
             batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute(
                 2, 0, 3, 1, 4)
-        # (output shape : 3, batch_size, head_ctoun, sequence_lenght, channel_count / head_count)
+        # (output shape : 3, batch_size, head_count, sequence_lenght, channel_count / head_count)
         query, key, value = qkv[0], qkv[1], qkv[2]
         attention = self.attention_drop(((query @ key.transpose(-2, -1)) * self.scale).softmax(dim=-1))
         return self.projection_drop(self.projector(
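A note on the shape comment in this hunk: the reshape/permute packs query, key and value into one tensor and then moves the "3" axis to the front. A minimal sketch, outside the diff, with purely illustrative sizes (batch_size=2, sequence_length=197, channel_count=768, head_count=8 are assumptions, not values from this repository):

import torch

batch_size, sequence_length, channel_count, head_count = 2, 197, 768, 8

# stand-in for self.qkv(input_data): one linear projection to 3 * channel_count
qkv_out = torch.randn(batch_size, sequence_length, 3 * channel_count)

qkv = qkv_out.reshape(
    batch_size, sequence_length, 3, head_count, channel_count // head_count).permute(
        2, 0, 3, 1, 4)

print(qkv.shape)  # torch.Size([3, 2, 8, 197, 96]) -> (3, batch, heads, tokens, head_dim)

query, key, value = qkv[0], qkv[1], qkv[2]
print(query.shape)  # torch.Size([2, 8, 197, 96])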
@@ -76,7 +82,7 @@ class VissionTransformer(nn.Module):
     NORM_LAYER = nn.LayerNorm
 
     def __init__(self, image_shape: tuple[int, int, int], class_count: int, depth: int,
-                 path_size: int = 16, embed_dim: int = 768,
+                 patch_size: int = 16, embed_dim: int = 768,
                  head_count: int = 8, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_scale: float = None,
                  representation_size=None, distilled: bool = False, drop_rate: float = 0.0,
                  attention_drop_rate: float = 0.0, drop_path_rate: float = 0.0, embed_layer=PatchEmbed,
@@ -92,7 +98,7 @@ class VissionTransformer(nn.Module):
         self.distilled = distilled
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
 
-        self.patch_embed = embed_layer(image_shape[1:], patch_size=path_size,
+        self.patch_embed = embed_layer(image_shape[1:], patch_size=patch_size,
                                        in_channels=image_shape[0], embed_dim=embed_dim)
         patch_count = self.patch_embed.patch_count
         token_count = 2 if distilled else 1
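For readers skimming the signature the commit fixes here, instantiation would look roughly like the sketch below. The keyword names and the (channels, height, width) reading of image_shape follow the diff (image_shape[0] feeds in_channels, image_shape[1:] the patch embedder); depth=12 and class_count=1000 are illustrative assumptions, not defaults from this repository.

# hypothetical usage of the constructor shown above
model = VissionTransformer(
    image_shape=(3, 224, 224),  # (channels, height, width)
    class_count=1000,           # illustrative value
    depth=12,                   # illustrative value
    patch_size=16,              # the keyword this commit renames from path_size
    embed_dim=768,
    head_count=8,
)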
@@ -128,6 +134,7 @@ class VissionTransformer(nn.Module):
         if self.distilled:
             nn.init.trunc_normal_(self.distillation_token, std=0.02)
 
+        # Applying weights initialization made no difference so far
         self.apply(partial(self._init_weights, head_bias=-math.log(self.class_count)))
 
 
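On the head_bias=-math.log(self.class_count) visible in this hunk: this matches the common prior-probability initialization, where every class starts near probability 1 / class_count while the head weights are still at zero. A quick arithmetic check with an assumed class_count=1000 (not a value taken from this repository):

import math

class_count = 1000
head_bias = -math.log(class_count)

# with zero head weights every logit equals head_bias, so a sigmoid head
# starts near 1 / (1 + class_count), roughly the uniform prior ...
print(1.0 / (1.0 + math.exp(-head_bias)))  # ~0.000999

# ... and a softmax head is exactly uniform, since all logits are equal
print(1.0 / class_count)                   # 0.001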
@@ -171,10 +178,11 @@ class VissionTransformer(nn.Module):
             elif name.startswith('pre_logits'):
                 nn.init.xavier_normal_(module.weight)
                 nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Conv2d):
-            nn.init.xavier_normal_(module.weight)
-            if module.bias is not None:
-                nn.init.zeros_(module.bias)
+        # pytorch init for conv is fine
+        # elif isinstance(module, nn.Conv2d):
+        #     nn.init.xavier_normal_(module.weight)
+        #     if module.bias is not None:
+        #         nn.init.zeros_(module.bias)
         elif isinstance(module, nn.LayerNorm):
             nn.init.ones_(module.weight)
             nn.init.zeros_(module.bias)
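Context for the "# pytorch init for conv is fine" comment introduced here: nn.Conv2d already initializes its parameters in reset_parameters() (Kaiming-uniform weights and a fan-in-scaled uniform bias), so commenting out the explicit branch simply falls back to that default. A small check, independent of this repository (the layer sizes are illustrative):

import torch
from torch import nn

# a patch-embedding-style convolution; 3 -> 768 channels with 16x16 patches is illustrative
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)

# PyTorch's built-in reset_parameters() already ran in the constructor,
# so the weights have a sensible non-zero spread out of the box
print(conv.weight.mean().item(), conv.weight.std().item())
print(conv.bias is not None)  # True: the bias also gets a default uniform init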