Change BatchGenerator save format from pickle to HDF5
This commit is contained in:
parent
7db99ffa51
commit
95bd1850b5
1 changed file with 15 additions and 6 deletions
|
|
@ -2,6 +2,7 @@ import math
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,10 +22,17 @@ class BatchGenerator:
|
||||||
else:
|
else:
|
||||||
self.data_processor = None
|
self.data_processor = None
|
||||||
self.label_processor = None
|
self.label_processor = None
|
||||||
|
save_path = save
|
||||||
|
if save is not None:
|
||||||
|
if '.' not in os.path.basename(save_path):
|
||||||
|
save_path += '.hdf5'
|
||||||
|
if not os.path.exists(os.path.dirname(save_path)):
|
||||||
|
os.makedirs(os.path.dirname(save_path))
|
||||||
|
|
||||||
if save and os.path.exists(save + '_data.npy'):
|
if save and os.path.exists(save_path):
|
||||||
self.data = np.load(save + '_data.npy', allow_pickle=True)
|
with h5py.File(save_path, 'r') as h5_file:
|
||||||
self.label = np.load(save + '_label.npy', allow_pickle=True)
|
self.data = np.asarray(h5_file['data'])
|
||||||
|
self.label = np.asarray(h5_file['label'])
|
||||||
else:
|
else:
|
||||||
if data_processor:
|
if data_processor:
|
||||||
self.data = np.asarray([data_processor(entry) for entry in data])
|
self.data = np.asarray([data_processor(entry) for entry in data])
|
||||||
|
|
@ -34,9 +42,10 @@ class BatchGenerator:
|
||||||
self.label = np.asarray([label_processor(entry) for entry in label])
|
self.label = np.asarray([label_processor(entry) for entry in label])
|
||||||
else:
|
else:
|
||||||
self.label = label
|
self.label = label
|
||||||
if save:
|
if save:
|
||||||
np.save(save + '_data.npy', self.data, allow_pickle=True)
|
with h5py.File(save_path, 'w') as h5_file:
|
||||||
np.save(save + '_label.npy', self.label, allow_pickle=True)
|
h5_file.create_dataset('data', data=self.data)
|
||||||
|
h5_file.create_dataset('label', data=self.label)
|
||||||
|
|
||||||
self.step_per_epoch = math.ceil(len(self.data) / batch_size)
|
self.step_per_epoch = math.ceil(len(self.data) / batch_size)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue