From 79e2436a8be51ee468f226c31f40e8dc72656df3 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Sun, 21 Feb 2021 20:41:31 +0800 Subject: [PATCH] fix distribute train in cascade_rcnn (#2246) --- dygraph/ppdet/modeling/heads/bbox_head.py | 4 ---- dygraph/ppdet/modeling/heads/cascade_head.py | 14 +++++++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/dygraph/ppdet/modeling/heads/bbox_head.py b/dygraph/ppdet/modeling/heads/bbox_head.py index 2846398bf..d96b8efb6 100644 --- a/dygraph/ppdet/modeling/heads/bbox_head.py +++ b/dygraph/ppdet/modeling/heads/bbox_head.py @@ -265,10 +265,6 @@ class BBoxHead(nn.Layer): reg_name = 'loss_bbox_reg' loss_bbox = {} - if fg_inds.numel() == 0: - loss_bbox[cls_name] = paddle.to_tensor(0., dtype='float32') - loss_bbox[reg_name] = paddle.to_tensor(0., dtype='float32') - return loss_bbox if cls_agnostic_bbox_reg: reg_delta = paddle.gather(deltas, fg_inds) else: diff --git a/dygraph/ppdet/modeling/heads/cascade_head.py b/dygraph/ppdet/modeling/heads/cascade_head.py index 4ce1b8dad..8993e1935 100644 --- a/dygraph/ppdet/modeling/heads/cascade_head.py +++ b/dygraph/ppdet/modeling/heads/cascade_head.py @@ -196,7 +196,17 @@ class CascadeHead(BBoxHead): if self.training: rois, rois_num, targets = self.bbox_assigner( rois, rois_num, inputs, i, is_cascade=True) - targets_list.append(targets) + tgt_labels = targets[0] + tgt_labels = paddle.concat(tgt_labels) if len( + tgt_labels) > 1 else tgt_labels[0] + tgt_labels.stop_gradient = True + fg_inds = paddle.nonzero( + paddle.logical_and(tgt_labels >= 0, tgt_labels < + self.num_classes)).flatten() + if fg_inds.numel() == 0: + targets_list.append(targets_list[-1]) + else: + targets_list.append(targets) rois_feat = self.roi_extractor(body_feats, rois, rois_num) bbox_feat = self.head(rois_feat, i) @@ -227,6 +237,8 @@ class CascadeHead(BBoxHead): clip_box = clip_bbox(boxes_per_image, im_shape[i]) if self.training: keep = nonempty_bbox(clip_box) + if keep.shape[0] == 0: + continue clip_box = paddle.gather(clip_box, keep) rois.append(clip_box) rois_num = paddle.concat([paddle.shape(r)[0] for r in rois]) -- GitLab