未验证 提交 588d78b4 编写于 作者: W wangguanzhong 提交者: GitHub

cherry-pick fix cascade (#2381)

上级 626d2fa1
......@@ -265,6 +265,11 @@ class BBoxHead(nn.Layer):
reg_name = 'loss_bbox_reg'
loss_bbox = {}
loss_weight = 1.
if fg_inds.numel() == 0:
fg_inds = paddle.zeros([1], dtype='int32')
loss_weight = 0.
if cls_agnostic_bbox_reg:
reg_delta = paddle.gather(deltas, fg_inds)
else:
......@@ -291,8 +296,8 @@ class BBoxHead(nn.Layer):
loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
) / tgt_labels.shape[0]
loss_bbox[cls_name] = loss_bbox_cls
loss_bbox[reg_name] = loss_bbox_reg
loss_bbox[cls_name] = loss_bbox_cls * loss_weight
loss_bbox[reg_name] = loss_bbox_reg * loss_weight
return loss_bbox
......
......@@ -196,17 +196,7 @@ class CascadeHead(BBoxHead):
if self.training:
rois, rois_num, targets = self.bbox_assigner(
rois, rois_num, inputs, i, is_cascade=True)
tgt_labels = targets[0]
tgt_labels = paddle.concat(tgt_labels) if len(
tgt_labels) > 1 else tgt_labels[0]
tgt_labels.stop_gradient = True
fg_inds = paddle.nonzero(
paddle.logical_and(tgt_labels >= 0, tgt_labels <
self.num_classes)).flatten()
if fg_inds.numel() == 0:
targets_list.append(targets_list[-1])
else:
targets_list.append(targets)
targets_list.append(targets)
rois_feat = self.roi_extractor(body_feats, rois, rois_num)
bbox_feat = self.head(rois_feat, i)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册