"""Batch generator with optional shared-memory prefetching and worker-based
parallel preprocessing, built on multiprocessing.shared_memory."""
import multiprocessing as mp
from multiprocessing import shared_memory
import os
from typing import Callable, Iterable, Optional, Tuple

import h5py
import numpy as np


class BatchGenerator:
    def __init__(self, data: Iterable, label: Iterable, batch_size: int,
                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
                 prefetch=True, preload=False, num_workers=1, shuffle=True,
                 initial_shuffle=False, left_right_flip=False, save: Optional[str] = None):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.prefetch = prefetch and not preload  # Prefetching is pointless once everything is preprocessed.
        self.num_workers = num_workers
        self.left_right_flip = left_right_flip
        if not preload:
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = np.asarray(data)
            self.label = np.asarray(label)
        else:
            # Preload: run the processors once over the whole dataset now,
            # optionally caching the result to an HDF5 file.
            self.data_processor = None
            self.label_processor = None
            save_path = save
            if save is not None:
                if '.' not in os.path.basename(save_path):
                    save_path += '.hdf5'
                save_dir = os.path.dirname(save_path)
                # Guard against an empty dirname (bare filename), which would crash os.makedirs.
                if save_dir and not os.path.exists(save_dir):
                    os.makedirs(save_dir)
            if save and os.path.exists(save_path):
                with h5py.File(save_path, 'r') as h5_file:
                    self.data = np.asarray(h5_file['data'])
                    self.label = np.asarray(h5_file['label'])
            else:
                if data_processor:
                    self.data = np.asarray([data_processor(entry) for entry in data])
                else:
                    self.data = np.asarray(data)
                if label_processor:
                    self.label = np.asarray([label_processor(entry) for entry in label])
                else:
                    self.label = np.asarray(label)
                if save:
                    with h5py.File(save_path, 'w') as h5_file:
                        h5_file.create_dataset('data', data=self.data)
                        h5_file.create_dataset('label', data=self.label)
        self.index_list = np.arange(len(self.data))
        if shuffle or initial_shuffle:
            np.random.shuffle(self.index_list)
        self.step_per_epoch = len(self.index_list) // self.batch_size
        self.last_batch_size = len(self.index_list) % self.batch_size
        if self.last_batch_size == 0:
            self.last_batch_size = self.batch_size
        self.epoch = 0
        self.global_step = 0
        self.step = 0
        # Process the first batch eagerly to learn the batch shape and dtype.
        # Use self.data_processor / self.label_processor (None when preloaded)
        # so already-processed entries are not processed a second time.
        first_data = np.array([self.data_processor(entry) if self.data_processor else entry
                               for entry in self.data[self.index_list[:batch_size]]])
        first_label = np.array([self.label_processor(entry) if self.label_processor else entry
                                for entry in self.label[self.index_list[:batch_size]]])
        self.batch_data = first_data
        self.batch_label = first_label
        self.process_id = 'NA'
        if self.prefetch or self.num_workers > 1:
            self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
            self.cache_indices = np.ndarray(
                self.index_list.shape, dtype=self.index_list.dtype, buffer=self.cache_memory_indices.buf)
            self.cache_indices[:] = self.index_list
            # Two shared buffers per array: one is consumed by the training loop
            # while the prefetch process fills the other.
            self.cache_memory_data = [
                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
            self.cache_data = [
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
            self.cache_memory_label = [
                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
            self.cache_label = [
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
        else:
            self.cache_memory_indices = None
            self.cache_data = [first_data]
            self.cache_label = [first_label]
        if self.prefetch:
            self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
            self.prefetch_stop.buf[0] = 0
            self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
            self.prefetch_skip.buf[0] = 0
            # Child processes inherit this object's state, which assumes the
            # 'fork' start method (the default on Linux).
            self.prefetch_process = mp.Process(target=self._prefetch_worker)
            self.prefetch_process.start()
            # The prefetch process owns the workers; the main process must not spawn its own.
            self.num_workers = 0
        self._init_workers()
        self.current_cache = 0
        self.process_id = 'main'
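
    # Double-buffering protocol: while the training loop reads batch N from the
    # buffer at index `current_cache`, the prefetch process writes batch N + 1
    # into the other buffer. next_batch() sends a token through the pipe and
    # receives back the index of the buffer that was just filled, so the two
    # processes alternate buffers in lock step and never write the one being read.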
    def _init_workers(self):
        if self.num_workers > 1:
            self.worker_stop = shared_memory.SharedMemory(create=True, size=1)
            self.worker_stop.buf[0] = 0
            self.worker_pipes = []
            self.worker_processes = []
            for _ in range(self.num_workers):
                self.worker_pipes.append(mp.Pipe())
            for worker_index in range(self.num_workers):
                self.worker_processes.append(mp.Process(target=self._worker, args=(worker_index,)))
                self.worker_processes[-1].start()

    def __del__(self):
        self.release()

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_value, _traceback):
        self.release()

    def release(self):
        print(f'Releasing {self.num_workers=} {self.prefetch=} {self.process_id=} {self.cache_memory_indices}')
        if self.num_workers > 1:
            self.worker_stop.buf[0] = 1
            for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
                pipe.send([-1, 0, 0, 0])  # Wake a worker blocked on recv() so it can see the stop flag.
                worker.join()
            self.worker_stop.close()
            self.worker_stop.unlink()
            self.num_workers = 0  # Avoids double release
        if self.prefetch:
            print(f'{self.process_id} cleaning prefetch')
            self.prefetch_stop.buf[0] = 1
            self.prefetch_pipe_parent.send(True)
            self.prefetch_process.join()
            self.prefetch_stop.close()
            self.prefetch_stop.unlink()
            self.prefetch_skip.close()
            self.prefetch_skip.unlink()
            self.prefetch = False  # Avoids double release
        if self.process_id == 'main' and self.cache_memory_indices is not None:
            print(f'{self.process_id} cleaning memory')
            self.cache_memory_indices.close()
            self.cache_memory_indices.unlink()
            for shared_mem in self.cache_memory_data + self.cache_memory_label:
                shared_mem.close()
                shared_mem.unlink()
            self.cache_memory_indices = None  # Avoids double release
        print(f'{self.process_id} released')

    def _prefetch_worker(self):
        self.prefetch = False
        self.current_cache = 1
        self._init_workers()
        self.process_id = 'prefetch'
        while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
            try:
                self.current_cache = 1 - self.current_cache
                if self.prefetch_skip.buf[0]:
                    self.prefetch_skip.buf[0] = 0
                    self.step = self.step_per_epoch
                # Pick up any re-shuffle published by the main process.
                self.index_list[:] = self.cache_indices
                if self.step >= self.step_per_epoch - 1:  # step starts at 0
                    self.step = 0
                    self.epoch += 1
                else:
                    self.step += 1
                self.global_step += 1
                self._next_batch()
                self.prefetch_pipe_child.recv()
                self.prefetch_pipe_child.send(self.current_cache)
            except KeyboardInterrupt:
                break
        if self.num_workers > 1:
            self.release()
        print(f'{self.process_id} exiting')

    def _worker(self, worker_index: int):
        self.process_id = f'worker_{worker_index}'
        self.num_workers = 0
        self.current_cache = 0
        self.step = 0  # _next_batch then slices from the head of the local index list.
        parent_cache_data = self.cache_data
        parent_cache_label = self.cache_label
        # Local scratch buffers; results are copied back into the parent's shared memory.
        cache_data = np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)
        cache_label = np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)
        self.cache_data = [cache_data]
        self.cache_label = [cache_label]
        self.index_list[:] = self.cache_indices
        pipe = self.worker_pipes[worker_index][1]
        while self.worker_stop.buf is not None and self.worker_stop.buf[0] == 0:
            try:
                current_cache, batch_index, start_index, self.batch_size = pipe.recv()
                if self.batch_size == 0:
                    continue
                self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy()
                self.cache_data[0] = cache_data[:self.batch_size]
                self.cache_label[0] = cache_label[:self.batch_size]
                self._next_batch()
                parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[
                    self.current_cache][:self.batch_size]
                parent_cache_label[current_cache][batch_index:batch_index + self.batch_size] = self.cache_label[
                    self.current_cache][:self.batch_size]
                pipe.send(True)
            except KeyboardInterrupt:
                break
            except ValueError:
                break
        print(f'{self.process_id} exiting')

    def skip_epoch(self):
        if self.prefetch:
            self.prefetch_skip.buf[0] = 1
        self.step = self.step_per_epoch
        self.next_batch()
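
    # A batch is split across workers as evenly as possible: each of the first
    # num_workers - 1 workers gets batch_len // num_workers indices and the last
    # worker gets the remainder, e.g. batch_len = 5 with 3 workers gives splits
    # of 1, 1 and 3. When there are fewer indices than workers, the surplus
    # workers are simply not messaged for that step.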
    def _worker_next_batch(self):
        index_list = self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]
        batch_len = len(index_list)
        indices_per_worker = batch_len // self.num_workers
        if indices_per_worker == 0:
            indices_per_worker = 1
        worker_params = []
        batch_index = 0
        start_index = self.step * self.batch_size
        for _worker_index in range(self.num_workers - 1):
            worker_params.append([self.current_cache, batch_index, start_index, indices_per_worker])
            batch_index += indices_per_worker
            start_index += indices_per_worker
        worker_params.append([
            self.current_cache, batch_index, start_index,
            batch_len - ((self.num_workers - 1) * indices_per_worker)])
        if indices_per_worker == 1 and batch_len < self.num_workers:
            worker_params = worker_params[:batch_len]
        for params, (pipe, _) in zip(worker_params, self.worker_pipes):
            pipe.send(params)
        # Wait for every messaged worker to finish writing its slice.
        for _, (pipe, _) in zip(worker_params, self.worker_pipes):
            pipe.recv()

    def _next_batch(self):
        if self.num_workers > 1:
            self._worker_next_batch()
            return
        # Loading data
        if self.data_processor is not None:
            data = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                data.append(self.data_processor(self.data[entry]))
            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
        else:
            data = self.data[
                self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]]
            self.cache_data[self.current_cache][:len(data)] = data
        # Loading label
        if self.label_processor is not None:
            label = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                label.append(self.label_processor(self.label[entry]))
            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
        else:
            label = self.label[
                self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]]
            self.cache_label[self.current_cache][:len(label)] = label

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        if self.step >= self.step_per_epoch - 1:  # step starts at 0
            self.step = 0
            self.epoch += 1
        else:
            self.step += 1
        self.global_step += 1
        if self.prefetch:
            # Shuffle one step ahead in prefetch mode because the next batch is already prepared
            if self.step == self.step_per_epoch - 1 and self.shuffle:
                np.random.shuffle(self.index_list)
                self.cache_indices[:] = self.index_list
            self.prefetch_pipe_parent.send(True)
            self.current_cache = self.prefetch_pipe_parent.recv()
        else:
            if self.step == 0 and self.shuffle:
                np.random.shuffle(self.index_list)
                if self.num_workers > 1:
                    self.cache_indices[:] = self.index_list
            self._next_batch()
        if self.step == self.step_per_epoch - 1:
            self.batch_data = self.cache_data[self.current_cache][:self.last_batch_size]
            self.batch_label = self.cache_label[self.current_cache][:self.last_batch_size]
        else:
            self.batch_data = self.cache_data[self.current_cache]
            self.batch_label = self.cache_label[self.current_cache]
        if self.left_right_flip and np.random.uniform() > 0.5:
            self.batch_data = self.batch_data[:, :, ::-1]
            self.batch_label = self.batch_label[:, :, ::-1]
        return self.batch_data, self.batch_label
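

# The self-test below uses 18 samples with batch_size=5, so
# step_per_epoch = 18 // 5 = 3 (steps 0, 1 and 2 per epoch) and
# last_batch_size = 18 % 5 = 3: the batch returned at the final step of each
# epoch is truncated to 3 items.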
if __name__ == '__main__':
    def test():
        data = np.array(
            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
            dtype=np.uint8)
        # float32, not uint8: fractional labels would truncate to zero as integers.
        label = np.array(
            [.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18],
            dtype=np.float32)
        for data_processor in [None, lambda x: x]:
            for prefetch in [True, False]:
                for num_workers in [3, 1]:
                    print(f'{data_processor=} {prefetch=} {num_workers=}')
                    with BatchGenerator(data, label, 5, data_processor=data_processor,
                                        prefetch=prefetch, num_workers=num_workers) as batch_generator:
                        for _ in range(19):
                            print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
                            batch_generator.next_batch()
                    print()

    test()
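
    # A minimal sketch (not exercised by test()) of the preload/save path:
    # the processed arrays are computed once and cached to an HDF5 file, so
    # later runs load them directly. The path 'cache/demo' is hypothetical;
    # '.hdf5' is appended automatically when the name has no extension.
    #
    #     with BatchGenerator(data, label, 5,
    #                         data_processor=lambda x: x * 2,
    #                         prefetch=False, preload=True,
    #                         save='cache/demo') as gen:
    #         batch_data, batch_label = gen.next_batch()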