SSDLoss implementation

This commit is contained in:
Corentin 2021-05-21 15:14:14 +09:00
commit d87bb89e6c
2 changed files with 113 additions and 1 deletion


@@ -54,7 +54,7 @@ class Linear(Layer):
     def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs):
         super().__init__(activation, batch_norm)
-        self.fc = nn.Linear(in_channels, out_channels, **kwargs)
+        self.fc = nn.Linear(in_channels, out_channels, bias=not self.batch_norm, **kwargs)
         self.batch_norm = nn.BatchNorm1d(
             out_channels,
             momentum=Layer.BATCH_NORM_MOMENTUM,

ssd/loss.py (new file, 112 additions)

@@ -0,0 +1,112 @@
import torch
import torch.nn as nn


class JaccardOverlap(nn.Module):
    def forward(self, anchors: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Boxes are rank 2 (number of boxes, locations); a location is (y, x, h, w).
        Jaccard overlap: A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
        Return:
            jaccard overlap: (tensor) Shape: [anchors.size(0), labels.size(0)]
        """
        anchors_count = anchors.size(0)
        labels_count = labels.size(0)
        # Corner coords (y_min, x_min, y_max, x_max), repeated to fill (anchor count, label count, 4)
        anchor_coords = torch.cat([
            anchors[:, :2] - (anchors[:, 2:] / 2),
            anchors[:, :2] + (anchors[:, 2:] / 2)], 1).unsqueeze(1).expand(anchors_count, labels_count, 4)
        label_coords = torch.cat([
            labels[:, :2] - (labels[:, 2:] / 2),
            labels[:, :2] + (labels[:, 2:] / 2)], 1).unsqueeze(0).expand(anchors_count, labels_count, 4)
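        # Intersection corners: max of the min corners, min of the max corners;
        # clamping at 0 zeroes out disjoint pairs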
        mins = torch.max(anchor_coords, label_coords)[:, :, :2]
        maxes = torch.min(anchor_coords, label_coords)[:, :, 2:]
        inter_coords = torch.clamp(maxes - mins, min=0)
        inter_area = inter_coords[:, :, 0] * inter_coords[:, :, 1]
        anchor_areas = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1).expand_as(inter_area)
        label_areas = (labels[:, 2] * labels[:, 3]).unsqueeze(0).expand_as(inter_area)
        union_area = anchor_areas + label_areas - inter_area
        return inter_area / union_area
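
# Illustrative check (not part of the commit): two unit squares whose centers
# are half a box apart intersect on area 0.5, so IoU = 0.5 / (1 + 1 - 0.5) = 1/3.
#   overlap = JaccardOverlap()
#   a = torch.tensor([[0.5, 0.5, 1.0, 1.0]])  # (y, x, h, w)
#   b = torch.tensor([[0.5, 1.0, 1.0, 1.0]])
#   overlap(a, b)  # tensor([[0.3333]])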


class SSDLoss(nn.Module):
    def __init__(self, anchors: torch.Tensor, label_per_image: int,
                 negative_mining_ratio: int, matching_iou: float,
                 location_dimension: int = 4, localization_loss_weight: float = 1.0):
        super().__init__()
        self.anchors = anchors
        self.anchor_count = anchors.size(0)
        self.label_per_image = label_per_image
        self.location_dimension = location_dimension
        self.negative_mining_ratio = negative_mining_ratio
        self.matching_iou = matching_iou
        self.localization_loss_weight = localization_loss_weight
        self.overlap = JaccardOverlap()
        self.matches = []
        # self.negative_matches = []
        self.positive_class_loss = torch.Tensor()
        self.negative_class_loss = torch.Tensor()
        self.localization_loss = torch.Tensor()
        self.class_loss = torch.Tensor()
        self.final_loss = torch.Tensor()

    def forward(self, input_data: torch.Tensor, input_labels: torch.Tensor) -> torch.Tensor:
        batch_size = input_data.size(0)
        expanded_anchors = self.anchors[:, :4].unsqueeze(0).unsqueeze(2).expand(
            batch_size, self.anchor_count, self.label_per_image, 4)
        expanded_labels = input_labels[:, :, :self.location_dimension].unsqueeze(1).expand(
            batch_size, self.anchor_count, self.label_per_image, self.location_dimension)
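        # SSD-style encoding: center offsets are scaled by anchor size and box
        # sizes become log-ratios, keeping regression targets comparable
        # across anchors of different scales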
        objective_pos = (expanded_labels[:, :, :, :2] - expanded_anchors[:, :, :, :2]) / (
            expanded_anchors[:, :, :, 2:])
        objective_size = torch.log(expanded_labels[:, :, :, 2:] / expanded_anchors[:, :, :, 2:])
        positive_objectives = []
        positive_predictions = []
        positive_class_loss = []
        negative_class_loss = []
        self.matches = []
        # self.negative_matches = []
        for batch_index in range(batch_size):
            predictions = input_data[batch_index]
            labels = input_labels[batch_index]
            overlaps = self.overlap(self.anchors[:, :4], labels[:, :4])
            mask = (overlaps >= self.matching_iou).long()
            match_indices = torch.nonzero(mask, as_tuple=False)
            self.matches.append(match_indices.detach().cpu())
            mining_count = int(self.negative_mining_ratio * len(self.matches[-1]))
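            # Hard negative mining: adding the 0/1 match mask to the background
            # score pushes matched anchors past every unmatched one in the
            # ascending sort, so the slice keeps only the unmatched anchors the
            # model is least confident are background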
            masked_prediction = predictions[:, self.location_dimension] + torch.max(mask, dim=1)[0]
            non_match_indices = torch.argsort(masked_prediction, dim=-1, descending=False)[:mining_count]
            # self.negative_matches.append(non_match_indices.detach().cpu())
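            # For each (anchor, label) match, keep the prediction, its
            # regression target, and the log-probability of the labelled class;
            # predictions are assumed laid out as [locations..., class
            # probabilities], with background as class 0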
            for anchor_index, label_index in match_indices:
                positive_predictions.append(predictions[anchor_index])
                positive_objectives.append(
                    torch.cat((
                        objective_pos[batch_index, anchor_index, label_index],
                        objective_size[batch_index, anchor_index, label_index]), dim=-1))
                positive_class_loss.append(torch.log(
                    predictions[anchor_index, self.location_dimension + labels[label_index, -1].long()]))
            for anchor_index in non_match_indices:
                negative_class_loss.append(
                    torch.log(predictions[anchor_index, self.location_dimension]))
        if not positive_predictions:
            return None
        positive_predictions = torch.stack(positive_predictions)
        positive_objectives = torch.stack(positive_objectives)
        self.positive_class_loss = -torch.sum(torch.stack(positive_class_loss))
        self.negative_class_loss = -torch.sum(torch.stack(negative_class_loss))
        self.localization_loss = nn.functional.smooth_l1_loss(
            positive_predictions[:, :self.location_dimension],
            positive_objectives)
        self.class_loss = self.positive_class_loss + self.negative_class_loss
        self.final_loss = (self.localization_loss_weight * self.localization_loss) + self.class_loss
        return self.final_loss
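
A minimal usage sketch, assuming the module is importable as ssd.loss, that the network emits per-anchor vectors of 4 location values followed by softmax class probabilities, and that labels are (y, x, h, w, class id) rows; the anchor tensor, class count, and hyper-parameters below are placeholders, not values from this commit:

    import torch
    from ssd.loss import SSDLoss

    anchors = torch.rand(8732, 4)  # placeholder (y, x, h, w) anchors
    criterion = SSDLoss(anchors, label_per_image=8,
                        negative_mining_ratio=3, matching_iou=0.5)

    # dummy network output: 4 location values then 21 class probabilities
    predictions = torch.rand(2, 8732, 4 + 21, requires_grad=True)
    boxes = torch.rand(2, 8, 4)
    classes = torch.randint(1, 21, (2, 8, 1)).float()  # class ids; 0 is background
    labels = torch.cat([boxes, classes], dim=-1)

    loss = criterion(predictions, labels)
    if loss is not None:  # forward() returns None when no anchor matched a label
        loss.backward()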