Fix batch generator when data/label isn't ndarray
This commit is contained in:
parent
379dd4814f
commit
128dfe511e
1 changed files with 4 additions and 4 deletions
|
|
@ -22,8 +22,8 @@ class BatchGenerator:
|
||||||
if not preload:
|
if not preload:
|
||||||
self.data_processor = data_processor
|
self.data_processor = data_processor
|
||||||
self.label_processor = label_processor
|
self.label_processor = label_processor
|
||||||
self.data = data
|
self.data = np.asarray(data)
|
||||||
self.label = label
|
self.label = np.asarray(label)
|
||||||
else:
|
else:
|
||||||
self.data_processor = None
|
self.data_processor = None
|
||||||
self.label_processor = None
|
self.label_processor = None
|
||||||
|
|
@ -42,11 +42,11 @@ class BatchGenerator:
|
||||||
if data_processor:
|
if data_processor:
|
||||||
self.data = np.asarray([data_processor(entry) for entry in data])
|
self.data = np.asarray([data_processor(entry) for entry in data])
|
||||||
else:
|
else:
|
||||||
self.data = data
|
self.data = np.asarray(data)
|
||||||
if label_processor:
|
if label_processor:
|
||||||
self.label = np.asarray([label_processor(entry) for entry in label])
|
self.label = np.asarray([label_processor(entry) for entry in label])
|
||||||
else:
|
else:
|
||||||
self.label = label
|
self.label = np.asarray(label)
|
||||||
if save:
|
if save:
|
||||||
with h5py.File(save_path, 'w') as h5_file:
|
with h5py.File(save_path, 'w') as h5_file:
|
||||||
h5_file.create_dataset('data', data=self.data)
|
h5_file.create_dataset('data', data=self.data)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue