Data flip implementation

This commit is contained in:
corentin 2020-12-18 20:30:36 +09:00
commit 144ff4a004
2 changed files with 40 additions and 17 deletions

View file

@@ -12,12 +12,12 @@ class BatchGenerator:
def __init__(self, data: Iterable, label: Iterable, batch_size: int,
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
left_right_flip=False, save: Optional[str] = None):
flip_data=False, save: Optional[str] = None):
self.batch_size = batch_size
self.shuffle = shuffle
self.prefetch = prefetch and not preload
self.num_workers = num_workers
self.left_right_flip = left_right_flip
self.flip_data = flip_data
if not preload:
self.data_processor = data_processor
@@ -129,7 +129,6 @@ class BatchGenerator:
self.release()
def release(self):
print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}')
if self.num_workers > 1:
self.worker_stop.buf[0] = 1
for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
@@ -141,7 +140,6 @@ class BatchGenerator:
self.num_workers = 0 # Avoids double release
if self.prefetch:
print(f'{self.process_id} cleaning prefetch')
self.prefetch_stop.buf[0] = 1
self.prefetch_pipe_parent.send(True)
self.prefetch_process.join()
@@ -153,14 +151,12 @@ class BatchGenerator:
self.prefetch = False # Avoids double release
if self.process_id == 'main' and self.cache_memory_indices is not None:
print(f'{self.process_id} cleaning memory')
self.cache_memory_indices.close()
self.cache_memory_indices.unlink()
for shared_mem in self.cache_memory_data + self.cache_memory_label:
shared_mem.close()
shared_mem.unlink()
self.cache_memory_indices = None # Avoids double release
print(f'{self.process_id} released')
def _prefetch_worker(self):
self.prefetch = False
@@ -191,7 +187,6 @@ class BatchGenerator:
if self.num_workers > 1:
self.release()
print(f'{self.process_id} exiting')
def _worker(self, worker_index: int):
self.process_id = f'worker_{worker_index}'
@@ -225,7 +220,6 @@ class BatchGenerator:
break
except ValueError:
break
print(f'{self.process_id} exiting')
def skip_epoch(self):
if self.prefetch:
@@ -273,7 +267,18 @@ class BatchGenerator:
else:
data = self.data[
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
self.cache_data[self.current_cache][:len(data)] = data
if self.flip_data:
flip = np.random.uniform()
if flip < 0.25:
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1]
elif flip < 0.5:
self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1]
elif flip < 0.75:
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1]
else:
self.cache_data[self.current_cache][:len(data)] = data
else:
self.cache_data[self.current_cache][:len(data)] = data
# Loading label
if self.label_processor is not None:
label = []
@@ -314,10 +319,6 @@ class BatchGenerator:
self.batch_data = self.cache_data[self.current_cache]
self.batch_label = self.cache_label[self.current_cache]
if self.left_right_flip and np.random.uniform() > 0.5:
self.batch_data = self.batch_data[:, :, ::-1]
self.batch_label = self.batch_label[:, :, ::-1]
return self.batch_data, self.batch_label