From 95bd1850b5c1af933c21e2c7ff310bd0e73b4f35 Mon Sep 17 00:00:00 2001 From: Corentin Risselin Date: Tue, 28 Apr 2020 16:54:44 +0900 Subject: [PATCH] Change BatchGenerator save format from pickle to hdf5 --- utils/batch_generator.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/utils/batch_generator.py b/utils/batch_generator.py index 2b72e93..008e5c6 100644 --- a/utils/batch_generator.py +++ b/utils/batch_generator.py @@ -2,6 +2,7 @@ import math import os from typing import Optional, Tuple +import h5py import numpy as np @@ -21,10 +22,17 @@ class BatchGenerator: else: self.data_processor = None self.label_processor = None + save_path = save + if save is not None: + if '.' not in os.path.basename(save_path): + save_path += '.hdf5' + if not os.path.exists(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path)) - if save and os.path.exists(save + '_data.npy'): - self.data = np.load(save + '_data.npy', allow_pickle=True) - self.label = np.load(save + '_label.npy', allow_pickle=True) + if save and os.path.exists(save_path): + with h5py.File(save_path, 'r') as h5_file: + self.data = np.asarray(h5_file['data']) + self.label = np.asarray(h5_file['label']) else: if data_processor: self.data = np.asarray([data_processor(entry) for entry in data]) @@ -34,9 +42,10 @@ class BatchGenerator: self.label = np.asarray([label_processor(entry) for entry in label]) else: self.label = label - if save: - np.save(save + '_data.npy', self.data, allow_pickle=True) - np.save(save + '_label.npy', self.label, allow_pickle=True) + if save: + with h5py.File(save_path, 'w') as h5_file: + h5_file.create_dataset('data', data=self.data) + h5_file.create_dataset('label', data=self.label) self.step_per_epoch = math.ceil(len(self.data) / batch_size)