未验证 提交 a8530168 编写于 作者: W wangguanzhong 提交者: GitHub

refine device when assign on cpu (#5924)

上级 7d01ea47
......@@ -194,8 +194,7 @@ class MaskPostProcess(object):
super(MaskPostProcess, self).__init__()
self.binary_thresh = binary_thresh
self.export_onnx = export_onnx
self.assign_on_cpu = assign_on_cpu and paddle.device.is_compiled_with_cuda(
)
self.assign_on_cpu = assign_on_cpu
def paste_mask(self, masks, boxes, im_h, im_w):
"""
......@@ -240,6 +239,7 @@ class MaskPostProcess(object):
"""
num_mask = mask_out.shape[0]
origin_shape = paddle.cast(origin_shape, 'int32')
device = paddle.device.get_device()
if self.export_onnx:
h, w = origin_shape[0][0], origin_shape[0][1]
......@@ -269,7 +269,7 @@ class MaskPostProcess(object):
im_w] = pred_mask
id_start += bbox_num[i]
if self.assign_on_cpu:
paddle.set_device('gpu')
paddle.set_device(device)
return pred_result
......
......@@ -74,9 +74,11 @@ def label_box(anchors,
is_crowd=None,
assign_on_cpu=False):
if assign_on_cpu:
device = paddle.device.get_device()
paddle.set_device("cpu")
iou = bbox_overlaps(gt_boxes, anchors)
paddle.set_device("gpu")
paddle.set_device(device)
else:
iou = bbox_overlaps(gt_boxes, anchors)
n_gt = gt_boxes.shape[0]
......
......@@ -68,8 +68,7 @@ class RPNTargetAssign(object):
self.negative_overlap = negative_overlap
self.ignore_thresh = ignore_thresh
self.use_random = use_random
self.assign_on_cpu = assign_on_cpu and paddle.device.is_compiled_with_cuda(
)
self.assign_on_cpu = assign_on_cpu
def __call__(self, inputs, anchors):
"""
......@@ -150,8 +149,7 @@ class BBoxAssigner(object):
self.use_random = use_random
self.cascade_iou = cascade_iou
self.num_classes = num_classes
self.assign_on_cpu = assign_on_cpu and paddle.device.is_compiled_with_cuda(
)
self.assign_on_cpu = assign_on_cpu
def __call__(self,
rpn_rois,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册