Data flip implementation

This commit is contained in:
corentin 2020-12-18 20:30:36 +09:00
commit 144ff4a004
2 changed files with 40 additions and 17 deletions

View file

@@ -12,12 +12,12 @@ class BatchGenerator:
def __init__(self, data: Iterable, label: Iterable, batch_size: int, def __init__(self, data: Iterable, label: Iterable, batch_size: int,
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None, data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
left_right_flip=False, save: Optional[str] = None): flip_data=False, save: Optional[str] = None):
self.batch_size = batch_size self.batch_size = batch_size
self.shuffle = shuffle self.shuffle = shuffle
self.prefetch = prefetch and not preload self.prefetch = prefetch and not preload
self.num_workers = num_workers self.num_workers = num_workers
self.left_right_flip = left_right_flip self.flip_data = flip_data
if not preload: if not preload:
self.data_processor = data_processor self.data_processor = data_processor
@@ -129,7 +129,6 @@ class BatchGenerator:
self.release() self.release()
def release(self): def release(self):
print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}')
if self.num_workers > 1: if self.num_workers > 1:
self.worker_stop.buf[0] = 1 self.worker_stop.buf[0] = 1
for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes): for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
@@ -141,7 +140,6 @@ class BatchGenerator:
self.num_workers = 0 # Avoids double release self.num_workers = 0 # Avoids double release
if self.prefetch: if self.prefetch:
print(f'{self.process_id} cleaning prefetch')
self.prefetch_stop.buf[0] = 1 self.prefetch_stop.buf[0] = 1
self.prefetch_pipe_parent.send(True) self.prefetch_pipe_parent.send(True)
self.prefetch_process.join() self.prefetch_process.join()
@@ -153,14 +151,12 @@ class BatchGenerator:
self.prefetch = False # Avoids double release self.prefetch = False # Avoids double release
if self.process_id == 'main' and self.cache_memory_indices is not None: if self.process_id == 'main' and self.cache_memory_indices is not None:
print(f'{self.process_id} cleaning memory')
self.cache_memory_indices.close() self.cache_memory_indices.close()
self.cache_memory_indices.unlink() self.cache_memory_indices.unlink()
for shared_mem in self.cache_memory_data + self.cache_memory_label: for shared_mem in self.cache_memory_data + self.cache_memory_label:
shared_mem.close() shared_mem.close()
shared_mem.unlink() shared_mem.unlink()
self.cache_memory_indices = None # Avoids double release self.cache_memory_indices = None # Avoids double release
print(f'{self.process_id} released')
def _prefetch_worker(self): def _prefetch_worker(self):
self.prefetch = False self.prefetch = False
@@ -191,7 +187,6 @@ class BatchGenerator:
if self.num_workers > 1: if self.num_workers > 1:
self.release() self.release()
print(f'{self.process_id} exiting')
def _worker(self, worker_index: int): def _worker(self, worker_index: int):
self.process_id = f'worker_{worker_index}' self.process_id = f'worker_{worker_index}'
@@ -225,7 +220,6 @@ class BatchGenerator:
break break
except ValueError: except ValueError:
break break
print(f'{self.process_id} exiting')
def skip_epoch(self): def skip_epoch(self):
if self.prefetch: if self.prefetch:
@@ -273,7 +267,18 @@ class BatchGenerator:
else: else:
data = self.data[ data = self.data[
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]] self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
self.cache_data[self.current_cache][:len(data)] = data if self.flip_data:
flip = np.random.uniform()
if flip < 0.25:
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1]
elif flip < 0.5:
self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1]
elif flip < 0.75:
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1]
else:
self.cache_data[self.current_cache][:len(data)] = data
else:
self.cache_data[self.current_cache][:len(data)] = data
# Loading label # Loading label
if self.label_processor is not None: if self.label_processor is not None:
label = [] label = []
@@ -314,10 +319,6 @@ class BatchGenerator:
self.batch_data = self.cache_data[self.current_cache] self.batch_data = self.cache_data[self.current_cache]
self.batch_label = self.cache_label[self.current_cache] self.batch_label = self.cache_label[self.current_cache]
if self.left_right_flip and np.random.uniform() > 0.5:
self.batch_data = self.batch_data[:, :, ::-1]
self.batch_label = self.batch_label[:, :, ::-1]
return self.batch_data, self.batch_label return self.batch_data, self.batch_label

View file

@@ -17,13 +17,13 @@ class SequenceGenerator(BatchGenerator):
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int, def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None, data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
save: Optional[str] = None): flip_data=False, save: Optional[str] = None):
self.batch_size = batch_size self.batch_size = batch_size
self.sequence_size = sequence_size self.sequence_size = sequence_size
self.shuffle = shuffle self.shuffle = shuffle
self.prefetch = prefetch and not preload self.prefetch = prefetch and not preload
self.num_workers = num_workers self.num_workers = num_workers
self.left_right_flip = False self.flip_data = flip_data
if not preload: if not preload:
self.data_processor = data_processor self.data_processor = data_processor
@@ -167,13 +167,35 @@ class SequenceGenerator(BatchGenerator):
data.append( data.append(
[self.data_processor(input_data) [self.data_processor(input_data)
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]) for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
self.cache_data[self.current_cache][:len(data)] = np.asarray(data) if self.flip_data:
flip = np.random.uniform()
if flip < 0.25:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
elif flip < 0.5:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
elif flip < 0.75:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
else: else:
data = [] data = []
for sequence_index, start_index in self.index_list[ for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]: self.step * self.batch_size:(self.step + 1) * self.batch_size]:
data.append(self.data[sequence_index][start_index: start_index + self.sequence_size]) data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
self.cache_data[self.current_cache][:len(data)] = np.asarray(data) if self.flip_data:
flip = np.random.uniform()
if flip < 0.25:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
elif flip < 0.5:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
elif flip < 0.75:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
# Loading label # Loading label
if self.label_processor is not None: if self.label_processor is not None: