bbox_head.py 8.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

15
import paddle
16 17
from paddle import ParamAttr
import paddle.nn as nn
18
import paddle.nn.functional as F
19 20 21 22 23
from paddle.nn import ReLU
from paddle.nn.initializer import Normal, XavierUniform
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling import ops
F
FDInSky 已提交
24 25 26


@register
class TwoFCHead(nn.Layer):
    """Stack of two fully connected layers (each followed by ReLU), one per stage.

    For every stage ``s`` the sublayers ``fc6_<s>``, ``fc6_<s>act``,
    ``fc7_<s>`` and ``fc7_<s>act`` are registered, in that order.

    Args:
        in_dim (int): first factor of the fc6 input width.
        mlp_dim (int): output width of fc6 and input/output width of fc7.
        resolution (int): fc6 takes ``in_dim * resolution * resolution``
            input features.
        num_stages (int): number of independent fc6/fc7 pairs to create.
    """

    __shared__ = ['num_stages']

    def __init__(self, in_dim=256, mlp_dim=1024, resolution=7, num_stages=1):
        super(TwoFCHead, self).__init__()
        self.in_dim = in_dim
        self.mlp_dim = mlp_dim
        self.num_stages = num_stages

        # Width of the flattened input; also used as fan_out for fc6's
        # Xavier initializer.
        flat_dim = in_dim * resolution * resolution

        # Plain Python lists are only lookup tables; the layers themselves
        # are registered on this Layer through add_sublayer below.
        self.fc6_list = []
        self.fc6_relu_list = []
        self.fc7_list = []
        self.fc7_relu_list = []

        for idx in range(num_stages):
            fc6_name = 'fc6_{}'.format(idx)
            fc7_name = 'fc7_{}'.format(idx)

            fc6_layer = nn.Linear(
                flat_dim,
                mlp_dim,
                weight_attr=ParamAttr(
                    initializer=XavierUniform(fan_out=flat_dim)),
                bias_attr=ParamAttr(
                    learning_rate=2., regularizer=L2Decay(0.)))
            self.fc6_list.append(self.add_sublayer(fc6_name, fc6_layer))
            self.fc6_relu_list.append(
                self.add_sublayer(fc6_name + 'act', ReLU()))

            fc7_layer = nn.Linear(
                mlp_dim,
                mlp_dim,
                weight_attr=ParamAttr(initializer=XavierUniform()),
                bias_attr=ParamAttr(
                    learning_rate=2., regularizer=L2Decay(0.)))
            self.fc7_list.append(self.add_sublayer(fc7_name, fc7_layer))
            self.fc7_relu_list.append(
                self.add_sublayer(fc7_name + 'act', ReLU()))

    def forward(self, rois_feat, stage=0):
        """Flatten ``rois_feat`` from axis 1 on and run it through the
        fc6 -> ReLU -> fc7 -> ReLU chain belonging to ``stage``."""
        out = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
        pipeline = (self.fc6_list[stage], self.fc6_relu_list[stage],
                    self.fc7_list[stage], self.fc7_relu_list[stage])
        for layer in pipeline:
            out = layer(out)
        return out
75 76 77


@register
class BBoxFeat(nn.Layer):
    """Compose an injected RoI extractor with an injected head-feature module.

    Args:
        roi_extractor: callable invoked as
            ``roi_extractor(body_feats, rois, spatial_scale)``.
        head_feat: callable invoked as ``head_feat(rois_feat, stage)`` on the
            extractor's output.
    """

    __inject__ = ['roi_extractor', 'head_feat']

    def __init__(self, roi_extractor, head_feat):
        super(BBoxFeat, self).__init__()
        self.roi_extractor = roi_extractor
        self.head_feat = head_feat

    def forward(self, body_feats, rois, spatial_scale, stage=0):
        """Extract RoI features, then pass them (with ``stage``) to the head
        feature module and return its result."""
        extracted = self.roi_extractor(body_feats, rois, spatial_scale)
        return self.head_feat(extracted, stage)
F
FDInSky 已提交
90 91 92


@register
class BBoxHead(nn.Layer):
    """Box head with one score layer and one delta layer per stage.

    For every stage ``s`` the sublayers ``bbox_score_<s>`` and
    ``bbox_delta_<s>`` are registered on this Layer.

    Args:
        bbox_feat: injected module called as
            ``bbox_feat(body_feats, rois, spatial_scale, stage)``.
        in_feat (int): input width of the score and delta linear layers.
        num_classes (int): output width of each score layer.
        cls_agnostic (bool): if True the delta layers predict ``4 * 2``
            values instead of ``4 * num_classes``.
        num_stages (int): number of score/delta layer pairs to create.
        with_pool (bool): if True, global average pooling is applied to the
            feature before the linear layers.
        score_stage (list[int]): stage indices whose softmax scores are
            averaged by ``get_prediction`` in the multi-stage case.
        delta_stage (list[int]): stage indices whose deltas and proposals
            are averaged by ``get_prediction`` in the multi-stage case.
    """

    __shared__ = ['num_classes', 'num_stages']
    __inject__ = ['bbox_feat']

    def __init__(self,
                 bbox_feat,
                 in_feat=1024,
                 num_classes=81,
                 cls_agnostic=False,
                 num_stages=1,
                 with_pool=False,
                 score_stage=[0, 1, 2],
                 delta_stage=[2]):
        super(BBoxHead, self).__init__()
        self.num_classes = num_classes
        # get_prediction branches on this flag, so it has to be stored.
        self.cls_agnostic = cls_agnostic
        self.delta_dim = 2 if cls_agnostic else num_classes
        self.bbox_feat = bbox_feat
        self.num_stages = num_stages
        self.bbox_score_list = []
        self.bbox_delta_list = []
        self.with_pool = with_pool
        # Copy the stage selections so instances never alias the shared
        # default lists from the signature (or a caller-owned list).
        self.score_stage = list(score_stage)
        self.delta_stage = list(delta_stage)
        for stage in range(num_stages):
            score_name = 'bbox_score_{}'.format(stage)
            delta_name = 'bbox_delta_{}'.format(stage)
            bbox_score = self.add_sublayer(
                score_name,
                nn.Linear(
                    in_feat,
                    1 * self.num_classes,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0.0, std=0.01)),
                    bias_attr=ParamAttr(
                        learning_rate=2., regularizer=L2Decay(0.))))

            bbox_delta = self.add_sublayer(
                delta_name,
                nn.Linear(
                    in_feat,
                    4 * self.delta_dim,
                    weight_attr=ParamAttr(initializer=Normal(
                        mean=0.0, std=0.001)),
                    bias_attr=ParamAttr(
                        learning_rate=2., regularizer=L2Decay(0.))))
            self.bbox_score_list.append(bbox_score)
            self.bbox_delta_list.append(bbox_delta)

    def forward(self, body_feats, rois, spatial_scale, stage=0):
        """Compute the box feature and the (scores, deltas) pair for ``stage``.

        Returns:
            tuple: ``(bbox_feat, bbox_head_out)`` where ``bbox_head_out`` is a
            one-element list holding ``(scores, deltas)``.
        """
        bbox_feat = self.bbox_feat(body_feats, rois, spatial_scale, stage)
        if self.with_pool:
            # NOTE(review): availability of F.pool2d depends on the Paddle
            # version in use -- confirm against the targeted release.
            bbox_feat = F.pool2d(
                bbox_feat, pool_type='avg', global_pooling=True)
        bbox_head_out = []
        scores = self.bbox_score_list[stage](bbox_feat)
        deltas = self.bbox_delta_list[stage](bbox_feat)
        bbox_head_out.append((scores, deltas))
        return bbox_feat, bbox_head_out

    def _get_head_loss(self, score, delta, target):
        """Return ``(loss_bbox_cls, loss_bbox_reg)`` for a single head output.

        ``target`` is indexed with the keys ``labels_int32``,
        ``bbox_targets``, ``bbox_inside_weights`` and
        ``bbox_outside_weights``.
        """
        # bbox cls
        labels_int64 = paddle.cast(x=target['labels_int32'], dtype='int64')
        labels_int64.stop_gradient = True
        loss_bbox_cls = F.softmax_with_cross_entropy(
            logits=score, label=labels_int64)
        loss_bbox_cls = paddle.mean(loss_bbox_cls)
        # bbox reg
        loss_bbox_reg = ops.smooth_l1(
            input=delta,
            label=target['bbox_targets'],
            inside_weight=target['bbox_inside_weights'],
            outside_weight=target['bbox_outside_weights'],
            sigma=1.0)
        loss_bbox_reg = paddle.mean(loss_bbox_reg)
        return loss_bbox_cls, loss_bbox_reg

    def get_loss(self, bbox_head_out, targets):
        """Build a dict of losses keyed ``loss_bbox_cls_<i>`` /
        ``loss_bbox_reg_<i>``, one pair per ``(head output, target)`` pair."""
        loss_bbox = {}
        for lvl, (bboxhead, target) in enumerate(zip(bbox_head_out, targets)):
            score, delta = bboxhead
            cls_name = 'loss_bbox_cls_{}'.format(lvl)
            reg_name = 'loss_bbox_reg_{}'.format(lvl)
            loss_bbox_cls, loss_bbox_reg = self._get_head_loss(score, delta,
                                                               target)
            loss_bbox[cls_name] = loss_bbox_cls
            loss_bbox[reg_name] = loss_bbox_reg
        return loss_bbox

    def get_prediction(self, bbox_head_out, rois):
        """Turn head outputs into ``((delta, bbox_prob), (proposal, proposal_num))``.

        With a single head output, ``rois`` is unpacked directly as
        ``(proposal, proposal_num)``. Otherwise ``rois`` and
        ``bbox_head_out`` are walked in lockstep, one entry per stage:
        softmax scores of stages listed in ``self.score_stage`` are averaged,
        and deltas/proposals of stages listed in ``self.delta_stage`` are
        averaged.
        """
        if len(bbox_head_out) == 1:
            proposal, proposal_num = rois
            score, delta = bbox_head_out[0]
            bbox_prob = F.softmax(score)
            delta = paddle.reshape(delta, (-1, self.delta_dim, 4))
            # NOTE(review): class-agnostic deltas are not expanded to
            # num_classes on this path -- confirm callers expect that.
        else:
            proposal_list = []
            prob_list = []
            delta_list = []
            # enumerate supplies the stage index that is matched against
            # score_stage / delta_stage.
            for stage, (proposals, bboxhead) in enumerate(
                    zip(rois, bbox_head_out)):
                score, delta = bboxhead
                # proposal_num keeps the value from the last stage visited.
                proposal, proposal_num = proposals
                if stage in self.score_stage:
                    bbox_prob = F.softmax(score)
                    prob_list.append(bbox_prob)
                if stage in self.delta_stage:
                    proposal_list.append(proposal)
                    delta_list.append(delta)
            bbox_prob = paddle.mean(paddle.stack(prob_list), axis=0)
            delta = paddle.mean(paddle.stack(delta_list), axis=0)
            proposal = paddle.mean(paddle.stack(proposal_list), axis=0)
            delta = paddle.reshape(delta, (-1, self.delta_dim, 4))
            if self.cls_agnostic:
                # Keep only the second of the two predicted boxes and repeat
                # it across all classes.
                N, C, M = delta.shape
                delta = delta[:, 1:2, :]
                delta = paddle.expand(delta, [N, self.num_classes, M])
        bboxes = (proposal, proposal_num)
        bbox_pred = (delta, bbox_prob)
        return bbox_pred, bboxes