Fix batch generator when data/label isn't ndarray
This commit is contained in:
parent
379dd4814f
commit
128dfe511e
1 changed files with 4 additions and 4 deletions
|
|
@ -22,8 +22,8 @@ class BatchGenerator:
|
|||
if not preload:
|
||||
self.data_processor = data_processor
|
||||
self.label_processor = label_processor
|
||||
self.data = data
|
||||
self.label = label
|
||||
self.data = np.asarray(data)
|
||||
self.label = np.asarray(label)
|
||||
else:
|
||||
self.data_processor = None
|
||||
self.label_processor = None
|
||||
|
|
@ -42,11 +42,11 @@ class BatchGenerator:
|
|||
if data_processor:
|
||||
self.data = np.asarray([data_processor(entry) for entry in data])
|
||||
else:
|
||||
self.data = data
|
||||
self.data = np.asarray(data)
|
||||
if label_processor:
|
||||
self.label = np.asarray([label_processor(entry) for entry in label])
|
||||
else:
|
||||
self.label = label
|
||||
self.label = np.asarray(label)
|
||||
if save:
|
||||
with h5py.File(save_path, 'w') as h5_file:
|
||||
h5_file.create_dataset('data', data=self.data)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue