Fix batch generator when data/label isn't ndarray

This commit is contained in:
Corentin 2020-12-18 16:24:21 +09:00
commit 128dfe511e

View file

@ -22,8 +22,8 @@ class BatchGenerator:
if not preload: if not preload:
self.data_processor = data_processor self.data_processor = data_processor
self.label_processor = label_processor self.label_processor = label_processor
self.data = data self.data = np.asarray(data)
self.label = label self.label = np.asarray(label)
else: else:
self.data_processor = None self.data_processor = None
self.label_processor = None self.label_processor = None
@ -42,11 +42,11 @@ class BatchGenerator:
if data_processor: if data_processor:
self.data = np.asarray([data_processor(entry) for entry in data]) self.data = np.asarray([data_processor(entry) for entry in data])
else: else:
self.data = data self.data = np.asarray(data)
if label_processor: if label_processor:
self.label = np.asarray([label_processor(entry) for entry in label]) self.label = np.asarray([label_processor(entry) for entry in label])
else: else:
self.label = label self.label = np.asarray(label)
if save: if save:
with h5py.File(save_path, 'w') as h5_file: with h5py.File(save_path, 'w') as h5_file:
h5_file.create_dataset('data', data=self.data) h5_file.create_dataset('data', data=self.data)