import colorsys
import math

import numpy as np
import torch
import torch.nn as nn

from .box import check_rectangle
from ..layers import Conv2d


class SSD(nn.Module):

    class Detector(nn.Module):
        # Per-layer detection head: a 3x3 conv whose output is permuted to NHWC.

        def __init__(self, input_features: int, output_features: int):
            super().__init__()
            self.conv = Conv2d(input_features, output_features,
                               kernel_size=3, padding=1, batch_norm=False, activation=None)
            self.output = None

        def forward(self, input_data: torch.Tensor) -> torch.Tensor:
            # Keep the most recent output (NHWC) around for inspection.
            self.output = self.conv(input_data).permute(0, 2, 3, 1)
            return self.output

    class DetectorMerge(nn.Module):
        # Splits the merged detector outputs into location offsets and class
        # scores, applying softmax over the class dimension only.

        def __init__(self, location_dimension: int):
            super().__init__()
            self.location_dim = location_dimension

        def forward(self, detector_outputs: torch.Tensor) -> torch.Tensor:
            return torch.cat(
                [detector_outputs[:, :, :self.location_dim],
                 torch.softmax(detector_outputs[:, :, self.location_dim:], dim=2)],
                dim=2)

    class AnchorInfo:
        # Bookkeeping for a single anchor: its grid position, shape
        # parameters, and the clipped corner rectangle.

        def __init__(self, center: tuple[float, float], size: tuple[float, float],
                     index: int, layer_index: int, map_index: tuple[int, int],
                     color_index: int, ratio: float, size_factor: float):
            self.index = index
            self.layer_index = layer_index
            self.map_index = map_index
            self.color_index = color_index
            self.ratio = ratio
            self.size_factor = size_factor
            self.center = center
            self.size = size
            # (y_min, x_min, y_max, x_max), validated/clipped by check_rectangle.
            self.box = check_rectangle(
                center[0] - (size[0] / 2), center[1] - (size[1] / 2),
                center[0] + (size[0] / 2), center[1] + (size[1] / 2))

        def __repr__(self):
            return (f'{self.__class__.__name__}'
                    f'(index:{self.index}, layer:{self.layer_index}, coord:{self.map_index}'
                    f', center:({self.center[0]:.03f}, {self.center[1]:.03f})'
                    f', size:({self.size[0]:.03f}, {self.size[1]:.03f})'
                    f', ratio:{self.ratio:.03f}, size_factor:{self.size_factor:.03f}'
                    f', y:[{self.box[0]:.03f}:{self.box[2]:.03f}]'
                    f', x:[{self.box[1]:.03f}:{self.box[3]:.03f}])')

        def __array__(self):
            return np.array([*self.center, *self.size])

    def __init__(self, base_network: nn.Module, input_sample: torch.Tensor,
                 classes: list[str], location_dimension: int,
                 layer_channels: list[int], layer_box_ratios: list[list[float]],
                 layer_args: list[dict], box_size_factors: list[float]):
        super().__init__()
        self.location_dim = location_dimension
        # Class index 0 is reserved for background ('none').
        self.classes = ['none'] + classes
        self.class_count = len(self.classes)
        self.base_input_shape = tuple(input_sample.shape[1:])
        self.base_network = base_network
        sample_output = base_network(input_sample)
        self.base_output_shape = list(sample_output.detach().numpy().shape)[-3:]

        layer_convs: list[nn.Module] = []
        layer_detectors: list[SSD.Detector] = []
        last_feature_count = self.base_output_shape[0]
        for layer_index, (output_features, kwargs) in enumerate(zip(layer_channels, layer_args)):
            # A layer whose kwargs contain 'disable' gets a detector but no
            # trailing conv; this is expected only on the final layer.
            if 'disable' not in kwargs:
                layer_convs.append(Conv2d(last_feature_count, output_features, **kwargs))
            layer_detectors.append(SSD.Detector(
                last_feature_count,
                (self.class_count + self.location_dim) * len(layer_box_ratios[layer_index])))
            last_feature_count = output_features
        self.layer_convs = nn.ModuleList(layer_convs)
        self.layer_detectors = nn.ModuleList(layer_detectors)
        self.merge = SSD.DetectorMerge(location_dimension)
        self.anchors_numpy, self.anchor_info, self.box_colors = self._create_anchors(
            sample_output, self.layer_convs, self.layer_detectors,
            layer_box_ratios, box_size_factors,
            input_sample.shape[3] / input_sample.shape[2])
        self.anchors = torch.from_numpy(self.anchors_numpy)
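
    # Note on data flow: forward() below mirrors _create_anchors(). Each
    # detector reads the current feature map *before* it is advanced through
    # the next extra conv, so detector i always sees pyramid level i. The
    # per-detector outputs are flattened to (batch, anchors, loc + classes)
    # and concatenated before the softmax merge.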
    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        head = self.base_network(input_data)
        detector_outputs = []
        for layer_index, detector in enumerate(self.layer_detectors):
            detector_out = detector(head)
            detector_outputs.append(detector_out.reshape(
                detector_out.size(0), -1, self.class_count + self.location_dim))
            if layer_index < len(self.layer_convs):
                head = self.layer_convs[layer_index](head)
        detector_outputs = torch.cat(detector_outputs, 1)
        return self.merge(detector_outputs)

    def _apply(self, fn):
        # Keep the anchor tensor on the same device/dtype as the parameters.
        super()._apply(fn)
        self.anchors = fn(self.anchors)
        return self

    @staticmethod
    def _create_anchors(
            base_output: torch.Tensor, layers: nn.ModuleList, detectors: nn.ModuleList,
            layer_box_ratios: list[list[float]], box_size_factors: list[float],
            image_ratio: float) -> tuple[np.ndarray, list['SSD.AnchorInfo'], list[np.ndarray]]:
        anchors = []
        anchor_info: list[SSD.AnchorInfo] = []
        box_colors: list[np.ndarray] = []
        head = base_output
        for layer_index, detector in enumerate(detectors):
            # Detector output is NHWC (permuted in Detector.forward); its
            # channels group one (ratio, size_factor) slot after another.
            detector_output = detector(head)
            if layer_index < len(layers):
                head = layers[layer_index](head)
            detector_rows = detector_output.size()[1]
            detector_cols = detector_output.size()[2]
            color_index = 0
            layer_ratios = layer_box_ratios[layer_index]
            for index_y in range(detector_rows):
                center_y = (index_y + 0.5) / detector_rows
                for index_x in range(detector_cols):
                    center_x = (index_x + 0.5) / detector_cols
                    for ratio, size_factor in zip(layer_ratios, box_size_factors):
                        # One debug color per anchor slot; hues wrap modulo 1.
                        box_colors.append((np.asarray(colorsys.hsv_to_rgb(
                            color_index / len(layer_ratios), 1.0, 1.0)) * 255).astype(np.uint8))
                        color_index += 1
                        # The size factor scales a one-grid-cell box; the aspect
                        # ratio is corrected for non-square input images.
                        unit_box_size = size_factor / max(detector_rows, detector_cols)
                        anchor_width = unit_box_size * math.sqrt(ratio / image_ratio)
                        anchor_height = unit_box_size / math.sqrt(ratio / image_ratio)
                        anchor_info.append(SSD.AnchorInfo(
                            (center_y, center_x), (anchor_height, anchor_width),
                            len(anchors), layer_index, (index_y, index_x),
                            len(box_colors) - 1, ratio, size_factor))
                        anchors.append([center_y, center_x, anchor_height, anchor_width])
        return np.asarray(anchors, dtype=np.float32), anchor_info, box_colors
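

# A minimal usage sketch, not part of the module's API: the backbone, class
# names, and hyperparameters below are illustrative assumptions, chosen only
# to exercise the constructor arguments visible above. 'disable' in the last
# layer's args suppresses that layer's trailing conv, as handled in __init__.
if __name__ == '__main__':
    backbone = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU())
    sample = torch.zeros(1, 3, 96, 128)  # NCHW; width/height gives image_ratio
    model = SSD(
        backbone, sample,
        classes=['cat', 'dog'],       # 'none' is prepended automatically
        location_dimension=4,         # (cy, cx, h, w) offsets per anchor
        layer_channels=[128, 128],
        layer_box_ratios=[[1.0, 2.0], [1.0, 2.0]],
        layer_args=[dict(kernel_size=3, padding=1, batch_norm=False, activation=None),
                    dict(disable=True)],
        box_size_factors=[1.0, 1.5],  # paired elementwise with each ratio list
    )
    predictions = model(sample)
    # (batch, total_anchors, location_dim + class_count), anchors: (total, 4)
    print(predictions.shape, model.anchors.shape)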