diff --git a/utils/batch_generator.py b/utils/batch_generator.py index aeada53..e5b12fa 100644 --- a/utils/batch_generator.py +++ b/utils/batch_generator.py @@ -22,8 +22,8 @@ class BatchGenerator: if not preload: self.data_processor = data_processor self.label_processor = label_processor - self.data = data - self.label = label + self.data = np.asarray(data) + self.label = np.asarray(label) else: self.data_processor = None self.label_processor = None @@ -42,11 +42,11 @@ class BatchGenerator: if data_processor: self.data = np.asarray([data_processor(entry) for entry in data]) else: - self.data = data + self.data = np.asarray(data) if label_processor: self.label = np.asarray([label_processor(entry) for entry in label]) else: - self.label = label + self.label = np.asarray(label) if save: with h5py.File(save_path, 'w') as h5_file: h5_file.create_dataset('data', data=self.data)