diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index de450d2ae1f8858aa559d0e70cb1d488a7aa7663..27890c17ec39f3e29a3126adab173bc9e3596bc2 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -179,7 +179,7 @@ class BBoxPostProcess(object): @register class MaskPostProcess(object): - __shared__ = ['export_onnx'] + __shared__ = ['export_onnx', 'assign_on_cpu'] """ refer to: https://github.com/facebookresearch/detectron2/layers/mask_ops.py @@ -187,10 +187,14 @@ class MaskPostProcess(object): Get Mask output according to the output from model """ - def __init__(self, binary_thresh=0.5, export_onnx=False): + def __init__(self, + binary_thresh=0.5, + export_onnx=False, + assign_on_cpu=False): super(MaskPostProcess, self).__init__() self.binary_thresh = binary_thresh self.export_onnx = export_onnx + self.assign_on_cpu = assign_on_cpu def paste_mask(self, masks, boxes, im_h, im_w): """ @@ -207,6 +211,8 @@ class MaskPostProcess(object): img_x = (img_x - x0) / (x1 - x0) * 2 - 1 # img_x, img_y have shapes (N, w), (N, h) + if self.assign_on_cpu: + paddle.set_device('cpu') gx = img_x[:, None, :].expand( [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]]) gy = img_y[:, :, None].expand( @@ -233,6 +239,7 @@ class MaskPostProcess(object): """ num_mask = mask_out.shape[0] origin_shape = paddle.cast(origin_shape, 'int32') + device = paddle.device.get_device() if self.export_onnx: h, w = origin_shape[0][0], origin_shape[0][1] @@ -261,6 +268,8 @@ class MaskPostProcess(object): pred_result[id_start:id_start + bbox_num[i], :im_h, : im_w] = pred_mask id_start += bbox_num[i] + if self.assign_on_cpu: + paddle.set_device(device) return pred_result diff --git a/ppdet/modeling/proposal_generator/target.py b/ppdet/modeling/proposal_generator/target.py index 7af30f64153acf5e9c68c51981a02c76acbe50f0..7202d048226e78a8ba470394253e7da488634b07 100644 --- a/ppdet/modeling/proposal_generator/target.py +++ b/ppdet/modeling/proposal_generator/target.py @@ -74,9 +74,11 @@ def label_box(anchors, is_crowd=None, assign_on_cpu=False): if assign_on_cpu: + device = paddle.device.get_device() paddle.set_device("cpu") iou = bbox_overlaps(gt_boxes, anchors) - paddle.set_device("gpu") + paddle.set_device(device) + else: iou = bbox_overlaps(gt_boxes, anchors) n_gt = gt_boxes.shape[0]