Fix SequenceGenerator for single sequence data
This commit is contained in:
parent
58310ba077
commit
95f2e52ff3
1 changed files with 35 additions and 18 deletions
|
|
@ -104,16 +104,6 @@ class SequenceGenerator:
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
|
|
||||||
self.index_list = []
|
|
||||||
for sequence_index in range(len(data)):
|
|
||||||
start_indices = np.expand_dims(
|
|
||||||
np.arange(len(data[sequence_index]) - sequence_size, dtype=np.uint8),
|
|
||||||
axis=-1)
|
|
||||||
start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
|
|
||||||
self.index_list.append(start_indices)
|
|
||||||
self.index_list = np.concatenate(self.index_list, axis=0)
|
|
||||||
self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
|
|
||||||
|
|
||||||
if not preload:
|
if not preload:
|
||||||
self.data_processor = data_processor
|
self.data_processor = data_processor
|
||||||
self.label_processor = label_processor
|
self.label_processor = label_processor
|
||||||
|
|
@ -122,25 +112,52 @@ class SequenceGenerator:
|
||||||
else:
|
else:
|
||||||
self.data_processor = None
|
self.data_processor = None
|
||||||
self.label_processor = None
|
self.label_processor = None
|
||||||
|
save_path = save
|
||||||
|
if save is not None:
|
||||||
|
if '.' not in os.path.basename(save_path):
|
||||||
|
save_path += '.hdf5'
|
||||||
|
if not os.path.exists(os.path.dirname(save_path)):
|
||||||
|
os.makedirs(os.path.dirname(save_path))
|
||||||
|
|
||||||
if save and os.path.exists(save + '_data.npy'):
|
if save and os.path.exists(save_path):
|
||||||
self.data = np.load(save + '_data.npy', allow_pickle=True)
|
with h5py.File(save_path, 'r') as h5_file:
|
||||||
self.label = np.load(save + '_label.npy', allow_pickle=True)
|
data_len = np.asarray(h5_file['data_len'])
|
||||||
|
self.data = []
|
||||||
|
self.label = []
|
||||||
|
for sequence_index in range(data_len):
|
||||||
|
self.data.append(np.asarray(h5_file[f'data_{sequence_index}']))
|
||||||
|
self.label.append(np.asarray(h5_file[f'label_{sequence_index}']))
|
||||||
|
self.data = np.asarray(self.data)
|
||||||
|
self.label = np.asarray(self.label)
|
||||||
else:
|
else:
|
||||||
if data_processor:
|
if data_processor:
|
||||||
self.data = np.asarray(
|
self.data = np.asarray(
|
||||||
[np.asarray([data_processor(entry) for entry in serie]) for serie in data])
|
[np.asarray([data_processor(entry) for entry in serie]) for serie in data],
|
||||||
|
dtype=np.object if len(data) > 1 else None)
|
||||||
else:
|
else:
|
||||||
self.data = data
|
self.data = data
|
||||||
if label_processor:
|
if label_processor:
|
||||||
self.label = np.asarray(
|
self.label = np.asarray(
|
||||||
[np.asarray([label_processor(entry) for entry in serie]) for serie in label])
|
[np.asarray([label_processor(entry) for entry in serie]) for serie in label],
|
||||||
|
dtype=np.object if len(label) > 1 else None)
|
||||||
else:
|
else:
|
||||||
self.label = label
|
self.label = label
|
||||||
|
if save:
|
||||||
|
with h5py.File(save_path, 'w') as h5_file:
|
||||||
|
h5_file.create_dataset(f'data_len', data=len(self.data))
|
||||||
|
for sequence_index in range(len(self.data)):
|
||||||
|
h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index])
|
||||||
|
h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index])
|
||||||
|
|
||||||
if save:
|
self.index_list = []
|
||||||
np.save(save + '_data.npy', self.data, allow_pickle=True)
|
for sequence_index in range(len(self.data)):
|
||||||
np.save(save + '_label.npy', self.label, allow_pickle=True)
|
start_indices = np.expand_dims(
|
||||||
|
np.arange(len(self.data[sequence_index]) - sequence_size, dtype=np.uint32),
|
||||||
|
axis=-1)
|
||||||
|
start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
|
||||||
|
self.index_list.append(start_indices)
|
||||||
|
self.index_list = np.concatenate(self.index_list, axis=0)
|
||||||
|
self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
|
||||||
|
|
||||||
self.epoch = 0
|
self.epoch = 0
|
||||||
self.global_step = -1
|
self.global_step = -1
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue