Improve code
This commit is contained in:
parent
128dfe511e
commit
8a0d9b51a3
2 changed files with 38 additions and 32 deletions
|
|
@ -71,7 +71,7 @@ class BatchGenerator:
|
|||
self.batch_data = first_data
|
||||
self.batch_label = first_label
|
||||
|
||||
self.main_process = False
|
||||
self.process_id = 'NA'
|
||||
if self.prefetch or self.num_workers > 1:
|
||||
self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
|
||||
self.cache_indices = np.ndarray(
|
||||
|
|
@ -95,17 +95,17 @@ class BatchGenerator:
|
|||
self.cache_label = [first_label]
|
||||
|
||||
if self.prefetch:
|
||||
self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
|
||||
self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
|
||||
self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
|
||||
self.prefetch_stop.buf[0] = 0
|
||||
self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
|
||||
self.prefetch_skip.buf[0] = 0
|
||||
self.cache_process = mp.Process(target=self._cache_worker)
|
||||
self.cache_process.start()
|
||||
self.prefetch_process = mp.Process(target=self._prefetch_worker)
|
||||
self.prefetch_process.start()
|
||||
self.num_workers = 0
|
||||
self._init_workers()
|
||||
self.current_cache = 0
|
||||
self.main_process = True
|
||||
self.process_id = 'main'
|
||||
|
||||
def _init_workers(self):
|
||||
if self.num_workers > 1:
|
||||
|
|
@ -129,6 +129,7 @@ class BatchGenerator:
|
|||
self.release()
|
||||
|
||||
def release(self):
|
||||
print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}')
|
||||
if self.num_workers > 1:
|
||||
self.worker_stop.buf[0] = 1
|
||||
for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
|
||||
|
|
@ -140,9 +141,10 @@ class BatchGenerator:
|
|||
self.num_workers = 0 # Avoids double release
|
||||
|
||||
if self.prefetch:
|
||||
print(f'{self.process_id} cleaning prefetch')
|
||||
self.prefetch_stop.buf[0] = 1
|
||||
self.cache_pipe_parent.send(True)
|
||||
self.cache_process.join()
|
||||
self.prefetch_pipe_parent.send(True)
|
||||
self.prefetch_process.join()
|
||||
|
||||
self.prefetch_stop.close()
|
||||
self.prefetch_stop.unlink()
|
||||
|
|
@ -150,23 +152,21 @@ class BatchGenerator:
|
|||
self.prefetch_skip.unlink()
|
||||
self.prefetch = False # Avoids double release
|
||||
|
||||
if self.main_process and self.cache_memory_indices is not None:
|
||||
if self.process_id == 'main' and self.cache_memory_indices is not None:
|
||||
print(f'{self.process_id} cleaning memory')
|
||||
self.cache_memory_indices.close()
|
||||
self.cache_memory_indices.unlink()
|
||||
self.cache_memory_data[0].close()
|
||||
self.cache_memory_data[0].unlink()
|
||||
self.cache_memory_data[1].close()
|
||||
self.cache_memory_data[1].unlink()
|
||||
self.cache_memory_label[0].close()
|
||||
self.cache_memory_label[0].unlink()
|
||||
self.cache_memory_label[1].close()
|
||||
self.cache_memory_label[1].unlink()
|
||||
for shared_mem in self.cache_memory_data + self.cache_memory_label:
|
||||
shared_mem.close()
|
||||
shared_mem.unlink()
|
||||
self.cache_memory_indices = None # Avoids double release
|
||||
print(f'{self.process_id} released')
|
||||
|
||||
def _cache_worker(self):
|
||||
def _prefetch_worker(self):
|
||||
self.prefetch = False
|
||||
self.current_cache = 1
|
||||
self._init_workers()
|
||||
self.process_id = 'prefetch'
|
||||
|
||||
while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
|
||||
try:
|
||||
|
|
@ -184,15 +184,17 @@ class BatchGenerator:
|
|||
self.global_step += 1
|
||||
|
||||
self._next_batch()
|
||||
self.cache_pipe_child.recv()
|
||||
self.cache_pipe_child.send(self.current_cache)
|
||||
self.prefetch_pipe_child.recv()
|
||||
self.prefetch_pipe_child.send(self.current_cache)
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
if self.num_workers > 1:
|
||||
self.release()
|
||||
print(f'{self.process_id} exiting')
|
||||
|
||||
def _worker(self, worker_index: int):
|
||||
self.process_id = f'worker_{worker_index}'
|
||||
self.num_workers = 0
|
||||
self.current_cache = 0
|
||||
parent_cache_data = self.cache_data
|
||||
|
|
@ -221,6 +223,9 @@ class BatchGenerator:
|
|||
pipe.send(True)
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except ValueError:
|
||||
break
|
||||
print(f'{self.process_id} exiting')
|
||||
|
||||
def skip_epoch(self):
|
||||
if self.prefetch:
|
||||
|
|
@ -293,8 +298,8 @@ class BatchGenerator:
|
|||
if self.step == self.step_per_epoch - 1 and self.shuffle:
|
||||
np.random.shuffle(self.index_list)
|
||||
self.cache_indices[:] = self.index_list
|
||||
self.cache_pipe_parent.send(True)
|
||||
self.current_cache = self.cache_pipe_parent.recv()
|
||||
self.prefetch_pipe_parent.send(True)
|
||||
self.current_cache = self.prefetch_pipe_parent.recv()
|
||||
else:
|
||||
if self.step == 0 and self.shuffle:
|
||||
np.random.shuffle(self.index_list)
|
||||
|
|
@ -324,14 +329,15 @@ if __name__ == '__main__':
|
|||
[.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18], dtype=np.uint8)
|
||||
|
||||
for data_processor in [None, lambda x:x]:
|
||||
for prefetch in [False, True]:
|
||||
for num_workers in [1, 3]:
|
||||
for prefetch in [True, False]:
|
||||
for num_workers in [3, 1]:
|
||||
print(f'{data_processor=} {prefetch=} {num_workers=}')
|
||||
with BatchGenerator(data, label, 5, data_processor=data_processor,
|
||||
prefetch=prefetch, num_workers=num_workers) as batch_generator:
|
||||
for _ in range(19):
|
||||
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
|
||||
batch_generator.next_batch()
|
||||
raise KeyboardInterrupt
|
||||
print()
|
||||
|
||||
test()
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@ class SequenceGenerator(BatchGenerator):
|
|||
if not preload:
|
||||
self.data_processor = data_processor
|
||||
self.label_processor = label_processor
|
||||
self.data = data
|
||||
self.label = label
|
||||
self.data = np.asarray(data)
|
||||
self.label = np.asarray(label)
|
||||
else:
|
||||
self.data_processor = None
|
||||
self.label_processor = None
|
||||
|
|
@ -56,13 +56,13 @@ class SequenceGenerator(BatchGenerator):
|
|||
[np.asarray([data_processor(entry) for entry in serie]) for serie in data],
|
||||
dtype=np.object if len(data) > 1 else None)
|
||||
else:
|
||||
self.data = data
|
||||
self.data = np.asarray(data)
|
||||
if label_processor:
|
||||
self.label = np.asarray(
|
||||
[np.asarray([label_processor(entry) for entry in serie]) for serie in label],
|
||||
dtype=np.object if len(label) > 1 else None)
|
||||
else:
|
||||
self.label = label
|
||||
self.label = np.asarray(label)
|
||||
if save:
|
||||
with h5py.File(save_path, 'w') as h5_file:
|
||||
h5_file.create_dataset('data_len', data=len(self.data))
|
||||
|
|
@ -118,7 +118,7 @@ class SequenceGenerator(BatchGenerator):
|
|||
self.batch_data = first_data
|
||||
self.batch_label = first_label
|
||||
|
||||
self.main_process = False
|
||||
self.process_id = 'NA'
|
||||
if self.prefetch or self.num_workers > 1:
|
||||
self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
|
||||
self.cache_indices = np.ndarray(
|
||||
|
|
@ -142,17 +142,17 @@ class SequenceGenerator(BatchGenerator):
|
|||
self.cache_label = [first_label]
|
||||
|
||||
if self.prefetch:
|
||||
self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
|
||||
self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
|
||||
self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
|
||||
self.prefetch_stop.buf[0] = 0
|
||||
self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
|
||||
self.prefetch_skip.buf[0] = 0
|
||||
self.cache_process = mp.Process(target=self._cache_worker)
|
||||
self.cache_process.start()
|
||||
self.prefetch_process = mp.Process(target=self._prefetch_worker)
|
||||
self.prefetch_process.start()
|
||||
self.num_workers = 0
|
||||
self._init_workers()
|
||||
self.current_cache = 0
|
||||
self.main_process = True
|
||||
self.process_id = 'main'
|
||||
|
||||
def _next_batch(self):
|
||||
if self.num_workers > 1:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue