Layers, batch generator, memory
This commit is contained in:
parent
9ab6adce7a
commit
268429fa1a
5 changed files with 277 additions and 0 deletions
189
utils/batch_generator.py
Normal file
189
utils/batch_generator.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
import math
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BatchGenerator:
    """Mini-batch iterator over a (data, label) dataset.

    Two modes:
      * lazy (``preload=False``): raw entries are kept and ``data_processor`` /
        ``label_processor`` are applied per entry at batch time.
      * eager (``preload=True``): processors are applied once up front, and the
        processed arrays can be cached on disk via ``save``.

    NOTE(review): without a processor, ``data``/``label`` must support numpy
    fancy indexing (i.e. be numpy arrays) — plain lists raise at batch time.
    """

    def __init__(self, data, label, batch_size, data_processor=None, label_processor=None,
                 shuffle=True, preload=False, save=None, left_right_flip=False):
        """
        Args:
            data: dataset entries (indexable; numpy array unless a processor is used).
            label: labels aligned with ``data``.
            batch_size: entries per batch; the last batch of an epoch may be smaller.
            data_processor: optional callable applied to each data entry.
            label_processor: optional callable applied to each label entry.
            shuffle: reshuffle index order at construction and at each epoch end.
            preload: apply the processors once now instead of per batch.
            save: path prefix used to cache/load preprocessed arrays
                ('<save>_data.npy' / '<save>_label.npy'); only used with ``preload``.
            left_right_flip: randomly flip both batch arrays along axis 2
                (data augmentation).
        """
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.left_right_flip = left_right_flip

        if not preload:
            # Lazy mode: processors run for every batch.
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            # Eager mode: entries are already processed, so the batch-time
            # processors are disabled.
            self.data_processor = None
            self.label_processor = None

            if save and os.path.exists(save + '_data.npy'):
                # Reuse the on-disk cache of preprocessed arrays.
                self.data = np.load(save + '_data.npy', allow_pickle=True)
                self.label = np.load(save + '_label.npy', allow_pickle=True)
            else:
                if data_processor:
                    self.data = np.asarray([data_processor(entry) for entry in data])
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray([label_processor(entry) for entry in label])
                else:
                    self.label = label
                if save:
                    np.save(save + '_data.npy', self.data, allow_pickle=True)
                    np.save(save + '_label.npy', self.label, allow_pickle=True)

        # The last batch of an epoch may be partial, hence ceil.
        self.step_per_epoch = math.ceil(len(self.data) / batch_size)

        # Counters start "before" the first batch; next_batch() advances them.
        self.epoch = 0
        self.global_step = -1
        self.step = -1

        self.batch_data: Optional[np.ndarray] = None
        self.batch_label: Optional[np.ndarray] = None

        self.index_list = np.arange(len(self.data))
        if shuffle:
            np.random.shuffle(self.index_list)

    def _batch_indices(self) -> np.ndarray:
        """Dataset indices selected for the current step."""
        return self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]

    def _gather(self, source, processor):
        """Build the current batch from *source*, applying *processor* per entry if set."""
        if processor is not None:
            return np.asarray([processor(source[i]) for i in self._batch_indices()])
        return source[self._batch_indices()]

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        """Advance one step and return the next ``(batch_data, batch_label)`` pair.

        Wraps around at the end of an epoch, incrementing ``epoch`` and, when
        ``shuffle`` is on, reshuffling the index order.
        """
        if self.step >= self.step_per_epoch - 1:  # step starts at 0
            self.step = 0
            self.epoch += 1
            if self.shuffle:
                np.random.shuffle(self.index_list)
        else:
            self.step += 1

        self.global_step += 1

        self.batch_data = self._gather(self.data, self.data_processor)
        self.batch_label = self._gather(self.label, self.label_processor)

        # Joint random horizontal flip (50% chance) of data and label along
        # axis 2 — assumes (N, H, W, ...)-shaped batches; TODO confirm.
        if self.left_right_flip and np.random.uniform() > 0.5:
            self.batch_data = self.batch_data[:, :, ::-1]
            self.batch_label = self.batch_label[:, :, ::-1]
        return self.batch_data, self.batch_label
|
||||
|
||||
|
||||
class SequenceGenerator:
    """Mini-batch iterator over fixed-length windows of multiple sequences.

    Every window of ``sequence_size`` consecutive entries in every sequence
    becomes one sample; a batch stacks ``batch_size`` such windows.
    """

    def __init__(self, data, label, sequence_size, batch_size, data_processor=None, label_processor=None,
                 preload=False, shuffle=True, save=None):
        """
        Args:
            data: list of sequences (each indexable and sliceable).
            label: list of label sequences aligned with ``data``.
            sequence_size: window length of each sample.
            batch_size: windows per batch; the last batch of an epoch may be smaller.
            data_processor: optional callable applied to each data entry.
            label_processor: optional callable applied to each label entry.
            preload: apply the processors once now instead of per batch.
            shuffle: shuffle window order at construction and at each epoch end.
            save: path prefix used to cache/load preprocessed arrays
                ('<save>_data.npy' / '<save>_label.npy'); only used with ``preload``.
        """
        self.sequence_size = sequence_size
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Enumerate every (sequence_index, start_offset) window.
        # BUG FIX: start offsets were generated with dtype=np.uint8, which
        # silently wraps for sequences longer than 255 entries and produces
        # wrong (aliased) windows; use a 64-bit integer instead.
        window_index = []
        for sequence_index in range(len(data)):
            start_indices = np.expand_dims(
                np.arange(len(data[sequence_index]) - sequence_size, dtype=np.int64),
                axis=-1)
            # Prepend the owning sequence's index to each start offset.
            start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
            window_index.append(start_indices)
        self.index_list = np.concatenate(window_index, axis=0)
        # The last batch of an epoch may be partial, hence ceil.
        self.step_per_epoch = math.ceil(len(self.index_list) / batch_size)

        if not preload:
            # Lazy mode: processors run for every batch.
            self.data_processor = data_processor
            self.label_processor = label_processor
            self.data = data
            self.label = label
        else:
            # Eager mode: entries are already processed, so the batch-time
            # processors are disabled.
            self.data_processor = None
            self.label_processor = None

            if save and os.path.exists(save + '_data.npy'):
                # Reuse the on-disk cache of preprocessed arrays.
                self.data = np.load(save + '_data.npy', allow_pickle=True)
                self.label = np.load(save + '_label.npy', allow_pickle=True)
            else:
                if data_processor:
                    self.data = np.asarray(
                        [np.asarray([data_processor(entry) for entry in serie]) for serie in data])
                else:
                    self.data = data
                if label_processor:
                    self.label = np.asarray(
                        [np.asarray([label_processor(entry) for entry in serie]) for serie in label])
                else:
                    self.label = label

                if save:
                    np.save(save + '_data.npy', self.data, allow_pickle=True)
                    np.save(save + '_label.npy', self.label, allow_pickle=True)

        # Counters start "before" the first batch; next_batch() advances them.
        self.epoch = 0
        self.global_step = -1
        self.step = -1

        self.batch_data: Optional[np.ndarray] = None
        self.batch_label: Optional[np.ndarray] = None

        if shuffle:
            np.random.shuffle(self.index_list)

    def _gather(self, source, processor):
        """Stack the current step's windows from *source*, applying *processor* per entry if set."""
        windows = []
        for sequence_index, start_index in self.index_list[
                self.step * self.batch_size:(self.step + 1) * self.batch_size]:
            window = source[sequence_index][start_index: start_index + self.sequence_size]
            if processor is not None:
                window = [processor(entry) for entry in window]
            windows.append(window)
        return np.asarray(windows)

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        """Advance one step and return the next ``(batch_data, batch_label)`` pair.

        Wraps around at the end of an epoch, incrementing ``epoch`` and, when
        ``shuffle`` is on, reshuffling the window order.
        """
        if self.step >= self.step_per_epoch - 1:  # step starts at 0
            self.step = 0
            self.epoch += 1
            if self.shuffle:
                np.random.shuffle(self.index_list)
        else:
            self.step += 1

        self.global_step += 1

        self.batch_data = self._gather(self.data, self.data_processor)
        self.batch_label = self._gather(self.label, self.label_processor)

        return self.batch_data, self.batch_label
|
||||
8
utils/memory.py
Normal file
8
utils/memory.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
def human_size(byte_count: int) -> str:
    """Output byte amount in human readable format"""
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi', 'Yi')
    value = float(byte_count)
    # Divide down by 1024 until the value fits under one unit step; if it
    # never does, the largest prefix ('Yi') is kept.
    unit = prefixes[-1]
    for prefix in prefixes:
        if value < 1024.0:
            unit = prefix
            break
        value /= 1024.0
    return f'{value:.2f}{unit}B'
|
||||
Loading…
Add table
Add a link
Reference in a new issue