import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from . import ops


@register
class Anchor(object):
    """Build per-level anchors and assemble the inputs of the RPN loss.

    The two collaborators listed in ``__inject__`` are presumably resolved
    from config by the ppdet ``register`` machinery (TODO confirm).
    """
    __inject__ = ['anchor_generator', 'anchor_target_generator']

    def __init__(self, anchor_generator, anchor_target_generator):
        super(Anchor, self).__init__()
        # Callable taking (feature, level_index) and returning (anchor, var).
        self.anchor_generator = anchor_generator
        # Callable invoked with flattened predictions, anchors and ground
        # truth; returns a 5-tuple of predictions/targets/weights.
        self.anchor_target_generator = anchor_target_generator

    def __call__(self, rpn_feats):
        """Generate anchors for every feature level.

        Args:
            rpn_feats: sequence of per-level feature maps.

        Returns:
            list of ``(anchor, var)`` tuples, one per level, in the same
            order as ``rpn_feats``.
        """
        anchors = []
        for level, rpn_feat in enumerate(rpn_feats):
            anchor, var = self.anchor_generator(rpn_feat, level)
            anchors.append((anchor, var))
        return anchors

    def _get_target_input(self, rpn_feats, anchors):
        """Flatten per-level RPN outputs and anchors, then concatenate levels.

        Args:
            rpn_feats: iterable of ``(rpn_score, rpn_delta)`` pairs, one per
                level (assumed NCHW — TODO confirm).
            anchors: iterable of ``(anchor, var)`` pairs as returned by
                ``__call__``; only ``anchor`` is used.

        Returns:
            tuple ``(rpn_scores, rpn_deltas, anchors)``: scores reshaped to
            ``(N, -1, 1)`` and deltas to ``(N, -1, 4)`` (N = dim 0 of the
            input) and concatenated on axis 1; anchors reshaped to
            ``(-1, 4)`` and concatenated on axis 0.
        """
        rpn_score_list = []
        rpn_delta_list = []
        anchor_list = []
        # Variances are unpacked but unused: only anchors feed target
        # assignment, so they are not reshaped.
        for (rpn_score, rpn_delta), (anchor, _var) in zip(rpn_feats, anchors):
            # Move dim 1 last before flattening; a 0 in `shape` copies that
            # dimension from the input (keeps the batch dim).
            rpn_score = paddle.transpose(rpn_score, perm=[0, 2, 3, 1])
            rpn_delta = paddle.transpose(rpn_delta, perm=[0, 2, 3, 1])
            rpn_score_list.append(paddle.reshape(x=rpn_score, shape=(0, -1, 1)))
            rpn_delta_list.append(paddle.reshape(x=rpn_delta, shape=(0, -1, 4)))
            anchor_list.append(paddle.reshape(anchor, shape=(-1, 4)))

        rpn_scores = paddle.concat(rpn_score_list, axis=1)
        rpn_deltas = paddle.concat(rpn_delta_list, axis=1)
        anchors = paddle.concat(anchor_list)
        return rpn_scores, rpn_deltas, anchors

    def generate_loss_inputs(self, inputs, rpn_head_out, anchors):
        """Assign RPN targets and package predictions/targets for the loss.

        Args:
            inputs: mapping providing 'gt_bbox', 'is_crowd' and 'im_info'.
            rpn_head_out: per-level ``(rpn_score, rpn_delta)`` pairs.
            anchors: per-level ``(anchor, var)`` pairs; must have the same
                length as ``rpn_head_out``.

        Returns:
            dict with keys 'rpn_score_pred', 'rpn_score_target',
            'rpn_rois_pred', 'rpn_rois_target' and 'rpn_rois_weight'.

        Raises:
            ValueError: if ``rpn_head_out`` and ``anchors`` differ in length.
        """
        if len(rpn_head_out) != len(anchors):
            raise ValueError(
                "rpn_head_out and anchors should have same length, "
                "but received rpn_head_out's length is {} and anchors' "
                "length is {}".format(len(rpn_head_out), len(anchors)))
        rpn_score, rpn_delta, anchors = self._get_target_input(rpn_head_out,
                                                               anchors)

        (score_pred, roi_pred, score_tgt, roi_tgt,
         roi_weight) = self.anchor_target_generator(
             bbox_pred=rpn_delta,
             cls_logits=rpn_score,
             anchor_box=anchors,
             gt_boxes=inputs['gt_bbox'],
             is_crowd=inputs['is_crowd'],
             im_info=inputs['im_info'])
        return {
            'rpn_score_pred': score_pred,
            'rpn_score_target': score_tgt,
            'rpn_rois_pred': roi_pred,
            'rpn_rois_target': roi_tgt,
            'rpn_rois_weight': roi_weight,
        }

@register
class Proposal(object):
    """Generate RPN proposals and, in training, sample RoI targets.

    Stage 0 builds proposals from the RPN outputs; later stages refine the
    previous stage's boxes with bbox-head deltas. In training mode the
    sampled targets are accumulated across stages in ``targets_list``.
    """
    __inject__ = ['proposal_generator', 'proposal_target_generator']

    def __init__(self, proposal_generator, proposal_target_generator):
        super(Proposal, self).__init__()
        self.proposal_generator = proposal_generator
        self.proposal_target_generator = proposal_target_generator
        # Per-call state, reset at stage 0 in __call__. Initialised here so
        # the getters (and stage > 0 calls) work before a stage-0 call.
        self.targets_list = []
        self.max_overlap = None

    def generate_proposal(self, inputs, rpn_head_out, anchor_out):
        """Decode RPN outputs into proposals, merging levels when needed.

        Args:
            inputs: mapping with 'mode' and either 'im_info' or 'im_shape'
                ('im_info' takes precedence).
            rpn_head_out: per-level ``(rpn_score, rpn_delta)`` pairs; scores
                are passed through a sigmoid before proposal generation.
            anchor_out: per-level ``(anchor, var)`` pairs, same length as
                ``rpn_head_out``.

        Returns:
            tuple ``(rois, rois_num)``. With a single level these come
            straight from the proposal generator; otherwise all levels are
            merged with ``ops.collect_fpn_proposals``.

        Raises:
            ValueError: if ``rpn_head_out`` is empty or its length differs
                from ``anchor_out``.
        """
        # TODO: delete im_info
        try:
            im_shape = inputs['im_info']
        except KeyError:
            im_shape = inputs['im_shape']

        num_levels = len(rpn_head_out)
        if num_levels == 0:
            # Without this, post_nms_top_n below would be unbound.
            raise ValueError("rpn_head_out should not be empty")
        if num_levels != len(anchor_out):
            # zip() would silently drop levels while end_level still counts
            # all of them.
            raise ValueError(
                "rpn_head_out and anchor_out should have same length, "
                "but received rpn_head_out's length is {} and anchor_out's "
                "length is {}".format(num_levels, len(anchor_out)))

        mode = inputs['mode']
        rpn_rois_list = []
        rpn_prob_list = []
        rpn_rois_num_list = []
        for (rpn_score, rpn_delta), (anchor, var) in zip(rpn_head_out,
                                                         anchor_out):
            rpn_prob = F.sigmoid(rpn_score)
            rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = self.proposal_generator(
                scores=rpn_prob,
                bbox_deltas=rpn_delta,
                anchors=anchor,
                variances=var,
                im_shape=im_shape,
                mode=mode)
            if num_levels == 1:
                # Single level: nothing to merge.
                return rpn_rois, rpn_rois_num
            rpn_rois_list.append(rpn_rois)
            rpn_prob_list.append(rpn_rois_prob)
            rpn_rois_num_list.append(rpn_rois_num)

        # Levels are assumed to be numbered from 2 (presumably FPN P2 —
        # TODO confirm against the neck configuration).
        start_level = 2
        end_level = start_level + num_levels
        rois_collect, rois_num_collect = ops.collect_fpn_proposals(
            rpn_rois_list,
            rpn_prob_list,
            start_level,
            end_level,
            post_nms_top_n,
            rois_num_per_level=rpn_rois_num_list)
        return rois_collect, rois_num_collect

    def generate_proposal_target(self,
                                 inputs,
                                 rois,
                                 rois_num,
                                 stage=0,
                                 max_overlap=None):
        """Sample RoIs and compute their classification/regression targets.

        Args:
            inputs: mapping providing 'gt_class', 'is_crowd', 'gt_bbox' and
                'im_info'.
            rois: proposals to sample from.
            rois_num: per-image proposal counts.
            stage: refinement stage index, forwarded to the generator.
            max_overlap: forwarded to the generator.

        Returns:
            tuple ``(rois, rois_num, targets, max_overlap)``. ``rois`` is the
            generator's output 0; ``targets`` maps 'labels_int32',
            'bbox_targets', 'bbox_inside_weights' and
            'bbox_outside_weights' to outputs 1-4; ``rois_num`` and
            ``max_overlap`` are its last two outputs.
        """
        outs = self.proposal_target_generator(
            rpn_rois=rois,
            rpn_rois_num=rois_num,
            gt_classes=inputs['gt_class'],
            is_crowd=inputs['is_crowd'],
            gt_boxes=inputs['gt_bbox'],
            im_info=inputs['im_info'],
            stage=stage,
            max_overlap=max_overlap)
        targets = {
            'labels_int32': outs[1],
            'bbox_targets': outs[2],
            'bbox_inside_weights': outs[3],
            'bbox_outside_weights': outs[4],
        }
        return outs[0], outs[-2], targets, outs[-1]

    def refine_bbox(self, roi, bbox_delta, stage=1):
        """Decode one group of bbox-head deltas onto the given RoIs.

        ``bbox_delta`` is reshaped to ``(-1, bbox_delta.shape[1] // 4, 4)``
        and only index 1 along axis 1 is decoded (presumably the first
        foreground / class-agnostic slot — TODO confirm). The box-coder
        variances are ``proposal_target_generator.bbox_reg_weights`` divided
        by ``stage``, so ``stage`` must be non-zero.

        Returns:
            refined boxes reshaped to ``(-1, 4)``.
        """
        num_groups = bbox_delta.shape[1] // 4
        bbox_delta_r = paddle.reshape(bbox_delta, (-1, num_groups, 4))
        bbox_delta_s = paddle.slice(
            bbox_delta_r, axes=[1], starts=[1], ends=[2])

        reg_weights = [
            w / stage for w in self.proposal_target_generator.bbox_reg_weights
        ]
        refined_bbox = ops.box_coder(
            prior_box=roi,
            prior_box_var=reg_weights,
            target_box=bbox_delta_s,
            code_type='decode_center_size',
            box_normalized=False,
            axis=1)
        return paddle.reshape(refined_bbox, shape=[-1, 4])

    def __call__(self,
                 inputs,
                 rpn_head_out,
                 anchor_out,
                 stage=0,
                 proposal_out=None,
                 bbox_head_out=None,
                 max_overlap=None):
        """Produce RoIs for ``stage`` and, in training, sample their targets.

        Stage 0 generates proposals via ``generate_proposal`` and resets
        ``targets_list`` and ``max_overlap``. Later stages refine
        ``proposal_out[0]`` with ``bbox_head_out[1]`` and reuse
        ``proposal_out[1]`` as the RoI counts. When ``inputs['mode']`` is
        'train', targets are sampled and appended to ``targets_list``, and
        ``self.max_overlap`` is updated for the next stage.

        The ``max_overlap`` argument is not read (``self.max_overlap`` is
        used instead); presumably it is kept for interface compatibility.

        Returns:
            tuple ``(rois, rois_num)``.
        """
        if stage == 0:
            roi, rois_num = self.generate_proposal(inputs, rpn_head_out,
                                                   anchor_out)
            self.targets_list = []
            self.max_overlap = None
        else:
            bbox_delta = bbox_head_out[1]
            roi = self.refine_bbox(proposal_out[0], bbox_delta, stage)
            rois_num = proposal_out[1]
        if inputs['mode'] == 'train':
            roi, rois_num, targets, self.max_overlap = self.generate_proposal_target(
                inputs, roi, rois_num, stage, self.max_overlap)
            self.targets_list.append(targets)
        return roi, rois_num

    def get_targets(self):
        """Return the target dicts accumulated since the last stage-0 call."""
        return self.targets_list

    def get_max_overlap(self):
        """Return the ``max_overlap`` from the latest target sampling, or None."""
        return self.max_overlap