Implement num_workers for batch generators
* Clean and factor code in batch generators
parent 4b786943f5
commit 40ebd0010d
2 changed files with 372 additions and 228 deletions
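
Before the diff, a minimal usage sketch of the interface this commit introduces, modeled on the test() helpers added at the bottom of each file. The module path follows the relative import added in the second file; the array contents and sizes are illustrative only:

    import numpy as np
    from batch_generator import BatchGenerator

    data = np.arange(100, dtype=np.uint8)    # toy samples
    label = np.arange(100, dtype=np.uint8)   # toy labels

    # prefetch=True prepares the next batch in a background process;
    # num_workers > 1 additionally splits each batch across worker processes.
    with BatchGenerator(data, label, 5, prefetch=True, num_workers=3) as generator:
        for _ in range(2 * generator.step_per_epoch):
            batch_data, batch_label = generator.next_batch()

The context-manager form matters: the new __exit__ calls release(), which stops the prefetch and worker processes and closes and unlinks the shared memory blocks.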
@@ -2,7 +2,7 @@ import math
 import multiprocessing as mp
 from multiprocessing import shared_memory
 import os
-from typing import Optional, Tuple
+from typing import Callable, Iterable, Optional, Tuple
 
 import h5py
 import numpy as np
@@ -10,12 +10,15 @@ import numpy as np
 
 class BatchGenerator:
 
-    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None, precache=True,
-                 shuffle=True, preload=False, save=None, initial_shuffle=False, left_right_flip=False):
+    def __init__(self, data: Iterable, label: Iterable, batch_size: int,
+                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
+                 prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
+                 left_right_flip=False, save: Optional[str] = None):
         self.batch_size = batch_size
         self.shuffle = shuffle
+        self.prefetch = prefetch and not preload
+        self.num_workers = num_workers
         self.left_right_flip = left_right_flip
-        self.precache = precache and not preload
 
         if not preload:
             self.data_processor = data_processor
@@ -50,54 +53,105 @@ class BatchGenerator:
             h5_file.create_dataset('data', data=self.data)
             h5_file.create_dataset('label', data=self.label)
 
-        self.step_per_epoch = math.ceil(len(self.data) / batch_size)
-
-        self.epoch = 0
-        self.global_step = -1
-        self.step = -1
-
-        self.batch_data: Optional[np.ndarray] = None
-        self.batch_label: Optional[np.ndarray] = None
-
         self.index_list = np.arange(len(self.data))
         if shuffle or initial_shuffle:
             np.random.shuffle(self.index_list)
+        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
+        self.last_batch_size = len(self.index_list) % self.batch_size
 
-        if self.precache:
-            data_sample = np.array([data_processor(entry) if data_processor else entry
-                                    for entry in self.data[:batch_size]])
-            label_sample = np.array([label_processor(entry) if label_processor else entry
-                                     for entry in self.label[:batch_size]])
+        self.epoch = 0
+        self.global_step = 0
+        self.step = 0
+
+        first_data = np.array([data_processor(entry) if data_processor else entry
+                               for entry in self.data[self.index_list[:batch_size]]])
+        first_label = np.array([label_processor(entry) if label_processor else entry
+                                for entry in self.label[self.index_list[:batch_size]]])
+        self.batch_data = first_data
+        self.batch_label = first_label
 
+        self.main_process = False
+        if self.prefetch or self.num_workers > 1:
+            self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
+            self.cache_indices = np.ndarray(
+                self.index_list.shape, dtype=self.index_list.dtype, buffer=self.cache_memory_indices.buf)
+            self.cache_indices[:] = self.index_list
             self.cache_memory_data = [
-                shared_memory.SharedMemory(create=True, size=data_sample.nbytes),
-                shared_memory.SharedMemory(create=True, size=data_sample.nbytes)]
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
             self.cache_data = [
-                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[0].buf),
-                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[1].buf)]
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
             self.cache_memory_label = [
-                shared_memory.SharedMemory(create=True, size=label_sample.nbytes),
-                shared_memory.SharedMemory(create=True, size=label_sample.nbytes)]
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
             self.cache_label = [
-                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[0].buf),
-                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[1].buf)]
+                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
+                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
+        else:
+            self.cache_memory_indices = None
+            self.cache_data = [first_data]
+            self.cache_label = [first_label]
 
+        if self.prefetch:
             self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
-            self.cache_stop = shared_memory.SharedMemory(create=True, size=1)
-            self.cache_stop.buf[0] = 0
-            self.cache_skip = shared_memory.SharedMemory(create=True, size=1)
-            self.cache_skip.buf[0] = 0
-            self.cache_process = mp.Process(target=self.cache_worker)
+            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
+            self.prefetch_stop.buf[0] = 0
+            self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
+            self.prefetch_skip.buf[0] = 0
+            self.cache_process = mp.Process(target=self._cache_worker)
             self.cache_process.start()
+            self.num_workers = 0
+        self._init_workers()
+        self.current_cache = 0
+        self.main_process = True
+
+    def _init_workers(self):
+        if self.num_workers > 1:
+            self.worker_stop = shared_memory.SharedMemory(create=True, size=1)
+            self.worker_stop.buf[0] = 0
+            self.worker_pipes = []
+            self.worker_processes = []
+            for _ in range(self.num_workers):
+                self.worker_pipes.append(mp.Pipe())
+            for worker_index in range(self.num_workers):
+                self.worker_processes.append(mp.Process(target=self._worker, args=(worker_index,)))
+                self.worker_processes[-1].start()
 
     def __del__(self):
-        if self.precache:
-            self.cache_stop.buf[0] = 1
+        self.release()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, _exc_type, _exc_value, _traceback):
+        self.release()
+
+    def release(self):
+        if self.num_workers > 1:
+            self.worker_stop.buf[0] = 1
+            for worker, (pipe, _) in zip(self.worker_processes, self.worker_pipes):
+                pipe.send([-1, 0, 0, 0])
+                worker.join()
+
+            self.worker_stop.close()
+            self.worker_stop.unlink()
+            self.num_workers = 0  # Avoids double release
+
+        if self.prefetch:
+            self.prefetch_stop.buf[0] = 1
             self.cache_pipe_parent.send(True)
             self.cache_process.join()
 
-            self.cache_stop.close()
-            self.cache_stop.unlink()
-            self.cache_skip.close()
-            self.cache_skip.unlink()
+            self.prefetch_stop.close()
+            self.prefetch_stop.unlink()
+            self.prefetch_skip.close()
+            self.prefetch_skip.unlink()
+            self.prefetch = False  # Avoids double release
+
+        if self.main_process and self.cache_memory_indices is not None:
+            self.cache_memory_indices.close()
+            self.cache_memory_indices.unlink()
             self.cache_memory_data[0].close()
             self.cache_memory_data[0].unlink()
             self.cache_memory_data[1].close()
@@ -106,71 +160,177 @@ class BatchGenerator:
             self.cache_memory_label[0].unlink()
             self.cache_memory_label[1].close()
             self.cache_memory_label[1].unlink()
+            self.cache_memory_indices = None  # Avoids double release
 
-    def cache_worker(self):
-        self.precache = False
-        self.next_batch()
-        self.cache_data[0][:] = self.batch_data[:]
-        self.cache_label[0][:] = self.batch_label[:]
-        current_cache = 0
+    def _cache_worker(self):
+        self.prefetch = False
+        self.current_cache = 1
+        self._init_workers()
 
-        while not self.cache_stop.buf[0]:
+        while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
             try:
-                self.cache_pipe_child.recv()
-                self.cache_pipe_child.send(current_cache)
-                if self.cache_skip.buf[0]:
-                    self.cache_skip.buf[0] = 0
+                self.current_cache = 1 - self.current_cache
+                if self.prefetch_skip.buf[0]:
+                    self.prefetch_skip.buf[0] = 0
                     self.step = self.step_per_epoch
-                self.next_batch()
-                current_cache = 1 - current_cache
-                self.cache_data[current_cache][:len(self.batch_data)] = self.batch_data[:]
-                self.cache_label[current_cache][:len(self.batch_label)] = self.batch_label[:]
+                    self.index_list[:] = self.cache_indices
+
+                if self.step >= self.step_per_epoch - 1:  # step starts at 0
+                    self.step = 0
+                    self.epoch += 1
+                else:
+                    self.step += 1
+                self.global_step += 1
+
+                self._next_batch()
+                self.cache_pipe_child.recv()
+                self.cache_pipe_child.send(self.current_cache)
+            except KeyboardInterrupt:
+                break
+
+        if self.num_workers > 1:
+            self.release()
+
+    def _worker(self, worker_index: int):
+        self.num_workers = 0
+        self.current_cache = 0
+        parent_cache_data = self.cache_data
+        parent_cache_label = self.cache_label
+        cache_data = np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)
+        cache_label = np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)
+        self.cache_data = [cache_data]
+        self.cache_label = [cache_label]
+        self.index_list[:] = self.cache_indices
+        pipe = self.worker_pipes[worker_index][1]
+
+        while self.worker_stop.buf is not None and self.worker_stop.buf[0] == 0:
+            try:
+                current_cache, batch_index, start_index, self.batch_size = pipe.recv()
+                if self.batch_size == 0:
+                    continue
+                self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy()
+
+                self.cache_data[0] = cache_data[:self.batch_size]
+                self.cache_label[0] = cache_label[:self.batch_size]
+                self._next_batch()
+                parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[
+                    self.current_cache][:self.batch_size]
+                parent_cache_label[current_cache][batch_index:batch_index + self.batch_size] = self.cache_label[
+                    self.current_cache][:self.batch_size]
+                pipe.send(True)
             except KeyboardInterrupt:
                 break
 
     def skip_epoch(self):
-        if self.precache:
-            self.cache_skip.buf[0] = 1
+        if self.prefetch:
+            self.prefetch_skip.buf[0] = 1
         self.step = self.step_per_epoch
         self.next_batch()
 
+    def _worker_next_batch(self):
+        index_list = self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]
+        batch_len = len(index_list)
+        indices_per_worker = batch_len // self.num_workers
+        if indices_per_worker == 0:
+            indices_per_worker = 1
+
+        worker_params = []
+        batch_index = 0
+        start_index = self.step * self.batch_size
+        for _worker_index in range(self.num_workers - 1):
+            worker_params.append([self.current_cache, batch_index, start_index, indices_per_worker])
+            batch_index += indices_per_worker
+            start_index += indices_per_worker
+        worker_params.append([
+            self.current_cache, batch_index, start_index,
+            batch_len - ((self.num_workers - 1) * indices_per_worker)])
+
+        if indices_per_worker == 1 and batch_len < self.num_workers:
+            worker_params = worker_params[:batch_len]
+
+        for params, (pipe, _) in zip(worker_params, self.worker_pipes):
+            pipe.send(params)
+        for _, (pipe, _) in zip(worker_params, self.worker_pipes):
+            pipe.recv()
+
+    def _next_batch(self):
+        if self.num_workers > 1:
+            self._worker_next_batch()
+            return
+
+        # Loading data
+        if self.data_processor is not None:
+            data = []
+            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
+                data.append(self.data_processor(self.data[entry]))
+            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+        else:
+            data = self.data[
+                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
+            self.cache_data[self.current_cache][:len(data)] = data
+        # Loading label
+        if self.label_processor is not None:
+            label = []
+            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
+                label.append(self.label_processor(self.label[entry]))
+            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
+        else:
+            label = self.label[
+                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
+            self.cache_label[self.current_cache][:len(label)] = label
+
     def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
         if self.step >= self.step_per_epoch - 1:  # step starts at 0
             self.step = 0
             self.epoch += 1
-            if self.shuffle:
-                np.random.shuffle(self.index_list)
         else:
             self.step += 1
 
         self.global_step += 1
-        # Loading data
-        if self.precache:
+        if self.prefetch:
+            # Shuffles one step ahead in prefetch mode because the next batch is already prepared
+            if self.step == self.step_per_epoch - 1 and self.shuffle:
+                np.random.shuffle(self.index_list)
+                self.cache_indices[:] = self.index_list
             self.cache_pipe_parent.send(True)
-            current_cache = self.cache_pipe_parent.recv()
-            self.batch_data = self.cache_data[current_cache].copy()
-        elif self.data_processor is not None:
-            self.batch_data = []
-            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_data.append(self.data_processor(self.data[entry]))
-            self.batch_data = np.asarray(self.batch_data)
+            self.current_cache = self.cache_pipe_parent.recv()
         else:
-            self.batch_data = self.data[
-                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
-        # Loading label
-        if self.precache:
-            self.batch_label = self.cache_label[current_cache].copy()
-        elif self.label_processor is not None:
-            self.batch_label = []
-            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_label.append(self.label_processor(self.label[entry]))
-            self.batch_label = np.asarray(self.batch_label)
+            if self.step == 0 and self.shuffle:
+                np.random.shuffle(self.index_list)
+            if self.num_workers > 1:
+                self.cache_indices[:] = self.index_list
+            self._next_batch()
+
+        if self.step == self.step_per_epoch - 1:
+            self.batch_data = self.cache_data[self.current_cache][:self.last_batch_size]
+            self.batch_label = self.cache_label[self.current_cache][:self.last_batch_size]
         else:
-            self.batch_label = self.label[
-                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
-        # print('next_batch : epoch {}, step {}/{}, data {}, label {}'.format(
-        #     self.epoch, self.step, self.step_per_epoch - 1, self.batch_data.shape, self.batch_label.shape))
+            self.batch_data = self.cache_data[self.current_cache]
+            self.batch_label = self.cache_label[self.current_cache]
+
         if self.left_right_flip and np.random.uniform() > 0.5:
             self.batch_data = self.batch_data[:, :, ::-1]
             self.batch_label = self.batch_label[:, :, ::-1]
 
         return self.batch_data, self.batch_label
+
+
+if __name__ == '__main__':
+    def test():
+        data = np.array(
+            [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.uint8)
+        label = np.array(
+            [.1, .2, .3, .4, .5, .6, .7, .8, .9, .10, .11, .12, .13, .14, .15, .16, .17, .18, .19], dtype=np.uint8)
+
+        for data_processor in [None, lambda x:x]:
+            for prefetch in [False, True]:
+                for num_workers in [1, 3]:
+                    print(f'{data_processor=} {prefetch=} {num_workers=}')
+                    with BatchGenerator(data, label, 5, data_processor=data_processor,
+                                        prefetch=prefetch, num_workers=num_workers) as batch_generator:
+                        for _ in range(9):
+                            print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
+                            batch_generator.next_batch()
+                    print()
+
+    test()
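The prefetch path above is a double-buffered shared-memory handshake: the parent requests the ready buffer over a pipe while the cache process fills the other one. A stripped-down sketch of the same pattern, independent of the classes in this commit (all names here are illustrative):

    import multiprocessing as mp
    import numpy as np
    from multiprocessing import shared_memory

    def producer(pipe, shm_names, shape, dtype, steps):
        # Attach to the two buffers created by the parent; keep the
        # SharedMemory objects alive for as long as the views are used.
        shms = [shared_memory.SharedMemory(name=name) for name in shm_names]
        buffers = [np.ndarray(shape, dtype=dtype, buffer=shm.buf) for shm in shms]
        current = 0
        for step in range(steps):
            buffers[current][:] = step   # "compute" the next batch
            pipe.recv()                  # wait for the parent's request
            pipe.send(current)           # hand over the filled buffer
            current = 1 - current        # fill the other buffer next time
        for shm in shms:
            shm.close()

    if __name__ == '__main__':
        shape, dtype, steps = (4,), np.int64, 5
        shms = [shared_memory.SharedMemory(create=True, size=4 * 8) for _ in range(2)]
        views = [np.ndarray(shape, dtype=dtype, buffer=shm.buf) for shm in shms]
        parent_pipe, child_pipe = mp.Pipe()
        process = mp.Process(target=producer,
                             args=(child_pipe, [shm.name for shm in shms], shape, dtype, steps))
        process.start()
        for _ in range(steps):
            parent_pipe.send(True)            # ask for the next batch
            print(views[parent_pipe.recv()])  # read while the other buffer is refilled
        process.join()
        for shm in shms:
            shm.close()
            shm.unlink()

The second file of the commit follows.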
@@ -2,20 +2,29 @@ import math
 import multiprocessing as mp
 from multiprocessing import shared_memory
 import os
-from typing import Optional, Tuple
+from typing import Callable, Iterable, Optional
 
 import h5py
 import numpy as np
 
-class SequenceGenerator:
-
-    def __init__(self, data, label, sequence_size, batch_size, data_processor=None, label_processor=None,
-                 precache=True, preload=False, shuffle=True, initial_shuffle=False, save=None):
+try:
+    from .batch_generator import BatchGenerator
+except ImportError:  # If running this script directly
+    from batch_generator import BatchGenerator
+
+
+class SequenceGenerator(BatchGenerator):
+
+    def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
+                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
+                 prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
+                 save: Optional[str] = None):
         self.batch_size = batch_size
         self.sequence_size = sequence_size
         self.shuffle = shuffle
-        self.precache = precache and not preload
+        self.prefetch = prefetch and not preload
+        self.num_workers = num_workers
+        self.left_right_flip = False
 
         if not preload:
             self.data_processor = data_processor
@@ -65,164 +74,139 @@ class SequenceGenerator:
         self.index_list = []
         for sequence_index in range(len(self.data)):
             start_indices = np.expand_dims(
-                np.arange(len(self.data[sequence_index]) - sequence_size, dtype=np.uint32),
+                np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
                 axis=-1)
             start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
             self.index_list.append(start_indices)
         self.index_list = np.concatenate(self.index_list, axis=0)
-        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
-
-        self.epoch = 0
-        self.global_step = -1
-        self.step = -1
-
-        self.batch_data: Optional[np.ndarray] = None
-        self.batch_label: Optional[np.ndarray] = None
-
         if shuffle or initial_shuffle:
             np.random.shuffle(self.index_list)
+        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
+        self.last_batch_size = len(self.index_list) % self.batch_size
+
+        self.epoch = 0
+        self.global_step = 0
+        self.step = 0
 
-        if self.precache:
         if data_processor:
-            data_sample = []
+            first_data = []
             for sequence_index, start_index in self.index_list[:batch_size]:
-                data_sample.append(
+                first_data.append(
                     [data_processor(input_data)
                      for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
-            data_sample = np.asarray(data_sample)
+            first_data = np.asarray(first_data)
         else:
-            data_sample = []
+            first_data = []
             for sequence_index, start_index in self.index_list[:batch_size]:
-                data_sample.append(
+                first_data.append(
                     self.data[sequence_index][start_index: start_index + self.sequence_size])
-            data_sample = np.asarray(data_sample)
+            first_data = np.asarray(first_data)
         if label_processor:
-            label_sample = []
+            first_label = []
             for sequence_index, start_index in self.index_list[:batch_size]:
-                label_sample.append(
+                first_label.append(
                     [label_processor(input_label)
                      for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]])
-            label_sample = np.asarray(label_sample)
+            first_label = np.asarray(first_label)
         else:
-            label_sample = []
+            first_label = []
             for sequence_index, start_index in self.index_list[:batch_size]:
-                label_sample.append(
+                first_label.append(
                     self.label[sequence_index][start_index: start_index + self.sequence_size])
-            label_sample = np.asarray(label_sample)
+            first_label = np.asarray(first_label)
+        self.batch_data = first_data
+        self.batch_label = first_label
 
+        self.main_process = False
+        if self.prefetch or self.num_workers > 1:
+            self.cache_memory_indices = shared_memory.SharedMemory(create=True, size=self.index_list.nbytes)
+            self.cache_indices = np.ndarray(
+                self.index_list.shape, dtype=self.index_list.dtype, buffer=self.cache_memory_indices.buf)
+            self.cache_indices[:] = self.index_list
             self.cache_memory_data = [
-                shared_memory.SharedMemory(create=True, size=data_sample.nbytes),
-                shared_memory.SharedMemory(create=True, size=data_sample.nbytes)]
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
             self.cache_data = [
-                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[0].buf),
-                np.ndarray(data_sample.shape, dtype=data_sample.dtype, buffer=self.cache_memory_data[1].buf)]
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
+                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
             self.cache_memory_label = [
-                shared_memory.SharedMemory(create=True, size=label_sample.nbytes),
-                shared_memory.SharedMemory(create=True, size=label_sample.nbytes)]
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
+                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
             self.cache_label = [
-                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[0].buf),
-                np.ndarray(label_sample.shape, dtype=label_sample.dtype, buffer=self.cache_memory_label[1].buf)]
-
-            self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
-            self.cache_stop = shared_memory.SharedMemory(create=True, size=1)
-            self.cache_stop.buf[0] = 0
-            self.cache_skip = shared_memory.SharedMemory(create=True, size=1)
-            self.cache_skip.buf[0] = 0
-            self.cache_process = mp.Process(target=self.cache_worker)
-            self.cache_process.start()
-
-    def __del__(self):
-        if self.precache:
-            self.cache_stop.buf[0] = 1
-            self.cache_pipe_parent.send(True)
-            self.cache_process.join()
-
-            self.cache_stop.close()
-            self.cache_stop.unlink()
-            self.cache_skip.close()
-            self.cache_skip.unlink()
-            self.cache_memory_data[0].close()
-            self.cache_memory_data[0].unlink()
-            self.cache_memory_data[1].close()
-            self.cache_memory_data[1].unlink()
-            self.cache_memory_label[0].close()
-            self.cache_memory_label[0].unlink()
-            self.cache_memory_label[1].close()
-            self.cache_memory_label[1].unlink()
-
-    def cache_worker(self):
-        self.precache = False
-        self.next_batch()
-        self.cache_data[0][:] = self.batch_data[:]
-        self.cache_label[0][:] = self.batch_label[:]
-        current_cache = 0
-
-        while not self.cache_stop.buf[0]:
-            try:
-                self.cache_pipe_child.recv()
-                self.cache_pipe_child.send(current_cache)
-                if self.cache_skip.buf[0]:
-                    self.cache_skip.buf[0] = 0
-                    self.step = self.step_per_epoch
-                self.next_batch()
-                current_cache = 1 - current_cache
-                self.cache_data[current_cache][:len(self.batch_data)] = self.batch_data[:]
-                self.cache_label[current_cache][:len(self.batch_label)] = self.batch_label[:]
-            except KeyboardInterrupt:
-                break
-
-    def skip_epoch(self):
-        if self.precache:
-            self.cache_skip.buf[0] = 1
-        self.step = self.step_per_epoch
-        self.next_batch()
-
-    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
-        if self.step >= self.step_per_epoch - 1:  # step start at 0
-            self.step = 0
-            self.epoch += 1
-            if self.shuffle:
-                np.random.shuffle(self.index_list)
         else:
-            self.step += 1
+            self.cache_memory_indices = None
+            self.cache_data = [first_data]
+            self.cache_label = [first_label]
 
+        if self.prefetch:
+            self.cache_pipe_parent, self.cache_pipe_child = mp.Pipe()
+            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
+            self.prefetch_stop.buf[0] = 0
+            self.prefetch_skip = shared_memory.SharedMemory(create=True, size=1)
+            self.prefetch_skip.buf[0] = 0
+            self.cache_process = mp.Process(target=self._cache_worker)
+            self.cache_process.start()
+            self.num_workers = 0
+        self._init_workers()
+        self.current_cache = 0
+        self.main_process = True
+
+    def _next_batch(self):
+        if self.num_workers > 1:
+            self._worker_next_batch()
+            return
+
-        self.global_step += 1
         # Loading data
-        if self.precache:
-            self.cache_pipe_parent.send(True)
-            current_cache = self.cache_pipe_parent.recv()
-            self.batch_data = self.cache_data[current_cache].copy()
-        elif self.data_processor is not None:
-            self.batch_data = []
+        if self.data_processor is not None:
+            data = []
             for sequence_index, start_index in self.index_list[
                     self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_data.append(
+                data.append(
                     [self.data_processor(input_data)
                      for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
-            self.batch_data = np.asarray(self.batch_data)
+            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
         else:
-            self.batch_data = []
+            data = []
             for sequence_index, start_index in self.index_list[
                     self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_data.append(
-                    self.data[sequence_index][start_index: start_index + self.sequence_size])
-            self.batch_data = np.asarray(self.batch_data)
+                data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
+            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
         # Loading label
-        if self.precache:
-            self.batch_label = self.cache_label[current_cache].copy()
-        elif self.label_processor is not None:
-            self.batch_label = []
+        if self.label_processor is not None:
+            label = []
             for sequence_index, start_index in self.index_list[
                     self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_label.append(
+                label.append(
                     [self.label_processor(input_data)
                      for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
-            self.batch_label = np.asarray(self.batch_label)
+            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
         else:
-            self.batch_label = []
+            label = []
             for sequence_index, start_index in self.index_list[
                     self.step * self.batch_size:(self.step + 1) * self.batch_size]:
-                self.batch_label.append(
-                    self.label[sequence_index][start_index: start_index + self.sequence_size])
-            self.batch_label = np.asarray(self.batch_label)
-
-        return self.batch_data, self.batch_label
+                label.append(self.label[sequence_index][start_index: start_index + self.sequence_size])
+            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
+
+
+if __name__ == '__main__':
+    def test():
+        data = np.array(
+            [[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19]], dtype=np.uint8)
+        label = np.array(
+            [[.1, .2, .3, .4, .5, .6, .7, .8, .9], [.11, .12, .13, .14, .15, .16, .17, .18, .19]], dtype=np.uint8)
+
+        for data_processor in [None, lambda x:x]:
+            for prefetch in [False, True]:
+                for num_workers in [1, 2]:
+                    print(f'{data_processor=} {prefetch=} {num_workers=}')
+                    with SequenceGenerator(data, label, 5, 3, data_processor=data_processor,
+                                           prefetch=prefetch, num_workers=num_workers) as batch_generator:
+                        for _ in range(9):
+                            print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
+                            batch_generator.next_batch()
+                    print()
+
+    test()
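
One detail worth calling out in the SequenceGenerator hunk: the `+ 1` added to the np.arange call. A sequence of length n contains n - sequence_size + 1 windows of length sequence_size, so the old code silently dropped the last valid start index. A quick check of the arithmetic:

    n, sequence_size = 9, 5
    starts = list(range(n - sequence_size + 1))  # [0, 1, 2, 3, 4]
    # every window stays in bounds, including the last one
    assert all(start + sequence_size <= n for start in starts)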