import os
import struct
from typing import Tuple

import numpy as np


def _read_header(data_file, n_fields: int, magic_number: int, kind: str) -> Tuple[int, ...]:
    """Read and validate an IDX header made of `n_fields` big-endian int32 values.

    The first field must equal `magic_number`; the remaining fields (counts and
    dimensions) must be non-negative. Returns all fields, magic number included.

    Raises:
        RuntimeError: if the header is truncated, the magic number is wrong, or a
            dimension field is negative.
    """
    header_size = 4 * n_fields
    header_data = data_file.read(header_size)
    if len(header_data) != header_size:
        raise RuntimeError(
            f'MNIST {kind} file is truncated: expected a {header_size}-byte header, '
            f'got {len(header_data)} bytes')
    fields = struct.unpack('>' + 'i' * n_fields, header_data)
    if fields[0] != magic_number:
        raise RuntimeError(
            f'MNIST {kind} file doesn\'t have correct magic number: {fields[0]} instead of {magic_number}')
    if any(field < 0 for field in fields[1:]):
        raise RuntimeError(f'MNIST {kind} file has a negative size in its header: {fields[1:]}')
    return fields


def _read_payload(data_file, n_bytes: int, kind: str) -> np.ndarray:
    """Read exactly `n_bytes` unsigned bytes from `data_file` into a writable uint8 array.

    Raises:
        RuntimeError: if fewer than `n_bytes` bytes remain in the file.
    """
    payload = data_file.read(n_bytes)
    if len(payload) != n_bytes:
        raise RuntimeError(
            f'MNIST {kind} file is truncated: expected {n_bytes} data bytes, got {len(payload)}')
    # frombuffer gives a read-only view over the bytes object; copy so callers
    # receive a regular writable array (as np.array(list) used to produce).
    return np.frombuffer(payload, dtype=np.uint8).copy()


def load_image_file(path: str, magic_number: int = 2051, flatten: bool = False) -> np.ndarray:
    """Load an MNIST image file (IDX format, unsigned-byte pixels).

    Args:
        path: path to the image file (e.g. ``train-images-idx3-ubyte``).
        magic_number: expected value of the first big-endian int32 of the header.
        flatten: if True, return shape ``(count, rows * cols)``; otherwise
            ``(count, rows, cols)``.

    Returns:
        A ``uint8`` array holding every image in the file, in file order.

    Raises:
        RuntimeError: if the magic number doesn't match, the header has a negative
            size, or the file holds fewer pixels than its header declares.
    """
    with open(path, 'rb') as data_file:
        # Header: magic, image count, rows, cols (4 * int32 = 16 bytes).
        _, data_count, rows, cols = _read_header(data_file, 4, magic_number, 'image')
        # Read the whole payload at once; any count works, not only multiples of a chunk size.
        images = _read_payload(data_file, data_count * rows * cols, 'image')
    shape = (data_count, rows * cols) if flatten else (data_count, rows, cols)
    return images.reshape(shape)


def load_label_file(path: str, magic_number: int = 2049) -> np.ndarray:
    """Load an MNIST label file (IDX format, one unsigned byte per label).

    Args:
        path: path to the label file (e.g. ``train-labels-idx1-ubyte``).
        magic_number: expected value of the first big-endian int32 of the header.

    Returns:
        A 1-D ``uint8`` array of labels, in file order.

    Raises:
        RuntimeError: if the magic number doesn't match, the header has a negative
            size, or the file holds fewer labels than its header declares.
    """
    with open(path, 'rb') as data_file:
        # Header: magic, label count (2 * int32 = 8 bytes).
        _, data_count = _read_header(data_file, 2, magic_number, 'label')
        return _read_payload(data_file, data_count, 'label')


def load_data(data_path: str, flatten: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST data from raw ubyte files (download at http://yann.lecun.com/exdb/mnist/)

    Args:
        data_path: directory containing the four uncompressed MNIST files.
        flatten: passed to `load_image_file` for both image sets.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``.

    Raises:
        RuntimeError: if any of the four files is missing or malformed.
    """
    train_images_filename = 'train-images-idx3-ubyte'
    train_labels_filename = 'train-labels-idx1-ubyte'
    test_images_filename = 't10k-images-idx3-ubyte'
    test_labels_filename = 't10k-labels-idx1-ubyte'
    # Check every file up front so the error names the missing one before any parsing starts.
    for filename in [train_images_filename, train_labels_filename, test_images_filename, test_labels_filename]:
        if not os.path.exists(os.path.join(data_path, filename)):
            raise RuntimeError(f'MNIST data load : Couldn\'t find {filename}')
    train_images = load_image_file(os.path.join(data_path, train_images_filename), flatten=flatten)
    train_labels = load_label_file(os.path.join(data_path, train_labels_filename))
    test_images = load_image_file(os.path.join(data_path, test_images_filename), flatten=flatten)
    test_labels = load_label_file(os.path.join(data_path, test_labels_filename))
    return train_images, train_labels, test_images, test_labels