未验证 提交 26a86017 编写于 作者: W wangguanzhong 提交者: GitHub

add assign_on_cpu on postprocess (#5625)

* add assign_on_cpu on postprocess

* add gpu check
上级 c7c59112
...@@ -179,7 +179,7 @@ class BBoxPostProcess(object): ...@@ -179,7 +179,7 @@ class BBoxPostProcess(object):
@register @register
class MaskPostProcess(object): class MaskPostProcess(object):
__shared__ = ['export_onnx'] __shared__ = ['export_onnx', 'assign_on_cpu']
""" """
refer to: refer to:
https://github.com/facebookresearch/detectron2/layers/mask_ops.py https://github.com/facebookresearch/detectron2/layers/mask_ops.py
...@@ -187,10 +187,15 @@ class MaskPostProcess(object): ...@@ -187,10 +187,15 @@ class MaskPostProcess(object):
Get Mask output according to the output from model Get Mask output according to the output from model
""" """
def __init__(self, binary_thresh=0.5, export_onnx=False): def __init__(self,
binary_thresh=0.5,
export_onnx=False,
assign_on_cpu=False):
super(MaskPostProcess, self).__init__() super(MaskPostProcess, self).__init__()
self.binary_thresh = binary_thresh self.binary_thresh = binary_thresh
self.export_onnx = export_onnx self.export_onnx = export_onnx
self.assign_on_cpu = assign_on_cpu and paddle.device.is_compiled_with_cuda(
)
def paste_mask(self, masks, boxes, im_h, im_w): def paste_mask(self, masks, boxes, im_h, im_w):
""" """
...@@ -207,6 +212,8 @@ class MaskPostProcess(object): ...@@ -207,6 +212,8 @@ class MaskPostProcess(object):
img_x = (img_x - x0) / (x1 - x0) * 2 - 1 img_x = (img_x - x0) / (x1 - x0) * 2 - 1
# img_x, img_y have shapes (N, w), (N, h) # img_x, img_y have shapes (N, w), (N, h)
if self.assign_on_cpu:
paddle.set_device('cpu')
gx = img_x[:, None, :].expand( gx = img_x[:, None, :].expand(
[N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]]) [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
gy = img_y[:, :, None].expand( gy = img_y[:, :, None].expand(
...@@ -261,6 +268,8 @@ class MaskPostProcess(object): ...@@ -261,6 +268,8 @@ class MaskPostProcess(object):
pred_result[id_start:id_start + bbox_num[i], :im_h, : pred_result[id_start:id_start + bbox_num[i], :im_h, :
im_w] = pred_mask im_w] = pred_mask
id_start += bbox_num[i] id_start += bbox_num[i]
if self.assign_on_cpu:
paddle.set_device('gpu')
return pred_result return pred_result
......
...@@ -68,7 +68,8 @@ class RPNTargetAssign(object): ...@@ -68,7 +68,8 @@ class RPNTargetAssign(object):
self.negative_overlap = negative_overlap self.negative_overlap = negative_overlap
self.ignore_thresh = ignore_thresh self.ignore_thresh = ignore_thresh
self.use_random = use_random self.use_random = use_random
self.assign_on_cpu = assign_on_cpu self.assign_on_cpu = assign_on_cpu and paddle.device.is_compiled_with_cuda(
)
def __call__(self, inputs, anchors): def __call__(self, inputs, anchors):
""" """
...@@ -149,7 +150,8 @@ class BBoxAssigner(object): ...@@ -149,7 +150,8 @@ class BBoxAssigner(object):
self.use_random = use_random self.use_random = use_random
self.cascade_iou = cascade_iou self.cascade_iou = cascade_iou
self.num_classes = num_classes self.num_classes = num_classes
self.assign_on_cpu = assign_on_cpu self.assign_on_cpu = assign_on_cpu and paddle.device.is_compiled_with_cuda(
)
def __call__(self, def __call__(self,
rpn_rois, rpn_rois,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册