From 8a0d9b51a3d1be68bc83f48bce9544f73d5b553d Mon Sep 17 00:00:00 2001 From: Corentin Date: Fri, 18 Dec 2020 17:40:39 +0900 Subject: [PATCH] Improve code --- utils/batch_generator.py | 52 +++++++++++++++++-------------- utils/sequence_batch_generator.py | 18 +++++------ 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/utils/batch_generator.py b/utils/batch_generator.py index e5b12fa..aa278e3 100644 --- a/utils/batch_generator.py +++ b/utils/batch_generator.py @@ -71,7 +71,7 @@ class BatchGenerator: self.batch_data = first_data self.batch_label = first_label - self.main_process = False + self.process_id = 'NA' if self.prefetch or self.num_workers > 1: self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes) self.cache_indices = np.ndarray( @@ -95,17 +95,17 @@ class BatchGenerator: self.cache_label = [first_label] if self.prefetch: - self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe() + self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe() self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1) self.prefetch_stop.buf[0] = 0 self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1) self.prefetch_skip.buf[0] = 0 - self.cache_process = mp.Process(target=self._cache_worker) - self.cache_process.start() + self.prefetch_process = mp.Process(target=self._prefetch_worker) + self.prefetch_process.start() self.num_workers = 0 self._init_workers() self.current_cache = 0 - self.main_process = True + self.process_id = 'main' def _init_workers(self): if self.num_workers > 1: @@ -129,6 +129,7 @@ class BatchGenerator: self.release() def release(self): + print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}') if self.num_workers > 1: self.worker_stop.buf[0] = 1 for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes): @@ -140,9 +141,10 @@ class BatchGenerator: self.num_workers = 0 # Avoids double release if self.prefetch: + print(f'{self.process_id} cleaning prefetch') self.prefetch_stop.buf[0] = 1 - self.cache_pipe_parent.send(True) - self.cache_process.join() + self.prefetch_pipe_parent.send(True) + self.prefetch_process.join() self.prefetch_stop.close() self.prefetch_stop.unlink() @@ -150,23 +152,21 @@ class BatchGenerator: self.prefetch_skip.unlink() self.prefetch = False # Avoids double release - if self.main_process and self.cache_memory_indices is not None: + if self.process_id == 'main' and self.cache_memory_indices is not None: + print(f'{self.process_id} cleaning memory') self.cache_memory_indices.close() self.cache_memory_indices.unlink() - self.cache_memory_data[0].close() - self.cache_memory_data[0].unlink() - self.cache_memory_data[1].close() - self.cache_memory_data[1].unlink() - self.cache_memory_label[0].close() - self.cache_memory_label[0].unlink() - self.cache_memory_label[1].close() - self.cache_memory_label[1].unlink() + for shared_mem in self.cache_memory_data + self.cache_memory_label: + shared_mem.close() + shared_mem.unlink() self.cache_memory_indices = None # Avoids double release + print(f'{self.process_id} released') - def _cache_worker(self): + def _prefetch_worker(self): self.prefetch = False self.current_cache = 1 self._init_workers() + self.process_id = 'prefetch' while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0: try: @@ -184,15 +184,17 @@ class BatchGenerator: self.global_step += 1 self._next_batch() - self.cache_pipe_child.recv() - self.cache_pipe_child.send(self.current_cache) + self.prefetch_pipe_child.recv() + self.prefetch_pipe_child.send(self.current_cache) except KeyboardInterrupt: break if self.num_workers > 1: self.release() + print(f'{self.process_id} exiting') def _worker(self, worker_index: int): + self.process_id = f'worker_{worker_index}' self.num_workers = 0 self.current_cache = 0 parent_cache_data = self.cache_data @@ -221,6 +223,9 @@ class BatchGenerator: pipe.send(True) except KeyboardInterrupt: break + except ValueError: + break + print(f'{self.process_id} exiting') def skip_epoch(self): if self.prefetch: @@ -293,8 +298,8 @@ class BatchGenerator: if self.step == self.step_per_epoch - 1 and self.shuffle: np.random.shuffle(self.index_list) self.cache_indices[:] = self.index_list - self.cache_pipe_parent.send(True) - self.current_cache = self.cache_pipe_parent.recv() + self.prefetch_pipe_parent.send(True) + self.current_cache = self.prefetch_pipe_parent.recv() else: if self.step == 0 and self.shuffle: np.random.shuffle(self.index_list) @@ -324,14 +329,15 @@ if __name__ == '__main__': [.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18], dtype=np.uint8) for data_processor in [None, lambda x:x]: - for prefetch in [False, True]: - for num_workers in [1, 3]: + for prefetch in [True, False]: + for num_workers in [3, 1]: print(f'{data_processor=} {prefetch=} {num_workers=}') with BatchGenerator(data, label, 5, data_processor=data_processor, prefetch=prefetch, num_workers=num_workers) as batch_generator: for _ in range(19): print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step) batch_generator.next_batch() + raise KeyboardInterrupt print() test() diff --git a/utils/sequence_batch_generator.py b/utils/sequence_batch_generator.py index 2a30aab..3b0cbd2 100644 --- a/utils/sequence_batch_generator.py +++ b/utils/sequence_batch_generator.py @@ -28,8 +28,8 @@ class SequenceGenerator(BatchGenerator): if not preload: self.data_processor = data_processor self.label_processor = label_processor - self.data = data - self.label = label + self.data = np.asarray(data) + self.label = np.asarray(label) else: self.data_processor = None self.label_processor = None @@ -56,13 +56,13 @@ class SequenceGenerator(BatchGenerator): [np.asarray([data_processor(entry) for entry in serie]) for serie in data], dtype=np.object if len(data) > 1 else None) else: - self.data = data + self.data = np.asarray(data) if label_processor: self.label = np.asarray( [np.asarray([label_processor(entry) for entry in serie]) for serie in label], dtype=np.object if len(label) > 1 else None) else: - self.label = label + self.label = np.asarray(label) if save: with h5py.File(save_path, 'w') as h5_file: h5_file.create_dataset('data_len', data=len(self.data)) @@ -118,7 +118,7 @@ class SequenceGenerator(BatchGenerator): self.batch_data = first_data self.batch_label = first_label - self.main_process = False + self.process_id = 'NA' if self.prefetch or self.num_workers > 1: self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes) self.cache_indices = np.ndarray( @@ -142,17 +142,17 @@ class SequenceGenerator(BatchGenerator): self.cache_label = [first_label] if self.prefetch: - self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe() + self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe() self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1) self.prefetch_stop.buf[0] = 0 self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1) self.prefetch_skip.buf[0] = 0 - self.cache_process = mp.Process(target=self._cache_worker) - self.cache_process.start() + self.prefetch_process = mp.Process(target=self._prefetch_worker) + self.prefetch_process.start() self.num_workers = 0 self._init_workers() self.current_cache = 0 - self.main_process = True + self.process_id = 'main' def _next_batch(self): if self.num_workers > 1: