# post_process.py
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.py_op.post_process import mask_post_process
from . import ops


@register
class BBoxPostProcess(object):
    """Decode raw bbox-head output into boxes, then run NMS on them.

    ``decode`` and ``nms`` are listed in ``__inject__`` so the ppdet
    workspace can supply them (presumably from the model config -- confirm
    against the workspace docs). Both are used here as plain callables.
    """
    __inject__ = ['decode', 'nms']

    def __init__(self, decode=None, nms=None):
        super(BBoxPostProcess, self).__init__()
        # Callables: decode(head_out, rois, im_shape, scale_factor) and
        # nms(bboxes, score); see __call__ for how their results are used.
        self.decode = decode
        self.nms = nms

    def __call__(self, head_out, rois, im_shape, scale_factor=None):
        """Decode boxes and suppress overlaps.

        Args:
            head_out: output of the bbox head, forwarded to ``decode``.
            rois: region proposals, forwarded to ``decode``.
            im_shape: per-image shape information, forwarded to ``decode``.
            scale_factor: optional scale information, forwarded to
                ``decode`` unchanged; how ``None`` is treated is up to
                ``decode``.

        Returns:
            tuple: ``(bbox_pred, bbox_num)`` -- the first two values returned
            by ``nms``. Its third return value is discarded.
        """
        bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
        bbox_pred, bbox_num, _ = self.nms(bboxes, score)
        return bbox_pred, bbox_num


@register
class MaskPostProcess(object):
    """Convert mask-head output into final masks via ``mask_post_process``.

    ``mask_resolution`` is listed in ``__shared__`` so the ppdet workspace
    can share the same value with other registered modules (presumably the
    mask head -- confirm in config).
    """
    __shared__ = ['mask_resolution']

    def __init__(self, mask_resolution=28, binary_thresh=0.5):
        super(MaskPostProcess, self).__init__()
        # Both values are passed straight through to mask_post_process.
        self.mask_resolution = mask_resolution
        self.binary_thresh = binary_thresh

    def __call__(self, bboxes, mask_head_out, im_shape, scale_factor=None):
        """Run NumPy-side mask post-processing.

        Args:
            bboxes: iterable of tensors (each must expose ``.numpy()``),
                e.g. the detection outputs; converted element-wise to NumPy.
            mask_head_out: tensor output of the mask head.
            im_shape: tensor of per-image shape information.
            scale_factor: tensor indexed as ``scale_factor[:, 0]``, so it
                must be at least 2-D. Although it defaults to ``None``, it
                is required.

        Returns:
            dict: ``{'mask': <result of mask_post_process>}``.

        Raises:
            TypeError: if ``scale_factor`` is ``None``.
        """
        # TODO: modify related ops for deploying
        if scale_factor is None:
            # Fail with a clear message instead of the opaque
            # "'NoneType' object is not subscriptable" from the slice below.
            raise TypeError("MaskPostProcess requires scale_factor, got None")

        # Materialize eagerly: a generator is single-use and cannot be
        # indexed or len()'d, which is fragile to hand to an opaque helper.
        # A tuple still supports plain iteration and unpacking.
        bboxes_np = tuple(t.numpy() for t in bboxes)

        # Only column 0 of scale_factor is used.
        # NOTE(review): confirm which scale column mask_post_process expects.
        mask = mask_post_process(bboxes_np,
                                 mask_head_out.numpy(),
                                 im_shape.numpy(),
                                 scale_factor[:, 0].numpy(),
                                 self.mask_resolution,
                                 self.binary_thresh)
        return {'mask': mask}