未验证 提交 22848d06 编写于 作者: W wangguanzhong 提交者: GitHub

refine device when assign on cpu (#5925)

* refine device when assign on cpu

* cherry-pick assign on cpu in mask postprocess
上级 08d62bf3
......@@ -154,6 +154,7 @@ class BBoxPostProcess(nn.Layer):
@register
class MaskPostProcess(object):
__shared__ = ['assign_on_cpu']
"""
refer to:
https://github.com/facebookresearch/detectron2/layers/mask_ops.py
......@@ -161,9 +162,10 @@ class MaskPostProcess(object):
Get Mask output according to the output from model
"""
def __init__(self, binary_thresh=0.5):
def __init__(self, binary_thresh=0.5, assign_on_cpu=False):
super(MaskPostProcess, self).__init__()
self.binary_thresh = binary_thresh
self.assign_on_cpu = assign_on_cpu
def paste_mask(self, masks, boxes, im_h, im_w):
"""
......@@ -210,6 +212,9 @@ class MaskPostProcess(object):
if bbox_num == 1 and bboxes[0][0] == -1:
return pred_result
if self.assign_on_cpu:
device = paddle.device.get_device()
paddle.set_device('cpu')
# TODO: optimize chunk paste
pred_result = []
for i in range(bboxes.shape[0]):
......@@ -220,6 +225,9 @@ class MaskPostProcess(object):
pred_mask = paddle.cast(pred_mask, 'int32')
pred_result.append(pred_mask)
pred_result = paddle.concat(pred_result)
if self.assign_on_cpu:
paddle.set_device(device)
return pred_result
......
......@@ -74,9 +74,10 @@ def label_box(anchors,
is_crowd=None,
assign_on_cpu=False):
if assign_on_cpu:
device = paddle.device.get_device()
paddle.set_device("cpu")
iou = bbox_overlaps(gt_boxes, anchors)
paddle.set_device("gpu")
paddle.set_device(device)
else:
iou = bbox_overlaps(gt_boxes, anchors)
n_gt = gt_boxes.shape[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册