From 588d78b4bd3b4588805356e183080478df6a989e Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Fri, 19 Mar 2021 15:28:30 +0800 Subject: [PATCH] cherry-pick fix cascade (#2381) --- dygraph/ppdet/modeling/heads/bbox_head.py | 9 +++++++-- dygraph/ppdet/modeling/heads/cascade_head.py | 12 +----------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/dygraph/ppdet/modeling/heads/bbox_head.py b/dygraph/ppdet/modeling/heads/bbox_head.py index d96b8efb6..a6480961c 100644 --- a/dygraph/ppdet/modeling/heads/bbox_head.py +++ b/dygraph/ppdet/modeling/heads/bbox_head.py @@ -265,6 +265,11 @@ class BBoxHead(nn.Layer): reg_name = 'loss_bbox_reg' loss_bbox = {} + loss_weight = 1. + if fg_inds.numel() == 0: + fg_inds = paddle.zeros([1], dtype='int32') + loss_weight = 0. + if cls_agnostic_bbox_reg: reg_delta = paddle.gather(deltas, fg_inds) else: @@ -291,8 +296,8 @@ class BBoxHead(nn.Layer): loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum( ) / tgt_labels.shape[0] - loss_bbox[cls_name] = loss_bbox_cls - loss_bbox[reg_name] = loss_bbox_reg + loss_bbox[cls_name] = loss_bbox_cls * loss_weight + loss_bbox[reg_name] = loss_bbox_reg * loss_weight return loss_bbox diff --git a/dygraph/ppdet/modeling/heads/cascade_head.py b/dygraph/ppdet/modeling/heads/cascade_head.py index dabcfc64c..4ce1b8dad 100644 --- a/dygraph/ppdet/modeling/heads/cascade_head.py +++ b/dygraph/ppdet/modeling/heads/cascade_head.py @@ -196,17 +196,7 @@ class CascadeHead(BBoxHead): if self.training: rois, rois_num, targets = self.bbox_assigner( rois, rois_num, inputs, i, is_cascade=True) - tgt_labels = targets[0] - tgt_labels = paddle.concat(tgt_labels) if len( - tgt_labels) > 1 else tgt_labels[0] - tgt_labels.stop_gradient = True - fg_inds = paddle.nonzero( - paddle.logical_and(tgt_labels >= 0, tgt_labels < - self.num_classes)).flatten() - if fg_inds.numel() == 0: - targets_list.append(targets_list[-1]) - else: - targets_list.append(targets) + targets_list.append(targets) rois_feat = self.roi_extractor(body_feats, rois, rois_num) bbox_feat = self.head(rois_feat, i) -- GitLab