Add MNIST loader
This commit is contained in:
parent
846160e961
commit
9ab6adce7a
1 changed files with 66 additions and 0 deletions
66
dataset/mnist.py
Normal file
66
dataset/mnist.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import os
|
||||
import struct
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_image_file(path: str, magic_number: int = 2051, flatten: bool = False) -> np.ndarray:
    """Load an MNIST image file (raw IDX3 ubyte format) into a uint8 array.

    The file starts with a 16-byte header made of four big-endian int32
    values (magic number, image count, rows, cols), followed by one
    unsigned byte per pixel.

    Args:
        path: Path to the uncompressed image file.
        magic_number: Value the first header field is expected to hold.
        flatten: When True every image is returned as a vector of
            ``rows * cols`` pixels instead of a ``rows x cols`` matrix.

    Returns:
        A writable uint8 array shaped ``(count, rows * cols)`` when
        ``flatten`` is True, ``(count, rows, cols)`` otherwise.

    Raises:
        RuntimeError: If the magic number doesn't match, the header holds
            negative dimensions, or the file contains less pixel data than
            the header announces.
    """
    with open(path, 'rb') as data_file:
        header_data = data_file.read(16)  ## 4 * int32 = 16 bytes
        data_magic, data_count, rows, cols = struct.unpack('>iiii', header_data)
        if data_magic != magic_number:
            raise RuntimeError(
                f'MNIST image file doesn\'t have correct mmagic number: {data_magic} instead of {magic_number}')
        if min(data_count, rows, cols) < 0:
            raise RuntimeError(
                f'MNIST image file header has negative dimensions: {data_count} x {rows} x {cols}')
        ## Read the whole payload at once: reading by fixed-size chunks would
        ## drop the trailing images when the count isn't a multiple of the chunk
        expected_size = data_count * rows * cols
        pixel_data = data_file.read(expected_size)

    if len(pixel_data) != expected_size:
        raise RuntimeError(
            f'MNIST image file is truncated: {len(pixel_data)} pixel bytes instead of {expected_size}')

    ## frombuffer gives a read-only view over the bytes, copy to get a writable array
    images = np.frombuffer(pixel_data, dtype=np.uint8).copy()
    if flatten:
        return np.reshape(images, (data_count, rows * cols))
    return np.reshape(images, (data_count, rows, cols))
|
||||
|
||||
|
||||
def load_label_file(path: str, magic_number: int = 2049) -> np.ndarray:
    """Load an MNIST label file (raw IDX1 ubyte format) into a uint8 array.

    The file starts with an 8-byte header made of two big-endian int32
    values (magic number, label count), followed by one unsigned byte per
    label.

    Args:
        path: Path to the uncompressed label file.
        magic_number: Value the first header field is expected to hold.

    Returns:
        A writable 1-D uint8 array holding ``count`` labels.

    Raises:
        RuntimeError: If the magic number doesn't match, the header holds a
            negative count, or the file contains fewer labels than the
            header announces.
    """
    with open(path, 'rb') as data_file:
        header_data = data_file.read(8)  ## 2 * int32 = 8 bytes
        data_magic, data_count = struct.unpack('>ii', header_data)
        if data_magic != magic_number:
            raise RuntimeError(
                f'MNIST label file doesn\'t have correct mmagic number: {data_magic} instead of {magic_number}')
        if data_count < 0:
            raise RuntimeError(f'MNIST label file header has negative label count: {data_count}')
        ## Read every label at once: reading by fixed-size chunks would drop
        ## the trailing labels when the count isn't a multiple of the chunk
        label_data = data_file.read(data_count)

    if len(label_data) != data_count:
        raise RuntimeError(
            f'MNIST label file is truncated: {len(label_data)} labels instead of {data_count}')

    ## frombuffer gives a read-only view over the bytes, copy to get a writable array
    return np.frombuffer(label_data, dtype=np.uint8).copy()
|
||||
|
||||
|
||||
def load_data(data_path: str, flatten: bool = False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load MNIST data from raw ubyte files (download at http://yann.lecun.com/exdb/mnist/)

    Args:
        data_path: Directory holding the four uncompressed MNIST files
            (train/t10k images and labels).
        flatten: Forwarded to the image loader, True returns each image as
            a flat vector instead of a 2-D matrix.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``.

    Raises:
        RuntimeError: If one of the four expected files is missing from
            ``data_path``.
    """
    train_images_filename = 'train-images-idx3-ubyte'
    train_labels_filename = 'train-labels-idx1-ubyte'
    test_images_filename = 't10k-images-idx3-ubyte'
    test_labels_filename = 't10k-labels-idx1-ubyte'

    ## Check every file up front so nothing is loaded when one is missing
    for filename in [train_images_filename, train_labels_filename, test_images_filename, test_labels_filename]:
        if not os.path.exists(os.path.join(data_path, filename)):
            raise RuntimeError(f'MNIST data load : Couldn\'t find {filename}')

    train_images = load_image_file(os.path.join(data_path, train_images_filename), flatten=flatten)
    train_labels = load_label_file(os.path.join(data_path, train_labels_filename))
    test_images = load_image_file(os.path.join(data_path, test_images_filename), flatten=flatten)
    ## Test labels come from the t10k label file (not the train one)
    test_labels = load_label_file(os.path.join(data_path, test_labels_filename))

    return train_images, train_labels, test_images, test_labels
|
||||
Loading…
Add table
Add a link
Reference in a new issue