未验证 提交 6ef164f6 编写于 作者: W wangguanzhong 提交者: GitHub

cherry-pick refine device (#5926)

* add assign_on_cpu on postprocess

* add gpu check

* refine device when assign on cpu
上级 ef155443
......@@ -179,7 +179,7 @@ class BBoxPostProcess(object):
@register
class MaskPostProcess(object):
__shared__ = ['export_onnx']
__shared__ = ['export_onnx', 'assign_on_cpu']
"""
refer to:
https://github.com/facebookresearch/detectron2/layers/mask_ops.py
......@@ -187,10 +187,14 @@ class MaskPostProcess(object):
Get Mask output according to the output from model
"""
def __init__(self, binary_thresh=0.5, export_onnx=False):
def __init__(self,
binary_thresh=0.5,
export_onnx=False,
assign_on_cpu=False):
super(MaskPostProcess, self).__init__()
self.binary_thresh = binary_thresh
self.export_onnx = export_onnx
self.assign_on_cpu = assign_on_cpu
def paste_mask(self, masks, boxes, im_h, im_w):
"""
......@@ -207,6 +211,8 @@ class MaskPostProcess(object):
img_x = (img_x - x0) / (x1 - x0) * 2 - 1
# img_x, img_y have shapes (N, w), (N, h)
if self.assign_on_cpu:
paddle.set_device('cpu')
gx = img_x[:, None, :].expand(
[N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
gy = img_y[:, :, None].expand(
......@@ -233,6 +239,7 @@ class MaskPostProcess(object):
"""
num_mask = mask_out.shape[0]
origin_shape = paddle.cast(origin_shape, 'int32')
device = paddle.device.get_device()
if self.export_onnx:
h, w = origin_shape[0][0], origin_shape[0][1]
......@@ -261,6 +268,8 @@ class MaskPostProcess(object):
pred_result[id_start:id_start + bbox_num[i], :im_h, :
im_w] = pred_mask
id_start += bbox_num[i]
if self.assign_on_cpu:
paddle.set_device(device)
return pred_result
......
......@@ -74,9 +74,11 @@ def label_box(anchors,
is_crowd=None,
assign_on_cpu=False):
if assign_on_cpu:
device = paddle.device.get_device()
paddle.set_device("cpu")
iou = bbox_overlaps(gt_boxes, anchors)
paddle.set_device("gpu")
paddle.set_device(device)
else:
iou = bbox_overlaps(gt_boxes, anchors)
n_gt = gt_boxes.shape[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册