From 22848d06a30af870fae56e1c888f7ebea646a155 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Tue, 10 May 2022 19:33:32 +0800 Subject: [PATCH] refine device when assign on cpu (#5925) * refine device when assign on cpu * cherry-pick assign on cpu in mask postprocess --- ppdet/modeling/post_process.py | 10 +++++++++- ppdet/modeling/proposal_generator/target.py | 3 ++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index 6cbb079d6..1765945df 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -154,6 +154,7 @@ class BBoxPostProcess(nn.Layer): @register class MaskPostProcess(object): + __shared__ = ['assign_on_cpu'] """ refer to: https://github.com/facebookresearch/detectron2/layers/mask_ops.py @@ -161,9 +162,10 @@ class MaskPostProcess(object): Get Mask output according to the output from model """ - def __init__(self, binary_thresh=0.5): + def __init__(self, binary_thresh=0.5, assign_on_cpu=False): super(MaskPostProcess, self).__init__() self.binary_thresh = binary_thresh + self.assign_on_cpu = assign_on_cpu def paste_mask(self, masks, boxes, im_h, im_w): """ @@ -210,6 +212,9 @@ class MaskPostProcess(object): if bbox_num == 1 and bboxes[0][0] == -1: return pred_result + if self.assign_on_cpu: + device = paddle.device.get_device() + paddle.set_device('cpu') # TODO: optimize chunk paste pred_result = [] for i in range(bboxes.shape[0]): @@ -220,6 +225,9 @@ class MaskPostProcess(object): pred_mask = paddle.cast(pred_mask, 'int32') pred_result.append(pred_mask) pred_result = paddle.concat(pred_result) + if self.assign_on_cpu: + paddle.set_device(device) + return pred_result diff --git a/ppdet/modeling/proposal_generator/target.py b/ppdet/modeling/proposal_generator/target.py index af83cfdb8..496417602 100644 --- a/ppdet/modeling/proposal_generator/target.py +++ b/ppdet/modeling/proposal_generator/target.py @@ -74,9 +74,10 @@ def label_box(anchors, is_crowd=None, assign_on_cpu=False): if assign_on_cpu: + device = paddle.device.get_device() paddle.set_device("cpu") iou = bbox_overlaps(gt_boxes, anchors) - paddle.set_device("gpu") + paddle.set_device(device) else: iou = bbox_overlaps(gt_boxes, anchors) n_gt = gt_boxes.shape[0] -- GitLab