diff --git a/utils/batch_generator.py b/utils/batch_generator.py
index 8f60332..2edaed6 100644
--- a/utils/batch_generator.py
+++ b/utils/batch_generator.py
@@ -2,7 +2,7 @@ import math
 import multiprocessing as mp
 from multiprocessing import shared_memory
 import os
-from typing import Optional, Tuple
+from typing import Callable, Iterable, Optional, Tuple
 
 import h5py
 import numpy as np
@@ -10,12 +10,15 @@ import numpy as np
 
 class BatchGenerator:
-    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None, precache=True,
-                 shuffle=True, preload=False, save=None, initial_shuffle=False, left_right_flip=False):
+    def __init__(self, data: Iterable, label: Iterable, batch_size: int,
+                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
+                 prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
+                 left_right_flip=False, save: Optional[str] = None):
         self.batch_size = batch_size
         self.shuffle = shuffle
+        self.prefetch = prefetch and not preload
+        self.num_workers = num_workers
         self.left_right_flip = left_right_flip
-        self.precache = precache and not preload
 
         if not preload:
             self.data_processor = data_processor
@@ -50,54 +53,105 @@ class BatchGenerator:
             h5_file.create_dataset('data', data=self.data)
             h5_file.create_dataset('label', data=self.label)
 
-        self.step_per_epoch = math.ceil(len(self.data) / batch_size)
-
-        self.epoch = 0
-        self.global_step = -1
-        self.step = -1
-
-        self.batch_data: Optional[np.ndarray] = None
-        self.batch_label: Optional[np.ndarray] = None
-
         self.index_list = np.arange(len(self.data))
         if shuffle or initial_shuffle:
            np.random.shuffle(self.index_list)
+        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
+        self.last_batch_size = len(self.index_list) % self.batch_size or self.batch_size  # full batch when evenly divisible
 
-        if self.precache:
-            data_sample = np.array([data_processor(entry) if data_processor else entry
-                                    for entry in self.data[:batch_size]])
-            label_sample = np.array([label_processor(entry) if label_processor else entry
-                                     for entry in self.label[:batch_size]])
+        self.epoch = 0
+        self.global_step = 0
+        self.step = 0
+
+        first_data = np.array([data_processor(entry) if data_processor else entry
+                               for entry in self.data[self.index_list[:batch_size]]])
+        first_label = np.array([label_processor(entry) if label_processor else entry
+                                for entry in self.label[self.index_list[:batch_size]]])
+        self.batch_data = first_data
+        self.batch_label = first_label
+
+        self.main_process = False
+        if self.prefetch or self.num_workers > 1:
+            self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
+            self.cache_indices = np.ndarray(
+                self.index_list.shape, dtype=self.index_list.dtype, buffer=self.cache_memory_indices.buf)
+            self.cache_indices[:] = self.index_list
             self.cache_memory_data = [
-                shared_memory.SharedMemory(create=True, size=data_sample.nbytes),
-                shared_memory.SharedMemory(create=True, size=data_sample.nbytes)]
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
             self.cache_data = [
-                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[0].buf),
-                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[1].buf)]
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
             self.cache_memory_label = [
-                shared_memory.SharedMemory(create=True, size=label_sample.nbytes),
-                shared_memory.SharedMemory(create=True, size=label_sample.nbytes)]
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
             self.cache_label = [
-                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[0].buf),
-                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[1].buf)]
+                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
+                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
+        else:
+            self.cache_memory_indices = None
+            self.cache_data = [first_data]
+            self.cache_label = [first_label]
+
+        if self.prefetch:
             self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
-            self.cache_stop = shared_memory.SharedMemory(create=True, size=1)
-            self.cache_stop.buf[0] = 0
-            self.cache_skip = shared_memory.SharedMemory(create=True, size=1)
-            self.cache_skip.buf[0] = 0
-            self.cache_process = mp.Process(target=self.cache_worker)
+            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
+            self.prefetch_stop.buf[0] = 0
+            self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
+            self.prefetch_skip.buf[0] = 0
+            self.cache_process = mp.Process(target=self._cache_worker)
             self.cache_process.start()
+            self.num_workers = 0  # Workers are spawned inside the prefetch process instead
+
+        self._init_workers()
+        self.current_cache = 0
+        self.main_process = True
+
+    def _init_workers(self):
+        if self.num_workers > 1:
+            self.worker_stop = shared_memory.SharedMemory(create=True, size=1)
+            self.worker_stop.buf[0] = 0
+            self.worker_pipes = []
+            self.worker_processes = []
+            for _ in range(self.num_workers):
+                self.worker_pipes.append(mp.Pipe())
+            for worker_index in range(self.num_workers):
+                self.worker_processes.append(mp.Process(target=self._worker, args=(worker_index,)))
+                self.worker_processes[-1].start()
 
     def __del__(self):
-        if self.precache:
-            self.cache_stop.buf[0] = 1
+        self.release()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, _exc_type, _exc_value, _traceback):
+        self.release()
+
+    def release(self):
+        if self.num_workers > 1:
+            self.worker_stop.buf[0] = 1
+            for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
+                pipe.send([-1, 0, 0, 0])
+                worker.join()
+
+            self.worker_stop.close()
+            self.worker_stop.unlink()
+            self.num_workers = 0  # Avoids double release
+
+        if self.prefetch:
+            self.prefetch_stop.buf[0] = 1
             self.cache_pipe_parent.send(True)
             self.cache_process.join()
 
-            self.cache_stop.close()
-            self.cache_stop.unlink()
-            self.cache_skip.close()
-            self.cache_skip.unlink()
+            self.prefetch_stop.close()
+            self.prefetch_stop.unlink()
+            self.prefetch_skip.close()
+            self.prefetch_skip.unlink()
+            self.prefetch = False  # Avoids double release
+
+        if self.main_process and self.cache_memory_indices is not None:
+            self.cache_memory_indices.close()
+            self.cache_memory_indices.unlink()
             self.cache_memory_data[0].close()
            self.cache_memory_data[0].unlink()
            self.cache_memory_data[1].close()
@@ -106,71 +160,177 @@ class BatchGenerator:
             self.cache_memory_label[0].unlink()
             self.cache_memory_label[1].close()
             self.cache_memory_label[1].unlink()
+            self.cache_memory_indices = None  # Avoids double release
 
-    def cache_worker(self):
-        self.precache = False
-        self.next_batch()
-        self.cache_data[0][:] = self.batch_data[:]
-        self.cache_label[0][:] = self.batch_label[:]
-        current_cache = 0
+    def _cache_worker(self):
+        self.prefetch = False
+        self.current_cache = 1
+        self._init_workers()
 
-        while not self.cache_stop.buf[0]:
+        while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
             try:
-                self.cache_pipe_child.recv()
-                self.cache_pipe_child.send(current_cache)
-                if self.cache_skip.buf[0]:
-                    self.cache_skip.buf[0] = 0
+                self.current_cache = 1 - self.current_cache
+                if self.prefetch_skip.buf[0]:
+                    self.prefetch_skip.buf[0] = 0
                     self.step = self.step_per_epoch
-                self.next_batch()
-                current_cache = 1 - current_cache
-                self.cache_data[current_cache][:len(self.batch_data)] = self.batch_data[:]
-                self.cache_label[current_cache][:len(self.batch_label)] = self.batch_label[:]
+                self.index_list[:] = self.cache_indices  # Pick up the parent's (re)shuffled order
+
+                if self.step >= self.step_per_epoch - 1:  # step start at 0
+                    self.step = 0
+                    self.epoch += 1
+                else:
+                    self.step += 1
+                self.global_step += 1
+
+                self._next_batch()
+                self.cache_pipe_child.recv()
+                self.cache_pipe_child.send(self.current_cache)
             except KeyboardInterrupt:
                 break
+
+        if self.num_workers > 1:
+            self.release()
+
+    def _worker(self, worker_index: int):
+        self.num_workers = 0
+        self.current_cache = 0
+        parent_cache_data = self.cache_data
+        parent_cache_label = self.cache_label
+        cache_data = np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)
+        cache_label = np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)
+        self.cache_data = [cache_data]
+        self.cache_label = [cache_label]
+        self.index_list[:] = self.cache_indices
+        pipe = self.worker_pipes[worker_index][1]
+
+        while self.worker_stop.buf is not None and self.worker_stop.buf[0] == 0:
+            try:
+                current_cache, batch_index, start_index, self.batch_size = pipe.recv()
+                if self.batch_size == 0:
+                    continue
+                self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy()
+
+                self.cache_data[0] = cache_data[:self.batch_size]
+                self.cache_label[0] = cache_label[:self.batch_size]
+                self._next_batch()
+                parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[
+                    self.current_cache][:self.batch_size]
+                parent_cache_label[current_cache][batch_index:batch_index + self.batch_size] = self.cache_label[
+                    self.current_cache][:self.batch_size]
+                pipe.send(True)
+            except KeyboardInterrupt:
+                break
 
     def skip_epoch(self):
-        if self.precache:
-            self.cache_skip.buf[0] = 1
+        if self.prefetch:
+            self.prefetch_skip.buf[0] = 1
         self.step = self.step_per_epoch
         self.next_batch()
 
+    def _worker_next_batch(self):
+        index_list = self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]
+        batch_len = len(index_list)
+        indices_per_worker = batch_len // self.num_workers
+        if indices_per_worker == 0:
+            indices_per_worker = 1
+
+        worker_params = []
+        batch_index = 0
+        start_index = self.step * self.batch_size
+        for _worker_index in range(self.num_workers - 1):
+            worker_params.append([self.current_cache, batch_index, start_index, indices_per_worker])
+            batch_index += indices_per_worker
+            start_index += indices_per_worker
+        worker_params.append([
+            self.current_cache, batch_index, start_index,
+            batch_len - ((self.num_workers - 1) * indices_per_worker)])
+
+        if indices_per_worker == 1 and batch_len < self.num_workers:
+            worker_params = worker_params[:batch_len]
+
+        for params, (pipe, _) in zip(worker_params, self.worker_pipes):
+            pipe.send(params)
+        for _, (pipe, _) in zip(worker_params, self.worker_pipes):
+            pipe.recv()
+
+    def _next_batch(self):
+        if self.num_workers > 1:
+            self._worker_next_batch()
+            return
+
+        # Loading data
+        if self.data_processor is not None:
+            data = []
+            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
+                data.append(self.data_processor(self.data[entry]))
+            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+        else:
+            data = self.data[
+                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
+            self.cache_data[self.current_cache][:len(data)] = data
+        # Loading label
+        if self.label_processor is not None:
+            label = []
+            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
+                label.append(self.label_processor(self.label[entry]))
+            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
+        else:
+            label = self.label[
+                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
+            self.cache_label[self.current_cache][:len(label)] = label
+
     def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
         if self.step >= self.step_per_epoch - 1:  # step start at 0
             self.step = 0
             self.epoch += 1
-            if self.shuffle:
-                np.random.shuffle(self.index_list)
         else:
             self.step += 1
-
         self.global_step += 1
-        # Loading data
-        if self.precache:
+
+        if self.prefetch:
+            # Shuffle one step ahead in prefetch mode because the next batch is already prepared
+            if self.step == self.step_per_epoch - 1 and self.shuffle:
+                np.random.shuffle(self.index_list)
+                self.cache_indices[:] = self.index_list
             self.cache_pipe_parent.send(True)
-            current_cache = self.cache_pipe_parent.recv()
-            self.batch_data = self.cache_data[current_cache].copy()
-        elif self.data_processor is not None:
-            self.batch_data = []
-            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_data.append(self.data_processor(self.data[entry]))
-            self.batch_data = np.asarray(self.batch_data)
+            self.current_cache = self.cache_pipe_parent.recv()
         else:
-            self.batch_data = self.data[
-                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
-        # Loading label
-        if self.precache:
-            self.batch_label = self.cache_label[current_cache].copy()
-        elif self.label_processor is not None:
-            self.batch_label = []
-            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_label.append(self.label_processor(self.label[entry]))
-            self.batch_label = np.asarray(self.batch_label)
+            if self.step == 0 and self.shuffle:
+                np.random.shuffle(self.index_list)
+                if self.num_workers > 1:
+                    self.cache_indices[:] = self.index_list
+            self._next_batch()
+
+        if self.step == self.step_per_epoch - 1:
+            self.batch_data = self.cache_data[self.current_cache][:self.last_batch_size]
+            self.batch_label = self.cache_label[self.current_cache][:self.last_batch_size]
         else:
-            self.batch_label = self.label[
-                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
-        # print('next_batch : epoch {}, step {}/{}, data {}, label {}'.format(
-        #     self.epoch, self.step, self.step_per_epoch - 1, self.batch_data.shape, self.batch_label.shape))
+            self.batch_data = self.cache_data[self.current_cache]
+            self.batch_label = self.cache_label[self.current_cache]
+
         if self.left_right_flip and np.random.uniform() > 0.5:
             self.batch_data = self.batch_data[:, :, ::-1]
             self.batch_label = self.batch_label[:, :, ::-1]
+        return self.batch_data, self.batch_label
+
+
+if __name__ == '__main__':
+    def test():
+        data = np.array(
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.uint8)
+        label = np.array(
+            [.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18, .19],
+            dtype=np.float32)  # float labels; dtype=np.uint8 would truncate them all to 0
+
+        for data_processor in [None, lambda x: x]:
+            for prefetch in [False, True]:
+                for num_workers in [1, 3]:
+                    print(f'{data_processor=} {prefetch=} {num_workers=}')
+                    with BatchGenerator(data, label, 5, data_processor=data_processor,
+                                        prefetch=prefetch, num_workers=num_workers) as batch_generator:
+                        for _ in range(9):
+                            print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
+                            batch_generator.next_batch()
+                    print()
+
    test()
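For reference, a minimal sketch of how the reworked `BatchGenerator` is meant to be driven (the import path, array shapes, and loop body here are illustrative assumptions, not part of the patch). Note the calling convention the smoke test above relies on: the first batch is built eagerly in `__init__`, so a loop reads `batch_data`/`batch_label` first and only then calls `next_batch()` to advance, and the new context-manager support guarantees `release()` joins the prefetch/worker processes and unlinks the shared-memory blocks.

```python
# Illustrative sketch only, not part of the patch; shapes are made up.
import numpy as np

from utils.batch_generator import BatchGenerator

if __name__ == '__main__':
    data = np.random.rand(100, 32, 32).astype(np.float32)        # 100 fake samples
    label = np.random.randint(0, 10, size=100).astype(np.int64)  # 100 fake class ids

    # prefetch=True overlaps batch preparation with training; num_workers > 1
    # additionally splits each batch across worker processes.
    with BatchGenerator(data, label, batch_size=8, prefetch=True, num_workers=4) as generator:
        for _ in range(3 * generator.step_per_epoch):  # roughly three epochs
            batch_data, batch_label = generator.batch_data, generator.batch_label
            # ... consume the batch here, then advance ...
            generator.next_batch()
```

Because the two shared-memory buffers are used round-robin, `next_batch()` in prefetch mode only blocks until the background process has filled the other buffer, so batch preparation overlaps with consumption.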
diff --git a/utils/sequence_batch_generator.py b/utils/sequence_batch_generator.py
index 011a075..abae375 100644
--- a/utils/sequence_batch_generator.py
+++ b/utils/sequence_batch_generator.py
@@ -2,20 +2,29 @@ import math
 import multiprocessing as mp
 from multiprocessing import shared_memory
 import os
-from typing import Optional, Tuple
+from typing import Callable, Iterable, Optional
 
 import h5py
 import numpy as np
 
+try:
+    from .batch_generator import BatchGenerator
+except ImportError:  # If running this script directly
+    from batch_generator import BatchGenerator
 
-class SequenceGenerator:
-    def __init__(self, data, label, sequence_size, batch_size, data_processor=None, label_processor=None,
-                 precache=True, preload=False, shuffle=True, initial_shuffle=False, save=None):
+
+class SequenceGenerator(BatchGenerator):
+
+    def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
+                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
+                 prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
+                 save: Optional[str] = None):
         self.batch_size = batch_size
         self.sequence_size = sequence_size
         self.shuffle = shuffle
-        self.precache = precache and not preload
+        self.prefetch = prefetch and not preload
+        self.num_workers = num_workers
+        self.left_right_flip = False
 
         if not preload:
             self.data_processor = data_processor
@@ -65,164 +74,139 @@ class SequenceGenerator:
         self.index_list = []
         for sequence_index in range(len(self.data)):
             start_indices = np.expand_dims(
-                np.arange(len(self.data[sequence_index]) - sequence_size, dtype=np.uint32),
+                np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
                 axis=-1)
             start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
             self.index_list.append(start_indices)
         self.index_list = np.concatenate(self.index_list, axis=0)
 
-        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
-
-        self.epoch = 0
-        self.global_step = -1
-        self.step = -1
-
-        self.batch_data: Optional[np.ndarray] = None
-        self.batch_label: Optional[np.ndarray] = None
-
         if shuffle or initial_shuffle:
             np.random.shuffle(self.index_list)
+        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
+        self.last_batch_size = len(self.index_list) % self.batch_size or self.batch_size  # full batch when evenly divisible
 
-        if self.precache:
-            if data_processor:
-                data_sample = []
-                for sequence_index, start_index in self.index_list[:batch_size]:
-                    data_sample.append(
-                        [data_processor(input_data)
-                         for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
-                data_sample = np.asarray(data_sample)
-            else:
-                data_sample = []
-                for sequence_index, start_index in self.index_list[:batch_size]:
-                    data_sample.append(
-                        self.data[sequence_index][start_index: start_index + self.sequence_size])
-                data_sample = np.asarray(data_sample)
-            if label_processor:
-                label_sample = []
-                for sequence_index, start_index in self.index_list[:batch_size]:
-                    label_sample.append(
-                        [label_processor(input_label)
-                         for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]])
-                label_sample = np.asarray(label_sample)
-            else:
-                label_sample = []
-                for sequence_index, start_index in self.index_list[:batch_size]:
-                    label_sample.append(
-                        self.label[sequence_index][start_index: start_index + self.sequence_size])
-                label_sample = np.asarray(label_sample)
+        self.epoch = 0
+        self.global_step = 0
+        self.step = 0
 
-            self.cache_memory_data = [
-                shared_memory.SharedMemory(create=True, size=data_sample.nbytes),
-                shared_memory.SharedMemory(create=True, size=data_sample.nbytes)]
-            self.cache_data = [
-                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[0].buf),
-                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[1].buf)]
-            self.cache_memory_label = [
-                shared_memory.SharedMemory(create=True, size=label_sample.nbytes),
-                shared_memory.SharedMemory(create=True, size=label_sample.nbytes)]
-            self.cache_label = [
-                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[0].buf),
-                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[1].buf)]
-            self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
-            self.cache_stop = shared_memory.SharedMemory(create=True, size=1)
-            self.cache_stop.buf[0] = 0
-            self.cache_skip = shared_memory.SharedMemory(create=True, size=1)
-            self.cache_skip.buf[0] = 0
-            self.cache_process = mp.Process(target=self.cache_worker)
-            self.cache_process.start()
-
-    def __del__(self):
-        if self.precache:
-            self.cache_stop.buf[0] = 1
-            self.cache_pipe_parent.send(True)
-            self.cache_process.join()
-
-            self.cache_stop.close()
-            self.cache_stop.unlink()
-            self.cache_skip.close()
-            self.cache_skip.unlink()
-            self.cache_memory_data[0].close()
-            self.cache_memory_data[0].unlink()
-            self.cache_memory_data[1].close()
-            self.cache_memory_data[1].unlink()
-            self.cache_memory_label[0].close()
-            self.cache_memory_label[0].unlink()
-            self.cache_memory_label[1].close()
-            self.cache_memory_label[1].unlink()
-
-    def cache_worker(self):
-        self.precache = False
-        self.next_batch()
-        self.cache_data[0][:] = self.batch_data[:]
-        self.cache_label[0][:] = self.batch_label[:]
-        current_cache = 0
-
-        while not self.cache_stop.buf[0]:
-            try:
-                self.cache_pipe_child.recv()
-                self.cache_pipe_child.send(current_cache)
-                if self.cache_skip.buf[0]:
-                    self.cache_skip.buf[0] = 0
-                    self.step = self.step_per_epoch
-                self.next_batch()
-                current_cache = 1 - current_cache
-                self.cache_data[current_cache][:len(self.batch_data)] = self.batch_data[:]
-                self.cache_label[current_cache][:len(self.batch_label)] = self.batch_label[:]
-            except KeyboardInterrupt:
-                break
-
-    def skip_epoch(self):
-        if self.precache:
-            self.cache_skip.buf[0] = 1
-        self.step = self.step_per_epoch
-        self.next_batch()
-
-    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
-        if self.step >= self.step_per_epoch - 1:  # step start at 0
-            self.step = 0
-            self.epoch += 1
-            if self.shuffle:
-                np.random.shuffle(self.index_list)
+        if data_processor:
+            first_data = []
+            for sequence_index, start_index in self.index_list[:batch_size]:
+                first_data.append(
+                    [data_processor(input_data)
+                     for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
+            first_data = np.asarray(first_data)
         else:
-            self.step += 1
-
-        self.global_step += 1
-        # Loading data
-        if self.precache:
-            self.cache_pipe_parent.send(True)
-            current_cache = self.cache_pipe_parent.recv()
-            self.batch_data = self.cache_data[current_cache].copy()
-        elif self.data_processor is not None:
-            self.batch_data = []
-            for sequence_index, start_index in self.index_list[
-                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_data.append(
-                    [self.data_processor(input_data)
-                     for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
-            self.batch_data = np.asarray(self.batch_data)
-        else:
-            self.batch_data = []
-            for sequence_index, start_index in self.index_list[
-                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_data.append(
+            first_data = []
+            for sequence_index, start_index in self.index_list[:batch_size]:
+                first_data.append(
                     self.data[sequence_index][start_index: start_index + self.sequence_size])
-            self.batch_data = np.asarray(self.batch_data)
-        # Loading label
-        if self.precache:
-            self.batch_label = self.cache_label[current_cache].copy()
-        elif self.label_processor is not None:
-            self.batch_label = []
+            first_data = np.asarray(first_data)
+        if label_processor:
+            first_label = []
+            for sequence_index, start_index in self.index_list[:batch_size]:
+                first_label.append(
+                    [label_processor(input_label)
+                     for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]])
+            first_label = np.asarray(first_label)
+        else:
+            first_label = []
+            for sequence_index, start_index in self.index_list[:batch_size]:
+                first_label.append(
+                    self.label[sequence_index][start_index: start_index + self.sequence_size])
+            first_label = np.asarray(first_label)
+        self.batch_data = first_data
+        self.batch_label = first_label
+
+        self.main_process = False
+        if self.prefetch or self.num_workers > 1:
+            self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
+            self.cache_indices = np.ndarray(
+                self.index_list.shape, dtype=self.index_list.dtype, buffer=self.cache_memory_indices.buf)
+            self.cache_indices[:] = self.index_list
+            self.cache_memory_data = [
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
+            self.cache_data = [
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
+            self.cache_memory_label = [
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
+            self.cache_label = [
+                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
+                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
+        else:
+            self.cache_memory_indices = None
+            self.cache_data = [first_data]
+            self.cache_label = [first_label]
+
+        if self.prefetch:
+            self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
+            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
+            self.prefetch_stop.buf[0] = 0
+            self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
+            self.prefetch_skip.buf[0] = 0
+            self.cache_process = mp.Process(target=self._cache_worker)
+            self.cache_process.start()
+            self.num_workers = 0
+
+        self._init_workers()
+        self.current_cache = 0
+        self.main_process = True
+
+    def _next_batch(self):
+        if self.num_workers > 1:
+            self._worker_next_batch()
+            return
+
+        # Loading data
+        if self.data_processor is not None:
+            data = []
             for sequence_index, start_index in self.index_list[
                     self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_label.append(
+                data.append(
+                    [self.data_processor(input_data)
+                     for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
+            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+        else:
+            data = []
+            for sequence_index, start_index in self.index_list[
+                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
+                data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
+            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+
+        # Loading label
+        if self.label_processor is not None:
+            label = []
+            for sequence_index, start_index in self.index_list[
+                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
+                label.append(
                     [self.label_processor(input_data)
                      for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
-            self.batch_label = np.asarray(self.batch_label)
+            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
         else:
-            self.batch_label = []
+            label = []
             for sequence_index, start_index in self.index_list[
                     self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_label.append(
-                    self.label[sequence_index][start_index: start_index + self.sequence_size])
-            self.batch_label = np.asarray(self.batch_label)
+                label.append(self.label[sequence_index][start_index: start_index + self.sequence_size])
+            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
 
-        return self.batch_data, self.batch_label
+
+if __name__ == '__main__':
+    def test():
+        data = np.array(
+            [[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19]], dtype=np.uint8)
+        label = np.array(
+            [[.1, .2, .3, .4, .5, .6, .7, .8, .9], [.11, .12, .13, .14, .15, .16, .17, .18, .19]],
+            dtype=np.float32)  # float labels; dtype=np.uint8 would truncate them all to 0
+
+        for data_processor in [None, lambda x: x]:
+            for prefetch in [False, True]:
+                for num_workers in [1, 2]:
+                    print(f'{data_processor=} {prefetch=} {num_workers=}')
+                    with SequenceGenerator(data, label, 5, 3, data_processor=data_processor,
+                                           prefetch=prefetch, num_workers=num_workers) as batch_generator:
+                        for _ in range(9):
+                            print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
+                            batch_generator.next_batch()
+                    print()
+
+    test()
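The same pattern applies to `SequenceGenerator`, which inherits the prefetch/worker machinery and `next_batch()` from `BatchGenerator` and only overrides batch assembly (`_next_batch`) to gather sliding windows; note that its `__init__` deliberately does not call `super().__init__` and instead repeats the shared-memory setup for its own batch shape. A minimal sketch follows, again with assumed shapes and import path. With the off-by-one fix above, a sequence of length N now yields N - sequence_size + 1 windows rather than N - sequence_size.

```python
# Illustrative sketch only, not part of the patch; shapes are made up.
import numpy as np

from utils.sequence_batch_generator import SequenceGenerator

if __name__ == '__main__':
    # 4 sequences of 50 timesteps with 8 features -> 4 * (50 - 10 + 1) = 164 windows
    data = np.random.rand(4, 50, 8).astype(np.float32)
    label = np.random.rand(4, 50, 1).astype(np.float32)

    with SequenceGenerator(data, label, sequence_size=10, batch_size=16) as generator:
        batch_data, batch_label = generator.batch_data, generator.batch_label
        assert batch_data.shape == (16, 10, 8)  # (batch, window, features)
        assert generator.step_per_epoch == 11   # ceil(164 / 16)
        generator.next_batch()
```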