bbox.py 7.2 KB
Newer Older
F
FDInSky 已提交
1 2
import numpy as np
import paddle.fluid as fluid
W
wangxinxin08 已提交
3 4 5
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
F
FDInSky 已提交
6
from ppdet.core.workspace import register
7
from . import ops
F
FDInSky 已提交
8 9 10


@register
class Anchor(object):
    """Build per-level anchors and assemble the inputs of the RPN loss.

    The two collaborators named in ``__inject__`` are stored as-is:
    ``anchor_generator`` is called once per feature level, and
    ``anchor_target_generator`` is called with the flattened predictions,
    the concatenated anchors and the ground-truth entries of ``inputs``.
    """

    __inject__ = ['anchor_generator', 'anchor_target_generator']

    def __init__(self, anchor_generator, anchor_target_generator):
        super(Anchor, self).__init__()
        self.anchor_generator = anchor_generator
        self.anchor_target_generator = anchor_target_generator

    def __call__(self, rpn_feats):
        """Return one ``(anchor, var)`` pair per entry of ``rpn_feats``.

        The generator receives the feature and its level index ``i``.
        """
        anchors = []
        for i, rpn_feat in enumerate(rpn_feats):
            anchor, var = self.anchor_generator(rpn_feat, i)
            anchors.append((anchor, var))
        return anchors

    def _get_target_input(self, rpn_feats, anchors):
        """Flatten per-level RPN outputs and anchors and concatenate them.

        Args:
            rpn_feats: iterable of ``(rpn_score, rpn_delta)`` pairs.
            anchors: iterable of ``(anchor, var)`` pairs, zipped with
                ``rpn_feats``; the variance part is not used here.

        Returns:
            ``(rpn_scores, rpn_deltas, anchors)``: scores and deltas are
            concatenated along axis 1, anchors along the default axis.
        """
        rpn_score_list = []
        rpn_delta_list = []
        anchor_list = []
        for (rpn_score, rpn_delta), (anchor, _) in zip(rpn_feats, anchors):
            # Move axis 1 to the end before flattening
            # (presumably NCHW -> NHWC; confirm against the RPN head).
            rpn_score = fluid.layers.transpose(rpn_score, perm=[0, 2, 3, 1])
            rpn_delta = fluid.layers.transpose(rpn_delta, perm=[0, 2, 3, 1])
            rpn_score = fluid.layers.reshape(x=rpn_score, shape=(0, -1, 1))
            rpn_delta = fluid.layers.reshape(x=rpn_delta, shape=(0, -1, 4))

            anchor = fluid.layers.reshape(anchor, shape=(-1, 4))

            rpn_score_list.append(rpn_score)
            rpn_delta_list.append(rpn_delta)
            anchor_list.append(anchor)

        rpn_scores = fluid.layers.concat(rpn_score_list, axis=1)
        rpn_deltas = fluid.layers.concat(rpn_delta_list, axis=1)
        anchors = fluid.layers.concat(anchor_list)
        return rpn_scores, rpn_deltas, anchors

    def generate_loss_inputs(self, inputs, rpn_head_out, anchors):
        """Produce the prediction/target tensors consumed by the RPN loss.

        Args:
            inputs: mapping read for ``'gt_bbox'``, ``'is_crowd'`` and
                ``'im_info'``.
            rpn_head_out: per-level ``(rpn_score, rpn_delta)`` pairs.
            anchors: per-level ``(anchor, var)`` pairs; must have the same
                length as ``rpn_head_out``.

        Returns:
            dict with keys ``rpn_score_pred``, ``rpn_score_target``,
            ``rpn_rois_pred``, ``rpn_rois_target`` and ``rpn_rois_weight``.

        Raises:
            AssertionError: if the two per-level sequences differ in length.
        """
        assert len(rpn_head_out) == len(
            anchors
        ), "rpn_head_out and anchors should have same length, but received rpn_head_out' length is {} and anchors' length is {}".format(
            len(rpn_head_out), len(anchors))
        rpn_score, rpn_delta, anchors = self._get_target_input(rpn_head_out,
                                                               anchors)

        score_pred, roi_pred, score_tgt, roi_tgt, roi_weight = self.anchor_target_generator(
            bbox_pred=rpn_delta,
            cls_logits=rpn_score,
            anchor_box=anchors,
            gt_boxes=inputs['gt_bbox'],
            is_crowd=inputs['is_crowd'],
            im_info=inputs['im_info'])
        outs = {
            'rpn_score_pred': score_pred,
            'rpn_score_target': score_tgt,
            'rpn_rois_pred': roi_pred,
            'rpn_rois_target': roi_tgt,
            'rpn_rois_weight': roi_weight
        }
        return outs

73 74 75

@register
class AnchorYOLO(object):
    """Thin callable wrapper around the ``anchor_generator`` collaborator."""

    __inject__ = ['anchor_generator']

    def __init__(self, anchor_generator):
        super(AnchorYOLO, self).__init__()
        # Kept as-is; it is only ever invoked with no arguments.
        self.anchor_generator = anchor_generator

    def __call__(self):
        """Invoke the stored generator without arguments and return its result."""
        generated = self.anchor_generator()
        return generated
84

F
FDInSky 已提交
85 86 87

@register
class Proposal(object):
    """Generate RoI proposals, sample their targets and refine them per stage.

    ``proposal_generator`` turns per-level RPN outputs into RoIs;
    ``proposal_target_generator`` samples training targets for those RoIs and
    also supplies ``bbox_reg_weights`` used when decoding refined boxes.
    Results of each ``__call__`` are accumulated in ``proposals_list`` and
    ``targets_list``; both are reset on every stage-0 call.
    """

    __inject__ = ['proposal_generator', 'proposal_target_generator']

    def __init__(self, proposal_generator, proposal_target_generator):
        super(Proposal, self).__init__()
        self.proposal_generator = proposal_generator
        self.proposal_target_generator = proposal_target_generator
        # Created here as well as at stage 0 so that get_targets() and
        # get_proposals() do not raise AttributeError before the first call.
        self.proposals_list = []
        self.targets_list = []

    def generate_proposal(self, inputs, rpn_head_out, anchor_out):
        """Run the proposal generator on every level and merge the results.

        Args:
            inputs: mapping read for ``'im_info'`` (falling back to
                ``'im_shape'``) and ``'mode'``.
            rpn_head_out: per-level ``(rpn_score, rpn_delta)`` pairs.
            anchor_out: per-level ``(anchor, var)`` pairs.

        Returns:
            ``(rois, rois_num)``. With a single level the generator output is
            returned directly; otherwise the levels are merged through
            ``ops.collect_fpn_proposals``.
        """
        # TODO: delete im_info
        try:
            im_shape = inputs['im_info']
        except KeyError:
            # Only a missing key should trigger the fallback; anything else
            # (including KeyboardInterrupt) must propagate.
            im_shape = inputs['im_shape']
        rpn_rois_list = []
        rpn_prob_list = []
        rpn_rois_num_list = []
        for (rpn_score, rpn_delta), (anchor, var) in zip(rpn_head_out,
                                                         anchor_out):
            rpn_prob = fluid.layers.sigmoid(rpn_score)
            rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = self.proposal_generator(
                scores=rpn_prob,
                bbox_deltas=rpn_delta,
                anchors=anchor,
                variances=var,
                im_shape=im_shape,
                mode=inputs['mode'])
            if len(rpn_head_out) == 1:
                return rpn_rois, rpn_rois_num
            rpn_rois_list.append(rpn_rois)
            rpn_prob_list.append(rpn_rois_prob)
            rpn_rois_num_list.append(rpn_rois_num)

        # Levels are numbered from 2 upwards (presumably the FPN P2
        # convention; confirm against the neck configuration).
        start_level = 2
        end_level = start_level + len(rpn_head_out)
        # post_nms_top_n is the value reported by the last level processed.
        rois_collect, rois_num_collect = ops.collect_fpn_proposals(
            rpn_rois_list,
            rpn_prob_list,
            start_level,
            end_level,
            post_nms_top_n,
            rois_num_per_level=rpn_rois_num_list)
        return rois_collect, rois_num_collect

    def generate_proposal_target(self, inputs, rois, rois_num, stage=0):
        """Sample training targets for ``rois`` at the given cascade stage.

        Args:
            inputs: mapping read for ``'gt_class'``, ``'is_crowd'``,
                ``'gt_bbox'`` and ``'im_info'``.
            rois: proposals passed to the target generator as ``rpn_rois``.
            rois_num: passed to the target generator as ``rpn_rois_num``.
            stage: stage index forwarded to the target generator.

        Returns:
            ``(rois, rois_num, targets)`` where ``rois`` is the first and
            ``rois_num`` the last generator output, and ``targets`` is a dict
            built from outputs 1-4.
        """
        outs = self.proposal_target_generator(
            rpn_rois=rois,
            rpn_rois_num=rois_num,
            gt_classes=inputs['gt_class'],
            is_crowd=inputs['is_crowd'],
            gt_boxes=inputs['gt_bbox'],
            im_info=inputs['im_info'],
            stage=stage)
        rois = outs[0]
        rois_num = outs[-1]
        targets = {
            'labels_int32': outs[1],
            'bbox_targets': outs[2],
            'bbox_inside_weights': outs[3],
            'bbox_outside_weights': outs[4]
        }
        return rois, rois_num, targets

    def refine_bbox(self, rois, bbox_delta, stage=0):
        """Decode ``bbox_delta`` against ``rois`` to obtain refined boxes.

        Args:
            rois: prior boxes handed to ``ops.box_coder``.
            bbox_delta: predicted deltas; its second dimension is split into
                groups of 4.
            stage: index into ``proposal_target_generator.bbox_reg_weights``.

        Returns:
            The decoded boxes reshaped to ``[-1, 4]``.
        """
        # Floor division keeps the reshape dimension an int under Python 3.
        out_dim = bbox_delta.shape[1] // 4
        bbox_delta_r = fluid.layers.reshape(bbox_delta, (-1, out_dim, 4))
        # Keep only index 1 along axis 1.
        bbox_delta_s = fluid.layers.slice(
            bbox_delta_r, axes=[1], starts=[1], ends=[2])

        refined_bbox = ops.box_coder(
            prior_box=rois,
            prior_box_var=self.proposal_target_generator.bbox_reg_weights[
                stage],
            target_box=bbox_delta_s,
            code_type='decode_center_size',
            box_normalized=False,
            axis=1)
        refined_bbox = fluid.layers.reshape(refined_bbox, shape=[-1, 4])
        return refined_bbox

    def __call__(self,
                 inputs,
                 rpn_head_out,
                 anchor_out,
                 stage=0,
                 proposal_out=None,
                 bbox_head_outs=None,
                 refined=False):
        """Produce the proposals for one stage.

        Args:
            inputs: mapping read for ``'mode'`` and the keys used by the
                helper methods.
            rpn_head_out: per-level RPN outputs (used at stage 0).
            anchor_out: per-level anchors (used at stage 0).
            stage: 0 generates fresh proposals and resets the accumulated
                lists; any other value refines ``proposal_out``.
            proposal_out: ``(roi, rois_num)`` from the previous stage.
            bbox_head_outs: indexable by ``stage``; element 0 of the selected
                entry is used as the box delta.
            refined: when true, ``proposal_out`` is returned unchanged.

        Returns:
            ``(roi, rois_num)``, or ``proposal_out`` when ``refined`` is true.
        """
        if refined:
            assert proposal_out is not None, "If proposal has been refined, proposal_out should not be None."
            return proposal_out
        if stage == 0:
            roi, rois_num = self.generate_proposal(inputs, rpn_head_out,
                                                   anchor_out)
            self.proposals_list = []
            self.targets_list = []

        else:
            bbox_delta = bbox_head_outs[stage][0]
            # Regression weights of the previous stage decode this stage's input.
            roi = self.refine_bbox(proposal_out[0], bbox_delta, stage - 1)
            rois_num = proposal_out[1]
        if inputs['mode'] == 'train':
            roi, rois_num, targets = self.generate_proposal_target(
                inputs, roi, rois_num, stage)
            self.targets_list.append(targets)
        self.proposals_list.append((roi, rois_num))
        return roi, rois_num

    def get_targets(self):
        """Return the targets accumulated since the last stage-0 call."""
        return self.targets_list

    def get_proposals(self):
        """Return the ``(roi, rois_num)`` pairs accumulated since the last stage-0 call."""
        return self.proposals_list