torch_utils/ssd/loss.py
import torch
import torch.nn as nn


class JaccardOverlap(nn.Module):
    def forward(self, anchors: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Assumes rank-2 inputs (number of boxes, location), where a location is (y, x, h, w).
        Jaccard overlap: A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
        Return:
            jaccard overlap: (tensor) Shape: [anchors.size(0), labels.size(0)]
        """
        anchors_count = anchors.size(0)
        labels_count = labels.size(0)
        # Convert (y, x, h, w) to corners (y_min, x_min, y_max, x_max) and
        # broadcast both sets to a pairwise (anchor count, label count, 4) grid.
        anchor_coords = torch.cat([
            anchors[:, :2] - (anchors[:, 2:] / 2),
            anchors[:, :2] + (anchors[:, 2:] / 2)], 1).unsqueeze(1).expand(anchors_count, labels_count, 4)
        label_coords = torch.cat([
            labels[:, :2] - (labels[:, 2:] / 2),
            labels[:, :2] + (labels[:, 2:] / 2)], 1).unsqueeze(0).expand(anchors_count, labels_count, 4)
        # Intersection box: the larger of the min corners, the smaller of the max corners.
        mins = torch.max(anchor_coords, label_coords)[:, :, :2]
        maxes = torch.min(anchor_coords, label_coords)[:, :, 2:]
        inter_coords = torch.clamp(maxes - mins, min=0)
        inter_area = inter_coords[:, :, 0] * inter_coords[:, :, 1]
        anchor_areas = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1).expand_as(inter_area)
        label_areas = (labels[:, 2] * labels[:, 3]).unsqueeze(0).expand_as(inter_area)
        union_area = anchor_areas + label_areas - inter_area
        return inter_area / union_area
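
# Minimal usage sketch (illustrative, not part of the original module): two
# hand-built (y, x, h, w) boxes against one reference box. An identical box
# should score IoU 1.0 and a disjoint box 0.0.
def _jaccard_overlap_demo() -> torch.Tensor:
    overlap = JaccardOverlap()
    boxes_a = torch.tensor([[0.5, 0.5, 0.2, 0.2],
                            [0.2, 0.2, 0.1, 0.1]])
    boxes_b = torch.tensor([[0.5, 0.5, 0.2, 0.2]])
    iou = overlap(boxes_a, boxes_b)  # shape (2, 1)
    assert torch.isclose(iou[0, 0], torch.tensor(1.0))
    assert torch.isclose(iou[1, 0], torch.tensor(0.0))
    return iou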

class SSDLoss(nn.Module):
    def __init__(self, anchors: torch.Tensor, label_per_image: int,
                 negative_mining_ratio: int, matching_iou: float,
                 location_dimension: int = 4, localization_loss_weight: float = 1.0):
        super().__init__()
        self.anchors = anchors
        self.anchor_count = anchors.size(0)
        self.label_per_image = label_per_image
        self.location_dimension = location_dimension
        self.negative_mining_ratio = negative_mining_ratio
        self.matching_iou = matching_iou
        self.localization_loss_weight = localization_loss_weight
        self.overlap = JaccardOverlap()
        self.matches = []
        # self.negative_matches = []
        self.positive_class_loss = torch.Tensor()
        self.negative_class_loss = torch.Tensor()
        self.localization_loss = torch.Tensor()
        self.class_loss = torch.Tensor()
        self.final_loss = torch.Tensor()
    def forward(self, input_data: torch.Tensor, input_labels: torch.Tensor) -> torch.Tensor:
        batch_size = input_data.size(0)
        # Broadcast anchors and labels to (batch, anchor, label, 4) so the
        # regression targets can be computed pairwise in one pass.
        expanded_anchors = self.anchors[:, :4].unsqueeze(0).unsqueeze(2).expand(
            batch_size, self.anchor_count, self.label_per_image, 4)
        expanded_labels = input_labels[:, :, :self.location_dimension].unsqueeze(1).expand(
            batch_size, self.anchor_count, self.label_per_image, self.location_dimension)
        # SSD-style target encoding: center offsets scaled by anchor size, log-ratio sizes.
        objective_pos = (expanded_labels[:, :, :, :2] - expanded_anchors[:, :, :, :2]) / (
            expanded_anchors[:, :, :, 2:])
        objective_size = torch.log(expanded_labels[:, :, :, 2:] / expanded_anchors[:, :, :, 2:])
        positive_objectives = []
        positive_predictions = []
        positive_class_loss = []
        negative_class_loss = []
        self.matches = []
        # self.negative_matches = []
        for batch_index in range(batch_size):
            predictions = input_data[batch_index]
            labels = input_labels[batch_index]
            overlaps = self.overlap(self.anchors[:, :4], labels[:, :4])
            mask = (overlaps >= self.matching_iou).long()
            match_indices = torch.nonzero(mask, as_tuple=False)
            self.matches.append(match_indices.detach().cpu())
            # Hard negative mining: keep negative_mining_ratio negatives per positive.
            mining_count = int(self.negative_mining_ratio * len(self.matches[-1]))
            # Adding the match mask pushes matched anchors above every pure
            # background score, so an ascending sort surfaces the unmatched
            # anchors with the lowest background confidence first.
            masked_prediction = predictions[:, self.location_dimension] + torch.max(mask, dim=1)[0]
            non_match_indices = torch.argsort(masked_prediction, dim=-1, descending=False)[:mining_count]
            # self.negative_matches.append(non_match_indices.detach().cpu())
            for anchor_index, label_index in match_indices:
                positive_predictions.append(predictions[anchor_index])
                positive_objectives.append(
                    torch.cat((
                        objective_pos[batch_index, anchor_index, label_index],
                        objective_size[batch_index, anchor_index, label_index]), dim=-1))
                # Channel location_dimension holds the background probability;
                # object-class probabilities follow it, indexed by the label's class id.
                positive_class_loss.append(torch.log(
                    predictions[anchor_index, self.location_dimension + labels[label_index, -1].long()]))
            for anchor_index in non_match_indices:
                negative_class_loss.append(
                    torch.log(predictions[anchor_index, self.location_dimension]))
        if not positive_predictions:
            return None
        positive_predictions = torch.stack(positive_predictions)
        positive_objectives = torch.stack(positive_objectives)
        self.positive_class_loss = -torch.sum(torch.stack(positive_class_loss))
        self.negative_class_loss = -torch.sum(torch.stack(negative_class_loss))
        # Smooth L1 over the location channels of the matched predictions only.
        self.localization_loss = nn.functional.smooth_l1_loss(
            positive_predictions[:, :self.location_dimension],
            positive_objectives)
        self.class_loss = self.positive_class_loss + self.negative_class_loss
        self.final_loss = (self.localization_loss_weight * self.localization_loss) + self.class_loss
        return self.final_loss
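
if __name__ == "__main__":
    # Hedged smoke test (an assumption, not part of the original module). The
    # layout below (four offset channels followed by softmaxed class
    # probabilities, channel 4 as background, the class id in the label's last
    # column) is inferred from how forward() indexes `predictions`.
    torch.manual_seed(0)
    num_classes = 3  # background + 2 object classes (assumed)
    anchor_count, label_per_image, batch_size = 8, 2, 2

    # Anchors in (y, x, h, w): a diagonal strip of fixed-size boxes.
    centers = torch.linspace(0.1, 0.9, anchor_count)
    sizes = torch.full((anchor_count,), 0.3)
    anchors = torch.stack([centers, centers, sizes, sizes], dim=1)

    # Two ground-truth boxes per image, sitting exactly on the first and last anchors.
    labels = torch.tensor([[0.1, 0.1, 0.3, 0.3, 1.0],
                           [0.9, 0.9, 0.3, 0.3, 2.0]]).repeat(batch_size, 1, 1)

    offsets = torch.randn(batch_size, anchor_count, 4)
    class_probs = torch.softmax(torch.randn(batch_size, anchor_count, num_classes), dim=-1)
    predictions = torch.cat([offsets, class_probs], dim=-1)

    criterion = SSDLoss(anchors, label_per_image,
                        negative_mining_ratio=3, matching_iou=0.5)
    print("final loss:", criterion(predictions, labels))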