未验证 提交 125409c8 编写于 作者: S shangliang Xu 提交者: GitHub

update ssd loss, make it faster (#2761)

* update ssd loss, make it faster test=develop

* update comments, test=develop

* ssd loss, fix bug when no object, text=develop
上级 c0404286
......@@ -19,9 +19,9 @@ from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from ppdet.core.workspace import register
from ..ops import bipartite_match, box_coder, iou_similarity
from ..ops import iou_similarity
from ..bbox_utils import bbox2delta
__all__ = ['SSDLoss']
......@@ -32,191 +32,131 @@ class SSDLoss(nn.Layer):
SSDLoss
Args:
match_type (str): The type of matching method, should be
'bipartite' or 'per_prediction'. None ('bipartite') by default.
overlap_threshold (float32, optional): If `match_type` is 'per_prediction',
this threshold is to determine the extra matching bboxes based
on the maximum distance, 0.5 by default.
overlap_threshold (float32, optional): IoU threshold for negative bboxes
and positive bboxes, 0.5 by default.
neg_pos_ratio (float): The ratio of negative samples / positive samples.
neg_overlap (float): The overlap threshold of negative samples.
loc_loss_weight (float): The weight of loc_loss.
conf_loss_weight (float): The weight of conf_loss.
prior_box_var (list): Variances corresponding to prior box coord, [0.1,
0.1, 0.2, 0.2] by default.
"""
def __init__(self,
match_type='per_prediction',
overlap_threshold=0.5,
neg_pos_ratio=3.0,
neg_overlap=0.5,
loc_loss_weight=1.0,
conf_loss_weight=1.0):
conf_loss_weight=1.0,
prior_box_var=[0.1, 0.1, 0.2, 0.2]):
super(SSDLoss, self).__init__()
self.match_type = match_type
self.overlap_threshold = overlap_threshold
self.neg_pos_ratio = neg_pos_ratio
self.neg_overlap = neg_overlap
self.loc_loss_weight = loc_loss_weight
self.conf_loss_weight = conf_loss_weight
def _label_target_assign(self,
gt_label,
matched_indices,
neg_mask=None,
mismatch_value=0):
gt_label = gt_label.numpy()
matched_indices = matched_indices.numpy()
if neg_mask is not None:
neg_mask = neg_mask.numpy()
batch_size, num_priors = matched_indices.shape
trg_lbl = np.ones((batch_size, num_priors, 1)).astype('int32')
trg_lbl *= mismatch_value
trg_lbl_wt = np.zeros((batch_size, num_priors, 1)).astype('float32')
self.prior_box_var = [1. / a for a in prior_box_var]
def _bipartite_match_for_batch(self, gt_bbox, gt_label, prior_boxes,
bg_index):
"""
Args:
gt_bbox (Tensor): [B, N, 4]
gt_label (Tensor): [B, N, 1]
prior_boxes (Tensor): [A, 4]
bg_index (int): Background class index
"""
batch_size, num_priors = gt_bbox.shape[0], prior_boxes.shape[0]
ious = iou_similarity(gt_bbox.reshape((-1, 4)), prior_boxes).reshape(
(batch_size, -1, num_priors))
# Calculate the number of object per sample.
num_object = (ious.sum(axis=-1) > 0).astype('int64').sum(axis=-1)
# For each prior box, get the max IoU of all GTs.
prior_max_iou, prior_argmax_iou = ious.max(axis=1), ious.argmax(axis=1)
# For each GT, get the max IoU of all prior boxes.
gt_max_iou, gt_argmax_iou = ious.max(axis=2), ious.argmax(axis=2)
# Gather target bbox and label according to 'prior_argmax_iou' index.
batch_ind = paddle.arange(
0, batch_size, dtype='int64').unsqueeze(-1).tile([1, num_priors])
prior_argmax_iou = paddle.stack([batch_ind, prior_argmax_iou], axis=-1)
targets_bbox = paddle.gather_nd(gt_bbox, prior_argmax_iou)
targets_label = paddle.gather_nd(gt_label, prior_argmax_iou)
# Assign negative
bg_index_tensor = paddle.full([batch_size, num_priors, 1], bg_index,
'int64')
targets_label = paddle.where(
prior_max_iou.unsqueeze(-1) < self.overlap_threshold,
bg_index_tensor, targets_label)
# Ensure each GT can match the max IoU prior box.
for i in range(batch_size):
col_ids = np.where(matched_indices[i] > -1)
col_val = matched_indices[i][col_ids]
trg_lbl[i][col_ids] = gt_label[i][col_val]
trg_lbl_wt[i][col_ids] = 1.0
if neg_mask is not None:
trg_lbl_wt += neg_mask[:, :, np.newaxis]
return paddle.to_tensor(trg_lbl), paddle.to_tensor(trg_lbl_wt)
def _bbox_target_assign(self, encoded_box, matched_indices):
encoded_box = encoded_box.numpy()
matched_indices = matched_indices.numpy()
batch_size, num_priors = matched_indices.shape
trg_bbox = np.zeros((batch_size, num_priors, 4)).astype('float32')
trg_bbox_wt = np.zeros((batch_size, num_priors, 1)).astype('float32')
for i in range(batch_size):
col_ids = np.where(matched_indices[i] > -1)
col_val = matched_indices[i][col_ids]
for v, c in zip(col_val.tolist(), col_ids[0]):
trg_bbox[i][c] = encoded_box[i][v][c]
trg_bbox_wt[i][col_ids] = 1.0
return paddle.to_tensor(trg_bbox), paddle.to_tensor(trg_bbox_wt)
def _mine_hard_example(self,
conf_loss,
matched_indices,
matched_dist,
neg_pos_ratio=3.0,
neg_overlap=0.5):
pos = (matched_indices > -1).astype(conf_loss.dtype)
if num_object[i] > 0:
targets_bbox[i] = paddle.scatter(
targets_bbox[i], gt_argmax_iou[i, :int(num_object[i])],
gt_bbox[i, :int(num_object[i])])
targets_label[i] = paddle.scatter(
targets_label[i], gt_argmax_iou[i, :int(num_object[i])],
gt_label[i, :int(num_object[i])])
# Encode box
prior_boxes = prior_boxes.unsqueeze(0).tile([batch_size, 1, 1])
targets_bbox = bbox2delta(
prior_boxes.reshape([-1, 4]),
targets_bbox.reshape([-1, 4]), self.prior_box_var)
targets_bbox = targets_bbox.reshape([batch_size, -1, 4])
return targets_bbox, targets_label
def _mine_hard_example(self, conf_loss, targets_label, bg_index):
pos = (targets_label != bg_index).astype(conf_loss.dtype)
num_pos = pos.sum(axis=1, keepdim=True)
neg = (matched_dist < neg_overlap).astype(conf_loss.dtype)
neg = (targets_label == bg_index).astype(conf_loss.dtype)
conf_loss = conf_loss * (1.0 - pos) * neg
conf_loss = conf_loss.clone() * neg
loss_idx = conf_loss.argsort(axis=1, descending=True)
idx_rank = loss_idx.argsort(axis=1)
num_negs = []
for i in range(matched_indices.shape[0]):
cur_idx = loss_idx[i]
for i in range(conf_loss.shape[0]):
cur_num_pos = num_pos[i]
num_neg = paddle.clip(cur_num_pos * neg_pos_ratio, max=pos.shape[1])
num_neg = paddle.clip(
cur_num_pos * self.neg_pos_ratio, max=pos.shape[1])
num_negs.append(num_neg)
num_neg = paddle.stack(num_negs, axis=0).expand_as(idx_rank)
num_neg = paddle.stack(num_negs).expand_as(idx_rank)
neg_mask = (idx_rank < num_neg).astype(conf_loss.dtype)
return neg_mask
def forward(self, boxes, scores, gt_box, gt_class, anchors):
return (neg_mask + pos).astype('bool')
def forward(self, boxes, scores, gt_bbox, gt_label, prior_boxes):
boxes = paddle.concat(boxes, axis=1)
scores = paddle.concat(scores, axis=1)
prior_boxes = paddle.concat(anchors, axis=0)
gt_label = gt_class.unsqueeze(-1)
batch_size, num_priors = scores.shape[:2]
num_classes = scores.shape[-1] - 1
def _reshape_to_2d(x):
return paddle.flatten(x, start_axis=2)
# 1. Find matched bounding box by prior box.
# 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
# 1.2 Compute matched bounding box by bipartite matching algorithm.
matched_indices = []
matched_dist = []
for i in range(gt_box.shape[0]):
iou = iou_similarity(gt_box[i], prior_boxes)
matched_indice, matched_d = bipartite_match(iou, self.match_type,
self.overlap_threshold)
matched_indices.append(matched_indice)
matched_dist.append(matched_d)
matched_indices = paddle.concat(matched_indices, axis=0)
matched_indices.stop_gradient = True
matched_dist = paddle.concat(matched_dist, axis=0)
matched_dist.stop_gradient = True
# 2. Compute confidence for mining hard examples
# 2.1. Get the target label based on matched indices
target_label, _ = self._label_target_assign(
gt_label, matched_indices, mismatch_value=num_classes)
confidence = _reshape_to_2d(scores)
# 2.2. Compute confidence loss.
# Reshape confidence to 2D tensor.
target_label = _reshape_to_2d(target_label).astype('int64')
conf_loss = F.softmax_with_cross_entropy(confidence, target_label)
conf_loss = paddle.reshape(conf_loss, [batch_size, num_priors])
# 3. Mining hard examples
neg_mask = self._mine_hard_example(
conf_loss,
matched_indices,
matched_dist,
neg_pos_ratio=self.neg_pos_ratio,
neg_overlap=self.neg_overlap)
# 4. Assign classification and regression targets
# 4.1. Encoded bbox according to the prior boxes.
prior_box_var = paddle.to_tensor(
np.array(
[0.1, 0.1, 0.2, 0.2], dtype='float32')).reshape(
[1, 4]).expand_as(prior_boxes)
encoded_bbox = []
for i in range(gt_box.shape[0]):
encoded_bbox.append(
box_coder(
prior_box=prior_boxes,
prior_box_var=prior_box_var,
target_box=gt_box[i],
code_type='encode_center_size'))
encoded_bbox = paddle.stack(encoded_bbox, axis=0)
# 4.2. Assign regression targets
target_bbox, target_loc_weight = self._bbox_target_assign(
encoded_bbox, matched_indices)
# 4.3. Assign classification targets
target_label, target_conf_weight = self._label_target_assign(
gt_label,
matched_indices,
neg_mask=neg_mask,
mismatch_value=num_classes)
# 5. Compute loss.
# 5.1 Compute confidence loss.
target_label = _reshape_to_2d(target_label).astype('int64')
conf_loss = F.softmax_with_cross_entropy(confidence, target_label)
target_conf_weight = _reshape_to_2d(target_conf_weight)
conf_loss = conf_loss * target_conf_weight * self.conf_loss_weight
# 5.2 Compute regression loss.
location = _reshape_to_2d(boxes)
target_bbox = _reshape_to_2d(target_bbox)
loc_loss = F.smooth_l1_loss(location, target_bbox, reduction='none')
loc_loss = paddle.sum(loc_loss, axis=-1, keepdim=True)
target_loc_weight = _reshape_to_2d(target_loc_weight)
loc_loss = loc_loss * target_loc_weight * self.loc_loss_weight
# 5.3 Compute overall weighted loss.
loss = conf_loss + loc_loss
loss = paddle.reshape(loss, [batch_size, num_priors])
loss = paddle.sum(loss, axis=1, keepdim=True)
normalizer = paddle.sum(target_loc_weight)
loss = paddle.sum(loss / normalizer)
gt_label = gt_label.unsqueeze(-1).astype('int64')
prior_boxes = paddle.concat(prior_boxes, axis=0)
bg_index = scores.shape[-1] - 1
# Match bbox and get targets.
targets_bbox, targets_label = \
self._bipartite_match_for_batch(gt_bbox, gt_label, prior_boxes, bg_index)
targets_bbox.stop_gradient = True
targets_label.stop_gradient = True
# Compute regression loss.
# Select positive samples.
bbox_mask = (targets_label != bg_index).astype(boxes.dtype)
loc_loss = bbox_mask * F.smooth_l1_loss(
boxes, targets_bbox, reduction='none')
loc_loss = loc_loss.sum() * self.loc_loss_weight
# Compute confidence loss.
conf_loss = F.softmax_with_cross_entropy(scores, targets_label)
# Mining hard examples.
label_mask = self._mine_hard_example(
conf_loss.squeeze(-1), targets_label.squeeze(-1), bg_index)
conf_loss = conf_loss * label_mask.unsqueeze(-1).astype(conf_loss.dtype)
conf_loss = conf_loss.sum() * self.conf_loss_weight
# Compute overall weighted loss.
normalizer = (targets_label != bg_index).astype('float32').sum().clip(
min=1)
loss = (conf_loss + loc_loss) / (normalizer + 1e-9)
return loss
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册