未验证 提交 79e2436a 编写于 作者: G Guanghua Yu 提交者: GitHub

fix distribute train in cascade_rcnn (#2246)

上级 79b1d807
...@@ -265,10 +265,6 @@ class BBoxHead(nn.Layer): ...@@ -265,10 +265,6 @@ class BBoxHead(nn.Layer):
reg_name = 'loss_bbox_reg' reg_name = 'loss_bbox_reg'
loss_bbox = {} loss_bbox = {}
if fg_inds.numel() == 0:
loss_bbox[cls_name] = paddle.to_tensor(0., dtype='float32')
loss_bbox[reg_name] = paddle.to_tensor(0., dtype='float32')
return loss_bbox
if cls_agnostic_bbox_reg: if cls_agnostic_bbox_reg:
reg_delta = paddle.gather(deltas, fg_inds) reg_delta = paddle.gather(deltas, fg_inds)
else: else:
......
...@@ -196,6 +196,16 @@ class CascadeHead(BBoxHead): ...@@ -196,6 +196,16 @@ class CascadeHead(BBoxHead):
if self.training: if self.training:
rois, rois_num, targets = self.bbox_assigner( rois, rois_num, targets = self.bbox_assigner(
rois, rois_num, inputs, i, is_cascade=True) rois, rois_num, inputs, i, is_cascade=True)
tgt_labels = targets[0]
tgt_labels = paddle.concat(tgt_labels) if len(
tgt_labels) > 1 else tgt_labels[0]
tgt_labels.stop_gradient = True
fg_inds = paddle.nonzero(
paddle.logical_and(tgt_labels >= 0, tgt_labels <
self.num_classes)).flatten()
if fg_inds.numel() == 0:
targets_list.append(targets_list[-1])
else:
targets_list.append(targets) targets_list.append(targets)
rois_feat = self.roi_extractor(body_feats, rois, rois_num) rois_feat = self.roi_extractor(body_feats, rois, rois_num)
...@@ -227,6 +237,8 @@ class CascadeHead(BBoxHead): ...@@ -227,6 +237,8 @@ class CascadeHead(BBoxHead):
clip_box = clip_bbox(boxes_per_image, im_shape[i]) clip_box = clip_bbox(boxes_per_image, im_shape[i])
if self.training: if self.training:
keep = nonempty_bbox(clip_box) keep = nonempty_bbox(clip_box)
if keep.shape[0] == 0:
continue
clip_box = paddle.gather(clip_box, keep) clip_box = paddle.gather(clip_box, keep)
rois.append(clip_box) rois.append(clip_box)
rois_num = paddle.concat([paddle.shape(r)[0] for r in rois]) rois_num = paddle.concat([paddle.shape(r)[0] for r in rois])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册