Improve code: rename the cache_* prefetch attributes to prefetch_*, replace the main_process flag with a process_id string, and loop over the shared-memory segments in release()

Corentin 2020-12-18 17:40:39 +09:00
commit 8a0d9b51a3
2 changed files with 38 additions and 32 deletions


@@ -71,7 +71,7 @@ class BatchGenerator:
         self.batch_data = first_data
         self.batch_label = first_label
-        self.main_process = False
+        self.process_id = 'NA'

         if self.prefetch or self.num_workers > 1:
             self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
             self.cache_indices = np.ndarray(
@@ -95,17 +95,17 @@ class BatchGenerator:
             self.cache_label = [first_label]

             if self.prefetch:
-                self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
+                self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
                 self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
                 self.prefetch_stop.buf[0] = 0
                 self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
                 self.prefetch_skip.buf[0] = 0
-                self.cache_process = mp.Process(target=self._cache_worker)
-                self.cache_process.start()
+                self.prefetch_process = mp.Process(target=self._prefetch_worker)
+                self.prefetch_process.start()
                 self.num_workers = 0
             self._init_workers()
             self.current_cache = 0
-        self.main_process = True
+        self.process_id = 'main'

     def _init_workers(self):
         if self.num_workers > 1:
@@ -129,6 +129,7 @@ class BatchGenerator:
         self.release()

     def release(self):
+        print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}')
         if self.num_workers > 1:
             self.worker_stop.buf[0] = 1
             for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
@@ -140,9 +141,10 @@ class BatchGenerator:
             self.num_workers = 0  # Avoids double release

         if self.prefetch:
+            print(f'{self.process_id} cleaning prefetch')
             self.prefetch_stop.buf[0] = 1
-            self.cache_pipe_parent.send(True)
-            self.cache_process.join()
+            self.prefetch_pipe_parent.send(True)
+            self.prefetch_process.join()
             self.prefetch_stop.close()
             self.prefetch_stop.unlink()
@@ -150,23 +152,21 @@ class BatchGenerator:
             self.prefetch_skip.unlink()
             self.prefetch = False  # Avoids double release

-        if self.main_process and self.cache_memory_indices is not None:
+        if self.process_id == 'main' and self.cache_memory_indices is not None:
+            print(f'{self.process_id} cleaning memory')
             self.cache_memory_indices.close()
             self.cache_memory_indices.unlink()
-            self.cache_memory_data[0].close()
-            self.cache_memory_data[0].unlink()
-            self.cache_memory_data[1].close()
-            self.cache_memory_data[1].unlink()
-            self.cache_memory_label[0].close()
-            self.cache_memory_label[0].unlink()
-            self.cache_memory_label[1].close()
-            self.cache_memory_label[1].unlink()
+            for shared_mem in self.cache_memory_data + self.cache_memory_label:
+                shared_mem.close()
+                shared_mem.unlink()
             self.cache_memory_indices = None  # Avoids double release
+        print(f'{self.process_id} released')

-    def _cache_worker(self):
+    def _prefetch_worker(self):
         self.prefetch = False
         self.current_cache = 1
         self._init_workers()
+        self.process_id = 'prefetch'

         while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
             try:
@@ -184,15 +184,17 @@ class BatchGenerator:
                 self.global_step += 1
                 self._next_batch()
-                self.cache_pipe_child.recv()
-                self.cache_pipe_child.send(self.current_cache)
+                self.prefetch_pipe_child.recv()
+                self.prefetch_pipe_child.send(self.current_cache)
             except KeyboardInterrupt:
                 break

         if self.num_workers > 1:
             self.release()
+        print(f'{self.process_id} exiting')

     def _worker(self, worker_index: int):
+        self.process_id = f'worker_{worker_index}'
         self.num_workers = 0
         self.current_cache = 0

         parent_cache_data = self.cache_data
@@ -221,6 +223,9 @@ class BatchGenerator:
                 pipe.send(True)
             except KeyboardInterrupt:
                 break
+            except ValueError:
+                break
+        print(f'{self.process_id} exiting')

     def skip_epoch(self):
         if self.prefetch:
@@ -293,8 +298,8 @@ class BatchGenerator:
             if self.step == self.step_per_epoch - 1 and self.shuffle:
                 np.random.shuffle(self.index_list)
                 self.cache_indices[:] = self.index_list
-            self.cache_pipe_parent.send(True)
-            self.current_cache = self.cache_pipe_parent.recv()
+            self.prefetch_pipe_parent.send(True)
+            self.current_cache = self.prefetch_pipe_parent.recv()
         else:
             if self.step == 0 and self.shuffle:
                 np.random.shuffle(self.index_list)
@@ -324,14 +329,15 @@ if __name__ == '__main__':
             [.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18], dtype=np.uint8)

         for data_processor in [None, lambda x:x]:
-            for prefetch in [False, True]:
-                for num_workers in [1, 3]:
+            for prefetch in [True, False]:
+                for num_workers in [3, 1]:
                     print(f'{data_processor=} {prefetch=} {num_workers=}')
                     with BatchGenerator(data, label, 5, data_processor=data_processor,
                                         prefetch=prefetch, num_workers=num_workers) as batch_generator:
                         for _ in range(19):
                             print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
                             batch_generator.next_batch()
+                        raise KeyboardInterrupt
                     print()

     test()
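For reference, the handshake that the renamed prefetch_pipe_parent / prefetch_pipe_child pair and the prefetch_stop flag implement above can be sketched on its own. This is a minimal, hypothetical example and not the project's API (the prefetch_worker function, the stop-flag name passing, and the 'batch ready' payload are illustrative): the main process requests a batch over a Pipe, while a one-byte SharedMemory segment acts as the stop flag that release() sets before joining the worker.

import multiprocessing as mp
from multiprocessing import shared_memory


def prefetch_worker(pipe, stop_name):
    # Re-attach to the shared one-byte stop flag by name inside the child.
    stop = shared_memory.SharedMemory(name=stop_name)
    while stop.buf[0] == 0:
        if pipe.poll(0.1):      # Wait briefly for a request from the parent.
            pipe.recv()
            pipe.send('batch ready')
    stop.close()


if __name__ == '__main__':
    parent_pipe, child_pipe = mp.Pipe()
    stop = shared_memory.SharedMemory(create=True, size=1)
    stop.buf[0] = 0
    worker = mp.Process(target=prefetch_worker, args=(child_pipe, stop.name))
    worker.start()

    parent_pipe.send(True)      # Ask for the next prefetched batch.
    print(parent_pipe.recv())   # -> 'batch ready'

    stop.buf[0] = 1             # Tell the worker to exit, then clean up.
    worker.join()
    stop.close()
    stop.unlink()

Passing the segment's name and re-attaching in the child is the documented way to share the flag even when the process start method is spawn; the generators above instead rely on the child processes inheriting the attributes directly.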


@@ -28,8 +28,8 @@ class SequenceGenerator(BatchGenerator):
         if not preload:
             self.data_processor = data_processor
             self.label_processor = label_processor
-            self.data = data
-            self.label = label
+            self.data = np.asarray(data)
+            self.label = np.asarray(label)
         else:
             self.data_processor = None
             self.label_processor = None
@@ -56,13 +56,13 @@ class SequenceGenerator(BatchGenerator):
                     [np.asarray([data_processor(entry) for entry in serie]) for serie in data],
                     dtype=np.object if len(data) > 1 else None)
             else:
-                self.data = data
+                self.data = np.asarray(data)
             if label_processor:
                 self.label = np.asarray(
                     [np.asarray([label_processor(entry) for entry in serie]) for serie in label],
                     dtype=np.object if len(label) > 1 else None)
             else:
-                self.label = label
+                self.label = np.asarray(label)
             if save:
                 with h5py.File(save_path, 'w') as h5_file:
                     h5_file.create_dataset('data_len', data=len(self.data))
@@ -118,7 +118,7 @@ class SequenceGenerator(BatchGenerator):
         self.batch_data = first_data
         self.batch_label = first_label
-        self.main_process = False
+        self.process_id = 'NA'

         if self.prefetch or self.num_workers > 1:
             self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
             self.cache_indices = np.ndarray(
@@ -142,17 +142,17 @@ class SequenceGenerator(BatchGenerator):
             self.cache_label = [first_label]

             if self.prefetch:
-                self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
+                self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
                 self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
                 self.prefetch_stop.buf[0] = 0
                 self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
                 self.prefetch_skip.buf[0] = 0
-                self.cache_process = mp.Process(target=self._cache_worker)
-                self.cache_process.start()
+                self.prefetch_process = mp.Process(target=self._prefetch_worker)
+                self.prefetch_process.start()
                 self.num_workers = 0
             self._init_workers()
             self.current_cache = 0
-        self.main_process = True
+        self.process_id = 'main'

     def _next_batch(self):
         if self.num_workers > 1:
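Both __init__ hunks above allocate cache_memory_indices and wrap it in an np.ndarray view (cache_indices) so that index shuffles made in one process are visible to the others, and the reworked release() then closes and unlinks every segment in a loop. A standalone sketch of those two patterns, with a hypothetical index_list standing in for the generator's own:

import numpy as np
from multiprocessing import shared_memory

index_list = np.arange(18, dtype=np.int64)   # Hypothetical stand-in for the generator's index list.

# Allocate a shared segment sized to the array and view it as an ndarray,
# as the cache_memory_indices / cache_indices pair does above.
shm = shared_memory.SharedMemory(create=True, size=index_list.nbytes)
cache_indices = np.ndarray(index_list.shape, dtype=index_list.dtype, buffer=shm.buf)
cache_indices[:] = index_list

np.random.shuffle(cache_indices)             # Visible to any process attached to the same segment.
print(cache_indices)

# Cleanup mirrors the loop in release(): close() detaches this process,
# unlink() frees the segment and must run exactly once, in the owning process.
for shared_mem in [shm]:
    shared_mem.close()
    shared_mem.unlink()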