Merge branch 'master' into 'BatchNormModifications'

# Conflicts:
#   layers.py
This commit is contained in:
Corentin 2021-05-21 06:53:31 +00:00
commit fe11f3e6d5
11 changed files with 753 additions and 159 deletions

View file

@ -22,8 +22,13 @@ class Layer(nn.Module):
def __init__(self, activation, use_batch_norm): def __init__(self, activation, use_batch_norm):
super().__init__() super().__init__()
# Preload default # Preload default
if activation == 0:
activation = Layer.ACTIVATION
if isinstance(activation, type):
self.activation = activation()
else:
self.activation = activation
self.batch_norm: torch.nn._BatchNorm = None self.batch_norm: torch.nn._BatchNorm = None
self.activation = Layer.ACTIVATION if activation == 0 else activation
self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm
def forward(self, input_data: torch.Tensor) -> torch.Tensor: def forward(self, input_data: torch.Tensor) -> torch.Tensor:
@ -40,7 +45,7 @@ class Linear(Layer):
def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = None, **kwargs): def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = None, **kwargs):
super().__init__(activation, use_batch_norm) super().__init__(activation, use_batch_norm)
self.fc = nn.Linear(in_channels, out_channels, **kwargs) self.fc = nn.Linear(in_channels, out_channels, bias=not self.batch_norm, **kwargs)
self.batch_norm = nn.BatchNorm1d( self.batch_norm = nn.BatchNorm1d(
out_channels, out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM, momentum=Layer.BATCH_NORM_MOMENTUM,
@ -76,7 +81,7 @@ class Conv2d(Layer):
self.batch_norm = nn.BatchNorm2d( self.batch_norm = nn.BatchNorm2d(
out_channels, out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM, momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=not Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
def forward(self, input_data: torch.Tensor) -> torch.Tensor: def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return super().forward(self.conv(input_data)) return super().forward(self.conv(input_data))
@ -109,7 +114,7 @@ class Deconv2d(Layer):
self.batch_norm = nn.BatchNorm2d( self.batch_norm = nn.BatchNorm2d(
out_channels, out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM, momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=not Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None
def forward(self, input_data: torch.Tensor) -> torch.Tensor: def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return super().forward(self.deconv(input_data)) return super().forward(self.deconv(input_data))

View file

@ -3,65 +3,51 @@ from typing import Union, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from .layers import LayerInfo, Layer from .layers import Conv2d, LayerInfo, Layer
class ResBlock(Layer): class ResBlock(Layer):
def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, def __init__(self, in_channels: int, out_channels: int = -1, kernel_size: int = 3, padding: int = 1,
activation=None, **kwargs): stride: Union[int, Tuple[int, int]] = 1, activation=None, batch_norm=None, **kwargs):
super().__init__(activation if activation is not None else 0, False) super().__init__(activation if activation is not None else 0, False)
self.batch_norm = None
if out_channels == -1:
out_channels = in_channels
self.seq = nn.Sequential( self.seq = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=False, **kwargs), Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, **kwargs),
nn.BatchNorm2d( Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
out_channels, activation=None, batch_norm=batch_norm))
momentum=Layer.BATCH_NORM_MOMENTUM, self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
track_running_stats=not Layer.BATCH_NORM_TRAINING), out_channels != in_channels or stride != 1) else None
torch.nn.LeakyReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, bias=False, padding=1),
nn.BatchNorm2d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=not Layer.BATCH_NORM_TRAINING))
self.batch_norm = nn.BatchNorm2d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
def forward(self, input_data: torch.Tensor) -> torch.Tensor: def forward(self, input_data: torch.Tensor) -> torch.Tensor:
if self.residual is not None:
return super().forward(self.residual(input_data) + self.seq(input_data))
return super().forward(input_data + self.seq(input_data)) return super().forward(input_data + self.seq(input_data))
class ResBottleneck(Layer): class ResBottleneck(Layer):
def __init__(self, in_channels: int, out_channels: int, planes: int = 1, kernel_size: int = 3, def __init__(self, in_channels: int, out_channels: int = -1, bottleneck_channels: int = -1, kernel_size: int = 3,
stride: Union[int, Tuple[int, int]] = 1, activation=None, **kwargs): stride: Union[int, Tuple[int, int]] = 1, padding=1,
activation=None, batch_norm=None, **kwargs):
super().__init__(activation if activation is not None else 0, False) super().__init__(activation if activation is not None else 0, False)
self.batch_norm = None self.batch_norm = None
if out_channels == -1:
out_channels = in_channels
if bottleneck_channels == -1:
bottleneck_channels = in_channels // 4
self.seq = nn.Sequential( self.seq = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), Conv2d(in_channels, bottleneck_channels, kernel_size=1),
nn.BatchNorm2d( Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=kernel_size,
out_channels, stride=stride, padding=padding, **kwargs),
momentum=Layer.BATCH_NORM_MOMENTUM, Conv2d(bottleneck_channels, out_channels, kernel_size=1,
track_running_stats=not Layer.BATCH_NORM_TRAINING), activation=None, batch_norm=batch_norm))
torch.nn.LeakyReLU(), self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=False, **kwargs), out_channels != in_channels or stride != 1) else None
nn.BatchNorm2d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=not Layer.BATCH_NORM_TRAINING),
torch.nn.LeakyReLU(),
nn.Conv2d(out_channels, planes * out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(
out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=not Layer.BATCH_NORM_TRAINING))
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, planes * out_channels, stride=stride, kernel_size=1),
nn.BatchNorm2d(
planes * out_channels,
momentum=Layer.BATCH_NORM_MOMENTUM,
track_running_stats=not Layer.BATCH_NORM_TRAINING))
def forward(self, input_data: torch.Tensor) -> torch.Tensor: def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return super().forward(self.downsample(input_data) + self.seq(input_data)) if self.residual is not None:
return super().forward(self.residual(input_data) + self.seq(input_data))
return super().forward(input_data + self.seq(input_data))

86
ssd/box.py Normal file
View file

@ -0,0 +1,86 @@
import numpy as np
def create_box(y_pos: float, x_pos: float, height: float, width: float) -> tuple[float, float, float, float]:
y_min, x_min, y_max, x_max = check_rectangle(
y_pos - (height / 2), x_pos - (width / 2), y_pos + (height / 2), x_pos + (width / 2))
return (y_min + y_max) / 2, (x_min + x_max) / 2, y_max - y_min, x_max - x_min
def check_rectangle(y_min: float, x_min: float, y_max: float, x_max: float) -> tuple[float, float, float, float]:
if y_min < 0:
y_min = 0
if x_min < 0:
x_min = 0
if y_min > 1:
y_min = 1
if x_min > 1:
x_min = 1
if y_max < 0:
y_max = 0
if x_max < 0:
x_max = 0
if y_max >= 1:
y_max = 1
if x_max >= 1:
x_max = 1
return y_min, x_min, y_max, x_max
def get_boxes(predictions: np.ndarray, anchors: np.ndarray, class_index: int) -> np.ndarray:
boxes = np.zeros(anchors.shape)
boxes[:, 0] = (predictions[:, 0] * anchors[:, 2]) + anchors[:, 0]
boxes[:, 1] = (predictions[:, 1] * anchors[:, 3]) + anchors[:, 1]
boxes[:, 2] = np.exp(predictions[:, 2]) * anchors[:, 2]
boxes[:, 3] = np.exp(predictions[:, 3]) * anchors[:, 3]
boxes = np.asarray([create_box(*box) for box in boxes])
# return np.insert(boxes, 4, predictions[:, class_index], axis=-1)
return np.concatenate([boxes, predictions[:, class_index:class_index + 1]], axis=1)
def fast_nms(boxes: np.ndarray, min_iou: float) -> np.ndarray:
# if there are no boxes, return an empty list
if len(boxes) == 0:
return []
# initialize the list of picked indexes
pick = []
# grab the coordinates of the bounding boxes
y_min = boxes[:, 0] - (boxes[:, 2] / 2)
y_max = boxes[:, 0] + (boxes[:, 2] / 2)
x_min = boxes[:, 1] - (boxes[:, 3] / 2)
x_max = boxes[:, 1] + (boxes[:, 3] / 2)
scores = boxes[:, 4]
# compute the area of the bounding boxes and sort the bounding boxes by the scores
areas = (x_max - x_min) * (y_max - y_min)
idxs = np.argsort(scores)
# keep looping while some indexes still remain in the indexes
# list
while len(idxs) > 0:
# grab the last index in the indexes list and add the
# index value to the list of picked indexes
last = len(idxs) - 1
i = idxs[last]
pick.append(i)
inter_tops = np.maximum(y_min[i], y_min[idxs[:last]])
inter_bottoms = np.minimum(y_max[i], y_max[idxs[:last]])
inter_lefts = np.maximum(x_min[i], x_min[idxs[:last]])
inter_rights = np.minimum(x_max[i], x_max[idxs[:last]])
inter_areas = (inter_rights - inter_lefts) * (inter_bottoms - inter_tops)
# compute the ratio of overlap
union_area = (areas[idxs[:last]] + areas[i]) - inter_areas
overlap = inter_areas / union_area
# delete all indexes from the index list that have less overlap than min_iou
idxs = np.delete(
idxs, np.concatenate(([last], np.where(overlap > min_iou)[0])))
# return only the bounding boxes that were picked using the
# integer data type
return boxes[pick]

112
ssd/loss.py Normal file
View file

@ -0,0 +1,112 @@
import torch
import torch.nn as nn
class JacardOverlap(nn.Module):
def forward(self, anchors: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
"""
Assuming rank 2 (number of boxes, locations), location is (y, x, h, w)
Jaccard overlap : A B / A B = A B / (area(A) + area(B) - A B)
Return:
jaccard overlap: (tensor) Shape: [predictions.size(0), labels.size(0)]
"""
anchors_count = anchors.size(0)
labels_count = labels.size(0)
# Getting coords (y_min, x_min, y_max, x_max) repeated to fill (anchor count, label count)
anchor_coords = torch.cat([
anchors[:, :2] - (anchors[:, 2:] / 2),
anchors[:, :2] + (anchors[:, 2:] / 2)], 1).unsqueeze(1).expand(anchors_count, labels_count, 4)
label_coords = torch.cat([
labels[:, :2] - (labels[:, 2:] / 2),
labels[:, :2] + (labels[:, 2:] / 2)], 1).unsqueeze(0).expand(anchors_count, labels_count, 4)
mins = torch.max(anchor_coords, label_coords)[:, :, :2]
maxes = torch.min(anchor_coords, label_coords)[:, :, 2:]
inter_coords = torch.clamp(maxes - mins, min=0)
inter_area = inter_coords[:, :, 0] * inter_coords[:, :, 1]
anchor_areas = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1).expand_as(inter_area)
label_areas = (labels[:, 2] * labels[:, 3]).unsqueeze(0).expand_as(inter_area)
union_area = anchor_areas + label_areas - inter_area
return inter_area / union_area
class SSDLoss(nn.Module):
def __init__(self, anchors: torch.Tensor, label_per_image: int,
negative_mining_ratio: int, matching_iou: float,
location_dimmension: int = 4, localization_loss_weight: float = 1.0):
super().__init__()
self.anchors = anchors
self.anchor_count = anchors.size(0)
self.label_per_image = label_per_image
self.location_dimmension = location_dimmension
self.negative_mining_ratio = negative_mining_ratio
self.matching_iou = matching_iou
self.localization_loss_weight = localization_loss_weight
self.overlap = JacardOverlap()
self.matches = []
# self.negative_matches = []
self.positive_class_loss = torch.Tensor()
self.negative_class_loss = torch.Tensor()
self.localization_loss = torch.Tensor()
self.class_loss = torch.Tensor()
self.final_loss = torch.Tensor()
def forward(self, input_data: torch.Tensor, input_labels: torch.Tensor) -> torch.Tensor:
batch_size = input_data.size(0)
expanded_anchors = self.anchors[:, :4].unsqueeze(0).unsqueeze(2).expand(
batch_size, self.anchor_count, self.label_per_image, 4)
expanded_labels = input_labels[:, :, :self.location_dimmension].unsqueeze(1).expand(
batch_size, self.anchor_count, self.label_per_image, self.location_dimmension)
objective_pos = (expanded_labels[:, :, :, :2] - expanded_anchors[:, :, :, :2]) / (
expanded_anchors[:, :, :, 2:])
objective_size = torch.log(expanded_labels[:, :, :, 2:] / expanded_anchors[:, :, :, 2:])
positive_objectives = []
positive_predictions = []
positive_class_loss = []
negative_class_loss = []
self.matches = []
# self.negative_matches = []
for batch_index in range(batch_size):
predictions = input_data[batch_index]
labels = input_labels[batch_index]
overlaps = self.overlap(self.anchors[:, :4], labels[:, :4])
mask = (overlaps >= self.matching_iou).long()
match_indices = torch.nonzero(mask, as_tuple=False)
self.matches.append(match_indices.detach().cpu())
mining_count = int(self.negative_mining_ratio * len(self.matches[-1]))
masked_prediction = predictions[:, self.location_dimmension] + torch.max(mask, dim=1)[0]
non_match_indices = torch.argsort(masked_prediction, dim=-1, descending=False)[:mining_count]
# self.negative_matches.append(non_match_indices.detach().cpu())
for anchor_index, label_index in match_indices:
positive_predictions.append(predictions[anchor_index])
positive_objectives.append(
torch.cat((
objective_pos[batch_index, anchor_index, label_index],
objective_size[batch_index, anchor_index, label_index]), dim=-1))
positive_class_loss.append(torch.log(
predictions[anchor_index, self.location_dimmension + labels[label_index, -1].long()]))
for anchor_index in non_match_indices:
negative_class_loss.append(
torch.log(predictions[anchor_index, self.location_dimmension]))
if not positive_predictions:
return None
positive_predictions = torch.stack(positive_predictions)
positive_objectives = torch.stack(positive_objectives)
self.positive_class_loss = -torch.sum(torch.stack(positive_class_loss))
self.negative_class_loss = -torch.sum(torch.stack(negative_class_loss))
self.localization_loss = nn.functional.smooth_l1_loss(
positive_predictions[:, self.location_dimmension],
positive_objectives)
self.class_loss = self.positive_class_loss + self.negative_class_loss
self.final_loss = (self.localization_loss_weight * self.localization_loss) + self.class_loss
return self.final_loss

165
ssd/ssd.py Normal file
View file

@ -0,0 +1,165 @@
import colorsys
import math
import numpy as np
import torch
import torch.nn as nn
from .box import check_rectangle
from ..layers import Conv2d
class SSD(nn.Module):
class Detector(nn.Module):
def __init__(self, input_features: int, output_features: int):
super().__init__()
self.conv = Conv2d(input_features, output_features, kernel_size=3, padding=1,
batch_norm=False, activation=None)
self.output = None
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
self.output = self.conv(input_data).permute(0, 2, 3, 1)
return self.output
class DetectorMerge(nn.Module):
def __init__(self, location_dimmension: int):
super().__init__()
self.location_dim = location_dimmension
def forward(self, detector_outputs: torch.Tensor) -> torch.Tensor:
return torch.cat(
[detector_outputs[:, :, :self.location_dim],
torch.softmax(detector_outputs[:, :, self.location_dim:], dim=2)], dim=2)
class AnchorInfo:
def __init__(self, center: tuple[float, float], size: tuple[float],
index: int, layer_index: int, map_index: tuple[int, int], color_index: int,
ratio: float, size_factor: float):
self.index = index
self.layer_index = layer_index
self.map_index = map_index
self.color_index = color_index
self.ratio = ratio
self.size_factor = size_factor
self.center = center
self.size = size
self.box = check_rectangle(
center[0] - (size[0] / 2), center[1] - (size[1] / 2),
center[0] + (size[0] / 2), center[1] + (size[1] / 2))
def __repr__(self):
return (f'{self.__class__.__name__}'
f'(index:{self.index}, layer:{self.layer_index}, coord:{self.map_index}'
f', center:({self.center[0]:.03f}, {self.center[1]:.03f})'
f', size:({self.size[0]:.03f}, {self.size[1]:.03f})'
f', ratio:{self.ratio:.03f}, size_factor:{self.size_factor:.03f})'
f', y:[{self.box[0]:.03f}:{self.box[2]:.03f}]'
f', x:[{self.box[1]:.03f}:{self.box[3]:.03f}])')
def __array__(self):
return np.array([*self.center, *self.size])
def __init__(self, base_network: nn.Module, input_sample: torch.Tensor, classes: list[str],
location_dimmension: int, layer_channels: list[int], layer_box_ratios: list[float], layer_args: dict,
box_size_factors: list[float]):
super().__init__()
self.location_dim = location_dimmension
self.classes = ['none'] + classes
self.class_count = len(self.classes)
self.base_input_shape = input_sample.numpy().shape[1:]
self.base_network = base_network
sample_output = base_network(input_sample)
self.base_output_shape = list(sample_output.detach().numpy().shape)[-3:]
layer_convs: list[nn.Module] = []
layer_detectors: list[SSD.Detector] = []
last_feature_count = self.base_output_shape[0]
for layer_index, (output_features, kwargs) in enumerate(zip(layer_channels, layer_args)):
if 'disable' not in kwargs:
layer_convs.append(Conv2d(last_feature_count, output_features, **kwargs))
layer_detectors.append(SSD.Detector(
last_feature_count, (self.class_count + self.location_dim) * len(layer_box_ratios[layer_index])))
# layers.append(SSD.Layer(
# last_feature_count, output_features,
# (self.class_count + self.location_dim) * len(layer_box_ratios[layer_index]),
# **kwargs))
last_feature_count = output_features
self.layer_convs = nn.ModuleList(layer_convs)
self.layer_detectors = nn.ModuleList(layer_detectors)
self.merge = self.DetectorMerge(location_dimmension)
self.anchors_numpy, self.anchor_info, self.box_colors = self._create_anchors(
sample_output, self.layer_convs, self.layer_detectors, layer_box_ratios, box_size_factors,
input_sample.shape[3] / input_sample.shape[2])
self.anchors = torch.from_numpy(self.anchors_numpy)
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
head = self.base_network(input_data)
detector_outputs = []
for layer_index, detector in enumerate(self.layer_detectors):
detector_out = detector(head)
detector_outputs.append(detector_out.reshape(
detector_out.size(0), -1, self.class_count + self.location_dim))
if layer_index < len(self.layer_convs):
head = self.layer_convs[layer_index](head)
detector_outputs = torch.cat(detector_outputs, 1)
return self.merge(detector_outputs)
# base_output = self.base_network(input_data)
# head = base_output
# outputs = []
# for layer in self.layers:
# head, detector_output = layer(head)
# outputs.append(detector_output.reshape(base_output.size(0), -1, self.class_count + self.location_dim))
# outputs = torch.cat(outputs, 1)
# return torch.cat(
# [outputs[:, :, :self.location_dim], torch.softmax(outputs[:, :, self.location_dim:], dim=2)], dim=2)
def _apply(self, fn):
super()._apply(fn)
self.anchors = fn(self.anchors)
return self
@staticmethod
def _create_anchors(
base_output: torch.Tensor, layers: nn.ModuleList, detectors: nn.ModuleList, layer_box_ratios: list[float],
box_size_factors: list[float], image_ratio: float) -> tuple[np.ndarray, np.ndarray, list[np.ndarray]]:
anchors = []
anchor_info: list[SSD.AnchorInfo] = []
box_colors: list[np.ndarray] = []
head = base_output
for layer_index, detector in enumerate(detectors):
detector_output = detector(head) # detector output shape : NCRSHW (Ratio, Size)
if layer_index < len(layers):
head = layers[layer_index](head)
detector_rows = detector_output.size()[1]
detector_cols = detector_output.size()[2]
color_index = 0
layer_ratios = layer_box_ratios[layer_index]
for index_y in range(detector_rows):
center_y = (index_y + 0.5) / detector_rows
for index_x in range(detector_cols):
center_x = (index_x + 0.5) / detector_cols
for ratio, size_factor in zip(layer_ratios, box_size_factors):
box_colors.append((np.asarray(colorsys.hsv_to_rgb(
color_index / len(layer_ratios), 1.0, 1.0)) * 255).astype(np.uint8))
color_index += 1
unit_box_size = size_factor / max(detector_rows, detector_cols)
anchor_width = unit_box_size * math.sqrt(ratio / image_ratio)
anchor_height = unit_box_size / math.sqrt(ratio / image_ratio)
anchor_info.append(SSD.AnchorInfo(
(center_y, center_x),
(anchor_height, anchor_width),
len(anchors),
layer_index,
(index_y, index_x),
len(box_colors) - 1,
ratio,
size_factor
))
anchors.append([center_y, center_x, anchor_height, anchor_width])
return np.asarray(anchors, dtype=np.float32), anchor_info, box_colors

View file

@ -28,11 +28,26 @@ def parameter_summary(network: torch.nn.Module) -> List[Tuple[str, Tuple[int], s
def resource_usage() -> Tuple[int, str]: def resource_usage() -> Tuple[int, str]:
memory_peak = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) memory_peak = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
return memory_peak, gpu_used_memory()
def gpu_used_memory() -> str:
gpu_memory = subprocess.check_output( gpu_memory = subprocess.check_output(
'nvidia-smi --query-gpu=memory.used --format=csv,noheader', shell=True).decode() 'nvidia-smi --query-gpu=memory.used --format=csv,noheader', shell=True).decode().strip()
if 'CUDA_VISIBLE_DEVICES' in os.environ: if 'CUDA_VISIBLE_DEVICES' in os.environ:
gpu_memory = gpu_memory.split('\n')[int(os.environ['CUDA_VISIBLE_DEVICES'])] gpu_memory = gpu_memory.split('\n')[int(os.environ['CUDA_VISIBLE_DEVICES'])]
else: else:
gpu_memory = ' '.join(gpu_memory.split('\n')) gpu_memory = ','.join(gpu_memory.split('\n'))
return memory_peak, gpu_memory return gpu_memory
def gpu_total_memory() -> str:
gpu_memory = subprocess.check_output(
'nvidia-smi --query-gpu=memory.total --format=csv,noheader', shell=True).decode().strip()
if 'CUDA_VISIBLE_DEVICES' in os.environ:
gpu_memory = gpu_memory.split('\n')[int(os.environ['CUDA_VISIBLE_DEVICES'])]
else:
gpu_memory = ','.join(gpu_memory.split('\n'))
return gpu_memory

View file

@ -23,12 +23,13 @@ class Trainer:
epoch_skip: int, summary_per_epoch: int, image_per_epoch: int, epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
data_dtype=None, label_dtype=None, data_dtype=None, label_dtype=None,
train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False, train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
logger=DummyLogger()): logger=DummyLogger(), verbose=True, save_src=True):
super().__init__() super().__init__()
self.device = device self.device = device
self.output_dir = output_dir self.output_dir = output_dir
self.data_is_label = data_is_label self.data_is_label = data_is_label
self.logger = logger self.logger = logger
self.verbose = verbose
self.should_stop = False self.should_stop = False
self.batch_generator_train = batch_generator_train self.batch_generator_train = batch_generator_train
@ -42,8 +43,9 @@ class Trainer:
self.network = network self.network = network
self.optimizer = optimizer self.optimizer = optimizer
self.criterion = criterion self.criterion = criterion
self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30) self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30) self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30)
self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30)
# Save network graph # Save network graph
batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device) batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)
@ -74,10 +76,11 @@ class Trainer:
if summary_per_epoch % image_per_epoch == 0: if summary_per_epoch % image_per_epoch == 0:
self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch) self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)
torch.save(network, os.path.join(output_dir, 'model_init.pt')) torch.save(network.state_dict(), os.path.join(output_dir, 'model_init.pt'))
# Save source files # Save source files
if save_src:
for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob( for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
os.path.join('src', '**', '*.py'), recursive=True): os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
dirname = os.path.join(output_dir, 'code', os.path.dirname(entry)) dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
@ -89,23 +92,31 @@ class Trainer:
self.processed_inputs = processed_inputs self.processed_inputs = processed_inputs
self.network_outputs = processed_inputs # Placeholder self.network_outputs = processed_inputs # Placeholder
self.train_loss = 0.0 self.train_loss = 0.0
self.train_accuracy = 0.0
self.running_loss = 0.0 self.running_loss = 0.0
self.running_accuracy = 0.0
self.running_count = 0 self.running_count = 0
self.benchmark_step = 0 self.benchmark_step = 0
self.benchmark_time = time.time() self.benchmark_time = time.time()
def end_epoch_callback(self):
pass
def train_step_callback( def train_step_callback(
self, self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor, batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor, batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float): loss: float, accuracy: float):
pass pass
def val_step_callback( def val_step_callback(
self, self,
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor, batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
batch_labels: torch.Tensor, network_outputs: torch.Tensor, batch_labels: torch.Tensor, network_outputs: torch.Tensor,
loss: float): loss: float, accuracy: float):
pass
def pre_summary_callback(self):
pass pass
def summary_callback( def summary_callback(
@ -129,6 +140,7 @@ class Trainer:
try: try:
while not self.should_stop and self.batch_generator_train.epoch < epochs: while not self.should_stop and self.batch_generator_train.epoch < epochs:
epoch = self.batch_generator_train.epoch epoch = self.batch_generator_train.epoch
if self.verbose:
print() print()
print(' ' * os.get_terminal_size()[0], end='\r') print(' ' * os.get_terminal_size()[0], end='\r')
print(f'Epoch {self.batch_generator_train.epoch}') print(f'Epoch {self.batch_generator_train.epoch}')
@ -138,7 +150,7 @@ class Trainer:
self.batch_labels = torch.as_tensor( self.batch_labels = torch.as_tensor(
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device) self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
if self.benchmark_step > 1: if self.verbose and self.benchmark_step > 1:
speed = self.benchmark_step / (time.time() - self.benchmark_time) speed = self.benchmark_step / (time.time() - self.benchmark_time)
print( print(
f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s' f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'
@ -150,35 +162,60 @@ class Trainer:
self.processed_inputs = self.train_pre_process(self.batch_inputs) self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs) self.network_outputs = self.network(self.processed_inputs)
loss = self.criterion( labels = self.batch_labels if not self.data_is_label else self.processed_inputs
self.network_outputs, loss = self.criterion(self.network_outputs, labels)
self.batch_labels if not self.data_is_label else self.processed_inputs)
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
self.train_loss = loss.item() self.train_loss = loss.item()
self.train_accuracy = self.accuracy_fn(
self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
self.running_loss += self.train_loss self.running_loss += self.train_loss
self.running_accuracy += self.train_accuracy
self.running_count += len(self.batch_generator_train.batch_data) self.running_count += len(self.batch_generator_train.batch_data)
self.train_step_callback( self.train_step_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels, self.batch_inputs, self.processed_inputs, labels,
self.network_outputs, self.train_loss) self.network_outputs, self.train_loss, self.train_accuracy)
self.benchmark_step += 1 self.benchmark_step += 1
self.save_summaries() self.save_summaries()
self.batch_generator_train.next_batch() self.batch_generator_train.next_batch()
self.end_epoch_callback()
except KeyboardInterrupt: except KeyboardInterrupt:
if self.verbose:
print() print()
# Small training loop for last metrics
for _ in range(20):
self.batch_inputs = torch.as_tensor(
self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
self.batch_labels = torch.as_tensor(
self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)
self.processed_inputs = self.train_pre_process(self.batch_inputs)
self.network_outputs = self.network(self.processed_inputs)
labels = self.batch_labels if not self.data_is_label else self.processed_inputs
loss = self.criterion(self.network_outputs, labels)
self.train_loss = loss.item()
self.train_accuracy = self.accuracy_fn(
self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
self.running_loss += self.train_loss
self.running_accuracy += self.train_accuracy
self.running_count += len(self.batch_generator_train.batch_data)
self.benchmark_step += 1
self.batch_generator_train.next_batch()
self.save_summaries(force_summary=True)
train_stop_time = time.time() train_stop_time = time.time()
self.writer_train.close() self.writer_train.close()
torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt')) self.writer_val.close()
memory_peak, gpu_memory = resource_usage() memory_peak, gpu_memory = resource_usage()
self.logger.info( self.logger.info(
f'Training time : {train_stop_time - train_start_time:.03f}s\n' f'Training time : {train_stop_time - train_start_time:.03f}s\n'
f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}') f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')
def save_summaries(self): def save_summaries(self, force_summary=False):
global_step = self.batch_generator_train.global_step global_step = self.batch_generator_train.global_step
if self.batch_generator_train.epoch < self.epoch_skip: if self.batch_generator_train.epoch < self.epoch_skip:
return return
@ -188,7 +225,9 @@ class Trainer:
if self.batch_generator_val.step != 0: if self.batch_generator_val.step != 0:
self.batch_generator_val.skip_epoch() self.batch_generator_val.skip_epoch()
self.pre_summary_callback()
val_loss = 0.0 val_loss = 0.0
val_accuracy = 0.0
val_count = 0 val_count = 0
self.network.train(False) self.network.train(False)
with torch.no_grad(): with torch.no_grad():
@ -201,23 +240,30 @@ class Trainer:
val_pre_process = self.pre_process(val_inputs) val_pre_process = self.pre_process(val_inputs)
val_outputs = self.network(val_pre_process) val_outputs = self.network(val_pre_process)
loss = self.criterion( val_labels = val_labels if not self.data_is_label else val_pre_process
val_outputs, loss = self.criterion(val_outputs, val_labels).item()
val_labels if not self.data_is_label else val_pre_process).item() accuracy = self.accuracy_fn(
val_outputs, val_labels).item() if self.accuracy_fn is not None else 0.0
val_loss += loss val_loss += loss
val_accuracy += accuracy
val_count += len(self.batch_generator_val.batch_data) val_count += len(self.batch_generator_val.batch_data)
self.val_step_callback( self.val_step_callback(
val_inputs, val_pre_process, val_labels, val_outputs, loss) val_inputs, val_pre_process, val_labels, val_outputs, loss, accuracy)
self.batch_generator_val.next_batch() self.batch_generator_val.next_batch()
self.network.train(True) self.network.train(True)
# Add summaries # Add summaries
if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1): if force_summary or self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
self.writer_train.add_scalar( self.writer_train.add_scalar(
'loss', self.running_loss / self.running_count, global_step=global_step) 'loss', self.running_loss / self.running_count, global_step=global_step)
self.writer_val.add_scalar( self.writer_val.add_scalar(
'loss', val_loss / val_count, global_step=global_step) 'loss', val_loss / val_count, global_step=global_step)
if self.accuracy_fn is not None:
self.writer_train.add_scalar(
'error', 1 - (self.running_accuracy / self.running_count), global_step=global_step)
self.writer_val.add_scalar(
'error', 1 - (val_accuracy / val_count), global_step=global_step)
self.summary_callback( self.summary_callback(
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count, self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
val_inputs, val_pre_process, val_labels, val_outputs, val_count) val_inputs, val_pre_process, val_labels, val_outputs, val_count)
@ -228,12 +274,18 @@ class Trainer:
self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
val_inputs, val_pre_process, val_labels, val_outputs) val_inputs, val_pre_process, val_labels, val_outputs)
if self.verbose:
speed = self.benchmark_step / (time.time() - self.benchmark_time) speed = self.benchmark_step / (time.time() - self.benchmark_time)
print(f'Step {global_step}, ' print(f'Step {global_step}, '
f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, ' f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec') f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt')) if self.accuracy_fn is not None:
torch.save(self.network.state_dict(), os.path.join(
self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
else:
torch.save(self.network.state_dict(), os.path.join(self.output_dir, f'step_{global_step}.pt'))
self.benchmark_time = time.time() self.benchmark_time = time.time()
self.benchmark_step = 0 self.benchmark_step = 0
self.running_loss = 0.0 self.running_loss = 0.0
self.running_accuracy = 0.0
self.running_count = 0 self.running_count = 0

View file

@ -11,6 +11,7 @@ class BatchGenerator:
def __init__(self, data: Iterable, label: Iterable, batch_size: int, def __init__(self, data: Iterable, label: Iterable, batch_size: int,
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None, data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
pipeline: Optional[Callable] = None,
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
flip_data=False, save: Optional[str] = None): flip_data=False, save: Optional[str] = None):
self.batch_size = batch_size self.batch_size = batch_size
@ -18,6 +19,7 @@ class BatchGenerator:
self.prefetch = prefetch and not preload self.prefetch = prefetch and not preload
self.num_workers = num_workers self.num_workers = num_workers
self.flip_data = flip_data self.flip_data = flip_data
self.pipeline = pipeline
if not preload: if not preload:
self.data_processor = data_processor self.data_processor = data_processor
@ -59,15 +61,22 @@ class BatchGenerator:
self.last_batch_size = len(self.index_list) % self.batch_size self.last_batch_size = len(self.index_list) % self.batch_size
if self.last_batch_size == 0: if self.last_batch_size == 0:
self.last_batch_size = self.batch_size self.last_batch_size = self.batch_size
else:
self.step_per_epoch += 1
self.epoch = 0 self.epoch = 0
self.global_step = 0 self.global_step = 0
self.step = 0 self.step = 0
first_data = np.array([data_processor(entry) if data_processor else entry first_data = [data_processor(entry) if data_processor else entry
for entry in self.data[self.index_list[:batch_size]]]) for entry in self.data[self.index_list[:batch_size]]]
first_label = np.array([label_processor(entry) if label_processor else entry first_label = [label_processor(entry) if label_processor else entry
for entry in self.label[self.index_list[:batch_size]]]) for entry in self.label[self.index_list[:batch_size]]]
if self.pipeline is not None:
for data_index, sample_data in enumerate(first_data):
first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index])
first_data = np.asarray(first_data)
first_label = np.asarray(first_label)
self.batch_data = first_data self.batch_data = first_data
self.batch_label = first_label self.batch_label = first_label
@ -194,10 +203,8 @@ class BatchGenerator:
self.current_cache = 0 self.current_cache = 0
parent_cache_data = self.cache_data parent_cache_data = self.cache_data
parent_cache_label = self.cache_label parent_cache_label = self.cache_label
cache_data = np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype) self.cache_data = [np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)]
cache_label = np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype) self.cache_label = [np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)]
self.cache_data = [cache_data]
self.cache_label = [cache_label]
self.index_list[:] = self.cache_indices self.index_list[:] = self.cache_indices
pipe = self.worker_pipes[worker_index][1] pipe = self.worker_pipes[worker_index][1]
@ -208,8 +215,6 @@ class BatchGenerator:
continue continue
self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy() self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy()
self.cache_data[0] = cache_data[:self.batch_size]
self.cache_label[0] = cache_label[:self.batch_size]
self._next_batch() self._next_batch()
parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[ parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[
self.current_cache][:self.batch_size] self.current_cache][:self.batch_size]
@ -263,31 +268,37 @@ class BatchGenerator:
data = [] data = []
for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]: for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
data.append(self.data_processor(self.data[entry])) data.append(self.data_processor(self.data[entry]))
self.cache_data[self.current_cache][:len(data)] = np.asarray(data) data = np.asarray(data)
else: else:
data = self.data[ data = self.data[
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]] self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
if self.flip_data: if self.flip_data:
flip = np.random.uniform() flip = np.random.uniform()
if flip < 0.25: if flip < 0.25:
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1] data = data[:, :, ::-1]
elif flip < 0.5: elif flip < 0.5:
self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1] data = data[:, :, :, ::-1]
elif flip < 0.75: elif flip < 0.75:
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1] data = data[:, :, ::-1, ::-1]
else:
self.cache_data[self.current_cache][:len(data)] = data
else:
self.cache_data[self.current_cache][:len(data)] = data
# Loading label # Loading label
if self.label_processor is not None: if self.label_processor is not None:
label = [] label = []
for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]: for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
label.append(self.label_processor(self.label[entry])) label.append(self.label_processor(self.label[entry]))
self.cache_label[self.current_cache][:len(label)] = np.asarray(label) label = np.asarray(label)
else: else:
label = self.label[ label = self.label[
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]] self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
# Process through pipeline
if self.pipeline is not None:
for data_index, data_entry in enumerate(data):
piped_data, piped_label = self.pipeline(data_entry, label[data_index])
self.cache_data[self.current_cache][data_index] = piped_data
self.cache_label[self.current_cache][data_index] = piped_label
else:
self.cache_data[self.current_cache][:len(data)] = data
self.cache_label[self.current_cache][:len(label)] = label self.cache_label[self.current_cache][:len(label)] = label
def next_batch(self) -> Tuple[np.ndarray, np.ndarray]: def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
@ -338,7 +349,6 @@ if __name__ == '__main__':
for _ in range(19): for _ in range(19):
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step) print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
batch_generator.next_batch() batch_generator.next_batch()
raise KeyboardInterrupt
print() print()
test() test()

138
utils/ipc_data_generator.py Normal file
View file

@ -0,0 +1,138 @@
import multiprocessing as mp
from multiprocessing import shared_memory
import signal
from typing import Callable, Optional, Tuple
import numpy as np
class IPCBatchGenerator:
def __init__(self, ipc_processor: Callable,
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
pipeline: Optional[Callable] = None,
prefetch=True, flip_data=False):
self.flip_data = flip_data
self.pipeline = pipeline
self.prefetch = prefetch
self.ipc_processor = ipc_processor
self.data_processor = data_processor
self.label_processor = label_processor
self.global_step = 0
self.data, self.label = ipc_processor()
first_data = [data_processor(entry) for entry in self.data] if data_processor else self.data
first_label = [label_processor(entry) for entry in self.label] if label_processor else self.label
if self.pipeline is not None:
for data_index, sample_data in enumerate(first_data):
first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index])
first_data = np.asarray(first_data)
first_label = np.asarray(first_label)
self.batch_data = first_data
self.batch_label = first_label
self.process_id = 'NA'
if self.prefetch:
self.cache_memory_data = [
shared_memory.SharedMemory(create=True, size=first_data.nbytes),
shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
self.cache_data = [
np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
self.cache_memory_label = [
shared_memory.SharedMemory(create=True, size=first_label.nbytes),
shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
self.cache_label = [
np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
self.prefetch_stop.buf[0] = 0
self.prefetch_process = mp.Process(target=self._prefetch_worker)
self.prefetch_process.start()
else:
self.cache_data = [first_data]
self.cache_label = [first_label]
self.current_cache = 0
self.process_id = 'main'
def __del__(self):
self.release()
def __enter__(self):
return self
def __exit__(self, _exc_type, _exc_value, _traceback):
self.release()
def release(self):
if self.prefetch:
self.prefetch_stop.buf[0] = 1
self.prefetch_pipe_parent.send(True)
self.prefetch_process.join()
for shared_mem in self.cache_memory_data + self.cache_memory_label:
shared_mem.close()
shared_mem.unlink()
self.prefetch_stop.close()
self.prefetch_stop.unlink()
self.prefetch = False # Avoids double release
def _prefetch_worker(self):
self.prefetch = False
self.current_cache = 1
self.process_id = 'prefetch'
signal.signal(signal.SIGINT, signal.SIG_IGN)
while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
self.current_cache = 1 - self.current_cache
self.global_step += 1
self._next_batch()
self.prefetch_pipe_child.recv()
self.prefetch_pipe_child.send(self.current_cache)
def _next_batch(self):
# Loading data
self.data, self.label = self.ipc_processor()
data = np.asarray([self.data_processor(entry) for entry in self.data]) if self.data_processor else self.data
if self.flip_data:
flip = np.random.uniform()
if flip < 0.25:
data = data[:, :, ::-1]
elif flip < 0.5:
data = data[:, :, :, ::-1]
elif flip < 0.75:
data = data[:, :, ::-1, ::-1]
# Loading label
label = np.asarray([
self.label_processor(entry) for entry in self.label]) if self.label_processor else self.label
# Process through pipeline
if self.pipeline is not None:
for data_index, data_entry in enumerate(data):
piped_data, piped_label = self.pipeline(data_entry, label[data_index])
self.cache_data[self.current_cache][data_index] = piped_data
self.cache_label[self.current_cache][data_index] = piped_label
else:
self.cache_data[self.current_cache][:len(data)] = data
self.cache_label[self.current_cache][:len(label)] = label
def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
self.global_step += 1
if self.prefetch:
self.prefetch_pipe_parent.send(True)
self.current_cache = self.prefetch_pipe_parent.recv()
else:
self._next_batch()
self.batch_data = self.cache_data[self.current_cache]
self.batch_label = self.cache_label[self.current_cache]
return self.batch_data, self.batch_label

View file

@ -6,3 +6,22 @@ def human_size(byte_count: int) -> str:
break break
amount /= 1024.0 amount /= 1024.0
return f'{amount:.2f}{unit}B' return f'{amount:.2f}{unit}B'
def human_to_bytes(text: str) -> float:
split_index = 0
while '0' <= text[split_index] <= '9':
split_index += 1
if split_index == len(text):
return float(text)
amount = float(text[:split_index])
unit = text[split_index:].strip()
if not unit:
return amount
if unit not in ['KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']:
raise RuntimeError(f'Unrecognized unit : {unit}')
for final_unit in ['KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']:
amount *= 1024.0
if unit == final_unit:
return amount

View file

@ -16,14 +16,15 @@ class SequenceGenerator(BatchGenerator):
def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int, def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None, data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
pipeline: Optional[Callable] = None, index_list=None,
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False, prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
flip_data=False, save: Optional[str] = None): sequence_stride=1, save: Optional[str] = None):
self.batch_size = batch_size self.batch_size = batch_size
self.sequence_size = sequence_size self.sequence_size = sequence_size
self.shuffle = shuffle self.shuffle = shuffle
self.prefetch = prefetch and not preload self.prefetch = prefetch and not preload
self.num_workers = num_workers self.num_workers = num_workers
self.flip_data = flip_data self.pipeline = pipeline
if not preload: if not preload:
self.data_processor = data_processor self.data_processor = data_processor
@ -70,14 +71,26 @@ class SequenceGenerator(BatchGenerator):
h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index]) h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index])
h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index]) h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index])
if index_list is not None:
self.index_list = index_list
else:
self.index_list = [] self.index_list = []
for sequence_index in range(len(self.data)): for sequence_index in range(len(self.data)):
if sequence_stride > 1:
start_indices = np.expand_dims(
np.arange(0,
len(self.data[sequence_index]) - sequence_size + 1,
sequence_stride,
dtype=np.uint32),
axis=-1)
else:
start_indices = np.expand_dims( start_indices = np.expand_dims(
np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32), np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
axis=-1) axis=-1)
start_indices = np.insert(start_indices, 0, sequence_index, axis=1) start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
self.index_list.append(start_indices) self.index_list.append(start_indices)
self.index_list = np.concatenate(self.index_list, axis=0) self.index_list = np.concatenate(self.index_list, axis=0)
if shuffle or initial_shuffle: if shuffle or initial_shuffle:
np.random.shuffle(self.index_list) np.random.shuffle(self.index_list)
self.step_per_epoch = len(self.index_list) // self.batch_size self.step_per_epoch = len(self.index_list) // self.batch_size
@ -95,25 +108,27 @@ class SequenceGenerator(BatchGenerator):
first_data.append( first_data.append(
[data_processor(input_data) [data_processor(input_data)
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]) for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
first_data = np.asarray(first_data)
else: else:
first_data = [] first_data = []
for sequence_index, start_index in self.index_list[:batch_size]: for sequence_index, start_index in self.index_list[:batch_size]:
first_data.append( first_data.append(
self.data[sequence_index][start_index: start_index + self.sequence_size]) self.data[sequence_index][start_index: start_index + self.sequence_size])
first_data = np.asarray(first_data)
if label_processor: if label_processor:
first_label = [] first_label = []
for sequence_index, start_index in self.index_list[:batch_size]: for sequence_index, start_index in self.index_list[:batch_size]:
first_label.append( first_label.append(
[label_processor(input_label) [label_processor(input_label)
for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]]) for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]])
first_label = np.asarray(first_label)
else: else:
first_label = [] first_label = []
for sequence_index, start_index in self.index_list[:batch_size]: for sequence_index, start_index in self.index_list[:batch_size]:
first_label.append( first_label.append(
self.label[sequence_index][start_index: start_index + self.sequence_size]) self.label[sequence_index][start_index: start_index + self.sequence_size])
if self.pipeline is not None:
for batch_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)):
first_data[batch_index], first_label[batch_index] = self.pipeline(
np.asarray(data_sequence), np.asarray(label_sequence))
first_data = np.asarray(first_data)
first_label = np.asarray(first_label) first_label = np.asarray(first_label)
self.batch_data = first_data self.batch_data = first_data
self.batch_label = first_label self.batch_label = first_label
@ -165,37 +180,14 @@ class SequenceGenerator(BatchGenerator):
for sequence_index, start_index in self.index_list[ for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]: self.step * self.batch_size:(self.step + 1) * self.batch_size]:
data.append( data.append(
np.asarray(
[self.data_processor(input_data) [self.data_processor(input_data)
for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]) for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]))
if self.flip_data:
flip = np.random.uniform()
if flip < 0.25:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
elif flip < 0.5:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
elif flip < 0.75:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
else: else:
data = [] data = []
for sequence_index, start_index in self.index_list[ for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]: self.step * self.batch_size:(self.step + 1) * self.batch_size]:
data.append(self.data[sequence_index][start_index: start_index + self.sequence_size]) data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
if self.flip_data:
flip = np.random.uniform()
if flip < 0.25:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
elif flip < 0.5:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
elif flip < 0.75:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
# Loading label # Loading label
if self.label_processor is not None: if self.label_processor is not None:
@ -203,14 +195,23 @@ class SequenceGenerator(BatchGenerator):
for sequence_index, start_index in self.index_list[ for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]: self.step * self.batch_size:(self.step + 1) * self.batch_size]:
label.append( label.append(
np.asarray(
[self.label_processor(input_data) [self.label_processor(input_data)
for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]]) for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]]))
self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
else: else:
label = [] label = []
for sequence_index, start_index in self.index_list[ for sequence_index, start_index in self.index_list[
self.step * self.batch_size:(self.step + 1) * self.batch_size]: self.step * self.batch_size:(self.step + 1) * self.batch_size]:
label.append(self.label[sequence_index][start_index: start_index + self.sequence_size]) label.append(self.label[sequence_index][start_index: start_index + self.sequence_size])
# Process through pipeline
if self.pipeline is not None:
for batch_index in range(len(data)):
piped_data, piped_label = self.pipeline(data[batch_index], label[batch_index])
self.cache_data[self.current_cache][batch_index] = piped_data
self.cache_label[self.current_cache][batch_index] = piped_label
else:
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
self.cache_label[self.current_cache][:len(label)] = np.asarray(label) self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
@ -219,16 +220,21 @@ if __name__ == '__main__':
data = np.array( data = np.array(
[[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19]], dtype=np.uint8) [[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19]], dtype=np.uint8)
label = np.array( label = np.array(
[[.1, .2, .3, .4, .5, .6, .7, .8, .9], [.11, .12, .13, .14, .15, .16, .17, .18, .19]], dtype=np.uint8) [[10, 20, 30, 40, 50, 60, 70, 80, 90], [110, 120, 130, 140, 150, 160, 170, 180, 190]], dtype=np.uint8)
def pipeline(data, label):
return data, label
for pipeline in [None, pipeline]:
for data_processor in [None, lambda x:x]: for data_processor in [None, lambda x:x]:
for prefetch in [False, True]: for prefetch in [False, True]:
for num_workers in [1, 2]: for num_workers in [1, 2]:
print(f'{data_processor=} {prefetch=} {num_workers=}') print(f'{pipeline=} {data_processor=} {prefetch=} {num_workers=}')
with SequenceGenerator(data, label, 5, 2, data_processor=data_processor, with SequenceGenerator(data, label, 5, 2, data_processor=data_processor, pipeline=pipeline,
prefetch=prefetch, num_workers=num_workers) as batch_generator: prefetch=prefetch, num_workers=num_workers) as batch_generator:
for _ in range(9): for _ in range(9):
print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step) print(batch_generator.batch_data.tolist(), batch_generator.batch_label.tolist(),
batch_generator.epoch, batch_generator.step)
batch_generator.next_batch() batch_generator.next_batch()
print() print()