SSDLoss implementation
This commit is contained in:
parent
092f4acc3b
commit
d87bb89e6c
2 changed files with 113 additions and 1 deletions
|
|
@ -54,7 +54,7 @@ class Linear(Layer):
|
|||
def __init__(self, in_channels: int, out_channels: int, activation=0, batch_norm=None, **kwargs):
|
||||
super().__init__(activation, batch_norm)
|
||||
|
||||
self.fc = nn.Linear(in_channels, out_channels, **kwargs)
|
||||
self.fc = nn.Linear(in_channels, out_channels, bias=not self.batch_norm, **kwargs)
|
||||
self.batch_norm = nn.BatchNorm1d(
|
||||
out_channels,
|
||||
momentum=Layer.BATCH_NORM_MOMENTUM,
|
||||
|
|
|
|||
112
ssd/loss.py
Normal file
112
ssd/loss.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class JacardOverlap(nn.Module):
    """Pairwise Jaccard (IoU) overlap between two sets of boxes.

    Boxes are rank-2 tensors of (y, x, h, w) center-size locations.
    """

    def forward(self, anchors: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Return the IoU of every anchor/label pair.

        Jaccard overlap: A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)

        Return:
            jaccard overlap: (tensor) Shape: [anchors.size(0), labels.size(0)]
        """
        n_anchors = anchors.size(0)
        n_labels = labels.size(0)

        # Convert (center, size) -> (y_min, x_min, y_max, x_max), then
        # broadcast both sets onto a common (n_anchors, n_labels, 4) grid.
        half_anchor = anchors[:, 2:] / 2
        half_label = labels[:, 2:] / 2
        anchor_corners = torch.cat(
            (anchors[:, :2] - half_anchor, anchors[:, :2] + half_anchor), 1
        ).unsqueeze(1).expand(n_anchors, n_labels, 4)
        label_corners = torch.cat(
            (labels[:, :2] - half_label, labels[:, :2] + half_label), 1
        ).unsqueeze(0).expand(n_anchors, n_labels, 4)

        # Intersection rectangle: max of the min-corners, min of the max-corners.
        top_left = torch.max(anchor_corners, label_corners)[:, :, :2]
        bottom_right = torch.min(anchor_corners, label_corners)[:, :, 2:]

        # Clamp at zero so disjoint boxes contribute no intersection area.
        extent = torch.clamp(bottom_right - top_left, min=0)
        intersection = extent[:, :, 0] * extent[:, :, 1]

        anchor_area = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1).expand_as(intersection)
        label_area = (labels[:, 2] * labels[:, 3]).unsqueeze(0).expand_as(intersection)

        union = anchor_area + label_area - intersection
        return intersection / union
|
||||
|
||||
|
||||
class SSDLoss(nn.Module):
    """SSD multibox training loss (Liu et al., "SSD: Single Shot MultiBox Detector").

    Combines a smooth-L1 localization loss over anchors matched to ground-truth
    boxes with a log-likelihood classification loss, using hard negative mining
    to keep roughly ``negative_mining_ratio`` negatives per positive match.

    NOTE(review): class scores are fed straight into ``torch.log``, so the
    prediction head is assumed to already output probabilities (post-softmax)
    — confirm against the model.
    """

    def __init__(self, anchors: torch.Tensor, label_per_image: int,
                 negative_mining_ratio: int, matching_iou: float,
                 location_dimmension: int = 4, localization_loss_weight: float = 1.0):
        """
        Args:
            anchors: (anchor_count, >=4) default boxes as (y, x, h, w).
            label_per_image: fixed number of ground-truth slots per image.
            negative_mining_ratio: negatives kept per positive match.
            matching_iou: IoU threshold for an anchor/label positive match.
            location_dimmension: number of location outputs per anchor; class
                scores start at this offset in each prediction vector.
            localization_loss_weight: weight of the localization term.
        """
        super().__init__()
        self.anchors = anchors
        self.anchor_count = anchors.size(0)
        self.label_per_image = label_per_image
        self.location_dimmension = location_dimmension
        self.negative_mining_ratio = negative_mining_ratio
        self.matching_iou = matching_iou
        self.localization_loss_weight = localization_loss_weight

        self.overlap = JacardOverlap()
        # Per-image (anchor_index, label_index) pairs from the last forward,
        # kept on CPU for inspection/debugging.
        self.matches = []
        # self.negative_matches = []
        self.positive_class_loss = torch.Tensor()
        self.negative_class_loss = torch.Tensor()
        self.localization_loss = torch.Tensor()
        self.class_loss = torch.Tensor()
        self.final_loss = torch.Tensor()

    def forward(self, input_data: torch.Tensor, input_labels: torch.Tensor) -> torch.Tensor:
        """Compute the SSD loss for a batch.

        Args:
            input_data: (batch, anchor_count, location_dimmension + num_classes)
                per-anchor predictions — box regression first, class scores after.
                The background class is assumed to sit at offset
                ``location_dimmension`` (index 0 of the class scores).
            input_labels: (batch, label_per_image, location_dimmension + 1)
                ground-truth boxes (y, x, h, w) with the class index last.

        Returns:
            Scalar loss tensor, or None when no anchor matched any label.
        """
        batch_size = input_data.size(0)
        # Broadcast anchors and labels to (batch, anchor, label, 4) so the
        # regression objectives exist for every anchor/label pair up front.
        expanded_anchors = self.anchors[:, :4].unsqueeze(0).unsqueeze(2).expand(
            batch_size, self.anchor_count, self.label_per_image, 4)
        expanded_labels = input_labels[:, :, :self.location_dimmension].unsqueeze(1).expand(
            batch_size, self.anchor_count, self.label_per_image, self.location_dimmension)
        # Standard SSD encoding: center offset scaled by anchor size, and the
        # log-ratio of box sizes.
        objective_pos = (expanded_labels[:, :, :, :2] - expanded_anchors[:, :, :, :2]) / (
            expanded_anchors[:, :, :, 2:])
        objective_size = torch.log(expanded_labels[:, :, :, 2:] / expanded_anchors[:, :, :, 2:])

        positive_objectives = []
        positive_predictions = []
        positive_class_loss = []
        negative_class_loss = []
        self.matches = []
        # self.negative_matches = []
        for batch_index in range(batch_size):
            predictions = input_data[batch_index]
            labels = input_labels[batch_index]
            overlaps = self.overlap(self.anchors[:, :4], labels[:, :4])
            mask = (overlaps >= self.matching_iou).long()
            match_indices = torch.nonzero(mask, as_tuple=False)
            self.matches.append(match_indices.detach().cpu())

            # Hard negative mining: keep the anchors with the LOWEST background
            # confidence (hardest negatives). Adding the per-anchor match flag
            # (0 or 1) pushes positive anchors to the end of the ascending
            # sort, excluding them from the negative pool.
            mining_count = int(self.negative_mining_ratio * len(self.matches[-1]))
            masked_prediction = predictions[:, self.location_dimmension] + torch.max(mask, dim=1)[0]
            non_match_indices = torch.argsort(masked_prediction, dim=-1, descending=False)[:mining_count]
            # self.negative_matches.append(non_match_indices.detach().cpu())

            for anchor_index, label_index in match_indices:
                positive_predictions.append(predictions[anchor_index])
                positive_objectives.append(
                    torch.cat((
                        objective_pos[batch_index, anchor_index, label_index],
                        objective_size[batch_index, anchor_index, label_index]), dim=-1))
                positive_class_loss.append(torch.log(
                    predictions[anchor_index, self.location_dimmension + labels[label_index, -1].long()]))

            for anchor_index in non_match_indices:
                negative_class_loss.append(
                    torch.log(predictions[anchor_index, self.location_dimmension]))

        if not positive_predictions:
            return None
        positive_predictions = torch.stack(positive_predictions)
        positive_objectives = torch.stack(positive_objectives)
        self.positive_class_loss = -torch.sum(torch.stack(positive_class_loss))
        # torch.stack([]) raises, so guard the case of zero mined negatives
        # (e.g. negative_mining_ratio == 0).
        if negative_class_loss:
            self.negative_class_loss = -torch.sum(torch.stack(negative_class_loss))
        else:
            self.negative_class_loss = torch.zeros_like(self.positive_class_loss)
        # BUG FIX: compare the regression outputs (the FIRST
        # location_dimmension columns) against the (N, location_dimmension)
        # objectives. The original indexed a single column
        # (`[:, self.location_dimmension]`, i.e. the background score), which
        # hands smooth_l1_loss mismatched shapes.
        self.localization_loss = nn.functional.smooth_l1_loss(
            positive_predictions[:, :self.location_dimmension],
            positive_objectives)
        self.class_loss = self.positive_class_loss + self.negative_class_loss
        self.final_loss = (self.localization_loss_weight * self.localization_loss) + self.class_loss
        return self.final_loss
|
||||
Loading…
Add table
Add a link
Reference in a new issue