# mask.py
import numpy as np
import paddle.fluid as fluid
from ppdet.core.workspace import register

# TODO: register mask_post_process op
from ppdet.py_op.post_process import mask_post_process


@register
class MaskPostProcess(object):
    """Turn raw mask-head output into final masks via ``mask_post_process``.

    Thin wrapper that converts its tensor inputs to NumPy (via ``.numpy()``)
    and forwards them to ``ppdet.py_op.post_process.mask_post_process``.

    Args:
        mask_resolution (int): forwarded to ``mask_post_process``; listed in
            ``__shared__`` so the workspace can share it with other modules.
        binary_thresh (float): forwarded to ``mask_post_process``.
    """
    __shared__ = ['mask_resolution']

    def __init__(self, mask_resolution=28, binary_thresh=0.5):
        super(MaskPostProcess, self).__init__()
        self.mask_resolution = mask_resolution
        self.binary_thresh = binary_thresh

    def __call__(self, bboxes, mask_head_out, im_info):
        """Post-process one batch of mask predictions.

        Args:
            bboxes: iterable of tensors exposing ``.numpy()``; every element
                is converted to NumPy before the call. (Presumably
                ``(bbox, bbox_num)`` — TODO confirm against the caller.)
            mask_head_out: tensor exposing ``.numpy()``.
            im_info: tensor exposing ``.numpy()``.

        Returns:
            dict: ``{'mask': <return value of mask_post_process>}``.
        """
        # TODO: modify related ops for deploying
        # Materialize as a tuple, not a one-shot generator: the helper may
        # unpack, index, len() or iterate the boxes more than once.
        bboxes_np = tuple(i.numpy() for i in bboxes)
        mask = mask_post_process(bboxes_np,
                                 mask_head_out.numpy(),
                                 im_info.numpy(), self.mask_resolution,
                                 self.binary_thresh)
        return {'mask': mask}


@register
class Mask(object):
    """Mask branch helper: builds mask training targets and post-processes.

    Both collaborators are injected by the ppdet workspace (``__inject__``).

    Args:
        mask_target_generator: callable used by ``generate_mask_target``; it
            must return a 4-tuple
            ``(mask_rois, mask_rois_num, rois_has_mask_int32, mask_int32)``.
        mask_post_process: callable invoked by ``post_process``.
    """
    __inject__ = ['mask_target_generator', 'mask_post_process']

    def __init__(self, mask_target_generator, mask_post_process):
        super(Mask, self).__init__()
        self.mask_target_generator = mask_target_generator
        self.mask_post_process = mask_post_process

    def __call__(self, inputs, rois, targets):
        """Generate mask targets; thin alias for ``generate_mask_target``.

        Returns:
            tuple: ``((mask_rois, mask_rois_num), rois_has_mask_int32)``.
        """
        mask_rois, rois_has_mask_int32 = self.generate_mask_target(
            inputs, rois, targets)
        return mask_rois, rois_has_mask_int32

    def generate_mask_target(self, inputs, rois, targets):
        """Run the injected target generator and cache its outputs.

        Args:
            inputs (dict): must provide ``'im_info'``, ``'gt_class'``,
                ``'is_crowd'`` and ``'gt_mask'``.
            rois: 2-sequence ``(proposals, proposals_num)``.
            targets (dict): must provide ``'labels_int32'``.

        Returns:
            tuple: ``((mask_rois, mask_rois_num), rois_has_mask_int32)``.

        Side effects:
            Stores ``self.mask_rois``, ``self.rois_has_mask_int32`` and
            ``self.mask_int32`` (the latter is read by ``get_targets``).
        """
        labels_int32 = targets['labels_int32']
        proposals, proposals_num = rois
        (mask_rois, mask_rois_num, self.rois_has_mask_int32,
         self.mask_int32) = self.mask_target_generator(
             im_info=inputs['im_info'],
             gt_classes=inputs['gt_class'],
             is_crowd=inputs['is_crowd'],
             gt_segms=inputs['gt_mask'],
             rois=proposals,
             rois_num=proposals_num,
             labels_int32=labels_int32)
        self.mask_rois = (mask_rois, mask_rois_num)
        return self.mask_rois, self.rois_has_mask_int32

    def get_targets(self):
        """Return the ``mask_int32`` cached by the last ``generate_mask_target``.

        Raises AttributeError if ``generate_mask_target`` has not run yet.
        """
        return self.mask_int32

    def post_process(self, bboxes, mask_head_out, im_info):
        """Delegate to the injected ``mask_post_process`` and return its result."""
        mask = self.mask_post_process(bboxes, mask_head_out, im_info)
        return mask