From 7a6f5821bd6615eabb6bc91ba9169828eb00243a Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Thu, 21 Jan 2021 16:10:10 +0900 Subject: [PATCH 1/5] Introduced the use_batch_norm variable, removed old code --- layers.py | 41 ++++++++++++++++------------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/layers.py b/layers.py index 10df5f2..0d7ae78 100644 --- a/layers.py +++ b/layers.py @@ -7,13 +7,6 @@ import torch.nn.functional as F from .utils.logger import DummyLogger -class LayerInfo(): - def __init__(self): - self.memory = 0.0 - self.ops = 0.0 - self.output = 0.0 - - class Layer(nn.Module): # Default layer arguments ACTIVATION = F.leaky_relu @@ -27,14 +20,12 @@ class Layer(nn.Module): VERBOSE = 0 LOGGER = DummyLogger() - def __init__(self, activation, batch_norm): + def __init__(self, activation): super().__init__() self.name = 'Layer' - self.info = LayerInfo() # Preload default self.activation = Layer.ACTIVATION if activation == 0 else activation - self.batch_norm = Layer.BATCH_NORM if batch_norm is None else batch_norm def forward(self, input_data: torch.Tensor) -> torch.Tensor: output = input_data @@ -46,14 +37,14 @@ class Layer(nn.Module): class Linear(Layer): - def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs): - super().__init__(activation, batch_norm) + def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = False, **kwargs): + super().__init__(activation) self.fc = nn.Linear(in_channels, out_channels, **kwargs) self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.fc(input_data)) @@ -61,15 +52,15 @@ class Linear(Layer): class Conv1d(Layer): def __init__(self, in_channels: int, 
out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs): - super().__init__(activation, batch_norm) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): + super().__init__(activation) self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, bias=not self.batch_norm, **kwargs) self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -77,15 +68,15 @@ class Conv1d(Layer): class Conv2d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs): - super().__init__(activation, batch_norm) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): + super().__init__(activation) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=not self.batch_norm, **kwargs) self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None + track_running_stats=not Layer.BATCH_NORM_TRAINING) if use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -93,15 +84,15 @@ class Conv2d(Layer): class Conv3d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs): - super().__init__(activation, batch_norm) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): + 
super().__init__(activation) self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, bias=not self.batch_norm, **kwargs) self.batch_norm = nn.BatchNorm3d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if self.batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -109,8 +100,8 @@ class Conv3d(Layer): class Deconv2d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, batch_norm=None, **kwargs): - super().__init__(activation, batch_norm) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): + super().__init__(activation) self.deconv = nn.ConvTranspose2d( in_channels, out_channels, kernel_size, stride=stride, @@ -118,7 +109,7 @@ class Deconv2d(Layer): self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None + track_running_stats=not Layer.BATCH_NORM_TRAINING) if use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.deconv(input_data)) From 54000b6c3405ebd54358f13fb1e26eb107698730 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Thu, 21 Jan 2021 20:36:22 +0900 Subject: [PATCH 2/5] Fixed default use_batch_norm value --- layers.py | 50 +++++++++++++++++++++++++------------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/layers.py b/layers.py index 0d7ae78..1f916d0 100644 --- a/layers.py +++ b/layers.py @@ -2,49 +2,49 @@ from typing import Union, Tuple import torch import torch.nn as nn -import torch.nn.functional as F from .utils.logger import DummyLogger class Layer(nn.Module): # Default layer arguments - ACTIVATION = 
F.leaky_relu + ACTIVATION = torch.nn.LeakyReLU + ACTIVATION_KWARGS = {"negative_slope": 0.1} - BATCH_NORM = True + USE_BATCH_NORM = True BATCH_NORM_TRAINING = True BATCH_NORM_MOMENTUM = 0.01 IS_TRAINING = False METRICS = False - VERBOSE = 0 LOGGER = DummyLogger() - def __init__(self, activation): + def __init__(self, activation, use_batch_norm): super().__init__() self.name = 'Layer' # Preload default self.activation = Layer.ACTIVATION if activation == 0 else activation + self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm def forward(self, input_data: torch.Tensor) -> torch.Tensor: output = input_data if self.activation is not None: output = self.activation(output) - if self.batch_norm is not None: + if self.use_batch_norm is not None: output = self.batch_norm(output) return output class Linear(Layer): - def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.fc = nn.Linear(in_channels, out_channels, **kwargs) self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.fc(input_data)) @@ -52,15 +52,15 @@ class Linear(Layer): class Conv1d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.conv = 
nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, - bias=not self.batch_norm, **kwargs) + bias=not Layer.USE_BATCH_NORM, **kwargs) self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -68,15 +68,15 @@ class Conv1d(Layer): class Conv2d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, - bias=not self.batch_norm, **kwargs) + bias=not Layer.USE_BATCH_NORM, **kwargs) self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=not Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -84,15 +84,15 @@ class Conv2d(Layer): class Conv3d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, - bias=not self.batch_norm, **kwargs) + bias=not Layer.USE_BATCH_NORM, **kwargs) 
self.batch_norm = nn.BatchNorm3d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -100,16 +100,16 @@ class Conv3d(Layer): class Deconv2d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, - stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = False, **kwargs): - super().__init__(activation) + stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): + super().__init__(activation, use_batch_norm) self.deconv = nn.ConvTranspose2d( in_channels, out_channels, kernel_size, stride=stride, - bias=not self.batch_norm, **kwargs) + bias=not Layer.USE_BATCH_NORM, **kwargs) self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=not Layer.BATCH_NORM_TRAINING) if use_batch_norm else None + track_running_stats=not Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.deconv(input_data)) From a4280a1b78dffcea39ff918dc1731855aacc67e9 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Fri, 22 Jan 2021 12:38:07 +0900 Subject: [PATCH 3/5] Fixed issues: layers now use self.use_batch_norm instead of default value, fixed Layer's forward --- layers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/layers.py b/layers.py index 1f916d0..3bedb24 100644 --- a/layers.py +++ b/layers.py @@ -21,9 +21,8 @@ class Layer(nn.Module): def __init__(self, activation, use_batch_norm): super().__init__() - self.name = 'Layer' - # Preload default + self.batch_norm: torch.nn._BatchNorm = None self.activation = Layer.ACTIVATION if activation == 0 else activation 
self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm @@ -31,7 +30,8 @@ class Layer(nn.Module): output = input_data if self.activation is not None: output = self.activation(output) - if self.use_batch_norm is not None: + if self.use_batch_norm: + # It is assumed here that if using batch norm, then self.batch_norm has been instanciated. output = self.batch_norm(output) return output @@ -44,7 +44,7 @@ class Linear(Layer): self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.fc(input_data)) @@ -60,7 +60,7 @@ class Conv1d(Layer): self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -92,7 +92,7 @@ class Conv3d(Layer): self.batch_norm = nn.BatchNorm3d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) From ce6314bf5eb824000611deaaeff628150ac404b7 Mon Sep 17 00:00:00 2001 From: Hoel Bagard Date: Fri, 22 Jan 2021 12:48:33 +0900 Subject: [PATCH 4/5] Fixed bias --- layers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/layers.py b/layers.py index 3bedb24..aaf219f 100644 --- a/layers.py +++ b/layers.py @@ -56,7 +56,7 @@ class Conv1d(Layer): super().__init__(activation, 
use_batch_norm) self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, - bias=not Layer.USE_BATCH_NORM, **kwargs) + bias=not self.use_batch_norm, **kwargs) self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, @@ -72,7 +72,7 @@ class Conv2d(Layer): super().__init__(activation, use_batch_norm) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, - bias=not Layer.USE_BATCH_NORM, **kwargs) + bias=not self.use_batch_norm, **kwargs) self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, @@ -88,7 +88,7 @@ class Conv3d(Layer): super().__init__(activation, use_batch_norm) self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, - bias=not Layer.USE_BATCH_NORM, **kwargs) + bias=not self.use_batch_norm, **kwargs) self.batch_norm = nn.BatchNorm3d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, @@ -105,7 +105,7 @@ class Deconv2d(Layer): self.deconv = nn.ConvTranspose2d( in_channels, out_channels, kernel_size, stride=stride, - bias=not Layer.USE_BATCH_NORM, **kwargs) + bias=not self.use_batch_norm, **kwargs) self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, From 770a9a4f8206d8553006936ceed770d22a155036 Mon Sep 17 00:00:00 2001 From: Corentin Date: Fri, 21 May 2021 16:00:16 +0900 Subject: [PATCH 5/5] Avoid use_batch_norm as layers instance variable --- layers.py | 31 +++++++++++++----------- transformer/vision_transformer.py | 40 +++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 14 deletions(-) create mode 100644 transformer/vision_transformer.py diff --git a/layers.py b/layers.py index 93e7749..3966c43 100644 --- a/layers.py +++ b/layers.py @@ -19,7 +19,7 @@ class Layer(nn.Module): METRICS = False LOGGER = DummyLogger() - def __init__(self, activation, use_batch_norm): + def __init__(self, activation): super().__init__() # Preload default if activation == 0: @@ -28,28 +28,27 @@ class Layer(nn.Module): 
self.activation = activation() else: self.activation = activation - self.batch_norm: torch.nn._BatchNorm = None - self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm + self.batch_norm: torch.nn._BatchNorm def forward(self, input_data: torch.Tensor) -> torch.Tensor: output = input_data - if self.activation is not None: + if self.activation: output = self.activation(output) - if self.use_batch_norm: - # It is assumed here that if using batch norm, then self.batch_norm has been instanciated. + if self.batch_norm: output = self.batch_norm(output) return output class Linear(Layer): def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = None, **kwargs): - super().__init__(activation, use_batch_norm) + super().__init__(activation) self.fc = nn.Linear(in_channels, out_channels, bias=not self.batch_norm, **kwargs) + use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.fc(input_data)) @@ -58,14 +57,15 @@ class Linear(Layer): class Conv1d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): - super().__init__(activation, use_batch_norm) + super().__init__(activation) self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, bias=not self.use_batch_norm, **kwargs) + use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm self.batch_norm = nn.BatchNorm1d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm 
else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -78,6 +78,7 @@ class Conv2d(Layer): self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=not self.use_batch_norm, **kwargs) + use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm self.batch_norm = nn.BatchNorm2d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, @@ -90,14 +91,15 @@ class Conv2d(Layer): class Conv3d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): - super().__init__(activation, use_batch_norm) + super().__init__(activation) self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, bias=not self.use_batch_norm, **kwargs) + use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm self.batch_norm = nn.BatchNorm3d( out_channels, momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.conv(input_data)) @@ -106,15 +108,16 @@ class Conv3d(Layer): class Deconv2d(Layer): def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs): - super().__init__(activation, use_batch_norm) + super().__init__(activation) self.deconv = nn.ConvTranspose2d( in_channels, out_channels, kernel_size, stride=stride, bias=not self.use_batch_norm, **kwargs) + use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm self.batch_norm = nn.BatchNorm2d( out_channels, 
momentum=Layer.BATCH_NORM_MOMENTUM, - track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None + track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None def forward(self, input_data: torch.Tensor) -> torch.Tensor: return super().forward(self.deconv(input_data)) diff --git a/transformer/vision_transformer.py b/transformer/vision_transformer.py new file mode 100644 index 0000000..2e5ef4b --- /dev/null +++ b/transformer/vision_transformer.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn + + +class Attention(nn.Module): + def __init__(self, dim: int, head_count: int = None, qkv_bias: bool = False, qk_scale: float = None, + attention_drop: float = None, projection_drop: float = None): + super().__init__() + self.head_count = head_count + head_dim = dim // head_count + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attention_drop = nn.Dropout( + attention_drop if attention_drop is not None else VisionTransformer.ATTENTION_DROP) + self.projector = nn.Linear(dim, dim) + self.projection_drop = nn.Dropout( + projection_drop if projection_drop is not None else VisionTransformer.PROJECTION_DROP) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, channel_count = input_data.shape + qkv = self.qkv(input_data).reshape( + batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute( + 2, 0, 3, 1, 4) + # (output shape : 3, batch_size, head_count, sequence_length, channel_count / head_count) + query, key, value = qkv[0], qkv[1], qkv[2] + attention = self.attention_drop(((query @ key.transpose(-2, -1)) * self.scale).softmax(dim=-1)) + return self.projection_drop(self.projector( + (attention @ value).transpose(1, 2).reshape(batch_size, sequence_length, channel_count))) + + +class VisionTransformer(nn.Module): + HEAD_COUNT = 8 + MLP_RATIO = 4.0 + QKV_BIAS = False + ATTENTION_DROP = 0.0 + PROJECTION_DROP = 0.0
+ + def __init__(self, dim: int, head_count: int, mlp_ratio: float = None, + qkv_bias: bool = None