Layers, batch generator, memory

This commit is contained in:
Corentin Risselin 2020-03-31 13:46:01 +09:00
commit 268429fa1a
5 changed files with 277 additions and 0 deletions

189
utils/batch_generator.py Normal file
View file

@ -0,0 +1,189 @@
import math
import os
from typing import Optional, Tuple
import numpy as np
class BatchGenerator:
def __init__(self, data, label, batch_size, data_processor=None, label_processor=None,
shuffle=True, preload=False, save=None, left_right_flip=False):
self.batch_size = batch_size
self.shuffle = shuffle
self.left_right_flip = left_right_flip
if not preload:
self.data_processor = data_processor
self.label_processor = label_processor
self.data = data
self.label = label
else:
self.data_processor = None
self.label_processor = None
if save and os.path.exists(save + '_data.npy'):
self.data = np.load(save + '_data.npy', allow_pickle=True)
self.label = np.load(save + '_label.npy', allow_pickle=True)
else:
if data_processor:
self.data = np.asarray([data_processor(entry) for entry in data])
else:
self.data = data
if label_processor:
self.label = np.asarray([label_processor(entry) for entry in label])
else:
self.label = label
if save:
np.save(save + '_data.npy', self.data, allow_pickle=True)
np.save(save + '_label.npy', self.label, allow_pickle=True)
self.step_per_epoch = math.ceil(len(self.data) / batch_size)
self.epoch = 0
self.global_step = -1
self.step = -1
self.batch_data: Optional[np.ndarray] = None
self.batch_label: Optional[np.ndarray] = None
self.index_list = np.arange(len(self.data))
if shuffle:
np.random.shuffle(self.index_list)
def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
if self.step >= self.step_per_epoch - 1: # step start at 0
self.step = 0
self.epoch += 1
if self.shuffle:
np.random.shuffle(self.index_list)
else:
self.step += 1
self.global_step += 1
# Loading data
if self.data_processor is not None:
self.batch_data = []
for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
self.batch_data.append(self.data_processor(self.data[entry]))
self.batch_data = np.asarray(self.batch_data)
else:
self.batch_data = self.data[
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
# Loading label
if self.label_processor is not None:
self.batch_label = []
for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
self.batch_label.append(self.label_processor(self.label[entry]))
self.batch_label = np.asarray(self.batch_label)
else:
self.batch_label = self.label[
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
# print('next_batch : epoch {}, step {}/{}, data {}, label {}'.format(
# self.epoch, self.step, self.step_per_epoch - 1, self.batch_data.shape, self.batch_label.shape))
if self.left_right_flip and np.random.uniform() > 0.5:
self.batch_data = self.batch_data[:, :, ::-1]
self.batch_label = self.batch_label[:, :, ::-1]
return self.batch_data, self.batch_label
class SequenceGenerator:
def __init__(self, data, label, sequence_size, batch_size, data_processor=None, label_processor=None,
preload=False, shuffle=True, save=None):
self.sequence_size = sequence_size
self.batch_size = batch_size
self.shuffle = shuffle
self.index_list = []
for sequence_index in range(len(data)):
start_indices = np.expand_dims(
np.arange(len(data[sequence_index]) - sequence_size, dtype=np.uint8),
axis=-1)
start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
self.index_list.append(start_indices)
self.index_list = np.concatenate(self.index_list, axis=0)
self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)
if not preload:
self.data_processor = data_processor
self.label_processor = label_processor
self.data = data
self.label = label
else:
self.data_processor = None
self.label_processor = None
if save and os.path.exists(save + '_data.npy'):
self.data = np.load(save + '_data.npy', allow_pickle=True)
self.label = np.load(save + '_label.npy', allow_pickle=True)
else:
if data_processor:
self.data = np.asarray(
[np.asarray([data_processor(entry) for entry in serie]) for serie in data])
else:
self.data = data
if label_processor:
self.label = np.asarray(
[np.asarray([label_processor(entry) for entry in serie]) for serie in label])
else:
self.label = label
if save:
np.save(save + '_data.npy', self.data, allow_pickle=True)
np.save(save + '_label.npy', self.label, allow_pickle=True)
self.epoch = 0
self.global_step = -1
self.step = -1
self.batch_data: Optional[np.ndarray] = None
self.batch_label: Optional[np.ndarray] = None
if shuffle:
np.random.shuffle(self.index_list)
def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
if self.step >= self.step_per_epoch - 1: # step start at 0
self.step = 0
self.epoch += 1
if self.shuffle:
np.random.shuffle(self.index_list)
else:
self.step += 1
self.global_step += 1
# Loading data
if self.data_processor is not None:
self.batch_data = []
for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
self.batch_data.append(
[self.data_processor(input_data)
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
self.batch_data = np.asarray(self.batch_data)
else:
self.batch_data = []
for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
self.batch_data.append(
self.data[sequence_index][start_index: start_index + self.sequence_size])
self.batch_data = np.asarray(self.batch_data)
# Loading label
if self.label_processor is not None:
self.batch_label = []
for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
self.batch_label.append(
[self.label_processor(input_data)
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
self.batch_label = np.asarray(self.batch_label)
else:
self.batch_label = []
for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]:
self.batch_label.append(
self.label[sequence_index][start_index: start_index + self.sequence_size])
self.batch_label = np.asarray(self.batch_label)
return self.batch_data, self.batch_label