Add pipeline to SequenceBatchGenerator

* Remove flip_data parameter (can be done in the pipeline)
This commit is contained in:
Corentin 2021-03-23 02:13:40 +09:00
commit 73457750a8
2 changed files with 31 additions and 36 deletions

View file

@@ -68,10 +68,15 @@ class BatchGenerator:
self.global_step = 0 self.global_step = 0
self.step = 0 self.step = 0
first_data = np.array([data_processor(entry) if data_processor else entry first_data = [data_processor(entry) if data_processor else entry
for entry in self.data[self.index_list[:batch_size]]]) for entry in self.data[self.index_list[:batch_size]]]
first_label = np.array([label_processor(entry) if label_processor else entry first_label = [label_processor(entry) if label_processor else entry
for entry in self.label[self.index_list[:batch_size]]]) for entry in self.label[self.index_list[:batch_size]]]
if self.pipeline is not None:
for data_index, sample_data in enumerate(first_data):
first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index])
first_data = np.asarray(first_data)
first_label = np.asarray(first_label)
self.batch_data = first_data self.batch_data = first_data
self.batch_label = first_label self.batch_label = first_label

View file

@@ -16,14 +16,15 @@ class SequenceGenerator(BatchGenerator):
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int, def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None, data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
index_list=None, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, pipeline: Optional[Callable] = None, index_list=None,
sequence_stride=1, flip_data=False, save: Optional[str] = None): prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
sequence_stride=1, save: Optional[str] = None):
self.batch_size = batch_size self.batch_size = batch_size
self.sequence_size = sequence_size self.sequence_size = sequence_size
self.shuffle = shuffle self.shuffle = shuffle
self.prefetch = prefetch and not preload self.prefetch = prefetch and not preload
self.num_workers = num_workers self.num_workers = num_workers
self.flip_data = flip_data self.pipeline = pipeline
if not preload: if not preload:
self.data_processor = data_processor self.data_processor = data_processor
@@ -107,26 +108,29 @@ class SequenceGenerator(BatchGenerator):
first_data.append( first_data.append(
[data_processor(input_data) [data_processor(input_data)
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]) for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
first_data = np.asarray(first_data)
else: else:
first_data = [] first_data = []
for sequence_index, start_index in self.index_list[:batch_size]: for sequence_index, start_index in self.index_list[:batch_size]:
first_data.append( first_data.append(
self.data[sequence_index][start_index: start_index + self.sequence_size]) self.data[sequence_index][start_index: start_index + self.sequence_size])
first_data = np.asarray(first_data)
if label_processor: if label_processor:
first_label = [] first_label = []
for sequence_index, start_index in self.index_list[:batch_size]: for sequence_index, start_index in self.index_list[:batch_size]:
first_label.append( first_label.append(
[label_processor(input_label) [label_processor(input_label)
for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]]) for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]])
first_label = np.asarray(first_label)
else: else:
first_label = [] first_label = []
for sequence_index, start_index in self.index_list[:batch_size]: for sequence_index, start_index in self.index_list[:batch_size]:
first_label.append( first_label.append(
self.label[sequence_index][start_index: start_index + self.sequence_size]) self.label[sequence_index][start_index: start_index + self.sequence_size])
first_label = np.asarray(first_label) if self.pipeline is not None:
for sequence_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)):
for data_index, (data_sample, label_sample) in enumerate(zip(data_sequence, label_sequence)):
first_data[sequence_index][data_index], first_label[sequence_index][data_index] = self.pipeline(
data_sample, label_sample)
first_data = np.asarray(first_data)
first_label = np.asarray(first_label)
self.batch_data = first_data self.batch_data = first_data
self.batch_label = first_label self.batch_label = first_label
@@ -179,35 +183,11 @@ class SequenceGenerator(BatchGenerator):
data.append( data.append(
[self.data_processor(input_data) [self.data_processor(input_data)
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]) for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
if self.flip_data:
flip = np.random.uniform()
if flip < 0.25:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
elif flip < 0.5:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
elif flip < 0.75:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
else: else:
data = [] data = []
for sequence_index, start_index in self.index_list[ for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]: self.step * self.batch_size:(self.step + 1) * self.batch_size]:
data.append(self.data[sequence_index][start_index: start_index + self.sequence_size]) data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
if self.flip_data:
flip = np.random.uniform()
if flip < 0.25:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
elif flip < 0.5:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
elif flip < 0.75:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
# Loading label # Loading label
if self.label_processor is not None: if self.label_processor is not None:
@@ -217,12 +197,22 @@ class SequenceGenerator(BatchGenerator):
label.append( label.append(
[self.label_processor(input_data) [self.label_processor(input_data)
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]]) for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
else: else:
label = [] label = []
for sequence_index, start_index in self.index_list[ for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]: self.step * self.batch_size:(self.step + 1) * self.batch_size]:
label.append(self.label[sequence_index][start_index: start_index + self.sequence_size]) label.append(self.label[sequence_index][start_index: start_index + self.sequence_size])
# Process through pipeline
if self.pipeline is not None:
for sequence_index in range(len(data)):
for data_index in range(len(data[sequence_index])):
piped_data, piped_label = self.pipeline(
data[sequence_index][data_index], label[sequence_index][data_index])
self.cache_data[self.current_cache][data_index] = piped_data
self.cache_label[self.current_cache][data_index] = piped_label
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
self.cache_label[self.current_cache][:len(label)] = np.asarray(label) self.cache_label[self.current_cache][:len(label)] = np.asarray(label)