# mask.py
import numpy as np
import paddle.fluid as fluid
from ppdet.core.workspace import register


@register
class Mask(object):
    """Wrap an injected mask-target generator and cache its outputs.

    Calling an instance runs ``mask_target_generator`` on the sampled RoIs
    and returns the mask RoIs together with ``rois_has_mask_int32``. The
    generator's ``mask_int32`` output is kept on the instance and exposed
    through :meth:`get_targets`.
    """

    # Constructor arguments that ppdet's workspace supplies from the config
    # (presumably resolved via ``@register``; confirm in ppdet.core.workspace).
    __inject__ = ['mask_target_generator']

    def __init__(self, mask_target_generator):
        """
        Args:
            mask_target_generator: callable invoked by
                :meth:`generate_mask_target` with the keyword arguments
                ``im_info``, ``gt_classes``, ``is_crowd``, ``gt_segms``,
                ``rois``, ``rois_num`` and ``labels_int32``; it must return
                a 4-tuple ``(mask_rois, mask_rois_num, rois_has_mask_int32,
                mask_int32)``.
        """
        super(Mask, self).__init__()
        self.mask_target_generator = mask_target_generator

    def __call__(self, inputs, rois, targets):
        """Generate mask targets; thin wrapper over generate_mask_target.

        Returns:
            tuple: ``((mask_rois, mask_rois_num), rois_has_mask_int32)``.
        """
        mask_rois, rois_has_mask_int32 = self.generate_mask_target(
            inputs, rois, targets)
        return mask_rois, rois_has_mask_int32

    def generate_mask_target(self, inputs, rois, targets):
        """Run the mask-target generator and cache its results on ``self``.

        Args:
            inputs (dict): batch inputs; the keys ``'im_info'``,
                ``'gt_class'``, ``'is_crowd'`` and ``'gt_mask'`` are read.
            rois (tuple): ``(proposals, proposals_num)`` pair.
            targets (dict): must contain ``'labels_int32'``.

        Returns:
            tuple: ``((mask_rois, mask_rois_num), rois_has_mask_int32)``.

        Side effects:
            Sets ``self.mask_rois``, ``self.rois_has_mask_int32`` and
            ``self.mask_int32``; the last is what :meth:`get_targets`
            returns.
        """
        labels_int32 = targets['labels_int32']
        proposals, proposals_num = rois
        (mask_rois, mask_rois_num, self.rois_has_mask_int32,
         self.mask_int32) = self.mask_target_generator(
             im_info=inputs['im_info'],
             gt_classes=inputs['gt_class'],
             is_crowd=inputs['is_crowd'],
             gt_segms=inputs['gt_mask'],
             rois=proposals,
             rois_num=proposals_num,
             labels_int32=labels_int32)
        self.mask_rois = (mask_rois, mask_rois_num)
        return self.mask_rois, self.rois_has_mask_int32

    def get_targets(self):
        """Return ``mask_int32`` from the most recent generate_mask_target call.

        Raises:
            AttributeError: if :meth:`generate_mask_target` has not run yet,
                because ``mask_int32`` is only assigned there.
        """
        return self.mask_int32