diff --git a/dataset/mnist.py b/dataset/mnist.py
index 2604af9..4b8c464 100644
--- a/dataset/mnist.py
+++ b/dataset/mnist.py
@@ -9,13 +9,13 @@ def load_image_file(path: str, magic_number: int = 2051, flatten: bool = False)
     """Load MNIST image file"""
     images = []
     with open(path, 'rb') as data_file:
-        header_data = data_file.read(16)  ## 4 * int32 = 16 bytes
+        header_data = data_file.read(16)  # 4 * int32 = 16 bytes
         data_magic, data_count, rows, cols = struct.unpack('>iiii', header_data)
         if data_magic != magic_number:
             raise RuntimeError(
                 f'MNIST image file doesn\'t have correct magic number: {data_magic} instead of {magic_number}')
         image_data_size = rows * cols
-        image_chunk = 1000  ## loading by chunk for faster IO
+        image_chunk = 1000  # loading by chunk for faster IO
         for _ in range(data_count // image_chunk):
             data = data_file.read(image_data_size * image_chunk)
             data = struct.unpack('B' * image_data_size * image_chunk, data)
@@ -32,12 +32,12 @@ def load_label_file(path: str, magic_number: int = 2049) -> np.ndarray:
     """Load MNIST label file"""
     labels = []
     with open(path, 'rb') as data_file:
-        header_data = data_file.read(8)  ## 2 * int32 = 8 bytes
+        header_data = data_file.read(8)  # 2 * int32 = 8 bytes
         data_magic, data_count = struct.unpack('>ii', header_data)
         if data_magic != magic_number:
             raise RuntimeError(
                 f'MNIST label file doesn\'t have correct magic number: {data_magic} instead of {magic_number}')
-        label_chunk = 1000  ## loading by chunk for faster IO
+        label_chunk = 1000  # loading by chunk for faster IO
         for _ in range(data_count // label_chunk):
            data = data_file.read(label_chunk)
            data = struct.unpack('B' * label_chunk, data)
diff --git a/layers.py b/layers.py
index fa473f0..2300e3c 100644
--- a/layers.py
+++ b/layers.py
@@ -18,11 +18,7 @@ class Layer(nn.Module):
 
     BATCH_NORM = True
     BATCH_NORM_TRAINING = False
-    BATCH_NORM_DECAY = 0.95
-
-    REGULARIZER = None
-
-    PADDING = 'SAME'
+    BATCH_NORM_MOMENTUM = 0.01
 
     IS_TRAINING = False
     METRICS = False
@@ -50,10 +46,23 @@ class Layer(nn.Module):
 
 class Conv2d(Layer):
     def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation=0, batch_norm=None, **kwargs):
-        super(Conv2d, self).__init__(activation, batch_norm)
+        super().__init__(activation, batch_norm)
 
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, **kwargs)
-        self.batch_norm = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.01) if self.batch_norm else None
+        self.batch_norm = nn.BatchNorm2d(
+            out_channels, eps=0.001, momentum=Layer.BATCH_NORM_MOMENTUM) if self.batch_norm else None
 
     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
         return super().forward(self.conv(input_data))
+
+
+class Linear(Layer):
+    def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs):
+        super().__init__(activation, batch_norm)
+
+        self.fc = nn.Linear(in_channels, out_channels, **kwargs)
+        self.batch_norm = nn.BatchNorm1d(
+            out_channels, eps=0.001, momentum=Layer.BATCH_NORM_MOMENTUM) if self.batch_norm else None
+
+    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
+        return super().forward(self.fc(input_data))
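
Note on the header reads in dataset/mnist.py: the MNIST IDX format stores four big-endian int32 fields (magic, count, rows, cols) for images and two (magic, count) for labels, which is why the loader reads 16 and 8 bytes and unpacks with '>iiii' and '>ii'. A minimal sketch of that layout, using fabricated header bytes rather than a real MNIST file:

import struct

# Fabricated 16-byte image-file header: magic 2051, 60000 images of 28x28.
header = struct.pack('>iiii', 2051, 60000, 28, 28)
magic, count, rows, cols = struct.unpack('>iiii', header)
assert (magic, count, rows, cols) == (2051, 60000, 28, 28)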
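
Note on the BATCH_NORM_DECAY -> BATCH_NORM_MOMENTUM rename in layers.py: PyTorch's BatchNorm momentum weights the new batch statistic, running = (1 - momentum) * running + momentum * batch, so it is the complement of a TensorFlow-style decay (momentum=0.01 behaves like decay 0.99). A small sketch of that update rule:

import torch

bn = torch.nn.BatchNorm1d(1, momentum=0.01)
bn.train()
bn(torch.full((4, 1), 10.0))  # batch mean is 10.0
# running_mean = (1 - 0.01) * 0.0 + 0.01 * 10.0 = 0.1
print(bn.running_mean)  # tensor([0.1000])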
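
The new Linear layer mirrors Conv2d: activation and batch-norm handling are inherited from Layer, with BatchNorm1d in place of BatchNorm2d. A hypothetical usage sketch under the defaults shown above; the architecture and shapes are illustrative, not from this repo:

import torch
from layers import Conv2d, Linear

# Toy MNIST classifier: two strided convs, then the new Linear head.
model = torch.nn.Sequential(
    Conv2d(1, 16, kernel_size=3, stride=2),   # 28x28 -> 13x13
    Conv2d(16, 32, kernel_size=3, stride=2),  # 13x13 -> 6x6
    torch.nn.Flatten(),                       # 32 * 6 * 6 = 1152 features
    Linear(32 * 6 * 6, 10),                   # class logits
)

logits = model(torch.randn(8, 1, 28, 28))  # batch of 8 dummy images
print(logits.shape)  # torch.Size([8, 10])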