diff --git a/ppdet/modeling/losses/pisa_utils.py b/ppdet/modeling/losses/pisa_utils.py index 431eaf99d515d1dcf4cac174f702102f4383661d..4731636afa4eedfba2c31f043b90a6bad1c7f967 100644 --- a/ppdet/modeling/losses/pisa_utils.py +++ b/ppdet/modeling/losses/pisa_utils.py @@ -20,25 +20,33 @@ import numpy as np __all__ = ['get_isr_p_func'] -def get_isr_p_func(pos_iou_thresh=0.25, bias=0, k=2): +def get_isr_p_func(max_box_num=50, pos_iou_thresh=0.5, bias=0, k=2): def irs_p(x): - np.save("data", x) x = np.array(x) - max_ious = x[:, :, 0] - gt_inds = x[:, :, 1].astype('int32') - cls = x[:, :, 2].astype('int32') + gt_label = x[:, :max_box_num] + gt_score = x[:, max_box_num:2 * max_box_num] + remain = x[:, 2 * max_box_num:] + pn = remain.shape[1] // 3 + max_ious = remain[:, :pn] + gt_inds = remain[:, pn:2 * pn].astype('int32') + cls = remain[:, 2 * pn:].astype('int32') - # # n_{max}: max gt box num in each class - # valid_gt = gt_box[:, :, 2] > 0. - # valid_gt_label = gt_label[valid_gt] - # max_l_num = np.bincount(valid_gt_label).max() + pos_mask = max_ious > pos_iou_thresh + if not np.any(pos_mask): + return np.zeros([max_ious.shape[0], pn, 2]).astype('float32') + + cls_target = np.zeros_like(max_ious) + cls_target_weights = np.zeros_like(max_ious) + for i in range(gt_label.shape[0]): + cls_target[i] = gt_label[i, gt_inds[i]] + cls_target_weights[i] = gt_score[i, gt_inds[i]] + # cls_target *= pos_mask.astype('float32') # divide gt index in each sample gt_inds = gt_inds + np.arange(gt_inds.shape[ - 0])[:, np.newaxis] * gt_inds.shape[1] + 0])[:, np.newaxis] * max_box_num - all_pos_weights = np.ones_like(max_ious) - pos_mask = max_ious > pos_iou_thresh + all_pos_weights = np.zeros_like(max_ious) cls = np.reshape(cls, list(max_ious.shape) + [-1]) max_ious = max_ious[pos_mask] pos_weights = all_pos_weights[pos_mask] @@ -58,12 +66,12 @@ def get_isr_p_func(pos_iou_thresh=0.25, bias=0, k=2): l_max_iou_rank = np.argsort(-l_max_ious).argsort().astype('float32') weight_factor = np.clip(max_l_num - l_max_iou_rank, 0., None) / max_l_num - weight_factor = np.power(bias + (1 - bias) * weight_factor, k) - pos_weights[l_inds] *= weight_factor - pos_weights = pos_weights / np.mean(pos_weights) + pos_weights[l_inds] = np.power(bias + (1 - bias) * weight_factor, k) + pos_weights = pos_weights / max(np.mean(pos_weights), 1e-6) all_pos_weights[pos_mask] = pos_weights + cls_target_weights *= all_pos_weights - return all_pos_weights + return np.stack([cls_target, cls_target_weights], axis=-1) return irs_p diff --git a/ppdet/modeling/losses/yolo_loss.py b/ppdet/modeling/losses/yolo_loss.py index b2f0191e2a5d35c38b7efa6b741971b9d5ad1625..4fdcfb9d373eb583b5892b735fa7f0e5b99dd02c 100644 --- a/ppdet/modeling/losses/yolo_loss.py +++ b/ppdet/modeling/losses/yolo_loss.py @@ -66,7 +66,7 @@ class YOLOv3Loss(object): anchor_masks, mask_anchors, num_classes, prefix_name): if self._use_fine_grained_loss: return self._get_fine_grained_loss( - outputs, targets, gt_box, gt_label, self._batch_size, + outputs, targets, gt_box, gt_label, gt_score, self._batch_size, num_classes, mask_anchors, self._ignore_thresh) else: losses = [] @@ -93,7 +93,7 @@ class YOLOv3Loss(object): return {'loss': sum(losses)} def _get_fine_grained_loss(self, outputs, targets, gt_box, gt_label, - batch_size, num_classes, mask_anchors, + gt_score, batch_size, num_classes, mask_anchors, ignore_thresh): """ Calculate fine grained YOLOv3 loss @@ -123,6 +123,7 @@ class YOLOv3Loss(object): "YOLOv3 output layer number not equal target number" loss_xys, loss_whs, loss_objs, loss_clss = [], [], [], [] + loss_carls, loss_isrp_clss = [], [] if self._iou_loss is not None: loss_ious = [] if self._iou_aware_loss is not None: @@ -144,22 +145,54 @@ class YOLOv3Loss(object): sorted_iou, sorted_gt_inds = fluid.layers.argsort( iou, axis=-1, descending=True) - max_iou = sorted_iou[:, :, 0:1] + max_iou = sorted_iou[:, :, 0] gt_inds = fluid.layers.cast( - sorted_gt_inds[:, :, 0:1], dtype='float32') - pred_cls = fluid.layers.argmax(cls, axis=-1) - pred_cls = fluid.layers.reshape(pred_cls, [batch_size, -1, 1]) + sorted_gt_inds[:, :, 0], dtype='float32') + cls_score = fluid.layers.sigmoid(cls) + sorted_cls_score, sorted_pred_cls = fluid.layers.argsort( + cls_score, axis=-1, descending=True) + pred_cls = fluid.layers.reshape(sorted_pred_cls[:, :, :, :, 0], + [batch_size, -1]) pred_cls = fluid.layers.cast(pred_cls, dtype='float32') + + gt_label_fp32 = fluid.layers.cast(gt_label, dtype='float32') + isr_p_input = fluid.layers.concat( - [max_iou, gt_inds, pred_cls], axis=-1) + [gt_label_fp32, gt_score, max_iou, gt_inds, pred_cls], axis=-1) isr_p = get_isr_p_func() - pos_weights = fluid.layers.zeros_like(max_iou) - fluid.layers.py_func(isr_p, isr_p_input, pos_weights) + isr_p_output = fluid.layers.zeros_like(sorted_iou[:, :, :2]) + fluid.layers.py_func(isr_p, isr_p_input, isr_p_output) tobj_shape = fluid.layers.shape(tobj) - pos_weights = fluid.layers.reshape(pos_weights, ( + isr_p_output = fluid.layers.reshape(isr_p_output, ( + -1, an_num, tobj_shape[2], tobj_shape[3], 2)) + cls_target = fluid.layers.cast( + isr_p_output[:, :, :, :, 0:1], dtype='int32') + cls_target = fluid.layers.one_hot(cls_target, num_classes) + cls_target_weights = isr_p_output[:, :, :, :, 1] + cls_target_weights.stop_gradient = True + + loss_isrp_cls = fluid.layers.sigmoid_cross_entropy_with_logits( + cls, cls_target) + loss_isrp_cls = fluid.layers.elementwise_mul( + loss_isrp_cls, cls_target_weights, axis=0) + loss_isrp_cls = fluid.layers.reduce_sum( + loss_isrp_cls, dim=[1, 2, 3]) + + bias = 0.2 + pos_cls_score = fluid.layers.reduce_sum( + cls_score * cls_target, dim=[-1]) + pos_cls_score = fluid.layers.reshape(pos_cls_score, [ + batch_size, + -1, + ]) + pos_mask = fluid.layers.cast( + sorted_iou[:, :, 0] > 0.5, dtype='float32') + carl_weights = bias + (1 - bias) * pos_cls_score * pos_mask + carl_weights *= fluid.layers.reduce_sum( + pos_mask) / fluid.layers.reduce_sum(carl_weights) + carl_weights = fluid.layers.reshape(carl_weights, ( -1, an_num, tobj_shape[2], tobj_shape[3])) - tobj = tobj * pos_weights # isr_tobj = tobj * pos_weights # loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls) @@ -172,18 +205,26 @@ class YOLOv3Loss(object): # new_loss_cls.stop_gradient = True # pos_loss_cls_ratio = orig_loss_cls / new_loss_cls - tscale_tobj = tscale * tobj - loss_x = fluid.layers.sigmoid_cross_entropy_with_logits( - x, tx) * tscale_tobj - loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3]) - loss_y = fluid.layers.sigmoid_cross_entropy_with_logits( - y, ty) * tscale_tobj - loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3]) + loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(x, + tx) * tscale + loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(y, + ty) * tscale + loss_xy = loss_x + loss_y # NOTE: we refined loss function of (w, h) as L1Loss - loss_w = fluid.layers.abs(w - tw) * tscale_tobj - loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3]) - loss_h = fluid.layers.abs(h - th) * tscale_tobj - loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3]) + loss_w = fluid.layers.abs(w - tw) * tscale + loss_h = fluid.layers.abs(h - th) * tscale + loss_wh = loss_w + loss_h + + loss_carl = (loss_xy + loss_wh) * carl_weights + loss_carl = fluid.layers.reduce_sum(loss_carl, dim=[1, 2, 3]) + + # loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3]) + # loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3]) + # loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3]) + # loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3]) + loss_xy = fluid.layers.reduce_sum(loss_xy * tobj, dim=[1, 2, 3]) + loss_wh = fluid.layers.reduce_sum(loss_wh * tobj, dim=[1, 2, 3]) + if self._iou_loss is not None: loss_iou = self._iou_loss(x, y, w, h, tx, ty, tw, th, anchors, downsample, self._batch_size) @@ -200,6 +241,8 @@ class YOLOv3Loss(object): loss_iou_aware, dim=[1, 2, 3]) loss_iou_awares.append(fluid.layers.reduce_mean(loss_iou_aware)) + # tobj = tobj * pos_weights + loss_obj_pos, loss_obj_neg = self._calc_obj_loss( output, obj, tobj, iou, an_num, self._ignore_thresh, scale_x_y) @@ -207,8 +250,10 @@ class YOLOv3Loss(object): loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0) loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4]) - loss_xys.append(fluid.layers.reduce_mean(loss_x + loss_y)) - loss_whs.append(fluid.layers.reduce_mean(loss_w + loss_h)) + loss_xys.append(fluid.layers.reduce_mean(loss_xy)) + loss_whs.append(fluid.layers.reduce_mean(loss_wh)) + loss_isrp_clss.append(fluid.layers.reduce_mean(loss_isrp_cls)) + loss_carls.append(fluid.layers.reduce_mean(loss_carl)) loss_objs.append( fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg)) loss_clss.append(fluid.layers.reduce_mean(loss_cls)) @@ -216,6 +261,8 @@ class YOLOv3Loss(object): losses_all = { "loss_xy": fluid.layers.sum(loss_xys), "loss_wh": fluid.layers.sum(loss_whs), + "loss_isrp_cls": fluid.layers.sum(loss_isrp_clss), + "loss_carl": fluid.layers.sum(loss_carls), "loss_obj": fluid.layers.sum(loss_objs), "loss_cls": fluid.layers.sum(loss_clss), }