torch_utils/dataset/mnist.py
2020-04-17 12:08:16 +09:00

66 lines
3 KiB
Python

import os
import struct
from typing import Tuple
import numpy as np
def load_image_file(path: str, magic_number: int = 2051, flatten: bool = False) -> np.ndarray:
    """Load an MNIST (IDX3) image file.

    The file begins with a 16-byte big-endian header of four int32 values
    (magic number, image count, rows, cols), followed by
    ``count * rows * cols`` unsigned bytes of pixel data.

    Args:
        path: Path to a raw ``*-images-idx3-ubyte`` file.
        magic_number: Magic number the header must contain.
        flatten: If True, return shape ``(count, rows * cols)``;
            otherwise ``(count, rows, cols)``.

    Returns:
        A writable ``np.uint8`` array holding the images.

    Raises:
        RuntimeError: If the header is incomplete, the magic number does not
            match, or the file holds fewer pixel bytes than the header declares.
    """
    with open(path, 'rb') as data_file:
        header_data = data_file.read(16)  # 4 * int32 = 16 bytes
        if len(header_data) != 16:
            raise RuntimeError(f'MNIST image file is too short to contain a header: {path}')
        data_magic, data_count, rows, cols = struct.unpack('>iiii', header_data)
        if data_magic != magic_number:
            raise RuntimeError(
                f'MNIST image file doesn\'t have correct magic number: {data_magic} instead of {magic_number}')
        expected_size = data_count * rows * cols
        # Read the whole payload in one call: works for any image count
        # (not just multiples of a chunk size) and lets numpy do the decoding.
        raw = data_file.read(expected_size)
    if len(raw) != expected_size:
        raise RuntimeError(
            f'MNIST image file is truncated: expected {expected_size} bytes of pixel data, got {len(raw)}')
    # frombuffer returns a read-only view over the bytes object; copy so callers get a writable array.
    images = np.frombuffer(raw, dtype=np.uint8).copy()
    if flatten:
        return images.reshape(data_count, rows * cols)
    return images.reshape(data_count, rows, cols)
def load_label_file(path: str, magic_number: int = 2049) -> np.ndarray:
    """Load an MNIST (IDX1) label file.

    The file begins with an 8-byte big-endian header of two int32 values
    (magic number, label count), followed by ``count`` unsigned label bytes.

    Args:
        path: Path to a raw ``*-labels-idx1-ubyte`` file.
        magic_number: Magic number the header must contain.

    Returns:
        A writable 1-D ``np.uint8`` array of length ``count``.

    Raises:
        RuntimeError: If the header is incomplete, the magic number does not
            match, or the file holds fewer label bytes than the header declares.
    """
    with open(path, 'rb') as data_file:
        header_data = data_file.read(8)  # 2 * int32 = 8 bytes
        if len(header_data) != 8:
            raise RuntimeError(f'MNIST label file is too short to contain a header: {path}')
        data_magic, data_count = struct.unpack('>ii', header_data)
        if data_magic != magic_number:
            raise RuntimeError(
                f'MNIST label file doesn\'t have correct magic number: {data_magic} instead of {magic_number}')
        # Read all labels in one call so counts that are not a multiple of a chunk size are kept.
        raw = data_file.read(data_count)
    if len(raw) != data_count:
        raise RuntimeError(
            f'MNIST label file is truncated: expected {data_count} label bytes, got {len(raw)}')
    # frombuffer returns a read-only view over the bytes object; copy so callers get a writable array.
    return np.frombuffer(raw, dtype=np.uint8).copy()
def load_data(data_path: str, flatten: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load the MNIST train/test splits from a directory of raw ubyte files.

    The directory must contain the four uncompressed files distributed at
    http://yann.lecun.com/exdb/mnist/. Their presence is checked (in the order
    train images, train labels, test images, test labels) before any of them
    is read.

    Args:
        data_path: Directory holding the four ubyte files.
        flatten: Forwarded to ``load_image_file`` for both image splits.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``.

    Raises:
        RuntimeError: Naming the first of the four files that does not exist.
    """
    file_names = (
        'train-images-idx3-ubyte',
        'train-labels-idx1-ubyte',
        't10k-images-idx3-ubyte',
        't10k-labels-idx1-ubyte',
    )
    full_paths = [os.path.join(data_path, name) for name in file_names]

    # Report the first missing file before doing any (potentially slow) loading.
    missing = next(
        (name for name, full in zip(file_names, full_paths) if not os.path.exists(full)),
        None,
    )
    if missing is not None:
        raise RuntimeError(f'MNIST data load : Couldn\'t find {missing}')

    train_img_path, train_lbl_path, test_img_path, test_lbl_path = full_paths
    return (
        load_image_file(train_img_path, flatten=flatten),
        load_label_file(train_lbl_path),
        load_image_file(test_img_path, flatten=flatten),
        load_label_file(test_lbl_path),
    )