Merge branch 'master' into 'BatchNormModifications'
# Conflicts:
#   layers.py

commit fe11f3e6d5
11 changed files with 753 additions and 159 deletions
layers.py (13 changed lines)

@@ -22,8 +22,13 @@ class Layer(nn.Module):
    def __init__(self, activation, use_batch_norm):
        super().__init__()
        # Preload default
        if activation == 0:
            activation = Layer.ACTIVATION
        if isinstance(activation, type):
            self.activation = activation()
        else:
            self.activation = activation
        self.batch_norm: torch.nn._BatchNorm = None
        self.activation = Layer.ACTIVATION if activation == 0 else activation
        self.use_batch_norm = Layer.USE_BATCH_NORM if use_batch_norm is None else use_batch_norm

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:

@@ -40,7 +45,7 @@ class Linear(Layer):
    def __init__(self, in_channels: int, out_channels: int, activation=0, use_batch_norm: bool = None, **kwargs):
        super().__init__(activation, use_batch_norm)

        self.fc = nn.Linear(in_channels, out_channels, **kwargs)
        self.fc = nn.Linear(in_channels, out_channels, bias=not self.batch_norm, **kwargs)
        self.batch_norm = nn.BatchNorm1d(
            out_channels,
            momentum=Layer.BATCH_NORM_MOMENTUM,

@@ -76,7 +81,7 @@ class Conv2d(Layer):
        self.batch_norm = nn.BatchNorm2d(
            out_channels,
            momentum=Layer.BATCH_NORM_MOMENTUM,
            track_running_stats=not Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None
            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        return super().forward(self.conv(input_data))

@@ -109,7 +114,7 @@ class Deconv2d(Layer):
        self.batch_norm = nn.BatchNorm2d(
            out_channels,
            momentum=Layer.BATCH_NORM_MOMENTUM,
            track_running_stats=not Layer.BATCH_NORM_TRAINING if Layer.USE_BATCH_NORM else None
            track_running_stats=Layer.BATCH_NORM_TRAINING) if self.use_batch_norm else None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        return super().forward(self.deconv(input_data))
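The Linear change above drops the affine bias when a batch-norm layer follows it; BatchNorm re-centers its input and learns its own shift (beta), so a preceding bias is redundant. A minimal standalone sketch of that pattern (not the repository's Layer class, names are illustrative):

import torch
import torch.nn as nn

class LinearBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, use_batch_norm: bool = True):
        super().__init__()
        # Skip the bias when BatchNorm follows: its learned shift replaces it.
        self.fc = nn.Linear(in_channels, out_channels, bias=not use_batch_norm)
        self.batch_norm = nn.BatchNorm1d(out_channels) if use_batch_norm else None
        self.activation = nn.LeakyReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc(x)
        if self.batch_norm is not None:
            x = self.batch_norm(x)
        return self.activation(x)

block = LinearBlock(8, 4)
print(block(torch.randn(16, 8)).shape)  # torch.Size([16, 4])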
residual.py (74 changed lines)

@@ -3,65 +3,51 @@ from typing import Union, Tuple
import torch
import torch.nn as nn

from .layers import LayerInfo, Layer
from .layers import Conv2d, LayerInfo, Layer


class ResBlock(Layer):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                 activation=None, **kwargs):
    def __init__(self, in_channels: int, out_channels: int = -1, kernel_size: int = 3, padding: int = 1,
                 stride: Union[int, Tuple[int, int]] = 1, activation=None, batch_norm=None, **kwargs):
        super().__init__(activation if activation is not None else 0, False)
        self.batch_norm = None
        if out_channels == -1:
            out_channels = in_channels

        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, bias=False, **kwargs),
            nn.BatchNorm2d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING),
            torch.nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, bias=False, padding=1),
            nn.BatchNorm2d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING))
        self.batch_norm = nn.BatchNorm2d(
            out_channels,
            momentum=Layer.BATCH_NORM_MOMENTUM,
            track_running_stats=not Layer.BATCH_NORM_TRAINING) if self.batch_norm else None
            Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, **kwargs),
            Conv2d(in_channels, out_channels, kernel_size=3, padding=1,
                   activation=None, batch_norm=batch_norm))
        self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
            out_channels != in_channels or stride != 1) else None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        if self.residual is not None:
            return super().forward(self.residual(input_data) + self.seq(input_data))
        return super().forward(input_data + self.seq(input_data))


class ResBottleneck(Layer):
    def __init__(self, in_channels: int, out_channels: int, planes: int = 1, kernel_size: int = 3,
                 stride: Union[int, Tuple[int, int]] = 1, activation=None, **kwargs):
    def __init__(self, in_channels: int, out_channels: int = -1, bottleneck_channels: int = -1, kernel_size: int = 3,
                 stride: Union[int, Tuple[int, int]] = 1, padding=1,
                 activation=None, batch_norm=None, **kwargs):
        super().__init__(activation if activation is not None else 0, False)
        self.batch_norm = None
        if out_channels == -1:
            out_channels = in_channels
        if bottleneck_channels == -1:
            bottleneck_channels = in_channels // 4

        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING),
            torch.nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, bias=False, **kwargs),
            nn.BatchNorm2d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING),
            torch.nn.LeakyReLU(),
            nn.Conv2d(out_channels, planes * out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(
                out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING))
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, planes * out_channels, stride=stride, kernel_size=1),
            nn.BatchNorm2d(
                planes * out_channels,
                momentum=Layer.BATCH_NORM_MOMENTUM,
                track_running_stats=not Layer.BATCH_NORM_TRAINING))
            Conv2d(in_channels, bottleneck_channels, kernel_size=1),
            Conv2d(bottleneck_channels, bottleneck_channels, kernel_size=kernel_size,
                   stride=stride, padding=padding, **kwargs),
            Conv2d(bottleneck_channels, out_channels, kernel_size=1,
                   activation=None, batch_norm=batch_norm))
        self.residual = Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, activation=None) if (
            out_channels != in_channels or stride != 1) else None

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        return super().forward(self.downsample(input_data) + self.seq(input_data))
        if self.residual is not None:
            return super().forward(self.residual(input_data) + self.seq(input_data))
        return super().forward(input_data + self.seq(input_data))
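The reworked blocks project the identity through a 1x1 convolution only when the output channel count or the stride differs from the input, so the skip addition stays shape-compatible. A minimal sketch of that rule with plain torch modules (not the repository's Conv2d wrapper):

import torch
import torch.nn as nn

class TinyResBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
        # Project the skip path only when the shapes would otherwise not match.
        self.residual = (nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
                         if out_channels != in_channels or stride != 1 else None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = x if self.residual is None else self.residual(x)
        return torch.relu(skip + self.seq(x))

x = torch.randn(1, 16, 32, 32)
print(TinyResBlock(16, 16)(x).shape)            # torch.Size([1, 16, 32, 32]), identity skip
print(TinyResBlock(16, 32, stride=2)(x).shape)  # torch.Size([1, 32, 16, 16]), projected skip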
ssd/box.py (new file, 86 lines)

@@ -0,0 +1,86 @@
import numpy as np


def create_box(y_pos: float, x_pos: float, height: float, width: float) -> tuple[float, float, float, float]:
    y_min, x_min, y_max, x_max = check_rectangle(
        y_pos - (height / 2), x_pos - (width / 2), y_pos + (height / 2), x_pos + (width / 2))
    return (y_min + y_max) / 2, (x_min + x_max) / 2, y_max - y_min, x_max - x_min


def check_rectangle(y_min: float, x_min: float, y_max: float, x_max: float) -> tuple[float, float, float, float]:
    if y_min < 0:
        y_min = 0
    if x_min < 0:
        x_min = 0
    if y_min > 1:
        y_min = 1
    if x_min > 1:
        x_min = 1
    if y_max < 0:
        y_max = 0
    if x_max < 0:
        x_max = 0
    if y_max >= 1:
        y_max = 1
    if x_max >= 1:
        x_max = 1
    return y_min, x_min, y_max, x_max


def get_boxes(predictions: np.ndarray, anchors: np.ndarray, class_index: int) -> np.ndarray:
    boxes = np.zeros(anchors.shape)
    boxes[:, 0] = (predictions[:, 0] * anchors[:, 2]) + anchors[:, 0]
    boxes[:, 1] = (predictions[:, 1] * anchors[:, 3]) + anchors[:, 1]
    boxes[:, 2] = np.exp(predictions[:, 2]) * anchors[:, 2]
    boxes[:, 3] = np.exp(predictions[:, 3]) * anchors[:, 3]
    boxes = np.asarray([create_box(*box) for box in boxes])

    # return np.insert(boxes, 4, predictions[:, class_index], axis=-1)
    return np.concatenate([boxes, predictions[:, class_index:class_index + 1]], axis=1)


def fast_nms(boxes: np.ndarray, min_iou: float) -> np.ndarray:
    # if there are no boxes, return an empty list
    if len(boxes) == 0:
        return []

    # initialize the list of picked indexes
    pick = []

    # grab the coordinates of the bounding boxes
    y_min = boxes[:, 0] - (boxes[:, 2] / 2)
    y_max = boxes[:, 0] + (boxes[:, 2] / 2)
    x_min = boxes[:, 1] - (boxes[:, 3] / 2)
    x_max = boxes[:, 1] + (boxes[:, 3] / 2)
    scores = boxes[:, 4]

    # compute the area of the bounding boxes and sort the bounding boxes by the scores
    areas = (x_max - x_min) * (y_max - y_min)
    idxs = np.argsort(scores)

    # keep looping while some indexes still remain in the indexes list
    while len(idxs) > 0:
        # grab the last index in the indexes list and add the
        # index value to the list of picked indexes
        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)

        inter_tops = np.maximum(y_min[i], y_min[idxs[:last]])
        inter_bottoms = np.minimum(y_max[i], y_max[idxs[:last]])
        inter_lefts = np.maximum(x_min[i], x_min[idxs[:last]])
        inter_rights = np.minimum(x_max[i], x_max[idxs[:last]])
        inter_areas = (inter_rights - inter_lefts) * (inter_bottoms - inter_tops)

        # compute the ratio of overlap
        union_area = (areas[idxs[:last]] + areas[i]) - inter_areas
        overlap = inter_areas / union_area

        # delete all indexes from the index list that have more overlap than min_iou
        idxs = np.delete(
            idxs, np.concatenate(([last], np.where(overlap > min_iou)[0])))

    # return only the bounding boxes that were picked
    return boxes[pick]
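A short usage sketch of fast_nms as defined above (the box values are made up, and fast_nms is assumed to be in scope): rows are (y_center, x_center, height, width, score) in normalized image coordinates, matching the layout produced by get_boxes. Note that the intersection term is not clamped at zero, so the sample boxes are chosen to overlap along at most one axis.

import numpy as np

boxes = np.array([
    [0.50, 0.50, 0.20, 0.20, 0.90],   # strong detection
    [0.51, 0.50, 0.20, 0.20, 0.60],   # near-duplicate of the first box (IoU ~0.9)
    [0.50, 0.90, 0.10, 0.10, 0.80],   # unrelated detection elsewhere
])
kept = fast_nms(boxes, min_iou=0.5)
print(kept[:, 4])   # [0.9 0.8]: the duplicate with score 0.60 is suppressed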
ssd/loss.py (new file, 112 lines)

@@ -0,0 +1,112 @@
import torch
import torch.nn as nn


class JacardOverlap(nn.Module):
    def forward(self, anchors: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Assuming rank 2 (number of boxes, locations), location is (y, x, h, w)
        Jaccard overlap : A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
        Return:
            jaccard overlap: (tensor) Shape: [anchors.size(0), labels.size(0)]
        """
        anchors_count = anchors.size(0)
        labels_count = labels.size(0)

        # Getting coords (y_min, x_min, y_max, x_max) repeated to fill (anchor count, label count)
        anchor_coords = torch.cat([
            anchors[:, :2] - (anchors[:, 2:] / 2),
            anchors[:, :2] + (anchors[:, 2:] / 2)], 1).unsqueeze(1).expand(anchors_count, labels_count, 4)
        label_coords = torch.cat([
            labels[:, :2] - (labels[:, 2:] / 2),
            labels[:, :2] + (labels[:, 2:] / 2)], 1).unsqueeze(0).expand(anchors_count, labels_count, 4)

        mins = torch.max(anchor_coords, label_coords)[:, :, :2]
        maxes = torch.min(anchor_coords, label_coords)[:, :, 2:]

        inter_coords = torch.clamp(maxes - mins, min=0)
        inter_area = inter_coords[:, :, 0] * inter_coords[:, :, 1]

        anchor_areas = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1).expand_as(inter_area)
        label_areas = (labels[:, 2] * labels[:, 3]).unsqueeze(0).expand_as(inter_area)

        union_area = anchor_areas + label_areas - inter_area
        return inter_area / union_area


class SSDLoss(nn.Module):
    def __init__(self, anchors: torch.Tensor, label_per_image: int,
                 negative_mining_ratio: int, matching_iou: float,
                 location_dimmension: int = 4, localization_loss_weight: float = 1.0):
        super().__init__()
        self.anchors = anchors
        self.anchor_count = anchors.size(0)
        self.label_per_image = label_per_image
        self.location_dimmension = location_dimmension
        self.negative_mining_ratio = negative_mining_ratio
        self.matching_iou = matching_iou
        self.localization_loss_weight = localization_loss_weight

        self.overlap = JacardOverlap()
        self.matches = []
        # self.negative_matches = []
        self.positive_class_loss = torch.Tensor()
        self.negative_class_loss = torch.Tensor()
        self.localization_loss = torch.Tensor()
        self.class_loss = torch.Tensor()
        self.final_loss = torch.Tensor()

    def forward(self, input_data: torch.Tensor, input_labels: torch.Tensor) -> torch.Tensor:
        batch_size = input_data.size(0)
        expanded_anchors = self.anchors[:, :4].unsqueeze(0).unsqueeze(2).expand(
            batch_size, self.anchor_count, self.label_per_image, 4)
        expanded_labels = input_labels[:, :, :self.location_dimmension].unsqueeze(1).expand(
            batch_size, self.anchor_count, self.label_per_image, self.location_dimmension)
        objective_pos = (expanded_labels[:, :, :, :2] - expanded_anchors[:, :, :, :2]) / (
            expanded_anchors[:, :, :, 2:])
        objective_size = torch.log(expanded_labels[:, :, :, 2:] / expanded_anchors[:, :, :, 2:])

        positive_objectives = []
        positive_predictions = []
        positive_class_loss = []
        negative_class_loss = []
        self.matches = []
        # self.negative_matches = []
        for batch_index in range(batch_size):
            predictions = input_data[batch_index]
            labels = input_labels[batch_index]
            overlaps = self.overlap(self.anchors[:, :4], labels[:, :4])
            mask = (overlaps >= self.matching_iou).long()
            match_indices = torch.nonzero(mask, as_tuple=False)
            self.matches.append(match_indices.detach().cpu())

            mining_count = int(self.negative_mining_ratio * len(self.matches[-1]))
            masked_prediction = predictions[:, self.location_dimmension] + torch.max(mask, dim=1)[0]
            non_match_indices = torch.argsort(masked_prediction, dim=-1, descending=False)[:mining_count]
            # self.negative_matches.append(non_match_indices.detach().cpu())

            for anchor_index, label_index in match_indices:
                positive_predictions.append(predictions[anchor_index])
                positive_objectives.append(
                    torch.cat((
                        objective_pos[batch_index, anchor_index, label_index],
                        objective_size[batch_index, anchor_index, label_index]), dim=-1))
                positive_class_loss.append(torch.log(
                    predictions[anchor_index, self.location_dimmension + labels[label_index, -1].long()]))

            for anchor_index in non_match_indices:
                negative_class_loss.append(
                    torch.log(predictions[anchor_index, self.location_dimmension]))

        if not positive_predictions:
            return None
        positive_predictions = torch.stack(positive_predictions)
        positive_objectives = torch.stack(positive_objectives)
        self.positive_class_loss = -torch.sum(torch.stack(positive_class_loss))
        self.negative_class_loss = -torch.sum(torch.stack(negative_class_loss))
        self.localization_loss = nn.functional.smooth_l1_loss(
            positive_predictions[:, self.location_dimmension],
            positive_objectives)
        self.class_loss = self.positive_class_loss + self.negative_class_loss
        self.final_loss = (self.localization_loss_weight * self.localization_loss) + self.class_loss
        return self.final_loss
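A quick numeric check of the Jaccard overlap defined above, assuming JacardOverlap from ssd/loss.py is in scope. Boxes are (y_center, x_center, height, width) in normalized coordinates.

import torch

anchors = torch.tensor([[0.5, 0.5, 0.4, 0.4]])   # area 0.16
labels = torch.tensor([[0.5, 0.7, 0.4, 0.4]])    # same size, shifted right by 0.2
# Intersection is 0.4 high x 0.2 wide = 0.08; union is 0.16 + 0.16 - 0.08 = 0.24,
# so IoU = 0.08 / 0.24 = 1/3.
print(JacardOverlap()(anchors, labels))  # tensor([[0.3333]])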
ssd/ssd.py (new file, 165 lines)

@@ -0,0 +1,165 @@
import colorsys
import math

import numpy as np
import torch
import torch.nn as nn

from .box import check_rectangle
from ..layers import Conv2d


class SSD(nn.Module):

    class Detector(nn.Module):
        def __init__(self, input_features: int, output_features: int):
            super().__init__()
            self.conv = Conv2d(input_features, output_features, kernel_size=3, padding=1,
                               batch_norm=False, activation=None)
            self.output = None

        def forward(self, input_data: torch.Tensor) -> torch.Tensor:
            self.output = self.conv(input_data).permute(0, 2, 3, 1)
            return self.output

    class DetectorMerge(nn.Module):
        def __init__(self, location_dimmension: int):
            super().__init__()
            self.location_dim = location_dimmension

        def forward(self, detector_outputs: torch.Tensor) -> torch.Tensor:
            return torch.cat(
                [detector_outputs[:, :, :self.location_dim],
                 torch.softmax(detector_outputs[:, :, self.location_dim:], dim=2)], dim=2)

    class AnchorInfo:
        def __init__(self, center: tuple[float, float], size: tuple[float],
                     index: int, layer_index: int, map_index: tuple[int, int], color_index: int,
                     ratio: float, size_factor: float):
            self.index = index
            self.layer_index = layer_index
            self.map_index = map_index
            self.color_index = color_index
            self.ratio = ratio
            self.size_factor = size_factor
            self.center = center
            self.size = size
            self.box = check_rectangle(
                center[0] - (size[0] / 2), center[1] - (size[1] / 2),
                center[0] + (size[0] / 2), center[1] + (size[1] / 2))

        def __repr__(self):
            return (f'{self.__class__.__name__}'
                    f'(index:{self.index}, layer:{self.layer_index}, coord:{self.map_index}'
                    f', center:({self.center[0]:.03f}, {self.center[1]:.03f})'
                    f', size:({self.size[0]:.03f}, {self.size[1]:.03f})'
                    f', ratio:{self.ratio:.03f}, size_factor:{self.size_factor:.03f})'
                    f', y:[{self.box[0]:.03f}:{self.box[2]:.03f}]'
                    f', x:[{self.box[1]:.03f}:{self.box[3]:.03f}])')

        def __array__(self):
            return np.array([*self.center, *self.size])

    def __init__(self, base_network: nn.Module, input_sample: torch.Tensor, classes: list[str],
                 location_dimmension: int, layer_channels: list[int], layer_box_ratios: list[float], layer_args: dict,
                 box_size_factors: list[float]):
        super().__init__()

        self.location_dim = location_dimmension
        self.classes = ['none'] + classes
        self.class_count = len(self.classes)
        self.base_input_shape = input_sample.numpy().shape[1:]
        self.base_network = base_network
        sample_output = base_network(input_sample)
        self.base_output_shape = list(sample_output.detach().numpy().shape)[-3:]

        layer_convs: list[nn.Module] = []
        layer_detectors: list[SSD.Detector] = []
        last_feature_count = self.base_output_shape[0]
        for layer_index, (output_features, kwargs) in enumerate(zip(layer_channels, layer_args)):
            if 'disable' not in kwargs:
                layer_convs.append(Conv2d(last_feature_count, output_features, **kwargs))
            layer_detectors.append(SSD.Detector(
                last_feature_count, (self.class_count + self.location_dim) * len(layer_box_ratios[layer_index])))
            # layers.append(SSD.Layer(
            #     last_feature_count, output_features,
            #     (self.class_count + self.location_dim) * len(layer_box_ratios[layer_index]),
            #     **kwargs))
            last_feature_count = output_features
        self.layer_convs = nn.ModuleList(layer_convs)
        self.layer_detectors = nn.ModuleList(layer_detectors)

        self.merge = self.DetectorMerge(location_dimmension)

        self.anchors_numpy, self.anchor_info, self.box_colors = self._create_anchors(
            sample_output, self.layer_convs, self.layer_detectors, layer_box_ratios, box_size_factors,
            input_sample.shape[3] / input_sample.shape[2])
        self.anchors = torch.from_numpy(self.anchors_numpy)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        head = self.base_network(input_data)
        detector_outputs = []
        for layer_index, detector in enumerate(self.layer_detectors):
            detector_out = detector(head)
            detector_outputs.append(detector_out.reshape(
                detector_out.size(0), -1, self.class_count + self.location_dim))
            if layer_index < len(self.layer_convs):
                head = self.layer_convs[layer_index](head)
        detector_outputs = torch.cat(detector_outputs, 1)
        return self.merge(detector_outputs)
        # base_output = self.base_network(input_data)
        # head = base_output
        # outputs = []
        # for layer in self.layers:
        #     head, detector_output = layer(head)
        #     outputs.append(detector_output.reshape(base_output.size(0), -1, self.class_count + self.location_dim))
        # outputs = torch.cat(outputs, 1)
        # return torch.cat(
        #     [outputs[:, :, :self.location_dim], torch.softmax(outputs[:, :, self.location_dim:], dim=2)], dim=2)

    def _apply(self, fn):
        super()._apply(fn)
        self.anchors = fn(self.anchors)
        return self

    @staticmethod
    def _create_anchors(
            base_output: torch.Tensor, layers: nn.ModuleList, detectors: nn.ModuleList, layer_box_ratios: list[float],
            box_size_factors: list[float], image_ratio: float) -> tuple[np.ndarray, np.ndarray, list[np.ndarray]]:
        anchors = []
        anchor_info: list[SSD.AnchorInfo] = []
        box_colors: list[np.ndarray] = []
        head = base_output

        for layer_index, detector in enumerate(detectors):
            detector_output = detector(head)  # detector output shape : NCRSHW (Ratio, Size)
            if layer_index < len(layers):
                head = layers[layer_index](head)

            detector_rows = detector_output.size()[1]
            detector_cols = detector_output.size()[2]
            color_index = 0
            layer_ratios = layer_box_ratios[layer_index]
            for index_y in range(detector_rows):
                center_y = (index_y + 0.5) / detector_rows
                for index_x in range(detector_cols):
                    center_x = (index_x + 0.5) / detector_cols
                    for ratio, size_factor in zip(layer_ratios, box_size_factors):
                        box_colors.append((np.asarray(colorsys.hsv_to_rgb(
                            color_index / len(layer_ratios), 1.0, 1.0)) * 255).astype(np.uint8))
                        color_index += 1
                        unit_box_size = size_factor / max(detector_rows, detector_cols)
                        anchor_width = unit_box_size * math.sqrt(ratio / image_ratio)
                        anchor_height = unit_box_size / math.sqrt(ratio / image_ratio)
                        anchor_info.append(SSD.AnchorInfo(
                            (center_y, center_x),
                            (anchor_height, anchor_width),
                            len(anchors),
                            layer_index,
                            (index_y, index_x),
                            len(box_colors) - 1,
                            ratio,
                            size_factor
                        ))
                        anchors.append([center_y, center_x, anchor_height, anchor_width])
        return np.asarray(anchors, dtype=np.float32), anchor_info, box_colors
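A worked example of the anchor sizing used in _create_anchors above (the numbers are made up): the size factor fixes the anchor area relative to the feature-map cell size, and the ratio/image_ratio term reshapes the normalized box so that its aspect ratio measured in pixels equals the requested ratio.

import math

detector_rows = detector_cols = 10
ratio, size_factor, image_ratio = 2.0, 1.5, 4 / 3   # 4:3 image, width / height

unit_box_size = size_factor / max(detector_rows, detector_cols)   # 0.15 of the image side
anchor_width = unit_box_size * math.sqrt(ratio / image_ratio)     # ~0.184 (normalized width)
anchor_height = unit_box_size / math.sqrt(ratio / image_ratio)    # ~0.122 (normalized height)

print(round(anchor_width * anchor_height, 4))                # 0.0225: area stays unit_box_size ** 2
print(round(anchor_width / anchor_height * image_ratio, 4))  # 2.0: pixel aspect ratio equals `ratio`
# Anchor centers are the cell centers: center_y = (index_y + 0.5) / detector_rows.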
train.py (21 changed lines)

@@ -28,11 +28,26 @@ def parameter_summary(network: torch.nn.Module) -> List[Tuple[str, Tuple[int], s
def resource_usage() -> Tuple[int, str]:
    memory_peak = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
    return memory_peak, gpu_used_memory()


def gpu_used_memory() -> str:
    gpu_memory = subprocess.check_output(
        'nvidia-smi --query-gpu=memory.used --format=csv,noheader', shell=True).decode()
        'nvidia-smi --query-gpu=memory.used --format=csv,noheader', shell=True).decode().strip()
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        gpu_memory = gpu_memory.split('\n')[int(os.environ['CUDA_VISIBLE_DEVICES'])]
    else:
        gpu_memory = ' '.join(gpu_memory.split('\n'))
        gpu_memory = ','.join(gpu_memory.split('\n'))

    return memory_peak, gpu_memory
    return gpu_memory


def gpu_total_memory() -> str:
    gpu_memory = subprocess.check_output(
        'nvidia-smi --query-gpu=memory.total --format=csv,noheader', shell=True).decode().strip()
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        gpu_memory = gpu_memory.split('\n')[int(os.environ['CUDA_VISIBLE_DEVICES'])]
    else:
        gpu_memory = ','.join(gpu_memory.split('\n'))

    return gpu_memory
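A sketch of the device-selection logic above, applied to a captured sample of the `nvidia-smi --query-gpu=memory.used --format=csv,noheader` output (one line per GPU; the values are made up), without shelling out:

sample_output = '1234 MiB\n567 MiB\n'
visible_device = '1'   # stand-in for os.environ.get('CUDA_VISIBLE_DEVICES')

lines = sample_output.strip().split('\n')
used = lines[int(visible_device)] if visible_device is not None else ','.join(lines)
print(used)            # 567 MiB (only the visible device's line is reported)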
trainer.py (120 changed lines)

@@ -23,12 +23,13 @@ class Trainer:
                 epoch_skip: int, summary_per_epoch: int, image_per_epoch: int,
                 data_dtype=None, label_dtype=None,
                 train_pre_process: Optional[nn.Module] = None, data_is_label: bool = False,
                 logger=DummyLogger()):
                 logger=DummyLogger(), verbose=True, save_src=True):
        super().__init__()
        self.device = device
        self.output_dir = output_dir
        self.data_is_label = data_is_label
        self.logger = logger
        self.verbose = verbose
        self.should_stop = False

        self.batch_generator_train = batch_generator_train

@@ -42,8 +43,9 @@ class Trainer:
        self.network = network
        self.optimizer = optimizer
        self.criterion = criterion
        self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'train'), flush_secs=30)
        self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'val'), flush_secs=30)
        self.accuracy_fn: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None
        self.writer_train = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'train'), flush_secs=30)
        self.writer_val = SummaryWriter(log_dir=os.path.join(output_dir, 'tensorboard', 'val'), flush_secs=30)

        # Save network graph
        batch_inputs = torch.as_tensor(batch_generator_train.batch_data[:2], dtype=data_dtype, device=device)

@@ -74,14 +76,15 @@ class Trainer:
        if summary_per_epoch % image_per_epoch == 0:
            self.image_period = self.summary_period * (summary_per_epoch // image_per_epoch)

        torch.save(network, os.path.join(output_dir, 'model_init.pt'))
        torch.save(network.state_dict(), os.path.join(output_dir, 'model_init.pt'))
        # Save source files
        for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
                os.path.join('src', '**', '*.py'), recursive=True):
            dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
            if not os.path.exists(dirname):
                os.makedirs(dirname)
            shutil.copy2(entry, os.path.join(output_dir, 'code', entry))
        if save_src:
            for entry in glob.glob(os.path.join('config', '**', '*.py'), recursive=True) + glob.glob(
                    os.path.join('src', '**', '*.py'), recursive=True) + glob.glob('*.py'):
                dirname = os.path.join(output_dir, 'code', os.path.dirname(entry))
                if not os.path.exists(dirname):
                    os.makedirs(dirname)
                shutil.copy2(entry, os.path.join(output_dir, 'code', entry))

        # Initialize training loop variables
        self.batch_inputs = batch_inputs

@@ -89,23 +92,31 @@ class Trainer:
        self.processed_inputs = processed_inputs
        self.network_outputs = processed_inputs  # Placeholder
        self.train_loss = 0.0
        self.train_accuracy = 0.0
        self.running_loss = 0.0
        self.running_accuracy = 0.0
        self.running_count = 0
        self.benchmark_step = 0
        self.benchmark_time = time.time()

    def end_epoch_callback(self):
        pass

    def train_step_callback(
            self,
            batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
            batch_labels: torch.Tensor, network_outputs: torch.Tensor,
            loss: float):
            loss: float, accuracy: float):
        pass

    def val_step_callback(
            self,
            batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
            batch_labels: torch.Tensor, network_outputs: torch.Tensor,
            loss: float):
            loss: float, accuracy: float):
        pass

    def pre_summary_callback(self):
        pass

    def summary_callback(

@@ -129,16 +140,17 @@ class Trainer:
        try:
            while not self.should_stop and self.batch_generator_train.epoch < epochs:
                epoch = self.batch_generator_train.epoch
                print()
                print(' ' * os.get_terminal_size()[0], end='\r')
                print(f'Epoch {self.batch_generator_train.epoch}')
                if self.verbose:
                    print()
                    print(' ' * os.get_terminal_size()[0], end='\r')
                    print(f'Epoch {self.batch_generator_train.epoch}')
                while not self.should_stop and epoch == self.batch_generator_train.epoch:
                    self.batch_inputs = torch.as_tensor(
                        self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
                    self.batch_labels = torch.as_tensor(
                        self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)

                    if self.benchmark_step > 1:
                    if self.verbose and self.benchmark_step > 1:
                        speed = self.benchmark_step / (time.time() - self.benchmark_time)
                        print(
                            f'Step {self.batch_generator_train.global_step}, {speed:0.02f} steps/s'

@@ -150,35 +162,60 @@ class Trainer:

                    self.processed_inputs = self.train_pre_process(self.batch_inputs)
                    self.network_outputs = self.network(self.processed_inputs)
                    loss = self.criterion(
                        self.network_outputs,
                        self.batch_labels if not self.data_is_label else self.processed_inputs)
                    labels = self.batch_labels if not self.data_is_label else self.processed_inputs
                    loss = self.criterion(self.network_outputs, labels)
                    loss.backward()
                    self.optimizer.step()

                    self.train_loss = loss.item()
                    self.train_accuracy = self.accuracy_fn(
                        self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
                    self.running_loss += self.train_loss
                    self.running_accuracy += self.train_accuracy
                    self.running_count += len(self.batch_generator_train.batch_data)
                    self.train_step_callback(
                        self.batch_inputs, self.processed_inputs, self.batch_labels,
                        self.network_outputs, self.train_loss)
                        self.batch_inputs, self.processed_inputs, labels,
                        self.network_outputs, self.train_loss, self.train_accuracy)

                    self.benchmark_step += 1
                    self.save_summaries()
                    self.batch_generator_train.next_batch()
                self.end_epoch_callback()
        except KeyboardInterrupt:
            print()
            if self.verbose:
                print()

        # Small training loop for last metrics
        for _ in range(20):
            self.batch_inputs = torch.as_tensor(
                self.batch_generator_train.batch_data, dtype=self.data_dtype, device=self.device)
            self.batch_labels = torch.as_tensor(
                self.batch_generator_train.batch_label, dtype=self.label_dtype, device=self.device)

            self.processed_inputs = self.train_pre_process(self.batch_inputs)
            self.network_outputs = self.network(self.processed_inputs)
            labels = self.batch_labels if not self.data_is_label else self.processed_inputs
            loss = self.criterion(self.network_outputs, labels)

            self.train_loss = loss.item()
            self.train_accuracy = self.accuracy_fn(
                self.network_outputs, labels).item() if self.accuracy_fn is not None else 0.0
            self.running_loss += self.train_loss
            self.running_accuracy += self.train_accuracy
            self.running_count += len(self.batch_generator_train.batch_data)
            self.benchmark_step += 1
            self.batch_generator_train.next_batch()
        self.save_summaries(force_summary=True)
        train_stop_time = time.time()
        self.writer_train.close()
        torch.save(self.network, os.path.join(self.output_dir, 'model_final.pt'))
        self.writer_val.close()

        memory_peak, gpu_memory = resource_usage()
        self.logger.info(
            f'Training time : {train_stop_time - train_start_time:.03f}s\n'
            f'\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}')

    def save_summaries(self):
    def save_summaries(self, force_summary=False):
        global_step = self.batch_generator_train.global_step
        if self.batch_generator_train.epoch < self.epoch_skip:
            return

@@ -188,7 +225,9 @@ class Trainer:

        if self.batch_generator_val.step != 0:
            self.batch_generator_val.skip_epoch()
        self.pre_summary_callback()
        val_loss = 0.0
        val_accuracy = 0.0
        val_count = 0
        self.network.train(False)
        with torch.no_grad():

@@ -201,23 +240,30 @@ class Trainer:

                val_pre_process = self.pre_process(val_inputs)
                val_outputs = self.network(val_pre_process)
                loss = self.criterion(
                    val_outputs,
                    val_labels if not self.data_is_label else val_pre_process).item()
                val_labels = val_labels if not self.data_is_label else val_pre_process
                loss = self.criterion(val_outputs, val_labels).item()
                accuracy = self.accuracy_fn(
                    val_outputs, val_labels).item() if self.accuracy_fn is not None else 0.0
                val_loss += loss
                val_accuracy += accuracy
                val_count += len(self.batch_generator_val.batch_data)
                self.val_step_callback(
                    val_inputs, val_pre_process, val_labels, val_outputs, loss)
                    val_inputs, val_pre_process, val_labels, val_outputs, loss, accuracy)

                self.batch_generator_val.next_batch()
        self.network.train(True)

        # Add summaries
        if self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
        if force_summary or self.batch_generator_train.step % self.summary_period == (self.summary_period - 1):
            self.writer_train.add_scalar(
                'loss', self.running_loss / self.running_count, global_step=global_step)
            self.writer_val.add_scalar(
                'loss', val_loss / val_count, global_step=global_step)
            if self.accuracy_fn is not None:
                self.writer_train.add_scalar(
                    'error', 1 - (self.running_accuracy / self.running_count), global_step=global_step)
                self.writer_val.add_scalar(
                    'error', 1 - (val_accuracy / val_count), global_step=global_step)
            self.summary_callback(
                self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs, self.running_count,
                val_inputs, val_pre_process, val_labels, val_outputs, val_count)

@@ -228,12 +274,18 @@ class Trainer:
                    self.batch_inputs, self.processed_inputs, self.batch_labels, self.network_outputs,
                    val_inputs, val_pre_process, val_labels, val_outputs)

            speed = self.benchmark_step / (time.time() - self.benchmark_time)
            print(f'Step {global_step}, '
                  f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
                  f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
            torch.save(self.network, os.path.join(self.output_dir, f'model_{global_step}.pt'))
            if self.verbose:
                speed = self.benchmark_step / (time.time() - self.benchmark_time)
                print(f'Step {global_step}, '
                      f'loss {self.running_loss / self.running_count:.03e} {val_loss / val_count:.03e}, '
                      f'{speed:0.02f} steps/s, {speed * self.batch_generator_train.batch_size:0.02f} input/sec')
            if self.accuracy_fn is not None:
                torch.save(self.network.state_dict(), os.path.join(
                    self.output_dir, f'step_{global_step}_acc_{val_accuracy / val_count:.04f}.pt'))
            else:
                torch.save(self.network.state_dict(), os.path.join(self.output_dir, f'step_{global_step}.pt'))
            self.benchmark_time = time.time()
            self.benchmark_step = 0
            self.running_loss = 0.0
            self.running_accuracy = 0.0
            self.running_count = 0
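The checkpointing above now writes network.state_dict() instead of the whole pickled module, so a checkpoint only contains parameter tensors. A minimal reload sketch, assuming a hypothetical build_network() factory that recreates the same architecture and an illustrative 'output_dir' path:

import os
import torch

network = build_network()                      # hypothetical: rebuild the model class first
state = torch.load(os.path.join('output_dir', 'model_init.pt'), map_location='cpu')
network.load_state_dict(state)                 # then restore only the saved tensors
network.eval()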
@@ -11,6 +11,7 @@ class BatchGenerator:

    def __init__(self, data: Iterable, label: Iterable, batch_size: int,
                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
                 pipeline: Optional[Callable] = None,
                 prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
                 flip_data=False, save: Optional[str] = None):
        self.batch_size = batch_size

@@ -18,6 +19,7 @@ class BatchGenerator:
        self.prefetch = prefetch and not preload
        self.num_workers = num_workers
        self.flip_data = flip_data
        self.pipeline = pipeline

        if not preload:
            self.data_processor = data_processor

@@ -59,15 +61,22 @@ class BatchGenerator:
        self.last_batch_size = len(self.index_list) % self.batch_size
        if self.last_batch_size == 0:
            self.last_batch_size = self.batch_size
        else:
            self.step_per_epoch += 1

        self.epoch = 0
        self.global_step = 0
        self.step = 0

        first_data = np.array([data_processor(entry) if data_processor else entry
                               for entry in self.data[self.index_list[:batch_size]]])
        first_label = np.array([label_processor(entry) if label_processor else entry
                                for entry in self.label[self.index_list[:batch_size]]])
        first_data = [data_processor(entry) if data_processor else entry
                      for entry in self.data[self.index_list[:batch_size]]]
        first_label = [label_processor(entry) if label_processor else entry
                       for entry in self.label[self.index_list[:batch_size]]]
        if self.pipeline is not None:
            for data_index, sample_data in enumerate(first_data):
                first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index])
        first_data = np.asarray(first_data)
        first_label = np.asarray(first_label)
        self.batch_data = first_data
        self.batch_label = first_label

@@ -194,10 +203,8 @@ class BatchGenerator:
        self.current_cache = 0
        parent_cache_data = self.cache_data
        parent_cache_label = self.cache_label
        cache_data = np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)
        cache_label = np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)
        self.cache_data = [cache_data]
        self.cache_label = [cache_label]
        self.cache_data = [np.ndarray(self.cache_data[0].shape, dtype=self.cache_data[0].dtype)]
        self.cache_label = [np.ndarray(self.cache_label[0].shape, dtype=self.cache_label[0].dtype)]
        self.index_list[:] = self.cache_indices
        pipe = self.worker_pipes[worker_index][1]

@@ -208,8 +215,6 @@ class BatchGenerator:
                continue
            self.index_list = self.cache_indices[start_index:start_index + self.batch_size].copy()

            self.cache_data[0] = cache_data[:self.batch_size]
            self.cache_label[0] = cache_label[:self.batch_size]
            self._next_batch()
            parent_cache_data[current_cache][batch_index:batch_index + self.batch_size] = self.cache_data[
                self.current_cache][:self.batch_size]

@@ -263,31 +268,37 @@ class BatchGenerator:
            data = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                data.append(self.data_processor(self.data[entry]))
            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
            data = np.asarray(data)
        else:
            data = self.data[
                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
        if self.flip_data:
            flip = np.random.uniform()
            if flip < 0.25:
                self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1]
                data = data[:, :, ::-1]
            elif flip < 0.5:
                self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1]
                data = data[:, :, :, ::-1]
            elif flip < 0.75:
                self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1]
            else:
                self.cache_data[self.current_cache][:len(data)] = data
        else:
            self.cache_data[self.current_cache][:len(data)] = data
                data = data[:, :, ::-1, ::-1]

        # Loading label
        if self.label_processor is not None:
            label = []
            for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                label.append(self.label_processor(self.label[entry]))
            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
            label = np.asarray(label)
        else:
            label = self.label[
                self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]

        # Process through pipeline
        if self.pipeline is not None:
            for data_index, data_entry in enumerate(data):
                piped_data, piped_label = self.pipeline(data_entry, label[data_index])
                self.cache_data[self.current_cache][data_index] = piped_data
                self.cache_label[self.current_cache][data_index] = piped_label
        else:
            self.cache_data[self.current_cache][:len(data)] = data
            self.cache_label[self.current_cache][:len(label)] = label

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:

@@ -338,7 +349,6 @@ if __name__ == '__main__':
        for _ in range(19):
            print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
            batch_generator.next_batch()
        raise KeyboardInterrupt
        print()

    test()
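The flip_data branch above draws a single uniform number per batch and flips the whole batch along one spatial axis, both, or neither, each with probability 0.25. A standalone sketch of that scheme, assuming an NCHW batch (values made up):

import numpy as np

batch = np.arange(2 * 1 * 2 * 3).reshape(2, 1, 2, 3)

flip = np.random.uniform()
if flip < 0.25:
    batch = batch[:, :, ::-1]          # flip along the height axis
elif flip < 0.5:
    batch = batch[:, :, :, ::-1]       # flip along the width axis
elif flip < 0.75:
    batch = batch[:, :, ::-1, ::-1]    # flip along both spatial axes
print(batch.shape)                     # (2, 1, 2, 3): flipping never changes the shape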
utils/ipc_data_generator.py (new file, 138 lines)

@@ -0,0 +1,138 @@
import multiprocessing as mp
from multiprocessing import shared_memory
import signal
from typing import Callable, Optional, Tuple

import numpy as np


class IPCBatchGenerator:

    def __init__(self, ipc_processor: Callable,
                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
                 pipeline: Optional[Callable] = None,
                 prefetch=True, flip_data=False):
        self.flip_data = flip_data
        self.pipeline = pipeline
        self.prefetch = prefetch

        self.ipc_processor = ipc_processor
        self.data_processor = data_processor
        self.label_processor = label_processor

        self.global_step = 0

        self.data, self.label = ipc_processor()
        first_data = [data_processor(entry) for entry in self.data] if data_processor else self.data
        first_label = [label_processor(entry) for entry in self.label] if label_processor else self.label
        if self.pipeline is not None:
            for data_index, sample_data in enumerate(first_data):
                first_data[data_index], first_label[data_index] = self.pipeline(sample_data, first_label[data_index])
        first_data = np.asarray(first_data)
        first_label = np.asarray(first_label)
        self.batch_data = first_data
        self.batch_label = first_label

        self.process_id = 'NA'
        if self.prefetch:
            self.cache_memory_data = [
                shared_memory.SharedMemory(create=True, size=first_data.nbytes),
                shared_memory.SharedMemory(create=True, size=first_data.nbytes)]
            self.cache_data = [
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[0].buf),
                np.ndarray(first_data.shape, dtype=first_data.dtype, buffer=self.cache_memory_data[1].buf)]
            self.cache_memory_label = [
                shared_memory.SharedMemory(create=True, size=first_label.nbytes),
                shared_memory.SharedMemory(create=True, size=first_label.nbytes)]
            self.cache_label = [
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[0].buf),
                np.ndarray(first_label.shape, dtype=first_label.dtype, buffer=self.cache_memory_label[1].buf)]
            self.prefetch_pipe_parent, self.prefetch_pipe_child = mp.Pipe()
            self.prefetch_stop = shared_memory.SharedMemory(create=True, size=1)
            self.prefetch_stop.buf[0] = 0
            self.prefetch_process = mp.Process(target=self._prefetch_worker)
            self.prefetch_process.start()
        else:
            self.cache_data = [first_data]
            self.cache_label = [first_label]

        self.current_cache = 0
        self.process_id = 'main'

    def __del__(self):
        self.release()

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_value, _traceback):
        self.release()

    def release(self):
        if self.prefetch:
            self.prefetch_stop.buf[0] = 1
            self.prefetch_pipe_parent.send(True)
            self.prefetch_process.join()

            for shared_mem in self.cache_memory_data + self.cache_memory_label:
                shared_mem.close()
                shared_mem.unlink()
            self.prefetch_stop.close()
            self.prefetch_stop.unlink()
            self.prefetch = False  # Avoids double release

    def _prefetch_worker(self):
        self.prefetch = False
        self.current_cache = 1
        self.process_id = 'prefetch'

        signal.signal(signal.SIGINT, signal.SIG_IGN)
        while self.prefetch_stop.buf is not None and self.prefetch_stop.buf[0] == 0:
            self.current_cache = 1 - self.current_cache
            self.global_step += 1

            self._next_batch()
            self.prefetch_pipe_child.recv()
            self.prefetch_pipe_child.send(self.current_cache)

    def _next_batch(self):
        # Loading data
        self.data, self.label = self.ipc_processor()

        data = np.asarray([self.data_processor(entry) for entry in self.data]) if self.data_processor else self.data
        if self.flip_data:
            flip = np.random.uniform()
            if flip < 0.25:
                data = data[:, :, ::-1]
            elif flip < 0.5:
                data = data[:, :, :, ::-1]
            elif flip < 0.75:
                data = data[:, :, ::-1, ::-1]

        # Loading label
        label = np.asarray([
            self.label_processor(entry) for entry in self.label]) if self.label_processor else self.label

        # Process through pipeline
        if self.pipeline is not None:
            for data_index, data_entry in enumerate(data):
                piped_data, piped_label = self.pipeline(data_entry, label[data_index])
                self.cache_data[self.current_cache][data_index] = piped_data
                self.cache_label[self.current_cache][data_index] = piped_label
        else:
            self.cache_data[self.current_cache][:len(data)] = data
            self.cache_label[self.current_cache][:len(label)] = label

    def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
        self.global_step += 1

        if self.prefetch:
            self.prefetch_pipe_parent.send(True)
            self.current_cache = self.prefetch_pipe_parent.recv()
        else:
            self._next_batch()

        self.batch_data = self.cache_data[self.current_cache]
        self.batch_label = self.cache_label[self.current_cache]

        return self.batch_data, self.batch_label
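IPCBatchGenerator prefetches by double-buffering: two shared-memory arrays are allocated up front, a worker process fills one while the consumer reads the other, and a Pipe round-trip hands back the index of the freshly filled buffer. A minimal standalone sketch of that idea (not the repository's class, names are illustrative):

import multiprocessing as mp
from multiprocessing import shared_memory

import numpy as np


def producer(shm_names, shape, dtype, conn, steps):
    # Attach to the two buffers created by the parent and alternate between them.
    shms = [shared_memory.SharedMemory(name=name) for name in shm_names]
    buffers = [np.ndarray(shape, dtype=dtype, buffer=shm.buf) for shm in shms]
    current = 1
    for step in range(steps):
        current = 1 - current
        buffers[current][:] = step   # "produce" the next batch into the free buffer
        conn.recv()                  # wait until the consumer asks for a batch
        conn.send(current)           # hand over the index of the ready buffer
    del buffers
    for shm in shms:
        shm.close()


if __name__ == '__main__':
    shape, dtype, steps = (4,), np.float32, 3
    shms = [shared_memory.SharedMemory(create=True, size=4 * 4) for _ in range(2)]
    views = [np.ndarray(shape, dtype=dtype, buffer=shm.buf) for shm in shms]
    parent_conn, child_conn = mp.Pipe()

    worker = mp.Process(target=producer,
                        args=([shm.name for shm in shms], shape, dtype, child_conn, steps))
    worker.start()
    for _ in range(steps):
        parent_conn.send(True)
        ready = parent_conn.recv()   # index of the buffer the worker just filled
        print(views[ready])
    worker.join()

    del views
    for shm in shms:
        shm.close()
        shm.unlink()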
@@ -6,3 +6,22 @@ def human_size(byte_count: int) -> str:
            break
        amount /= 1024.0
    return f'{amount:.2f}{unit}B'


def human_to_bytes(text: str) -> float:
    split_index = 0
    while '0' <= text[split_index] <= '9':
        split_index += 1
    if split_index == len(text):
        return float(text)
    amount = float(text[:split_index])
    unit = text[split_index:].strip()

    if not unit:
        return amount
    if unit not in ['KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']:
        raise RuntimeError(f'Unrecognized unit : {unit}')
    for final_unit in ['KiB', 'MiB', 'GiB', 'TiB', 'PiB', 'EiB', 'ZiB', 'YiB']:
        amount *= 1024.0
        if unit == final_unit:
            return amount
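A quick check of human_to_bytes above: the numeric prefix is multiplied by 1024 once per unit step, so '2 MiB' becomes 2 * 1024 * 1024 bytes, which also parses the 'NNNN MiB' strings that nvidia-smi reports in train.py. (A purely numeric string such as '512' appears to hit an IndexError in the digit-scanning loop before the length check is reached, so the examples keep a unit suffix.)

print(human_to_bytes('2 MiB'))   # 2097152.0
print(human_to_bytes('1 GiB'))   # 1073741824.0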
@@ -16,14 +16,15 @@ class SequenceGenerator(BatchGenerator):

    def __init__(self, data: Iterable, label: Iterable, sequence_size: int, batch_size: int,
                 data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
                 pipeline: Optional[Callable] = None, index_list=None,
                 prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
                 flip_data=False, save: Optional[str] = None):
                 sequence_stride=1, save: Optional[str] = None):
        self.batch_size = batch_size
        self.sequence_size = sequence_size
        self.shuffle = shuffle
        self.prefetch = prefetch and not preload
        self.num_workers = num_workers
        self.flip_data = flip_data
        self.pipeline = pipeline

        if not preload:
            self.data_processor = data_processor

@@ -70,14 +71,26 @@ class SequenceGenerator(BatchGenerator):
                h5_file.create_dataset(f'data_{sequence_index}', data=self.data[sequence_index])
                h5_file.create_dataset(f'label_{sequence_index}', data=self.label[sequence_index])

        self.index_list = []
        for sequence_index in range(len(self.data)):
            start_indices = np.expand_dims(
                np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
                axis=-1)
            start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
            self.index_list.append(start_indices)
        self.index_list = np.concatenate(self.index_list, axis=0)
        if index_list is not None:
            self.index_list = index_list
        else:
            self.index_list = []
            for sequence_index in range(len(self.data)):
                if sequence_stride > 1:
                    start_indices = np.expand_dims(
                        np.arange(0,
                                  len(self.data[sequence_index]) - sequence_size + 1,
                                  sequence_stride,
                                  dtype=np.uint32),
                        axis=-1)
                else:
                    start_indices = np.expand_dims(
                        np.arange(len(self.data[sequence_index]) - sequence_size + 1, dtype=np.uint32),
                        axis=-1)
                start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
                self.index_list.append(start_indices)
            self.index_list = np.concatenate(self.index_list, axis=0)

        if shuffle or initial_shuffle:
            np.random.shuffle(self.index_list)
        self.step_per_epoch = len(self.index_list) // self.batch_size

@@ -95,26 +108,28 @@ class SequenceGenerator(BatchGenerator):
                first_data.append(
                    [data_processor(input_data)
                     for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
            first_data = np.asarray(first_data)
        else:
            first_data = []
            for sequence_index, start_index in self.index_list[:batch_size]:
                first_data.append(
                    self.data[sequence_index][start_index: start_index + self.sequence_size])
            first_data = np.asarray(first_data)
        if label_processor:
            first_label = []
            for sequence_index, start_index in self.index_list[:batch_size]:
                first_label.append(
                    [label_processor(input_label)
                     for input_label in self.label[sequence_index][start_index: start_index + self.sequence_size]])
            first_label = np.asarray(first_label)
        else:
            first_label = []
            for sequence_index, start_index in self.index_list[:batch_size]:
                first_label.append(
                    self.label[sequence_index][start_index: start_index + self.sequence_size])
            first_label = np.asarray(first_label)
        if self.pipeline is not None:
            for batch_index, (data_sequence, label_sequence) in enumerate(zip(first_data, first_label)):
                first_data[batch_index], first_label[batch_index] = self.pipeline(
                    np.asarray(data_sequence), np.asarray(label_sequence))
        first_data = np.asarray(first_data)
        first_label = np.asarray(first_label)
        self.batch_data = first_data
        self.batch_label = first_label

@@ -165,37 +180,14 @@ class SequenceGenerator(BatchGenerator):
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                data.append(
                    [self.data_processor(input_data)
                     for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]])
            if self.flip_data:
                flip = np.random.uniform()
                if flip < 0.25:
                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
                elif flip < 0.5:
                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
                elif flip < 0.75:
                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
                else:
                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
            else:
                self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
                    np.asarray(
                        [self.data_processor(input_data)
                         for input_data in self.data[sequence_index][start_index: start_index + self.sequence_size]]))
        else:
            data = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                data.append(self.data[sequence_index][start_index: start_index + self.sequence_size])
            if self.flip_data:
                flip = np.random.uniform()
                if flip < 0.25:
                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1]
                elif flip < 0.5:
                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, :, ::-1]
                elif flip < 0.75:
                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)[:, :, :, ::-1, ::-1]
                else:
                    self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
            else:
                self.cache_data[self.current_cache][:len(data)] = np.asarray(data)

        # Loading label
        if self.label_processor is not None:

@@ -203,14 +195,23 @@ class SequenceGenerator(BatchGenerator):
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                label.append(
                    [self.label_processor(input_data)
                     for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]])
            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
                    np.asarray(
                        [self.label_processor(input_data)
                         for input_data in self.label[sequence_index][start_index: start_index + self.sequence_size]]))
        else:
            label = []
            for sequence_index, start_index in self.index_list[
                    self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                label.append(self.label[sequence_index][start_index: start_index + self.sequence_size])

        # Process through pipeline
        if self.pipeline is not None:
            for batch_index in range(len(data)):
                piped_data, piped_label = self.pipeline(data[batch_index], label[batch_index])
                self.cache_data[self.current_cache][batch_index] = piped_data
                self.cache_label[self.current_cache][batch_index] = piped_label
        else:
            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)

@@ -219,17 +220,22 @@ if __name__ == '__main__':
        data = np.array(
            [[1, 2, 3, 4, 5, 6, 7, 8, 9], [11, 12, 13, 14, 15, 16, 17, 18, 19]], dtype=np.uint8)
        label = np.array(
            [[.1, .2, .3, .4, .5, .6, .7, .8, .9], [.11, .12, .13, .14, .15, .16, .17, .18, .19]], dtype=np.uint8)
            [[10, 20, 30, 40, 50, 60, 70, 80, 90], [110, 120, 130, 140, 150, 160, 170, 180, 190]], dtype=np.uint8)

        for data_processor in [None, lambda x: x]:
            for prefetch in [False, True]:
                for num_workers in [1, 2]:
                    print(f'{data_processor=} {prefetch=} {num_workers=}')
                    with SequenceGenerator(data, label, 5, 2, data_processor=data_processor,
                                           prefetch=prefetch, num_workers=num_workers) as batch_generator:
                        for _ in range(9):
                            print(batch_generator.batch_data, batch_generator.epoch, batch_generator.step)
                            batch_generator.next_batch()
                        print()
        def pipeline(data, label):
            return data, label

        for pipeline in [None, pipeline]:
            for data_processor in [None, lambda x: x]:
                for prefetch in [False, True]:
                    for num_workers in [1, 2]:
                        print(f'{pipeline=} {data_processor=} {prefetch=} {num_workers=}')
                        with SequenceGenerator(data, label, 5, 2, data_processor=data_processor, pipeline=pipeline,
                                               prefetch=prefetch, num_workers=num_workers) as batch_generator:
                            for _ in range(9):
                                print(batch_generator.batch_data.tolist(), batch_generator.batch_label.tolist(),
                                      batch_generator.epoch, batch_generator.step)
                                batch_generator.next_batch()
                            print()

    test()
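The stride-aware index list built above enumerates one (sequence_index, start_index) pair per training window. A standalone worked sketch of that windowing (values made up):

import numpy as np

sequences = [np.arange(9), np.arange(7)]   # two sequences of length 9 and 7
sequence_size, sequence_stride = 5, 2

index_list = []
for sequence_index, sequence in enumerate(sequences):
    # Window start positions, strided, then tagged with the owning sequence index.
    start_indices = np.expand_dims(
        np.arange(0, len(sequence) - sequence_size + 1, sequence_stride, dtype=np.uint32), axis=-1)
    start_indices = np.insert(start_indices, 0, sequence_index, axis=1)
    index_list.append(start_indices)
index_list = np.concatenate(index_list, axis=0)
print(index_list.tolist())   # [[0, 0], [0, 2], [0, 4], [1, 0], [1, 2]]

# A window is then sequences[sequence_index][start_index:start_index + sequence_size].
print(sequences[0][2:2 + sequence_size].tolist())   # [2, 3, 4, 5, 6]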