Add pipeline to SequenceBatchGenerator
* Remove flip_data parameter (can be done in the pipeline)
This commit is contained in:
parent
d0fd6b2642
commit
73457750a8
2 changed files with 31 additions and 36 deletions
|
|
@ -68,10 +68,15 @@ class BatchGenerator:
|
||||||
self.global_step = 0
|
self.global_step = 0
|
||||||
self.step = 0
|
self.step = 0
|
||||||
|
|
||||||
first_data = np.array([data_processor(entry) if data_processor else entry
|
first_data = [data_processor(entry) if data_processor else entry
|
||||||
for entry in self.data[self.index_list[:batch_size]]])
|
for entry in self.data[self.index_list[:batch_size]]]
|
||||||
first_label = np.array([label_processor(entry) if label_processor else entry
|
first_label = [label_processor(entry) if label_processor else entry
|
||||||
for entry in self.label[self.index_list[:batch_size]]])
|
for entry in self.label[self.index_list[:batch_size]]]
|
||||||
|
if self.pipeline is not None:
|
||||||
|
for data_index, sample_data in enumerate(first_data):
|
||||||
|
first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index])
|
||||||
|
first_data = np.asarray(first_data)
|
||||||
|
first_label = np.asarray(first_label)
|
||||||
self.batch_data = first_data
|
self.batch_data = first_data
|
||||||
self.batch_label = first_label
|
self.batch_label = first_label
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,14 +16,15 @@ class SequenceGenerator(BatchGenerator):
|
||||||
|
|
||||||
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
|
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
|
||||||
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
||||||
index_list=None, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
pipeline: Optional[Callable] = None, index_list=None,
|
||||||
sequence_stride=1, flip_data=False, save: Optional[str] = None):
|
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
||||||
|
sequence_stride=1, save: Optional[str] = None):
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.sequence_size = sequence_size
|
self.sequence_size = sequence_size
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.prefetch = prefetch and not preload
|
self.prefetch = prefetch and not preload
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
self.flip_data = flip_data
|
self.pipeline = pipeline
|
||||||
|
|
||||||
if not preload:
|
if not preload:
|
||||||
self.data_processor = data_processor
|
self.data_processor = data_processor
|
||||||
|
|
@ -107,25 +108,28 @@ class SequenceGenerator(BatchGenerator):
|
||||||
first_data.append(
|
first_data.append(
|
||||||
[data_processor(input_data)
|
[data_processor(input_data)
|
||||||
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
|
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
|
||||||
first_data = np.asarray(first_data)
|
|
||||||
else:
|
else:
|
||||||
first_data = []
|
first_data = []
|
||||||
for sequence_index, start_index in self.index_list[:batch_size]:
|
for sequence_index, start_index in self.index_list[:batch_size]:
|
||||||
first_data.append(
|
first_data.append(
|
||||||
self.data[sequence_index][start_index: start_index + self.sequence_size])
|
self.data[sequence_index][start_index: start_index + self.sequence_size])
|
||||||
first_data = np.asarray(first_data)
|
|
||||||
if label_processor:
|
if label_processor:
|
||||||
first_label = []
|
first_label = []
|
||||||
for sequence_index, start_index in self.index_list[:batch_size]:
|
for sequence_index, start_index in self.index_list[:batch_size]:
|
||||||
first_label.append(
|
first_label.append(
|
||||||
[label_processor(input_label)
|
[label_processor(input_label)
|
||||||
for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]])
|
for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]])
|
||||||
first_label = np.asarray(first_label)
|
|
||||||
else:
|
else:
|
||||||
first_label = []
|
first_label = []
|
||||||
for sequence_index, start_index in self.index_list[:batch_size]:
|
for sequence_index, start_index in self.index_list[:batch_size]:
|
||||||
first_label.append(
|
first_label.append(
|
||||||
self.label[sequence_index][start_index: start_index + self.sequence_size])
|
self.label[sequence_index][start_index: start_index + self.sequence_size])
|
||||||
|
if self.pipeline is not None:
|
||||||
|
for sequence_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)):
|
||||||
|
for data_index, (data_sample, label_sample) in enumerate(zip(data_sequence, label_sequence)):
|
||||||
|
first_data[sequence_index][data_index], first_label[sequence_index][data_index] = self.pipeline(
|
||||||
|
data_sample, label_sample)
|
||||||
|
first_data = np.asarray(first_data)
|
||||||
first_label = np.asarray(first_label)
|
first_label = np.asarray(first_label)
|
||||||
self.batch_data = first_data
|
self.batch_data = first_data
|
||||||
self.batch_label = first_label
|
self.batch_label = first_label
|
||||||
|
|
@ -179,35 +183,11 @@ class SequenceGenerator(BatchGenerator):
|
||||||
data.append(
|
data.append(
|
||||||
[self.data_processor(input_data)
|
[self.data_processor(input_data)
|
||||||
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
|
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
|
||||||
if self.flip_data:
|
|
||||||
flip = np.random.uniform()
|
|
||||||
if flip < 0.25:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
|
|
||||||
elif flip < 0.5:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
|
|
||||||
elif flip < 0.75:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
|
|
||||||
else:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
|
||||||
else:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
|
||||||
else:
|
else:
|
||||||
data = []
|
data = []
|
||||||
for sequence_index, start_index in self.index_list[
|
for sequence_index, start_index in self.index_list[
|
||||||
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
||||||
data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
|
data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
|
||||||
if self.flip_data:
|
|
||||||
flip = np.random.uniform()
|
|
||||||
if flip < 0.25:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
|
|
||||||
elif flip < 0.5:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
|
|
||||||
elif flip < 0.75:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
|
|
||||||
else:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
|
||||||
else:
|
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
|
||||||
|
|
||||||
# Loading label
|
# Loading label
|
||||||
if self.label_processor is not None:
|
if self.label_processor is not None:
|
||||||
|
|
@ -217,12 +197,22 @@ class SequenceGenerator(BatchGenerator):
|
||||||
label.append(
|
label.append(
|
||||||
[self.label_processor(input_data)
|
[self.label_processor(input_data)
|
||||||
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
|
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
|
||||||
self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
|
|
||||||
else:
|
else:
|
||||||
label = []
|
label = []
|
||||||
for sequence_index, start_index in self.index_list[
|
for sequence_index, start_index in self.index_list[
|
||||||
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
||||||
label.append(self.label[sequence_index][start_index: start_index + self.sequence_size])
|
label.append(self.label[sequence_index][start_index: start_index + self.sequence_size])
|
||||||
|
|
||||||
|
# Process through pipeline
|
||||||
|
if self.pipeline is not None:
|
||||||
|
for sequence_index in range(len(data)):
|
||||||
|
for data_index in range(len(data[sequence_index])):
|
||||||
|
piped_data, piped_label = self.pipeline(
|
||||||
|
data[sequence_index][data_index], label[sequence_index][data_index])
|
||||||
|
self.cache_data[self.current_cache][data_index] = piped_data
|
||||||
|
self.cache_label[self.current_cache][data_index] = piped_label
|
||||||
|
else:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
||||||
self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
|
self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue