import torch
import torch.nn as nn


class JacardOverlap(nn.Module):
    """Pairwise Jaccard overlap (IoU) between two sets of center-format boxes."""

    def forward(self, anchors: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the overlap of every anchor box with every label box.

        Both inputs are rank 2, (number of boxes, 4), where each row is
        (y, x, h, w): (y, x) is the box center and (h, w) its size.

        Jaccard overlap: A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)

        Args:
            anchors: tensor of shape [anchors.size(0), 4].
            labels: tensor of shape [labels.size(0), 4].

        Returns:
            Tensor of shape [anchors.size(0), labels.size(0)] holding the overlap
            of each (anchor, label) pair. A pair whose union area is zero
            (degenerate boxes) yields 0 / 0, i.e. NaN.
        """
        anchors_count = anchors.size(0)
        labels_count = labels.size(0)

        # Corner coords (y_min, x_min, y_max, x_max), broadcast to (anchor count, label count, 4).
        anchor_coords = torch.cat([
            anchors[:, :2] - (anchors[:, 2:] / 2),
            anchors[:, :2] + (anchors[:, 2:] / 2)], 1).unsqueeze(1).expand(anchors_count, labels_count, 4)
        label_coords = torch.cat([
            labels[:, :2] - (labels[:, 2:] / 2),
            labels[:, :2] + (labels[:, 2:] / 2)], 1).unsqueeze(0).expand(anchors_count, labels_count, 4)

        # Intersection rectangle: the larger of the two mins and the smaller of the two maxes.
        # A negative extent means the boxes do not overlap, so clamp it to zero.
        mins = torch.max(anchor_coords, label_coords)[:, :, :2]
        maxes = torch.min(anchor_coords, label_coords)[:, :, 2:]
        inter_coords = torch.clamp(maxes - mins, min=0)
        inter_area = inter_coords[:, :, 0] * inter_coords[:, :, 1]

        anchor_areas = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1).expand_as(inter_area)
        label_areas = (labels[:, 2] * labels[:, 3]).unsqueeze(0).expand_as(inter_area)
        union_area = anchor_areas + label_areas - inter_area
        return inter_area / union_area


class SSDLoss(nn.Module):
    """SSD-style detection loss with hard negative mining.

    The loss combines a smooth-L1 localization term with a negative
    log-likelihood classification term.

    Args:
        anchors: anchor boxes; the first four columns are (y, x, h, w), with
            (y, x) the box center.
        label_per_image: number of label rows per image in ``input_labels``.
        negative_mining_ratio: number of hard negatives mined per positive
            (anchor, label) match.
        matching_iou: minimum Jaccard overlap for an (anchor, label) pair to
            count as a positive match.
        location_dimmension: number of leading localization columns in both
            the predictions and the labels.
        localization_loss_weight: multiplier applied to the localization loss.

    After each ``forward`` call, the individual loss terms are kept on the
    module: ``positive_class_loss``, ``negative_class_loss``,
    ``localization_loss``, ``class_loss`` and ``final_loss``. The per-image
    (anchor, label) match indices are kept in ``matches``.
    """

    def __init__(self, anchors: torch.Tensor, label_per_image: int, negative_mining_ratio: int,
                 matching_iou: float, location_dimmension: int = 4, localization_loss_weight: float = 1.0):
        super().__init__()
        self.anchors = anchors
        self.anchor_count = anchors.size(0)
        self.label_per_image = label_per_image
        self.location_dimmension = location_dimmension
        self.negative_mining_ratio = negative_mining_ratio
        self.matching_iou = matching_iou
        self.localization_loss_weight = localization_loss_weight
        self.overlap = JacardOverlap()

        # Diagnostics refreshed on every forward call.
        self.matches = []
        self.positive_class_loss = torch.Tensor()
        self.negative_class_loss = torch.Tensor()
        self.localization_loss = torch.Tensor()
        self.class_loss = torch.Tensor()
        self.final_loss = torch.Tensor()

    def forward(self, input_data: torch.Tensor, input_labels: torch.Tensor) -> torch.Tensor:
        """Compute the loss for one batch.

        Args:
            input_data: predictions indexed as ``[batch, anchor, column]``.
                Columns ``[:location_dimmension]`` are the box regression
                outputs. Column ``location_dimmension`` is the background
                score, and column ``location_dimmension + c`` is the score for
                class ``c``. Scores are passed straight to ``torch.log``, so
                they are expected to be probabilities (e.g. softmax output).
                NOTE(review): this means class id 0 would alias the background
                column; confirm label class ids start at 1.
            input_labels: labels indexed as ``[batch, label, column]``.
                Columns ``[:location_dimmension]`` are the box (y, x, h, w),
                and the last column is the class id.

        Returns:
            Scalar loss tensor, or ``None`` when no anchor matched any label
            anywhere in the batch.
        """
        batch_size = input_data.size(0)
        loc_dim = self.location_dimmension

        # Regression targets for every (batch, anchor, label) triple:
        # the center offset is scaled by the anchor size, and the size is the log-ratio.
        expanded_anchors = self.anchors[:, :4].unsqueeze(0).unsqueeze(2).expand(
            batch_size, self.anchor_count, self.label_per_image, 4)
        expanded_labels = input_labels[:, :, :loc_dim].unsqueeze(1).expand(
            batch_size, self.anchor_count, self.label_per_image, loc_dim)
        objective_pos = (expanded_labels[..., :2] - expanded_anchors[..., :2]) / expanded_anchors[..., 2:]
        objective_size = torch.log(expanded_labels[..., 2:] / expanded_anchors[..., 2:])

        positive_predictions = []
        positive_objectives = []
        positive_class_scores = []
        negative_class_scores = []
        self.matches = []
        for batch_index in range(batch_size):
            predictions = input_data[batch_index]
            labels = input_labels[batch_index]

            matched = self.overlap(self.anchors[:, :4], labels[:, :4]) >= self.matching_iou
            match_indices = torch.nonzero(matched, as_tuple=False)
            self.matches.append(match_indices.detach().cpu())
            anchor_indices, label_indices = match_indices[:, 0], match_indices[:, 1]

            # Positive (anchor, label) pairs, gathered with advanced indexing.
            positive_predictions.append(predictions[anchor_indices])
            positive_objectives.append(torch.cat((
                objective_pos[batch_index, anchor_indices, label_indices],
                objective_size[batch_index, anchor_indices, label_indices]), dim=-1))
            positive_class_scores.append(
                predictions[anchor_indices, loc_dim + labels[label_indices, -1].long()])

            # Hard negative mining: among anchors that matched no label, keep those with the
            # lowest background score. Matched anchors are never candidates, even when
            # mining_count exceeds the number of unmatched anchors.
            mining_count = int(self.negative_mining_ratio * len(match_indices))
            unmatched_anchors = torch.nonzero(~matched.any(dim=1), as_tuple=True)[0]
            background_scores = predictions[unmatched_anchors, loc_dim]
            hardest = torch.argsort(background_scores, dim=-1, descending=False)[:mining_count]
            negative_class_scores.append(background_scores[hardest])

        if sum(len(match) for match in self.matches) == 0:
            return None

        positive_predictions = torch.cat(positive_predictions)
        positive_objectives = torch.cat(positive_objectives)
        # torch.cat tolerates empty per-image pieces, and an empty sum is 0,
        # so zero mined negatives is handled.
        self.positive_class_loss = -torch.log(torch.cat(positive_class_scores)).sum()
        self.negative_class_loss = -torch.log(torch.cat(negative_class_scores)).sum()
        # Regress only the localization columns against the targets.
        self.localization_loss = nn.functional.smooth_l1_loss(
            positive_predictions[:, :loc_dim], positive_objectives)
        self.class_loss = self.positive_class_loss + self.negative_class_loss
        self.final_loss = (self.localization_loss_weight * self.localization_loss) + self.class_loss
        return self.final_loss