bbox.py 7.2 KB
Newer Older
F
FDInSky 已提交
1
import numpy as np
W
wangxinxin08 已提交
2 3 4
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
F
FDInSky 已提交
5
from ppdet.core.workspace import register
6
from . import ops
F
FDInSky 已提交
7 8 9


@register
class Anchor(object):
    """Pair an anchor generator with an anchor target generator.

    Calling the instance produces anchors for every RPN feature level;
    ``generate_loss_inputs`` turns RPN head outputs plus those anchors into
    the prediction/target tensors returned by ``anchor_target_generator``.

    Args:
        anchor_generator: callable invoked as ``anchor_generator(feat, level)``
            and expected to return an ``(anchor, var)`` pair.
        anchor_target_generator: callable invoked with the keyword arguments
            ``bbox_pred``, ``cls_logits``, ``anchor_box``, ``gt_boxes``,
            ``is_crowd`` and ``im_info``; expected to return five values.
    """
    __inject__ = ['anchor_generator', 'anchor_target_generator']

    def __init__(self, anchor_generator, anchor_target_generator):
        super(Anchor, self).__init__()
        self.anchor_generator = anchor_generator
        self.anchor_target_generator = anchor_target_generator

    def __call__(self, rpn_feats):
        """Generate anchors for each feature level.

        Args:
            rpn_feats: iterable of per-level features; the position of each
                feature in the iterable is passed to the generator as its
                level index.

        Returns:
            list: one ``(anchor, var)`` tuple per entry of ``rpn_feats``.
        """
        anchors = []
        for level, rpn_feat in enumerate(rpn_feats):
            anchor, var = self.anchor_generator(rpn_feat, level)
            anchors.append((anchor, var))
        return anchors

    def _get_target_input(self, rpn_feats, anchors):
        """Flatten per-level RPN outputs and anchors and join the levels.

        Args:
            rpn_feats: iterable of ``(rpn_score, rpn_delta)`` pairs, one per
                level. Both are 4-D tensors (they are transposed with a
                4-element permutation).
            anchors: iterable of ``(anchor, var)`` pairs, one per level. Only
                ``anchor`` is used here.

        Returns:
            tuple: ``(rpn_scores, rpn_deltas, anchors)`` where scores and
            deltas are concatenated along axis 1 and anchors along the
            default axis of ``paddle.concat``.
        """
        rpn_score_list = []
        rpn_delta_list = []
        anchor_list = []
        # `var` is not needed for target assignment, so it is only unpacked
        # (keeping the 2-tuple requirement on each entry) and then ignored.
        for (rpn_score, rpn_delta), (anchor, _) in zip(rpn_feats, anchors):
            # Move axis 1 to the end before flattening, so each row of the
            # reshaped tensor holds values taken from that trailing axis.
            rpn_score = paddle.transpose(rpn_score, perm=[0, 2, 3, 1])
            rpn_delta = paddle.transpose(rpn_delta, perm=[0, 2, 3, 1])
            # A 0 in `shape` keeps the corresponding input dimension.
            rpn_score = paddle.reshape(x=rpn_score, shape=(0, -1, 1))
            rpn_delta = paddle.reshape(x=rpn_delta, shape=(0, -1, 4))

            anchor = paddle.reshape(anchor, shape=(-1, 4))

            rpn_score_list.append(rpn_score)
            rpn_delta_list.append(rpn_delta)
            anchor_list.append(anchor)

        rpn_scores = paddle.concat(rpn_score_list, axis=1)
        rpn_deltas = paddle.concat(rpn_delta_list, axis=1)
        anchors = paddle.concat(anchor_list)
        return rpn_scores, rpn_deltas, anchors

    def generate_loss_inputs(self, inputs, rpn_head_out, anchors):
        """Build the prediction/target dictionary for the RPN loss.

        Args:
            inputs: mapping that provides ``'gt_bbox'``, ``'is_crowd'`` and
                ``'im_info'``.
            rpn_head_out: per-level ``(rpn_score, rpn_delta)`` pairs.
            anchors: per-level ``(anchor, var)`` pairs; must have the same
                length as ``rpn_head_out``.

        Returns:
            dict: keys ``'rpn_score_pred'``, ``'rpn_score_target'``,
            ``'rpn_rois_pred'``, ``'rpn_rois_target'`` and
            ``'rpn_rois_weight'``.

        Raises:
            AssertionError: if ``rpn_head_out`` and ``anchors`` differ in
                length.
        """
        assert len(rpn_head_out) == len(
            anchors
        ), "rpn_head_out and anchors should have same length, but received rpn_head_out' length is {} and anchors' length is {}".format(
            len(rpn_head_out), len(anchors))
        rpn_score, rpn_delta, anchors = self._get_target_input(rpn_head_out,
                                                               anchors)

        score_pred, roi_pred, score_tgt, roi_tgt, roi_weight = self.anchor_target_generator(
            bbox_pred=rpn_delta,
            cls_logits=rpn_score,
            anchor_box=anchors,
            gt_boxes=inputs['gt_bbox'],
            is_crowd=inputs['is_crowd'],
            im_info=inputs['im_info'])
        outs = {
            'rpn_score_pred': score_pred,
            'rpn_score_target': score_tgt,
            'rpn_rois_pred': roi_pred,
            'rpn_rois_target': roi_tgt,
            'rpn_rois_weight': roi_weight
        }
        return outs

71 72 73

@register
class AnchorYOLO(object):
    """Thin callable wrapper around an injected anchor generator.

    Args:
        anchor_generator: callable taking no arguments; whatever it returns
            is returned unchanged when this object is called.
    """
    __inject__ = ['anchor_generator']

    def __init__(self, anchor_generator):
        super(AnchorYOLO, self).__init__()
        self.anchor_generator = anchor_generator

    def __call__(self):
        """Invoke the wrapped generator and return its result."""
        generate = self.anchor_generator
        return generate()
82

F
FDInSky 已提交
83 84 85

@register
class Proposal(object):
    """Generate RoI proposals and, in training, their sampling targets.

    The instance keeps per-forward state: ``targets_list`` collects the
    target dict produced at each stage and ``max_overlap`` holds the last
    value returned by the target generator. Both are reset whenever the
    instance is called with ``stage == 0``.

    Args:
        proposal_generator: callable invoked once per level with the keyword
            arguments ``scores``, ``bbox_deltas``, ``anchors``,
            ``variances``, ``im_shape`` and ``mode``; expected to return
            ``(rois, rois_prob, rois_num, post_nms_top_n)``.
        proposal_target_generator: callable used to sample training targets;
            must also expose a ``bbox_reg_weights`` iterable.
    """
    __inject__ = ['proposal_generator', 'proposal_target_generator']

    def __init__(self, proposal_generator, proposal_target_generator):
        super(Proposal, self).__init__()
        self.proposal_generator = proposal_generator
        self.proposal_target_generator = proposal_target_generator
        # Define the per-forward state up front so get_targets() and
        # get_max_overlap() are usable before the first stage-0 call.
        self.targets_list = []
        self.max_overlap = None

    def generate_proposal(self, inputs, rpn_head_out, anchor_out):
        """Run the proposal generator on every level and merge the results.

        Args:
            inputs: mapping providing ``'mode'`` and either ``'im_info'`` or
                ``'im_shape'``.
            rpn_head_out: per-level ``(rpn_score, rpn_delta)`` pairs.
            anchor_out: per-level ``(anchor, var)`` pairs.

        Returns:
            tuple: ``(rois, rois_num)``. With a single level the generator's
            output is returned directly; otherwise the per-level outputs are
            merged with ``ops.collect_fpn_proposals``.
        """
        # TODO: delete im_info
        # Only a missing key should trigger the fallback; any other error
        # (including KeyboardInterrupt) must propagate.
        try:
            im_shape = inputs['im_info']
        except KeyError:
            im_shape = inputs['im_shape']

        rpn_rois_list = []
        rpn_prob_list = []
        rpn_rois_num_list = []
        for (rpn_score, rpn_delta), (anchor, var) in zip(rpn_head_out,
                                                         anchor_out):
            rpn_prob = F.sigmoid(rpn_score)
            rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = self.proposal_generator(
                scores=rpn_prob,
                bbox_deltas=rpn_delta,
                anchors=anchor,
                variances=var,
                im_shape=im_shape,
                mode=inputs['mode'])
            # Single-level head: nothing to collect across levels.
            if len(rpn_head_out) == 1:
                return rpn_rois, rpn_rois_num
            rpn_rois_list.append(rpn_rois)
            rpn_prob_list.append(rpn_rois_prob)
            rpn_rois_num_list.append(rpn_rois_num)

        # NOTE(review): if rpn_head_out is empty the loop never binds
        # post_nms_top_n and the call below fails with a NameError.
        start_level = 2
        end_level = start_level + len(rpn_head_out)
        rois_collect, rois_num_collect = ops.collect_fpn_proposals(
            rpn_rois_list,
            rpn_prob_list,
            start_level,
            end_level,
            post_nms_top_n,
            rois_num_per_level=rpn_rois_num_list)
        return rois_collect, rois_num_collect

    def generate_proposal_target(self,
                                 inputs,
                                 rois,
                                 rois_num,
                                 stage=0,
                                 max_overlap=None):
        """Sample training targets for the given RoIs.

        Args:
            inputs: mapping providing ``'gt_class'``, ``'is_crowd'``,
                ``'gt_bbox'`` and ``'im_info'``.
            rois: proposals forwarded as ``rpn_rois``.
            rois_num: forwarded as ``rpn_rois_num``.
            stage: stage index forwarded to the target generator.
            max_overlap: forwarded to the target generator unchanged.

        Returns:
            tuple: ``(rois, rois_num, targets, max_overlap)`` where
            ``targets`` is a dict with keys ``'labels_int32'``,
            ``'bbox_targets'``, ``'bbox_inside_weights'`` and
            ``'bbox_outside_weights'``.
        """
        outs = self.proposal_target_generator(
            rpn_rois=rois,
            rpn_rois_num=rois_num,
            gt_classes=inputs['gt_class'],
            is_crowd=inputs['is_crowd'],
            gt_boxes=inputs['gt_bbox'],
            im_info=inputs['im_info'],
            stage=stage,
            max_overlap=max_overlap)
        # Generator output layout: sampled rois first, then the four target
        # tensors, with rois_num and max_overlap as the last two entries.
        rois = outs[0]
        max_overlap = outs[-1]
        rois_num = outs[-2]
        targets = {
            'labels_int32': outs[1],
            'bbox_targets': outs[2],
            'bbox_inside_weights': outs[3],
            'bbox_outside_weights': outs[4]
        }
        return rois, rois_num, targets, max_overlap

    def refine_bbox(self, roi, bbox_delta, stage=1):
        """Decode predicted deltas against ``roi`` to get refined boxes.

        Args:
            roi: boxes passed to ``ops.box_coder`` as ``prior_box``.
            bbox_delta: 2-D tensor whose second dimension is a multiple
                of 4.
            stage: divisor applied to every entry of
                ``proposal_target_generator.bbox_reg_weights``; must be
                non-zero.

        Returns:
            The decoded boxes reshaped to ``[-1, 4]``.
        """
        out_dim = bbox_delta.shape[1] // 4
        bbox_delta_r = paddle.reshape(bbox_delta, (-1, out_dim, 4))
        # Keep only index 1 along axis 1 (presumably the deltas of one
        # specific output slot -- confirm against the bbox head).
        bbox_delta_s = paddle.slice(
            bbox_delta_r, axes=[1], starts=[1], ends=[2])

        reg_weights = [
            i / stage for i in self.proposal_target_generator.bbox_reg_weights
        ]
        refined_bbox = ops.box_coder(
            prior_box=roi,
            prior_box_var=reg_weights,
            target_box=bbox_delta_s,
            code_type='decode_center_size',
            box_normalized=False,
            axis=1)
        refined_bbox = paddle.reshape(refined_bbox, shape=[-1, 4])
        return refined_bbox

    def __call__(self,
                 inputs,
                 rpn_head_out,
                 anchor_out,
                 stage=0,
                 proposal_out=None,
                 bbox_head_out=None,
                 max_overlap=None):
        """Produce RoIs for ``stage`` and record targets when training.

        Args:
            inputs: mapping providing ``'mode'`` plus the keys needed by
                ``generate_proposal`` / ``generate_proposal_target``.
            rpn_head_out: RPN head outputs; used only when ``stage == 0``.
            anchor_out: anchors; used only when ``stage == 0``.
            stage: 0 generates fresh proposals and resets the stored state;
                any other value refines ``proposal_out`` instead.
            proposal_out: ``(rois, rois_num)`` from the previous stage;
                required when ``stage != 0``.
            bbox_head_out: previous bbox head output whose element at
                index 1 is used as the deltas; required when ``stage != 0``.
            max_overlap: accepted for interface compatibility; the stored
                ``self.max_overlap`` is what is forwarded to the target
                generator.

        Returns:
            tuple: ``(roi, rois_num)``.
        """
        if stage == 0:
            roi, rois_num = self.generate_proposal(inputs, rpn_head_out,
                                                   anchor_out)
            # A new forward pass starts here: drop state from the last one.
            self.targets_list = []
            self.max_overlap = None

        else:
            bbox_delta = bbox_head_out[1]
            roi = self.refine_bbox(proposal_out[0], bbox_delta, stage)
            rois_num = proposal_out[1]
        if inputs['mode'] == 'train':
            roi, rois_num, targets, self.max_overlap = self.generate_proposal_target(
                inputs, roi, rois_num, stage, self.max_overlap)
            self.targets_list.append(targets)
        return roi, rois_num

    def get_targets(self):
        """Return the list of target dicts collected since the last stage 0."""
        return self.targets_list

    def get_max_overlap(self):
        """Return the most recently stored ``max_overlap`` value."""
        return self.max_overlap