From 770a9a4f8206d8553006936ceed770d22a155036 Mon Sep 17 00:00:00 2001
From: Corentin
Date: Fri, 21 May 2021 16:00:16 +0900
Subject: [PATCH] Avoid use_batch_norm as layers instance variable

---
 layers.py                         | 41 ++++++++++++++++++-----------------
 transformer/vision_transformer.py | 40 +++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 19 deletions(-)
 create mode 100644 transformer/vision_transformer.py

diff --git a/layers.py b/layers.py
index 93e7749..3966c43 100644
--- a/layers.py
+++ b/layers.py
@@ -19,7 +19,7 @@ class Layer(nn.Module):
     METRICS = False
     LOGGER = DummyLogger()
 
-    def __init__(self, activation, use_batch_norm):
+    def __init__(self, activation):
         super().__init__()
         # Preload default
         if activation == 0:
@@ -28,28 +28,27 @@ class Layer(nn.Module):
             self.activation = activation()
         else:
             self.activation = activation
-        self.batch_norm: torch.nn._BatchNorm = None
-        self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.batch_norm: torch.nn._BatchNorm
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         output = input_data
-        if self.activation is not None:
+        if self.activation:
             output = self.activation(output)
-        if self.use_batch_norm:
-            # It is assumed here that if using batch norm, then self.batch_norm has been instanciated.
+        if self.batch_norm:
             output = self.batch_norm(output)
         return output
 
 
 class Linear(Layer):
     def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = None,
                  **kwargs):
-        super().__init__(activation, use_batch_norm)
-        self.fc = nn.Linear(in_channels, out_channels, bias=not self.batch_norm, **kwargs)
+        super().__init__(activation)
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
+        self.fc = nn.Linear(in_channels, out_channels, bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.fc(input_data))
@@ -58,14 +57,15 @@ class Linear(Layer):
 class Conv1d(Layer):
     def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
-        super().__init__(activation, use_batch_norm)
+        super().__init__(activation)
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride,
-                              bias=not self.use_batch_norm, **kwargs)
+                              bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
@@ -78,6 +78,7 @@ class Conv2d(Layer):
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
         self.conv = nn.Conv2d(
             in_channels, out_channels, kernel_size, stride=stride,
-            bias=not self.use_batch_norm, **kwargs)
+            bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
@@ -90,14 +91,15 @@ class Conv3d(Layer):
     def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
-        super().__init__(activation, use_batch_norm)
+        super().__init__(activation)
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
         self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
-                              bias=not self.use_batch_norm, **kwargs)
+                              bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm3d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
@@ -106,15 +108,16 @@ class Conv3d(Layer):
 class Deconv2d(Layer):
     def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                  stride: Union[int, Tuple[int, int]] = 1, activation=0, use_batch_norm: bool = None, **kwargs):
-        super().__init__(activation, use_batch_norm)
+        super().__init__(activation)
+        use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
         self.deconv = nn.ConvTranspose2d(
             in_channels, out_channels, kernel_size, stride=stride,
-            bias=not self.use_batch_norm, **kwargs)
+            bias=not use_batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm2d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,
-            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
+            track_running_stats=Layer.BATCH_NORM_TRAINING) if use_batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.deconv(input_data))
diff --git a/transformer/vision_transformer.py b/transformer/vision_transformer.py
new file mode 100644
index 0000000..2e5ef4b
--- /dev/null
+++ b/transformer/vision_transformer.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+
+
+class Attention(nn.Module):
+    def __init__(self, dim: int, head_count: int = None, qkv_bias: bool = False, qk_scale: float = None,
+                 attention_drop: float = None, projection_drop: float = None):
+        super().__init__()
+        self.head_count = head_count if head_count is not None else VisionTransformer.HEAD_COUNT
+        head_dim = dim // self.head_count
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attention_drop = nn.Dropout(
+            attention_drop if attention_drop is not None else VisionTransformer.ATTENTION_DROP)
+        self.projector = nn.Linear(dim, dim)
+        self.projection_drop = nn.Dropout(
+            projection_drop if projection_drop is not None else VisionTransformer.PROJECTION_DROP)
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, channel_count = input_data.shape
+        qkv = self.qkv(input_data).reshape(
+            batch_size, sequence_length, 3, self.head_count, channel_count // self.head_count).permute(
+            2, 0, 3, 1, 4)
+        # (output shape: 3, batch_size, head_count, sequence_length, channel_count / head_count)
+        query, key, value = qkv[0], qkv[1], qkv[2]
+        attention = self.attention_drop(((query @ key.transpose(-2, -1)) * self.scale).softmax(dim=-1))
+        return self.projection_drop(self.projector(
+            (attention @ value).transpose(1, 2).reshape(batch_size, sequence_length, channel_count)))
+
+
+class VisionTransformer(nn.Module):
+    HEAD_COUNT = 8
+    MLP_RATIO = 4.0
+    QKV_BIAS = False
+    ATTENTION_DROP = 0.0
+    PROJECTION_DROP = 0.0
+
+    def __init__(self, dim: int, head_count: int, mlp_ratio: float = None,
+                 qkv_bias: bool = None