未验证 提交 6ef164f6 编写于 作者: W wangguanzhong 提交者: GitHub

cherry-pick refine device (#5926)

* add assign_on_cpu on postprocess

* add gpu check

* refine device when assign on cpu
上级 ef155443
...@@ -179,7 +179,7 @@ class BBoxPostProcess(object): ...@@ -179,7 +179,7 @@ class BBoxPostProcess(object):
@register @register
class MaskPostProcess(object): class MaskPostProcess(object):
__shared__ = ['export_onnx'] __shared__ = ['export_onnx', 'assign_on_cpu']
""" """
refer to: refer to:
https://github.com/facebookresearch/detectron2/layers/mask_ops.py https://github.com/facebookresearch/detectron2/layers/mask_ops.py
...@@ -187,10 +187,14 @@ class MaskPostProcess(object): ...@@ -187,10 +187,14 @@ class MaskPostProcess(object):
Get Mask output according to the output from model Get Mask output according to the output from model
""" """
def __init__(self, binary_thresh=0.5, export_onnx=False): def __init__(self,
binary_thresh=0.5,
export_onnx=False,
assign_on_cpu=False):
super(MaskPostProcess, self).__init__() super(MaskPostProcess, self).__init__()
self.binary_thresh = binary_thresh self.binary_thresh = binary_thresh
self.export_onnx = export_onnx self.export_onnx = export_onnx
self.assign_on_cpu = assign_on_cpu
def paste_mask(self, masks, boxes, im_h, im_w): def paste_mask(self, masks, boxes, im_h, im_w):
""" """
...@@ -207,6 +211,8 @@ class MaskPostProcess(object): ...@@ -207,6 +211,8 @@ class MaskPostProcess(object):
img_x = (img_x - x0) / (x1 - x0) * 2 - 1 img_x = (img_x - x0) / (x1 - x0) * 2 - 1
# img_x, img_y have shapes (N, w), (N, h) # img_x, img_y have shapes (N, w), (N, h)
if self.assign_on_cpu:
paddle.set_device('cpu')
gx = img_x[:, None, :].expand( gx = img_x[:, None, :].expand(
[N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]]) [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
gy = img_y[:, :, None].expand( gy = img_y[:, :, None].expand(
...@@ -233,6 +239,7 @@ class MaskPostProcess(object): ...@@ -233,6 +239,7 @@ class MaskPostProcess(object):
""" """
num_mask = mask_out.shape[0] num_mask = mask_out.shape[0]
origin_shape = paddle.cast(origin_shape, 'int32') origin_shape = paddle.cast(origin_shape, 'int32')
device = paddle.device.get_device()
if self.export_onnx: if self.export_onnx:
h, w = origin_shape[0][0], origin_shape[0][1] h, w = origin_shape[0][0], origin_shape[0][1]
...@@ -261,6 +268,8 @@ class MaskPostProcess(object): ...@@ -261,6 +268,8 @@ class MaskPostProcess(object):
pred_result[id_start:id_start + bbox_num[i], :im_h, : pred_result[id_start:id_start + bbox_num[i], :im_h, :
im_w] = pred_mask im_w] = pred_mask
id_start += bbox_num[i] id_start += bbox_num[i]
if self.assign_on_cpu:
paddle.set_device(device)
return pred_result return pred_result
......
...@@ -74,9 +74,11 @@ def label_box(anchors, ...@@ -74,9 +74,11 @@ def label_box(anchors,
is_crowd=None, is_crowd=None,
assign_on_cpu=False): assign_on_cpu=False):
if assign_on_cpu: if assign_on_cpu:
device = paddle.device.get_device()
paddle.set_device("cpu") paddle.set_device("cpu")
iou = bbox_overlaps(gt_boxes, anchors) iou = bbox_overlaps(gt_boxes, anchors)
paddle.set_device("gpu") paddle.set_device(device)
else: else:
iou = bbox_overlaps(gt_boxes, anchors) iou = bbox_overlaps(gt_boxes, anchors)
n_gt = gt_boxes.shape[0] n_gt = gt_boxes.shape[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册