Improve code

This commit is contained in:
Corentin 2020-12-18 17:40:39 +09:00
commit 8a0d9b51a3
2 changed files with 38 additions and 32 deletions

View file

@ -71,7 +71,7 @@ class BatchGenerator:
self.batch_data = first_data
self.batch_label = first_label
self.main_process = False
self.process_id = 'NA'
if self.prefetch or self.num_workers > 1:
self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
self.cache_indices = np.ndarray(
@ -95,17 +95,17 @@ class BatchGenerator:
self.cache_label = [first_label]
if self.prefetch:
self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
self.prefetch_stop.buf[0] = 0
self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
self.prefetch_skip.buf[0] = 0
self.cache_process = mp.Process(target=self._cache_worker)
self.cache_process.start()
self.prefetch_process = mp.Process(target=self._prefetch_worker)
self.prefetch_process.start()
self.num_workers = 0
self._init_workers()
self.current_cache = 0
self.main_process = True
self.process_id = 'main'
def _init_workers(self):
if self.num_workers > 1:
@ -129,6 +129,7 @@ class BatchGenerator:
self.release()
def release(self):
print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}')
if self.num_workers > 1:
self.worker_stop.buf[0] = 1
for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
@ -140,9 +141,10 @@ class BatchGenerator:
self.num_workers = 0 # Avoids double release
if self.prefetch:
print(f'{self.process_id} cleaning prefetch')
self.prefetch_stop.buf[0] = 1
self.cache_pipe_parent.send(True)
self.cache_process.join()
self.prefetch_pipe_parent.send(True)
self.prefetch_process.join()
self.prefetch_stop.close()
self.prefetch_stop.unlink()
@ -150,23 +152,21 @@ class BatchGenerator:
self.prefetch_skip.unlink()
self.prefetch = False # Avoids double release
if self.main_process and self.cache_memory_indices is not None:
if self.process_id == 'main' and self.cache_memory_indices is not None:
print(f'{self.process_id} cleaning memory')
self.cache_memory_indices.close()
self.cache_memory_indices.unlink()
self.cache_memory_data[0].close()
self.cache_memory_data[0].unlink()
self.cache_memory_data[1].close()
self.cache_memory_data[1].unlink()
self.cache_memory_label[0].close()
self.cache_memory_label[0].unlink()
self.cache_memory_label[1].close()
self.cache_memory_label[1].unlink()
for shared_mem in self.cache_memory_data + self.cache_memory_label:
shared_mem.close()
shared_mem.unlink()
self.cache_memory_indices = None # Avoids double release
print(f'{self.process_id} released')
def _cache_worker(self):
def _prefetch_worker(self):
self.prefetch = False
self.current_cache = 1
self._init_workers()
self.process_id = 'prefetch'
while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
try:
@ -184,15 +184,17 @@ class BatchGenerator:
self.global_step += 1
self._next_batch()
self.cache_pipe_child.recv()
self.cache_pipe_child.send(self.current_cache)
self.prefetch_pipe_child.recv()
self.prefetch_pipe_child.send(self.current_cache)
except KeyboardInterrupt:
break
if self.num_workers > 1:
self.release()
print(f'{self.process_id} exiting')
def _worker(self, worker_index: int):
self.process_id = f'worker_{worker_index}'
self.num_workers = 0
self.current_cache = 0
parent_cache_data = self.cache_data
@ -221,6 +223,9 @@ class BatchGenerator:
pipe.send(True)
except KeyboardInterrupt:
break
except ValueError:
break
print(f'{self.process_id} exiting')
def skip_epoch(self):
if self.prefetch:
@ -293,8 +298,8 @@ class BatchGenerator:
if self.step == self.step_per_epoch - 1 and self.shuffle:
np.random.shuffle(self.index_list)
self.cache_indices[:] = self.index_list
self.cache_pipe_parent.send(True)
self.current_cache = self.cache_pipe_parent.recv()
self.prefetch_pipe_parent.send(True)
self.current_cache = self.prefetch_pipe_parent.recv()
else:
if self.step == 0 and self.shuffle:
np.random.shuffle(self.index_list)
@ -324,14 +329,15 @@ if __name__ == '__main__':
[.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18], dtype=np.uint8)
for data_processor in [None, lambda x:x]:
for prefetch in [False, True]:
for num_workers in [1, 3]:
for prefetch in [True, False]:
for num_workers in [3, 1]:
print(f'{data_processor=} {prefetch=} {num_workers=}')
with BatchGenerator(data, label, 5, data_processor=data_processor,
prefetch=prefetch, num_workers=num_workers) as batch_generator:
for _ in range(19):
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
batch_generator.next_batch()
raise KeyboardInterrupt
print()
test()

View file

@ -28,8 +28,8 @@ class SequenceGenerator(BatchGenerator):
if not preload:
self.data_processor = data_processor
self.label_processor = label_processor
self.data = data
self.label = label
self.data = np.asarray(data)
self.label = np.asarray(label)
else:
self.data_processor = None
self.label_processor = None
@ -56,13 +56,13 @@ class SequenceGenerator(BatchGenerator):
[np.asarray([data_processor(entry) for entry in serie]) for serie in data],
dtype=np.object if len(data) > 1 else None)
else:
self.data = data
self.data = np.asarray(data)
if label_processor:
self.label = np.asarray(
[np.asarray([label_processor(entry) for entry in serie]) for serie in label],
dtype=np.object if len(label) > 1 else None)
else:
self.label = label
self.label = np.asarray(label)
if save:
with h5py.File(save_path, 'w') as h5_file:
h5_file.create_dataset('data_len', data=len(self.data))
@ -118,7 +118,7 @@ class SequenceGenerator(BatchGenerator):
self.batch_data = first_data
self.batch_label = first_label
self.main_process = False
self.process_id = 'NA'
if self.prefetch or self.num_workers > 1:
self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
self.cache_indices = np.ndarray(
@ -142,17 +142,17 @@ class SequenceGenerator(BatchGenerator):
self.cache_label = [first_label]
if self.prefetch:
self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
self.prefetch_stop.buf[0] = 0
self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
self.prefetch_skip.buf[0] = 0
self.cache_process = mp.Process(target=self._cache_worker)
self.cache_process.start()
self.prefetch_process = mp.Process(target=self._prefetch_worker)
self.prefetch_process.start()
self.num_workers = 0
self._init_workers()
self.current_cache = 0
self.main_process = True
self.process_id = 'main'
def _next_batch(self):
if self.num_workers > 1: