import torch
import torch.nn as nn


class JacardOverlap(nn.Module):
    """Pairwise Jaccard overlap (IoU) between two sets of boxes.

    Both inputs are rank-2 tensors of center-form locations ``(y, x, h, w)``.
    Jaccard overlap: A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B).
    """

    def forward(self, anchors: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the overlap matrix.

        Args:
            anchors: (anchors_count, 4) center-form boxes.
            labels: (labels_count, 4) center-form boxes.

        Returns:
            Tensor of shape (anchors.size(0), labels.size(0)) with IoU values.
        """
        anchors_count = anchors.size(0)
        labels_count = labels.size(0)

        # Convert center form to corner form (y_min, x_min, y_max, x_max) and
        # broadcast both sets to (anchors_count, labels_count, 4) so every
        # anchor/label pair is compared.
        anchor_coords = torch.cat([
            anchors[:, :2] - (anchors[:, 2:] / 2),
            anchors[:, :2] + (anchors[:, 2:] / 2)], 1).unsqueeze(1).expand(anchors_count, labels_count, 4)
        label_coords = torch.cat([
            labels[:, :2] - (labels[:, 2:] / 2),
            labels[:, :2] + (labels[:, 2:] / 2)], 1).unsqueeze(0).expand(anchors_count, labels_count, 4)

        # Intersection box: max of the min corners, min of the max corners.
        mins = torch.max(anchor_coords, label_coords)[:, :, :2]
        maxes = torch.min(anchor_coords, label_coords)[:, :, 2:]

        # Clamp at zero so disjoint pairs contribute no intersection area.
        inter_coords = torch.clamp(maxes - mins, min=0)
        inter_area = inter_coords[:, :, 0] * inter_coords[:, :, 1]

        anchor_areas = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1).expand_as(inter_area)
        label_areas = (labels[:, 2] * labels[:, 3]).unsqueeze(0).expand_as(inter_area)

        # NOTE(review): union is zero only when both boxes are degenerate
        # (h * w == 0); assumed not to occur — confirm against anchor generation.
        union_area = anchor_areas + label_areas - inter_area
        return inter_area / union_area


class SSDLoss(nn.Module):
    """SSD multibox loss: localization (smooth L1) + class confidence loss
    with hard-negative mining.

    Per-anchor predictions are laid out as
    ``[loc(location_dimmension), class confidences...]`` where the confidence
    at index ``location_dimmension`` is the background class.

    Args:
        anchors: (anchor_count, >=4) center-form anchor boxes ``(y, x, h, w)``.
        label_per_image: number of ground-truth labels provided per image.
        negative_mining_ratio: negatives kept per positive match.
        matching_iou: minimum IoU for an anchor/label pair to count as positive.
        location_dimmension: number of location components (default 4).
        localization_loss_weight: multiplier on the localization term.
    """

    def __init__(self, anchors: torch.Tensor, label_per_image: int,
                 negative_mining_ratio: int, matching_iou: float,
                 location_dimmension: int = 4, localization_loss_weight: float = 1.0):
        super().__init__()
        self.anchors = anchors
        self.anchor_count = anchors.size(0)
        self.label_per_image = label_per_image
        self.location_dimmension = location_dimmension
        self.negative_mining_ratio = negative_mining_ratio
        self.matching_iou = matching_iou
        self.localization_loss_weight = localization_loss_weight

        self.overlap = JacardOverlap()
        # Per-batch-element (anchor_index, label_index) positive matches,
        # kept on CPU for inspection after forward().
        self.matches = []
        # Loss components are stored as attributes so callers can log them.
        self.positive_class_loss = torch.Tensor()
        self.negative_class_loss = torch.Tensor()
        self.localization_loss = torch.Tensor()
        self.class_loss = torch.Tensor()
        self.final_loss = torch.Tensor()

    def forward(self, input_data: torch.Tensor, input_labels: torch.Tensor) -> torch.Tensor:
        """Compute the combined SSD loss.

        Args:
            input_data: (batch, anchor_count, location_dimmension + classes)
                per-anchor predictions.
            input_labels: (batch, label_per_image, location_dimmension + 1)
                ground truth; last column is the (1-based) object class index.

        Returns:
            Scalar loss tensor, or ``None`` when no anchor matched any label.
        """
        batch_size = input_data.size(0)
        # Regression objectives for every (anchor, label) pair:
        # center offset normalized by anchor size, and log size ratio.
        expanded_anchors = self.anchors[:, :4].unsqueeze(0).unsqueeze(2).expand(
            batch_size, self.anchor_count, self.label_per_image, 4)
        expanded_labels = input_labels[:, :, :self.location_dimmension].unsqueeze(1).expand(
            batch_size, self.anchor_count, self.label_per_image, self.location_dimmension)
        objective_pos = (expanded_labels[:, :, :, :2] - expanded_anchors[:, :, :, :2]) / (
            expanded_anchors[:, :, :, 2:])
        objective_size = torch.log(expanded_labels[:, :, :, 2:] / expanded_anchors[:, :, :, 2:])

        positive_objectives = []
        positive_predictions = []
        positive_class_loss = []
        negative_class_loss = []
        self.matches = []
        for batch_index in range(batch_size):
            predictions = input_data[batch_index]
            labels = input_labels[batch_index]
            overlaps = self.overlap(self.anchors[:, :4], labels[:, :4])
            mask = (overlaps >= self.matching_iou).long()
            match_indices = torch.nonzero(mask, as_tuple=False)
            self.matches.append(match_indices.detach().cpu())

            # Hard-negative mining: keep the anchors with the LOWEST background
            # confidence (hardest negatives). Adding the match mask pushes
            # positive anchors to the end of the ascending sort so they are
            # never selected as negatives.
            mining_count = int(self.negative_mining_ratio * len(self.matches[-1]))
            masked_prediction = predictions[:, self.location_dimmension] + torch.max(mask, dim=1)[0]
            non_match_indices = torch.argsort(masked_prediction, dim=-1, descending=False)[:mining_count]

            for anchor_index, label_index in match_indices:
                positive_predictions.append(predictions[anchor_index])
                positive_objectives.append(
                    torch.cat((
                        objective_pos[batch_index, anchor_index, label_index],
                        objective_size[batch_index, anchor_index, label_index]), dim=-1))
                # Confidence of the matched object class (offset past the
                # location components and the background entry at index 0).
                positive_class_loss.append(torch.log(
                    predictions[anchor_index, self.location_dimmension + labels[label_index, -1].long()]))

            for anchor_index in non_match_indices:
                # Mined negatives should predict background confidently.
                negative_class_loss.append(
                    torch.log(predictions[anchor_index, self.location_dimmension]))

        if not positive_predictions:
            # No anchor matched any label anywhere in the batch; the caller is
            # expected to skip this step.
            return None
        positive_predictions = torch.stack(positive_predictions)
        positive_objectives = torch.stack(positive_objectives)
        self.positive_class_loss = -torch.sum(torch.stack(positive_class_loss))
        # BUG FIX: guard against an empty negative list (mining_count == 0);
        # torch.stack([]) would raise.
        if negative_class_loss:
            self.negative_class_loss = -torch.sum(torch.stack(negative_class_loss))
        else:
            self.negative_class_loss = input_data.new_zeros(())
        # BUG FIX: compare the LOCATION predictions (columns
        # [:location_dimmension]) against the 4-dim objectives. The original
        # sliced the single background-confidence column
        # ([:, location_dimmension]) and silently broadcast it.
        self.localization_loss = nn.functional.smooth_l1_loss(
            positive_predictions[:, :self.location_dimmension],
            positive_objectives)
        self.class_loss = self.positive_class_loss + self.negative_class_loss
        self.final_loss = (self.localization_loss_weight * self.localization_loss) + self.class_loss
        return self.final_loss