Data flip implementation
This commit is contained in:
parent
8a0d9b51a3
commit
144ff4a004
2 changed files with 40 additions and 17 deletions
|
|
@ -12,12 +12,12 @@ class BatchGenerator:
|
||||||
def __init__(self, data: Iterable, label: Iterable, batch_size: int,
|
def __init__(self, data: Iterable, label: Iterable, batch_size: int,
|
||||||
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
||||||
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
||||||
left_right_flip=False, save: Optional[str] = None):
|
flip_data=False, save: Optional[str] = None):
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.prefetch = prefetch and not preload
|
self.prefetch = prefetch and not preload
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
self.left_right_flip = left_right_flip
|
self.flip_data = flip_data
|
||||||
|
|
||||||
if not preload:
|
if not preload:
|
||||||
self.data_processor = data_processor
|
self.data_processor = data_processor
|
||||||
|
|
@ -129,7 +129,6 @@ class BatchGenerator:
|
||||||
self.release()
|
self.release()
|
||||||
|
|
||||||
def release(self):
|
def release(self):
|
||||||
print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}')
|
|
||||||
if self.num_workers > 1:
|
if self.num_workers > 1:
|
||||||
self.worker_stop.buf[0] = 1
|
self.worker_stop.buf[0] = 1
|
||||||
for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
|
for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
|
||||||
|
|
@ -141,7 +140,6 @@ class BatchGenerator:
|
||||||
self.num_workers = 0 # Avoids double release
|
self.num_workers = 0 # Avoids double release
|
||||||
|
|
||||||
if self.prefetch:
|
if self.prefetch:
|
||||||
print(f'{self.process_id} cleaning prefetch')
|
|
||||||
self.prefetch_stop.buf[0] = 1
|
self.prefetch_stop.buf[0] = 1
|
||||||
self.prefetch_pipe_parent.send(True)
|
self.prefetch_pipe_parent.send(True)
|
||||||
self.prefetch_process.join()
|
self.prefetch_process.join()
|
||||||
|
|
@ -153,14 +151,12 @@ class BatchGenerator:
|
||||||
self.prefetch = False # Avoids double release
|
self.prefetch = False # Avoids double release
|
||||||
|
|
||||||
if self.process_id == 'main' and self.cache_memory_indices is not None:
|
if self.process_id == 'main' and self.cache_memory_indices is not None:
|
||||||
print(f'{self.process_id} cleaning memory')
|
|
||||||
self.cache_memory_indices.close()
|
self.cache_memory_indices.close()
|
||||||
self.cache_memory_indices.unlink()
|
self.cache_memory_indices.unlink()
|
||||||
for shared_mem in self.cache_memory_data + self.cache_memory_label:
|
for shared_mem in self.cache_memory_data + self.cache_memory_label:
|
||||||
shared_mem.close()
|
shared_mem.close()
|
||||||
shared_mem.unlink()
|
shared_mem.unlink()
|
||||||
self.cache_memory_indices = None # Avoids double release
|
self.cache_memory_indices = None # Avoids double release
|
||||||
print(f'{self.process_id} released')
|
|
||||||
|
|
||||||
def _prefetch_worker(self):
|
def _prefetch_worker(self):
|
||||||
self.prefetch = False
|
self.prefetch = False
|
||||||
|
|
@ -191,7 +187,6 @@ class BatchGenerator:
|
||||||
|
|
||||||
if self.num_workers > 1:
|
if self.num_workers > 1:
|
||||||
self.release()
|
self.release()
|
||||||
print(f'{self.process_id} exiting')
|
|
||||||
|
|
||||||
def _worker(self, worker_index: int):
|
def _worker(self, worker_index: int):
|
||||||
self.process_id = f'worker_{worker_index}'
|
self.process_id = f'worker_{worker_index}'
|
||||||
|
|
@ -225,7 +220,6 @@ class BatchGenerator:
|
||||||
break
|
break
|
||||||
except ValueError:
|
except ValueError:
|
||||||
break
|
break
|
||||||
print(f'{self.process_id} exiting')
|
|
||||||
|
|
||||||
def skip_epoch(self):
|
def skip_epoch(self):
|
||||||
if self.prefetch:
|
if self.prefetch:
|
||||||
|
|
@ -273,6 +267,17 @@ class BatchGenerator:
|
||||||
else:
|
else:
|
||||||
data = self.data[
|
data = self.data[
|
||||||
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
|
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
|
||||||
|
if self.flip_data:
|
||||||
|
flip = np.random.uniform()
|
||||||
|
if flip < 0.25:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1]
|
||||||
|
elif flip < 0.5:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1]
|
||||||
|
elif flip < 0.75:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1]
|
||||||
|
else:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = data
|
||||||
|
else:
|
||||||
self.cache_data[self.current_cache][:len(data)] = data
|
self.cache_data[self.current_cache][:len(data)] = data
|
||||||
# Loading label
|
# Loading label
|
||||||
if self.label_processor is not None:
|
if self.label_processor is not None:
|
||||||
|
|
@ -314,10 +319,6 @@ class BatchGenerator:
|
||||||
self.batch_data = self.cache_data[self.current_cache]
|
self.batch_data = self.cache_data[self.current_cache]
|
||||||
self.batch_label = self.cache_label[self.current_cache]
|
self.batch_label = self.cache_label[self.current_cache]
|
||||||
|
|
||||||
if self.left_right_flip and np.random.uniform() > 0.5:
|
|
||||||
self.batch_data = self.batch_data[:, :, ::-1]
|
|
||||||
self.batch_label = self.batch_label[:, :, ::-1]
|
|
||||||
|
|
||||||
return self.batch_data, self.batch_label
|
return self.batch_data, self.batch_label
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,13 @@ class SequenceGenerator(BatchGenerator):
|
||||||
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
|
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
|
||||||
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
||||||
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
||||||
save: Optional[str] = None):
|
flip_data=False, save: Optional[str] = None):
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.sequence_size = sequence_size
|
self.sequence_size = sequence_size
|
||||||
self.shuffle = shuffle
|
self.shuffle = shuffle
|
||||||
self.prefetch = prefetch and not preload
|
self.prefetch = prefetch and not preload
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
self.left_right_flip = False
|
self.flip_data = flip_data
|
||||||
|
|
||||||
if not preload:
|
if not preload:
|
||||||
self.data_processor = data_processor
|
self.data_processor = data_processor
|
||||||
|
|
@ -167,12 +167,34 @@ class SequenceGenerator(BatchGenerator):
|
||||||
data.append(
|
data.append(
|
||||||
[self.data_processor(input_data)
|
[self.data_processor(input_data)
|
||||||
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
|
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
|
||||||
|
if self.flip_data:
|
||||||
|
flip = np.random.uniform()
|
||||||
|
if flip < 0.25:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
|
||||||
|
elif flip < 0.5:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
|
||||||
|
elif flip < 0.75:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
|
||||||
|
else:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
||||||
|
else:
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
||||||
else:
|
else:
|
||||||
data = []
|
data = []
|
||||||
for sequence_index, start_index in self.index_list[
|
for sequence_index, start_index in self.index_list[
|
||||||
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
||||||
data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
|
data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
|
||||||
|
if self.flip_data:
|
||||||
|
flip = np.random.uniform()
|
||||||
|
if flip < 0.25:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
|
||||||
|
elif flip < 0.5:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
|
||||||
|
elif flip < 0.75:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
|
||||||
|
else:
|
||||||
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
||||||
|
else:
|
||||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
||||||
|
|
||||||
# Loading label
|
# Loading label
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue