Data flip implementation
This commit is contained in:
parent
8a0d9b51a3
commit
144ff4a004
2 changed files with 40 additions and 17 deletions
|
|
@ -12,12 +12,12 @@ class BatchGenerator:
|
|||
def __init__(self, data: Iterable, label: Iterable, batch_size: int,
|
||||
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
||||
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
||||
left_right_flip=False, save: Optional[str] = None):
|
||||
flip_data=False, save: Optional[str] = None):
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.prefetch = prefetch and not preload
|
||||
self.num_workers = num_workers
|
||||
self.left_right_flip = left_right_flip
|
||||
self.flip_data = flip_data
|
||||
|
||||
if not preload:
|
||||
self.data_processor = data_processor
|
||||
|
|
@ -129,7 +129,6 @@ class BatchGenerator:
|
|||
self.release()
|
||||
|
||||
def release(self):
|
||||
print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}')
|
||||
if self.num_workers > 1:
|
||||
self.worker_stop.buf[0] = 1
|
||||
for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
|
||||
|
|
@ -141,7 +140,6 @@ class BatchGenerator:
|
|||
self.num_workers = 0 # Avoids double release
|
||||
|
||||
if self.prefetch:
|
||||
print(f'{self.process_id} cleaning prefetch')
|
||||
self.prefetch_stop.buf[0] = 1
|
||||
self.prefetch_pipe_parent.send(True)
|
||||
self.prefetch_process.join()
|
||||
|
|
@ -153,14 +151,12 @@ class BatchGenerator:
|
|||
self.prefetch = False # Avoids double release
|
||||
|
||||
if self.process_id == 'main' and self.cache_memory_indices is not None:
|
||||
print(f'{self.process_id} cleaning memory')
|
||||
self.cache_memory_indices.close()
|
||||
self.cache_memory_indices.unlink()
|
||||
for shared_mem in self.cache_memory_data + self.cache_memory_label:
|
||||
shared_mem.close()
|
||||
shared_mem.unlink()
|
||||
self.cache_memory_indices = None # Avoids double release
|
||||
print(f'{self.process_id} released')
|
||||
|
||||
def _prefetch_worker(self):
|
||||
self.prefetch = False
|
||||
|
|
@ -191,7 +187,6 @@ class BatchGenerator:
|
|||
|
||||
if self.num_workers > 1:
|
||||
self.release()
|
||||
print(f'{self.process_id} exiting')
|
||||
|
||||
def _worker(self, worker_index: int):
|
||||
self.process_id = f'worker_{worker_index}'
|
||||
|
|
@ -225,7 +220,6 @@ class BatchGenerator:
|
|||
break
|
||||
except ValueError:
|
||||
break
|
||||
print(f'{self.process_id} exiting')
|
||||
|
||||
def skip_epoch(self):
|
||||
if self.prefetch:
|
||||
|
|
@ -273,7 +267,18 @@ class BatchGenerator:
|
|||
else:
|
||||
data = self.data[
|
||||
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
|
||||
self.cache_data[self.current_cache][:len(data)] = data
|
||||
if self.flip_data:
|
||||
flip = np.random.uniform()
|
||||
if flip < 0.25:
|
||||
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1]
|
||||
elif flip < 0.5:
|
||||
self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1]
|
||||
elif flip < 0.75:
|
||||
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1]
|
||||
else:
|
||||
self.cache_data[self.current_cache][:len(data)] = data
|
||||
else:
|
||||
self.cache_data[self.current_cache][:len(data)] = data
|
||||
# Loading label
|
||||
if self.label_processor is not None:
|
||||
label = []
|
||||
|
|
@ -314,10 +319,6 @@ class BatchGenerator:
|
|||
self.batch_data = self.cache_data[self.current_cache]
|
||||
self.batch_label = self.cache_label[self.current_cache]
|
||||
|
||||
if self.left_right_flip and np.random.uniform() > 0.5:
|
||||
self.batch_data = self.batch_data[:, :, ::-1]
|
||||
self.batch_label = self.batch_label[:, :, ::-1]
|
||||
|
||||
return self.batch_data, self.batch_label
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue