66 lines
3 KiB
Python
66 lines
3 KiB
Python
import os
|
|
import struct
|
|
from typing import Tuple
|
|
|
|
import numpy as np
|
|
|
|
|
|
def load_image_file(path: str, magic_number: int = 2051, flatten: bool = False) -> np.ndarray:
    """Load an MNIST image file (raw IDX format, uncompressed).

    The file starts with a 16-byte header of four big-endian int32 values
    (magic number, image count, rows, cols) followed by one unsigned byte
    per pixel.

    Args:
        path: Path to the raw ``*-images-idx3-ubyte`` file.
        magic_number: Expected value of the header's magic number.
        flatten: If True, return shape ``(count, rows * cols)``; otherwise
            ``(count, rows, cols)``.

    Returns:
        A writable ``np.uint8`` array holding every image in the file.

    Raises:
        RuntimeError: If the header is missing or invalid, the magic number
            does not match, or the pixel data is shorter than the header claims.
    """
    with open(path, 'rb') as data_file:
        header_data = data_file.read(16)  # 4 * int32 = 16 bytes
        if len(header_data) != 16:
            raise RuntimeError(f'MNIST image file is too short to contain a header: {path}')
        data_magic, data_count, rows, cols = struct.unpack('>iiii', header_data)
        if data_magic != magic_number:
            raise RuntimeError(
                f'MNIST image file doesn\'t have correct magic number: {data_magic} instead of {magic_number}')
        # The header fields are signed int32; negative values can only mean corruption.
        if data_count < 0 or rows < 0 or cols < 0:
            raise RuntimeError(
                f'MNIST image file has an invalid header: count={data_count}, rows={rows}, cols={cols}')
        expected_size = data_count * rows * cols
        # Read the whole payload at once. Fixed 1000-image chunking silently
        # dropped the last `data_count % 1000` images (and then failed in
        # reshape), and unpacking byte-by-byte through struct was slow.
        pixel_data = data_file.read(expected_size)

    if len(pixel_data) != expected_size:
        raise RuntimeError(
            f'MNIST image file is truncated: expected {expected_size} bytes of pixel data, '
            f'got {len(pixel_data)}')

    # frombuffer over immutable bytes yields a read-only view; copy so callers
    # get a writable array, like the np.array(...) result they had before.
    images = np.frombuffer(pixel_data, dtype=np.uint8).copy()
    if flatten:
        return images.reshape(data_count, rows * cols)
    return images.reshape(data_count, rows, cols)
|
|
|
|
|
|
def load_label_file(path: str, magic_number: int = 2049) -> np.ndarray:
    """Load an MNIST label file (raw IDX format, uncompressed).

    The file starts with an 8-byte header of two big-endian int32 values
    (magic number, label count) followed by one unsigned byte per label.

    Args:
        path: Path to the raw ``*-labels-idx1-ubyte`` file.
        magic_number: Expected value of the header's magic number.

    Returns:
        A writable 1-D ``np.uint8`` array with one entry per label in the file.

    Raises:
        RuntimeError: If the header is missing or invalid, the magic number
            does not match, or the label data is shorter than the header claims.
    """
    with open(path, 'rb') as data_file:
        header_data = data_file.read(8)  # 2 * int32 = 8 bytes
        if len(header_data) != 8:
            raise RuntimeError(f'MNIST label file is too short to contain a header: {path}')
        data_magic, data_count = struct.unpack('>ii', header_data)
        if data_magic != magic_number:
            raise RuntimeError(
                f'MNIST label file doesn\'t have correct magic number: {data_magic} instead of {magic_number}')
        # The count is a signed int32; a negative value can only mean corruption.
        if data_count < 0:
            raise RuntimeError(f'MNIST label file has an invalid header: count={data_count}')
        # Read every label at once. Fixed 1000-label chunking silently dropped
        # the last `data_count % 1000` labels (all of them when count < 1000).
        label_data = data_file.read(data_count)

    if len(label_data) != data_count:
        raise RuntimeError(
            f'MNIST label file is truncated: expected {data_count} labels, got {len(label_data)}')

    # Copy so the result is writable, as the np.array(...) result was before.
    return np.frombuffer(label_data, dtype=np.uint8).copy()
|
|
|
|
|
|
def load_data(data_path: str, flatten: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST data from raw ubyte files (download at http://yann.lecun.com/exdb/mnist/).

    Args:
        data_path: Directory containing the four uncompressed MNIST files.
        flatten: Forwarded to ``load_image_file`` for both image sets.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``.

    Raises:
        RuntimeError: If any of the four expected files does not exist; the
            first missing one (in return order) is named in the message.
    """
    expected_names = (
        'train-images-idx3-ubyte',
        'train-labels-idx1-ubyte',
        't10k-images-idx3-ubyte',
        't10k-labels-idx1-ubyte',
    )
    full_paths = [os.path.join(data_path, name) for name in expected_names]

    # Verify the whole set is present before spending time parsing any file.
    for name, full_path in zip(expected_names, full_paths):
        if not os.path.exists(full_path):
            raise RuntimeError(f'MNIST data load : Couldn\'t find {name}')

    train_img_path, train_lbl_path, test_img_path, test_lbl_path = full_paths
    return (
        load_image_file(train_img_path, flatten=flatten),
        load_label_file(train_lbl_path),
        load_image_file(test_img_path, flatten=flatten),
        load_label_file(test_lbl_path),
    )
|