未验证 提交 a8530168 编写于 作者: W wangguanzhong 提交者: GitHub

refine device when assign on cpu (#5924)

上级 7d01ea47
...@@ -194,8 +194,7 @@ class MaskPostProcess(object): ...@@ -194,8 +194,7 @@ class MaskPostProcess(object):
super(MaskPostProcess, self).__init__() super(MaskPostProcess, self).__init__()
self.binary_thresh = binary_thresh self.binary_thresh = binary_thresh
self.export_onnx = export_onnx self.export_onnx = export_onnx
self.assign_on_cpu = assign_on_cpu and paddle.device.is_compiled_with_cuda( self.assign_on_cpu = assign_on_cpu
)
def paste_mask(self, masks, boxes, im_h, im_w): def paste_mask(self, masks, boxes, im_h, im_w):
""" """
...@@ -240,6 +239,7 @@ class MaskPostProcess(object): ...@@ -240,6 +239,7 @@ class MaskPostProcess(object):
""" """
num_mask = mask_out.shape[0] num_mask = mask_out.shape[0]
origin_shape = paddle.cast(origin_shape, 'int32') origin_shape = paddle.cast(origin_shape, 'int32')
device = paddle.device.get_device()
if self.export_onnx: if self.export_onnx:
h, w = origin_shape[0][0], origin_shape[0][1] h, w = origin_shape[0][0], origin_shape[0][1]
...@@ -269,7 +269,7 @@ class MaskPostProcess(object): ...@@ -269,7 +269,7 @@ class MaskPostProcess(object):
im_w] = pred_mask im_w] = pred_mask
id_start += bbox_num[i] id_start += bbox_num[i]
if self.assign_on_cpu: if self.assign_on_cpu:
paddle.set_device('gpu') paddle.set_device(device)
return pred_result return pred_result
......
...@@ -74,9 +74,11 @@ def label_box(anchors, ...@@ -74,9 +74,11 @@ def label_box(anchors,
is_crowd=None, is_crowd=None,
assign_on_cpu=False): assign_on_cpu=False):
if assign_on_cpu: if assign_on_cpu:
device = paddle.device.get_device()
paddle.set_device("cpu") paddle.set_device("cpu")
iou = bbox_overlaps(gt_boxes, anchors) iou = bbox_overlaps(gt_boxes, anchors)
paddle.set_device("gpu") paddle.set_device(device)
else: else:
iou = bbox_overlaps(gt_boxes, anchors) iou = bbox_overlaps(gt_boxes, anchors)
n_gt = gt_boxes.shape[0] n_gt = gt_boxes.shape[0]
......
...@@ -68,8 +68,7 @@ class RPNTargetAssign(object): ...@@ -68,8 +68,7 @@ class RPNTargetAssign(object):
self.negative_overlap = negative_overlap self.negative_overlap = negative_overlap
self.ignore_thresh = ignore_thresh self.ignore_thresh = ignore_thresh
self.use_random = use_random self.use_random = use_random
self.assign_on_cpu = assign_on_cpu and paddle.device.is_compiled_with_cuda( self.assign_on_cpu = assign_on_cpu
)
def __call__(self, inputs, anchors): def __call__(self, inputs, anchors):
""" """
...@@ -150,8 +149,7 @@ class BBoxAssigner(object): ...@@ -150,8 +149,7 @@ class BBoxAssigner(object):
self.use_random = use_random self.use_random = use_random
self.cascade_iou = cascade_iou self.cascade_iou = cascade_iou
self.num_classes = num_classes self.num_classes = num_classes
self.assign_on_cpu = assign_on_cpu and paddle.device.is_compiled_with_cuda( self.assign_on_cpu = assign_on_cpu
)
def __call__(self, def __call__(self,
rpn_rois, rpn_rois,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册