112 lines
5.5 KiB
Python
112 lines
5.5 KiB
Python
from typing import Optional

import torch
import torch.nn as nn
|
||
|
||
|
||
class JacardOverlap(nn.Module):
    """Pairwise Jaccard overlap (intersection over union) of two sets of boxes."""

    def forward(self, anchors: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the overlap of every anchor with every label.

        Both inputs are rank 2, ``(number of boxes, location)``, with each location
        laid out as ``(y, x, h, w)``: the box centre followed by the box size.

        Jaccard overlap: A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)

        Returns:
            Tensor of shape ``[anchors.size(0), labels.size(0)]`` whose entry
            ``[i, j]`` is the overlap of ``anchors[i]`` with ``labels[j]``.
        """
        def to_corners(boxes):
            # (y, x, h, w) -> ((y_min, x_min), (y_max, x_max))
            half_size = boxes[:, 2:] / 2
            return boxes[:, :2] - half_size, boxes[:, :2] + half_size

        anchor_low, anchor_high = to_corners(anchors)
        label_low, label_high = to_corners(labels)

        # Broadcast (anchors, 1, 2) against (1, labels, 2) to cover every pair.
        overlap_low = torch.max(anchor_low.unsqueeze(1), label_low.unsqueeze(0))
        overlap_high = torch.min(anchor_high.unsqueeze(1), label_high.unsqueeze(0))
        # Disjoint boxes give a negative extent; clamp it so their intersection is 0.
        overlap_extent = torch.clamp(overlap_high - overlap_low, min=0)
        intersection = overlap_extent[..., 0] * overlap_extent[..., 1]

        anchor_area = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1)
        label_area = (labels[:, 2] * labels[:, 3]).unsqueeze(0)
        union = anchor_area + label_area - intersection
        return intersection / union
|
||
|
||
|
||
class SSDLoss(nn.Module):
    """SSD-style multibox loss: localization on matched anchors plus classification
    with hard negative mining.

    Anchors and label boxes use the ``(y, x, h, w)`` centre/size layout passed to
    ``JacardOverlap``. Each prediction row holds ``location_dimmension`` box-offset
    columns followed by per-class scores. The first score column (index
    ``location_dimmension``) is used as the background score for negatives. Each
    label row carries its class id in its last column.

    NOTE(review): ``torch.log`` is applied to the scores directly, so they are
    presumably probabilities (softmax already applied). Confirm with the model head.
    """

    def __init__(self, anchors: torch.Tensor, label_per_image: int,
                 negative_mining_ratio: int, matching_iou: float,
                 location_dimmension: int = 4, localization_loss_weight: float = 1.0):
        """
        Args:
            anchors: anchor boxes, one per row, with ``(y, x, h, w)`` in the first
                four columns.
            label_per_image: number of label rows per image. Stored for callers;
                ``forward`` reads the actual count from ``input_labels``.
            negative_mining_ratio: hard negatives kept per positive match.
            matching_iou: minimum overlap for an (anchor, label) pair to be a
                positive match.
            location_dimmension: number of box-offset columns at the start of
                each prediction row (and box columns at the start of each label row).
            localization_loss_weight: multiplier applied to the localization term.
        """
        super().__init__()
        # Non-persistent buffer: it moves with ``.to(device)`` / ``.cuda()`` like the
        # rest of the module, but it is not written to ``state_dict``.
        self.register_buffer('anchors', anchors, persistent=False)
        self.anchor_count = anchors.size(0)
        self.label_per_image = label_per_image
        self.location_dimmension = location_dimmension
        self.negative_mining_ratio = negative_mining_ratio
        self.matching_iou = matching_iou
        self.localization_loss_weight = localization_loss_weight

        self.overlap = JacardOverlap()
        # Per-image (anchor_index, label_index) matches from the last forward call, on CPU.
        self.matches = []
        # Components of the most recently computed loss, kept for inspection/logging.
        self.positive_class_loss = torch.Tensor()
        self.negative_class_loss = torch.Tensor()
        self.localization_loss = torch.Tensor()
        self.class_loss = torch.Tensor()
        self.final_loss = torch.Tensor()

    @staticmethod
    def _encode_locations(anchors: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
        """Return regression targets for row-aligned (anchor, box) pairs.

        Both inputs hold ``(y, x, h, w)`` rows. Each target row is the centre offset
        divided by the anchor size, followed by ``log(box size / anchor size)``.
        """
        centre_offset = (boxes[:, :2] - anchors[:, :2]) / anchors[:, 2:]
        log_size_ratio = torch.log(boxes[:, 2:] / anchors[:, 2:])
        return torch.cat((centre_offset, log_size_ratio), dim=-1)

    def forward(self, input_data: torch.Tensor,
                input_labels: torch.Tensor) -> Optional[torch.Tensor]:
        """Compute the loss for a batch.

        Args:
            input_data: per-image predictions, ``input_data[b][anchor, column]``.
                Columns ``[:location_dimmension]`` are box offsets. Column
                ``location_dimmension + c`` is the score for class id ``c``.
            input_labels: per-image labels, ``input_labels[b][label, column]``.
                The box is in the first ``location_dimmension`` columns and the
                class id is in the last column.

        Returns:
            ``localization_loss_weight * localization + positive_class + negative_class``,
            or ``None`` when no anchor in the batch reaches ``matching_iou`` with any
            label. The individual terms are also stored on ``self``.
        """
        positive_predictions = []
        positive_objectives = []
        positive_class_loss = []
        negative_class_loss = []
        self.matches = []

        for predictions, labels in zip(input_data, input_labels):
            overlaps = self.overlap(self.anchors[:, :4], labels[:, :4])
            is_match = overlaps >= self.matching_iou
            match_indices = torch.nonzero(is_match, as_tuple=False)
            self.matches.append(match_indices.detach().cpu())
            anchor_index, label_index = match_indices.unbind(dim=1)

            # Positives: every (anchor, label) pair whose overlap reaches matching_iou.
            positive_predictions.append(predictions[anchor_index])
            positive_objectives.append(self._encode_locations(
                self.anchors[anchor_index, :4],
                labels[label_index, :self.location_dimmension]))
            class_column = self.location_dimmension + labels[label_index, -1].long()
            positive_class_loss.append(torch.log(predictions[anchor_index, class_column]))

            # Hard negatives: unmatched anchors ordered by ascending background score.
            # Matched anchors are ranked last and excluded from the count, so a positive
            # anchor is never also scored as a background negative.
            anchor_matched = is_match.any(dim=1)
            background = predictions[:, self.location_dimmension]
            mining_count = min(int(self.negative_mining_ratio * len(match_indices)),
                               int((~anchor_matched).sum()))
            ranking = background.detach().masked_fill(anchor_matched, float('inf'))
            negative_index = torch.argsort(ranking, dim=-1, descending=False)[:mining_count]
            negative_class_loss.append(torch.log(background[negative_index]))

        if sum(len(image_matches) for image_matches in self.matches) == 0:
            return None

        positive_predictions = torch.cat(positive_predictions)
        positive_objectives = torch.cat(positive_objectives)
        # torch.cat tolerates empty per-image pieces; the sum of an empty tensor is 0.
        self.positive_class_loss = -torch.sum(torch.cat(positive_class_loss))
        self.negative_class_loss = -torch.sum(torch.cat(negative_class_loss))
        # Regress the box-offset columns (not the background score) onto the targets.
        self.localization_loss = nn.functional.smooth_l1_loss(
            positive_predictions[:, :self.location_dimmension],
            positive_objectives)
        self.class_loss = self.positive_class_loss + self.negative_class_loss
        self.final_loss = (self.localization_loss_weight * self.localization_loss) + self.class_loss
        return self.final_loss
|