diff --git a/dataset/mnist.py b/dataset/mnist.py new file mode 100644 index 0000000..2604af9 --- /dev/null +++ b/dataset/mnist.py @@ -0,0 +1,66 @@ +import os +import struct +from typing import Tuple + +import numpy as np + + +def load_image_file(path: str, magic_number: int = 2051, flatten: bool = False) -> np.ndarray: + """Load MNIST image file""" + images = [] + with open(path, 'rb') as data_file: + header_data = data_file.read(16) ## 4 * int32 = 16 bytes + data_magic, data_count, rows, cols = struct.unpack('>iiii', header_data) + if data_magic != magic_number: + raise RuntimeError( + f'MNIST image file doesn\'t have correct mmagic number: {data_magic} instead of {magic_number}') + image_data_size = rows * cols + image_chunk = 1000 ## loading by chunk for faster IO + for _ in range(data_count // image_chunk): + data = data_file.read(image_data_size * image_chunk) + data = struct.unpack('B' * image_data_size * image_chunk, data) + images += data + images = np.array(images, dtype=np.uint8) + if flatten: + images = np.reshape(images, (data_count, rows * cols)) + else: + images = np.reshape(images, (data_count, rows, cols)) + return images + + +def load_label_file(path: str, magic_number: int = 2049) -> np.ndarray: + """Load MNIST label file""" + labels = [] + with open(path, 'rb') as data_file: + header_data = data_file.read(8) ## 2 * int32 = 8 bytes + data_magic, data_count = struct.unpack('>ii', header_data) + if data_magic != magic_number: + raise RuntimeError( + f'MNIST label file doesn\'t have correct mmagic number: {data_magic} instead of {magic_number}') + label_chunk = 1000 ## loading by chunk for faster IO + for _ in range(data_count // label_chunk): + data = data_file.read(label_chunk) + data = struct.unpack('B' * label_chunk, data) + labels += data + labels = np.array(labels, dtype=np.uint8) + labels = np.asarray(labels) + return labels + + +def load_data(data_path: str, flatten: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Load MNIST data from raw ubyte files (download at http://yann.lecun.com/exdb/mnist/)""" + train_images_filename = 'train-images-idx3-ubyte' + train_labels_filename = 'train-labels-idx1-ubyte' + test_images_filename = 't10k-images-idx3-ubyte' + test_labels_filename = 't10k-labels-idx1-ubyte' + + for filename in [train_images_filename, train_labels_filename, test_images_filename, test_labels_filename]: + if not os.path.exists(os.path.join(data_path, filename)): + raise RuntimeError(f'MNIST data load : Couldn\'t find {filename}') + + train_images = load_image_file(os.path.join(data_path, train_images_filename), flatten=flatten) + train_labels = load_label_file(os.path.join(data_path, train_labels_filename)) + test_images = load_image_file(os.path.join(data_path, test_images_filename), flatten=flatten) + test_labels = load_label_file(os.path.join(data_path, train_labels_filename)) + + return train_images, train_labels, test_images, test_labels