Change BatchGenerator save format from pickle to h5df

This commit is contained in:
Corentin Risselin 2020-04-28 16:54:44 +09:00
commit 95bd1850b5

View file

@ -2,6 +2,7 @@ import math
import os import os
from typing import Optional, Tuple from typing import Optional, Tuple
import h5py
import numpy as np import numpy as np
@ -21,10 +22,17 @@ class BatchGenerator:
else: else:
self.data_processor = None self.data_processor = None
self.label_processor = None self.label_processor = None
save_path = save
if save is not None:
if '.' not in os.path.basename(save_path):
save_path += '.hdf5'
if not os.path.exists(os.path.dirname(save_path)):
os.makedirs(os.path.dirname(save_path))
if save and os.path.exists(save + '_data.npy'): if save and os.path.exists(save_path):
self.data = np.load(save + '_data.npy', allow_pickle=True) with h5py.File(save_path, 'r') as h5_file:
self.label = np.load(save + '_label.npy', allow_pickle=True) self.data = np.asarray(h5_file['data'])
self.label = np.asarray(h5_file['label'])
else: else:
if data_processor: if data_processor:
self.data = np.asarray([data_processor(entry) for entry in data]) self.data = np.asarray([data_processor(entry) for entry in data])
@ -34,9 +42,10 @@ class BatchGenerator:
self.label = np.asarray([label_processor(entry) for entry in label]) self.label = np.asarray([label_processor(entry) for entry in label])
else: else:
self.label = label self.label = label
if save: if save:
np.save(save + '_data.npy', self.data, allow_pickle=True) with h5py.File(save_path, 'w') as h5_file:
np.save(save + '_label.npy', self.label, allow_pickle=True) h5_file.create_dataset('data', data=self.data)
h5_file.create_dataset('label', data=self.label)
self.step_per_epoch = math.ceil(len(self.data) / batch_size) self.step_per_epoch = math.ceil(len(self.data) / batch_size)