bbox.py 7.3 KB
Newer Older
F
FDInSky 已提交
1 2
import numpy as np
import paddle.fluid as fluid
W
wangxinxin08 已提交
3 4 5
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
F
FDInSky 已提交
6
from ppdet.core.workspace import register
7
from . import ops
F
FDInSky 已提交
8 9 10


@register
class Anchor(object):
    """Couples an anchor generator with an anchor target generator.

    Both collaborators are injected through the workspace (see
    ``__inject__``) and are treated as opaque callables here.
    """
    __inject__ = ['anchor_generator', 'anchor_target_generator']

    def __init__(self, anchor_generator, anchor_target_generator):
        """Store the two injected collaborators.

        Args:
            anchor_generator: callable invoked as
                ``anchor_generator(feat, level_index)``; must return an
                ``(anchor, var)`` pair.
            anchor_target_generator: callable invoked with keyword
                arguments in ``generate_loss_inputs``.
        """
        super(Anchor, self).__init__()
        self.anchor_generator = anchor_generator
        self.anchor_target_generator = anchor_target_generator

    def __call__(self, rpn_feats):
        """Generate anchors for every entry of ``rpn_feats``.

        Args:
            rpn_feats: iterable of per-level inputs; each one is handed to
                ``anchor_generator`` together with its position index.

        Returns:
            list of ``(anchor, var)`` tuples, one per entry of
            ``rpn_feats``, in the same order.
        """
        anchors = []
        for level, rpn_feat in enumerate(rpn_feats):
            anchor, var = self.anchor_generator(rpn_feat, level)
            anchors.append((anchor, var))
        return anchors

    def _get_target_input(self, rpn_feats, anchors):
        """Flatten per-level RPN outputs and anchors into single tensors.

        Args:
            rpn_feats: iterable of ``(rpn_score, rpn_delta)`` pairs.
            anchors: iterable of ``(anchor, var)`` pairs, zipped with
                ``rpn_feats``; the ``var`` element is not used here.

        Returns:
            tuple ``(rpn_scores, rpn_deltas, anchors)`` where scores and
            deltas are concatenated along axis 1 and anchors along the
            default axis of ``fluid.layers.concat``.
        """
        rpn_score_list = []
        rpn_delta_list = []
        anchor_list = []
        for (rpn_score, rpn_delta), (anchor, _) in zip(rpn_feats, anchors):
            # Move axis 1 to the end (assumes NCHW input -- TODO confirm)
            # so the trailing dimension can be folded into 1 / 4 columns.
            rpn_score = fluid.layers.transpose(rpn_score, perm=[0, 2, 3, 1])
            rpn_delta = fluid.layers.transpose(rpn_delta, perm=[0, 2, 3, 1])
            rpn_score = fluid.layers.reshape(x=rpn_score, shape=(0, -1, 1))
            rpn_delta = fluid.layers.reshape(x=rpn_delta, shape=(0, -1, 4))

            anchor = fluid.layers.reshape(anchor, shape=(-1, 4))
            rpn_score_list.append(rpn_score)
            rpn_delta_list.append(rpn_delta)
            anchor_list.append(anchor)

        rpn_scores = fluid.layers.concat(rpn_score_list, axis=1)
        rpn_deltas = fluid.layers.concat(rpn_delta_list, axis=1)
        anchors = fluid.layers.concat(anchor_list)
        return rpn_scores, rpn_deltas, anchors

    def generate_loss_inputs(self, inputs, rpn_head_out, anchors):
        """Build the prediction/target dictionary consumed by the RPN loss.

        Args:
            inputs: mapping that must provide ``'gt_bbox'``, ``'is_crowd'``
                and ``'im_info'``.
            rpn_head_out: per-level ``(rpn_score, rpn_delta)`` pairs.
            anchors: per-level ``(anchor, var)`` pairs; must have the same
                length as ``rpn_head_out``.

        Returns:
            dict with keys ``'rpn_score_pred'``, ``'rpn_score_target'``,
            ``'rpn_rois_pred'``, ``'rpn_rois_target'`` and
            ``'rpn_rois_weight'``.

        Raises:
            AssertionError: if ``rpn_head_out`` and ``anchors`` differ in
                length.
        """
        assert len(rpn_head_out) == len(
            anchors
        ), "rpn_head_out and anchors should have same length, but received rpn_head_out' length is {} and anchors' length is {}".format(
            len(rpn_head_out), len(anchors))
        rpn_score, rpn_delta, anchors = self._get_target_input(rpn_head_out,
                                                               anchors)

        score_pred, roi_pred, score_tgt, roi_tgt, roi_weight = self.anchor_target_generator(
            bbox_pred=rpn_delta,
            cls_logits=rpn_score,
            anchor_box=anchors,
            gt_boxes=inputs['gt_bbox'],
            is_crowd=inputs['is_crowd'],
            im_info=inputs['im_info'])
        outs = {
            'rpn_score_pred': score_pred,
            'rpn_score_target': score_tgt,
            'rpn_rois_pred': roi_pred,
            'rpn_rois_target': roi_tgt,
            'rpn_rois_weight': roi_weight
        }
        return outs

72 73 74

@register
class AnchorYOLO(object):
    """Thin wrapper that exposes an injected anchor generator as a callable."""
    __inject__ = ['anchor_generator']

    def __init__(self, anchor_generator):
        """Store the injected generator.

        Args:
            anchor_generator: zero-argument callable; its result is
                returned unchanged by ``__call__``.
        """
        super(AnchorYOLO, self).__init__()
        self.anchor_generator = anchor_generator

    def __call__(self):
        """Return whatever ``anchor_generator()`` produces."""
        return self.anchor_generator()
83

F
FDInSky 已提交
84 85 86

@register
class Proposal(object):
    """Generates RoIs from RPN outputs and, in train mode, their targets.

    Stage 0 produces proposals from the RPN head; later stages refine the
    previous stage's RoIs with the bbox head's deltas. Targets produced in
    train mode are accumulated on the instance (``targets_list``,
    ``max_overlap``) and reset on every stage-0 call.
    """
    __inject__ = ['proposal_generator', 'proposal_target_generator']

    def __init__(self, proposal_generator, proposal_target_generator):
        """Store the two injected collaborators.

        Args:
            proposal_generator: callable invoked with keyword arguments in
                ``generate_proposal``.
            proposal_target_generator: callable invoked with keyword
                arguments in ``generate_proposal_target``; its
                ``bbox_reg_weights`` attribute is read by ``refine_bbox``.
        """
        super(Proposal, self).__init__()
        self.proposal_generator = proposal_generator
        self.proposal_target_generator = proposal_target_generator

    def generate_proposal(self, inputs, rpn_head_out, anchor_out):
        """Run the proposal generator on every level and merge the results.

        Args:
            inputs: mapping providing ``'mode'`` and either ``'im_info'``
                or ``'im_shape'``.
            rpn_head_out: per-level ``(rpn_score, rpn_delta)`` pairs.
            anchor_out: per-level ``(anchor, var)`` pairs, zipped with
                ``rpn_head_out``.

        Returns:
            tuple ``(rois, rois_num)``. With a single level the generator's
            output is returned directly; otherwise the per-level results
            are merged by ``ops.collect_fpn_proposals``.
        """
        # TODO: delete im_info
        # Only a missing key should trigger the fallback; any other error
        # raised by the lookup is allowed to propagate.
        try:
            im_shape = inputs['im_info']
        except KeyError:
            im_shape = inputs['im_shape']
        rpn_rois_list = []
        rpn_prob_list = []
        rpn_rois_num_list = []
        for (rpn_score, rpn_delta), (anchor, var) in zip(rpn_head_out,
                                                         anchor_out):
            rpn_prob = fluid.layers.sigmoid(rpn_score)
            rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = self.proposal_generator(
                scores=rpn_prob,
                bbox_deltas=rpn_delta,
                anchors=anchor,
                variances=var,
                im_shape=im_shape,
                mode=inputs['mode'])
            # Single-level case: nothing to merge, return immediately.
            if len(rpn_head_out) == 1:
                return rpn_rois, rpn_rois_num
            rpn_rois_list.append(rpn_rois)
            rpn_prob_list.append(rpn_rois_prob)
            rpn_rois_num_list.append(rpn_rois_num)

        # Level range handed to collect_fpn_proposals; the first level is
        # hard-coded to 2 here.
        start_level = 2
        end_level = start_level + len(rpn_head_out)
        # post_nms_top_n is the value returned by the last level's
        # generator call.
        rois_collect, rois_num_collect = ops.collect_fpn_proposals(
            rpn_rois_list,
            rpn_prob_list,
            start_level,
            end_level,
            post_nms_top_n,
            rois_num_per_level=rpn_rois_num_list)
        return rois_collect, rois_num_collect

    def generate_proposal_target(self,
                                 inputs,
                                 rois,
                                 rois_num,
                                 stage=0,
                                 max_overlap=None):
        """Call the target generator and repackage its outputs.

        Args:
            inputs: mapping providing ``'gt_class'``, ``'is_crowd'``,
                ``'gt_bbox'`` and ``'im_info'``.
            rois: RoIs forwarded as ``rpn_rois``.
            rois_num: forwarded as ``rpn_rois_num``.
            stage: forwarded unchanged to the target generator.
            max_overlap: forwarded unchanged to the target generator.

        Returns:
            tuple ``(rois, rois_num, targets, max_overlap)`` where
            ``targets`` is a dict with keys ``'labels_int32'``,
            ``'bbox_targets'``, ``'bbox_inside_weights'`` and
            ``'bbox_outside_weights'``.
        """
        # NOTE(review): 'im_info' is required here, whereas
        # generate_proposal falls back to 'im_shape' -- confirm whether the
        # same fallback is wanted for training.
        outs = self.proposal_target_generator(
            rpn_rois=rois,
            rpn_rois_num=rois_num,
            gt_classes=inputs['gt_class'],
            is_crowd=inputs['is_crowd'],
            gt_boxes=inputs['gt_bbox'],
            im_info=inputs['im_info'],
            stage=stage,
            max_overlap=max_overlap)
        # Output layout: [rois, labels, bbox targets, inside weights,
        # outside weights, ..., rois_num, max_overlap].
        rois = outs[0]
        max_overlap = outs[-1]
        rois_num = outs[-2]
        targets = {
            'labels_int32': outs[1],
            'bbox_targets': outs[2],
            'bbox_inside_weights': outs[3],
            'bbox_outside_weights': outs[4]
        }
        return rois, rois_num, targets, max_overlap

    def refine_bbox(self, roi, bbox_delta, stage=1):
        """Decode ``bbox_delta`` against ``roi`` to get refined boxes.

        Args:
            roi: prior boxes passed to ``ops.box_coder``.
            bbox_delta: deltas whose second dimension is a multiple of 4.
            stage: divisor applied to every entry of
                ``proposal_target_generator.bbox_reg_weights``; must be
                non-zero.

        Returns:
            The decoded boxes reshaped to ``[-1, 4]``.
        """
        out_dim = bbox_delta.shape[1] // 4
        bbox_delta_r = paddle.reshape(bbox_delta, (-1, out_dim, 4))
        # Keep only index 1 along axis 1 (presumably the foreground entry
        # of a class-agnostic regressor -- TODO confirm).
        bbox_delta_s = paddle.slice(
            bbox_delta_r, axes=[1], starts=[1], ends=[2])

        reg_weights = [
            i / stage for i in self.proposal_target_generator.bbox_reg_weights
        ]
        refined_bbox = ops.box_coder(
            prior_box=roi,
            prior_box_var=reg_weights,
            target_box=bbox_delta_s,
            code_type='decode_center_size',
            box_normalized=False,
            axis=1)
        refined_bbox = paddle.reshape(refined_bbox, shape=[-1, 4])
        return refined_bbox

    def __call__(self,
                 inputs,
                 rpn_head_out,
                 anchor_out,
                 stage=0,
                 proposal_out=None,
                 bbox_head_out=None,
                 max_overlap=None):
        """Produce RoIs for the given stage and record targets when training.

        Args:
            inputs: mapping providing ``'mode'`` plus the keys required by
                ``generate_proposal`` / ``generate_proposal_target``.
            rpn_head_out: RPN head outputs (used when ``stage == 0``).
            anchor_out: anchors (used when ``stage == 0``).
            stage: 0 generates fresh proposals and resets the stored
                targets; any other value refines ``proposal_out``.
            proposal_out: previous stage's ``(rois, rois_num)``; required
                when ``stage != 0``.
            bbox_head_out: previous stage's bbox head output; element 1 is
                used as the deltas when ``stage != 0``.
            max_overlap: accepted for interface compatibility; the value
                stored on the instance is what gets used.

        Returns:
            tuple ``(roi, rois_num)``.
        """
        if stage == 0:
            roi, rois_num = self.generate_proposal(inputs, rpn_head_out,
                                                   anchor_out)
            # A new stage-0 pass starts a fresh target history.
            self.targets_list = []
            self.max_overlap = None

        else:
            bbox_delta = bbox_head_out[1]
            roi = self.refine_bbox(proposal_out[0], bbox_delta, stage)
            rois_num = proposal_out[1]
        if inputs['mode'] == 'train':
            roi, rois_num, targets, self.max_overlap = self.generate_proposal_target(
                inputs, roi, rois_num, stage, self.max_overlap)
            self.targets_list.append(targets)
        return roi, rois_num

    def get_targets(self):
        """Return the list of target dicts accumulated since the last stage-0 call."""
        return self.targets_list

    def get_max_overlap(self):
        """Return the ``max_overlap`` value stored by the most recent call."""
        return self.max_overlap