post_process.py 1.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.py_op.post_process import mask_post_process
from . import ops


@register
class BBoxPostProcess(object):
    """Decode bbox-head outputs and run NMS on the decoded boxes.

    Args:
        decode (object): injected callable; called with
            ``(head_out, rois, im_shape, scale_factor, var_weight)`` and
            expected to return exactly two values ``(bboxes, score)``.
        nms (object): injected callable; called with ``(bboxes, score)`` and
            expected to return exactly three values, of which the first two
            are returned to the caller.
    """
    __inject__ = ['decode', 'nms']

    def __init__(self, decode=None, nms=None):
        super(BBoxPostProcess, self).__init__()
        self.decode, self.nms = decode, nms

    def __call__(self,
                 head_out,
                 rois,
                 im_shape,
                 scale_factor=None,
                 var_weight=1.):
        """Return ``(bbox_pred, bbox_num)`` for the given head outputs.

        ``scale_factor`` and ``var_weight`` are forwarded unchanged to the
        injected ``decode`` callable; this method does not interpret them.
        """
        decode_fn = self.decode
        nms_fn = self.nms

        # Explicit 2-way unpack: decode must yield exactly (boxes, scores).
        boxes, scores = decode_fn(head_out, rois, im_shape, scale_factor,
                                  var_weight)

        # nms yields three values; the trailing one is not needed here.
        pred_boxes, num_boxes, _unused = nms_fn(boxes, scores)
        return pred_boxes, num_boxes


@register
class MaskPostProcess(object):
    """Turn mask-head outputs into final masks via ``mask_post_process``.

    Args:
        mask_resolution (int): resolution forwarded to ``mask_post_process``
            (presumably the side length of the predicted mask); shared with
            other modules through ``__shared__``. Default 28.
        binary_thresh (float): threshold forwarded to ``mask_post_process``
            (presumably used to binarize the masks). Default 0.5.
    """
    __shared__ = ['mask_resolution']

    def __init__(self, mask_resolution=28, binary_thresh=0.5):
        super(MaskPostProcess, self).__init__()
        self.mask_resolution = mask_resolution
        self.binary_thresh = binary_thresh

    def __call__(self, bboxes, mask_head_out, im_shape, scale_factor=None):
        """Run mask post-processing on NumPy copies of the inputs.

        Args:
            bboxes (iterable): tensors exposing ``.numpy()``; each is
                converted and the results are forwarded as a tuple to
                ``mask_post_process`` (presumably the NMS outputs such as
                ``(bbox_pred, bbox_num)`` -- confirm against the caller).
            mask_head_out (Tensor): mask-head output, converted via
                ``.numpy()``.
            im_shape (Tensor): image shapes, converted via ``.numpy()``.
            scale_factor (Tensor): at least 2-D; only column 0 is forwarded.
                It defaults to None only for signature compatibility and is
                in fact required.

        Returns:
            dict: ``{'mask': <result of mask_post_process>}``.

        Raises:
            TypeError: if ``scale_factor`` is None.
        """
        # TODO: modify related ops for deploying
        if scale_factor is None:
            # Previously this surfaced as an opaque "'NoneType' object is not
            # subscriptable"; keep the same exception type, clearer message.
            raise TypeError(
                "MaskPostProcess requires scale_factor, got None")

        # Materialize eagerly: a generator is single-pass and cannot be
        # indexed or len()-ed, so hand the helper a real sequence.
        bboxes_np = tuple(b.numpy() for b in bboxes)

        # Only column 0 of scale_factor is used -- NOTE(review): assumes the
        # helper expects a single per-image scale; confirm for non-uniform
        # scaling.
        mask = mask_post_process(bboxes_np,
                                 mask_head_out.numpy(),
                                 im_shape.numpy(),
                                 scale_factor[:, 0].numpy(),
                                 self.mask_resolution, self.binary_thresh)
        return {'mask': mask}