bbox_head.py 10.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

15
import paddle
16 17
from paddle import ParamAttr
import paddle.nn as nn
18
import paddle.nn.functional as F
19 20 21 22 23
from paddle.nn import ReLU
from paddle.nn.initializer import Normal, XavierUniform
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling import ops
F
FDInSky 已提交
24

G
Guanghua Yu 已提交
25 26 27
from ..backbone.name_adapter import NameAdapter
from ..backbone.resnet import Blocks

F
FDInSky 已提交
28 29

@register
class TwoFCHead(nn.Layer):
    """Per-stage stack of two fully-connected layers, each followed by ReLU.

    For every RoI stage a separate ``fc6 -> ReLU -> fc7 -> ReLU`` chain is
    created and registered as sublayers named ``fc6_<stage>``,
    ``fc6_<stage>act``, ``fc7_<stage>`` and ``fc7_<stage>act``.

    Args:
        in_dim (int): used with ``resolution`` to size the fc6 input
            (``in_dim * resolution * resolution``).
        mlp_dim (int): output width of both fully-connected layers.
        resolution (int): see ``in_dim``; presumably the spatial side of the
            pooled RoI feature -- confirm against the RoI extractor config.
        roi_stages (int): number of independent fc6/fc7 chains to build.
    """

    __shared__ = ['roi_stages']

    def __init__(self, in_dim=256, mlp_dim=1024, resolution=7, roi_stages=1):
        super(TwoFCHead, self).__init__()
        self.in_dim = in_dim
        self.mlp_dim = mlp_dim
        self.roi_stages = roi_stages

        # Size of the flattened input to fc6; also used as fc6's Xavier
        # ``fan_out`` value.
        flat_dim = in_dim * resolution * resolution

        # Plain Python lists indexed by stage; the layers themselves are
        # registered on this Layer through ``add_sublayer``.
        self.fc6_list = []
        self.fc6_relu_list = []
        self.fc7_list = []
        self.fc7_relu_list = []

        for idx in range(roi_stages):
            # Learning-rate multiplier doubles with every stage.
            lr_scale = 2**idx
            fc6_key = 'fc6_{}'.format(idx)
            fc7_key = 'fc7_{}'.format(idx)

            first_fc = self._add_fc(fc6_key, flat_dim, mlp_dim, lr_scale,
                                    XavierUniform(fan_out=flat_dim))
            first_act = self.add_sublayer(fc6_key + 'act', ReLU())
            second_fc = self._add_fc(fc7_key, mlp_dim, mlp_dim, lr_scale,
                                     XavierUniform())
            second_act = self.add_sublayer(fc7_key + 'act', ReLU())

            self.fc6_list.append(first_fc)
            self.fc6_relu_list.append(first_act)
            self.fc7_list.append(second_fc)
            self.fc7_relu_list.append(second_act)

    def _add_fc(self, name, in_features, out_features, lr_scale, weight_init):
        """Create an ``nn.Linear`` and register it as sublayer ``name``.

        The weight uses learning rate ``lr_scale`` and ``weight_init``; the
        bias uses twice that learning rate and ``L2Decay(0.)``.
        """
        weight_attr = ParamAttr(
            learning_rate=lr_scale, initializer=weight_init)
        bias_attr = ParamAttr(
            learning_rate=2. * lr_scale, regularizer=L2Decay(0.))
        linear = nn.Linear(
            in_features,
            out_features,
            weight_attr=weight_attr,
            bias_attr=bias_attr)
        return self.add_sublayer(name, linear)

    def forward(self, rois_feat, stage=0):
        """Flatten ``rois_feat`` from axis 1 on and run the chain of ``stage``.

        Args:
            rois_feat: input tensor; every axis after the first is flattened.
            stage (int): index of the fc6/fc7 chain to apply.

        Returns:
            Output of the second ReLU of the selected stage.
        """
        flat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
        hidden = self.fc6_list[stage](flat)
        hidden = self.fc6_relu_list[stage](hidden)
        out = self.fc7_list[stage](hidden)
        out = self.fc7_relu_list[stage](out)
        return out
81 82


G
Guanghua Yu 已提交
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
@register
class Res5Head(nn.Layer):
    """RoI feature head built from a ``Blocks`` group with ``stage_num=5``.

    Args:
        feat_in (int): first positional argument passed to ``Blocks``.
        feat_out (int): second positional argument passed to ``Blocks``;
            ``self.feat_out`` is set to four times this value.
    """

    def __init__(self, feat_in=1024, feat_out=512):
        super(Res5Head, self).__init__()
        adapter = NameAdapter(self)
        # Not used inside this class; kept so the attribute still exists.
        self.res5_conv = []
        res5_blocks = Blocks(
            feat_in, feat_out, count=3, name_adapter=adapter, stage_num=5)
        self.res5 = self.add_sublayer('res5_roi_feat', res5_blocks)
        self.feat_out = feat_out * 4

    def forward(self, roi_feat, stage=0):
        """Apply the res5 blocks to ``roi_feat``.

        ``stage`` is accepted for call compatibility with the other heads in
        this file and is not used.
        """
        return self.res5(roi_feat)


100
@register
class BBoxFeat(nn.Layer):
    """Chain an RoI extractor with a feature head.

    Args:
        roi_extractor: callable invoked as
            ``roi_extractor(body_feats, rois, spatial_scale)``.
        head_feat: callable invoked as ``head_feat(rois_feat, stage)``.
    """

    __inject__ = ['roi_extractor', 'head_feat']

    def __init__(self, roi_extractor, head_feat):
        super(BBoxFeat, self).__init__()
        self.roi_extractor = roi_extractor
        self.head_feat = head_feat
        # Not used inside this class; kept so the attribute still exists.
        self.rois_feat_list = []

    def forward(self, body_feats, rois, spatial_scale, stage=0):
        """Extract RoI features and pass them through the feature head.

        Returns:
            tuple: ``(rois_feat, bbox_feat)`` -- the extractor output and the
            head output computed from it for ``stage``.
        """
        extracted = self.roi_extractor(body_feats, rois, spatial_scale)
        head_out = self.head_feat(extracted, stage)
        return extracted, head_out
F
FDInSky 已提交
114 115 116


@register
class BBoxHead(nn.Layer):
    """Box classification / regression head with one score and one delta
    ``nn.Linear`` per RoI stage.

    Args:
        bbox_feat: feature module called as
            ``bbox_feat(body_feats, rois, spatial_scale, stage)`` and exposing
            a ``head_feat`` attribute callable as ``head_feat(rois_feat, stage)``.
        in_feat (int): input width of the score and delta layers.
        num_classes (int): output width of the score layer.
        cls_agnostic (bool): when True the delta layer predicts ``4 * 2``
            values instead of ``4 * num_classes``.
        roi_stages (int): number of score/delta layer pairs to build.
        with_pool (bool): when True the head feature is reduced with
            ``adaptive_avg_pool2d(output_size=1)`` and axes 2 and 3 are
            squeezed before the linear layers.
        score_stage (list[int]): stages whose scores are averaged in
            ``get_cascade_prediction``.
        delta_stage (list[int]): stages whose deltas and proposals are
            averaged in ``get_cascade_prediction``.
    """

    __shared__ = ['num_classes', 'roi_stages']
    __inject__ = ['bbox_feat']

    def __init__(self,
                 bbox_feat,
                 in_feat=1024,
                 num_classes=81,
                 cls_agnostic=False,
                 roi_stages=1,
                 with_pool=False,
                 score_stage=[0, 1, 2],
                 delta_stage=[2]):
        super(BBoxHead, self).__init__()
        self.num_classes = num_classes
        self.cls_agnostic = cls_agnostic
        self.delta_dim = 2 if cls_agnostic else num_classes
        self.bbox_feat = bbox_feat
        self.roi_stages = roi_stages
        self.bbox_score_list = []
        self.bbox_delta_list = []
        # One slot per stage; ``forward`` overwrites the slot with the RoI
        # features it extracted for that stage.
        self.roi_feat_list = [[] for _ in range(roi_stages)]
        self.with_pool = with_pool
        # Copy the stage lists: the signature defaults are single list objects
        # shared by every call, so storing them directly would make all
        # default-constructed instances alias (and possibly mutate) one list.
        self.score_stage = list(score_stage)
        self.delta_stage = list(delta_stage)
        for stage in range(roi_stages):
            score_name = 'bbox_score_{}'.format(stage)
            delta_name = 'bbox_delta_{}'.format(stage)
            # Learning-rate multiplier doubles with every stage.
            lr_factor = 2**stage
            bbox_score = self.add_sublayer(
                score_name,
                nn.Linear(
                    in_feat,
                    1 * self.num_classes,
                    weight_attr=ParamAttr(
                        learning_rate=lr_factor,
                        initializer=Normal(
                            mean=0.0, std=0.01)),
                    bias_attr=ParamAttr(
                        learning_rate=2. * lr_factor, regularizer=L2Decay(0.))))

            bbox_delta = self.add_sublayer(
                delta_name,
                nn.Linear(
                    in_feat,
                    4 * self.delta_dim,
                    weight_attr=ParamAttr(
                        learning_rate=lr_factor,
                        initializer=Normal(
                            mean=0.0, std=0.001)),
                    bias_attr=ParamAttr(
                        learning_rate=2. * lr_factor, regularizer=L2Decay(0.))))
            self.bbox_score_list.append(bbox_score)
            self.bbox_delta_list.append(bbox_delta)

    def forward(self,
                body_feats=None,
                rois=None,
                spatial_scale=None,
                stage=0,
                roi_stage=-1):
        """Compute scores and deltas with the layers of ``stage``.

        When ``rois`` is given, RoI features are extracted through
        ``self.bbox_feat`` and cached in ``self.roi_feat_list[stage]``.
        When ``rois`` is None, the cached features at
        ``self.roi_feat_list[roi_stage]`` are fed to
        ``self.bbox_feat.head_feat`` instead (``roi_stage=-1`` selects the
        last slot).

        Returns:
            tuple: ``(bbox_feat, (scores, deltas), self.bbox_feat.head_feat)``.
        """
        if rois is not None:
            rois_feat, bbox_feat = self.bbox_feat(body_feats, rois,
                                                  spatial_scale, stage)
            self.roi_feat_list[stage] = rois_feat
        else:
            rois_feat = self.roi_feat_list[roi_stage]
            bbox_feat = self.bbox_feat.head_feat(rois_feat, stage)
        if self.with_pool:
            # Pool to 1x1 and drop axes 2 and 3 before the linear layers;
            # the un-pooled ``bbox_feat`` is what gets returned.
            bbox_feat_ = F.adaptive_avg_pool2d(bbox_feat, output_size=1)
            bbox_feat_ = paddle.squeeze(bbox_feat_, axis=[2, 3])
            scores = self.bbox_score_list[stage](bbox_feat_)
            deltas = self.bbox_delta_list[stage](bbox_feat_)
        else:
            scores = self.bbox_score_list[stage](bbox_feat)
            deltas = self.bbox_delta_list[stage](bbox_feat)
        bbox_head_out = (scores, deltas)
        return bbox_feat, bbox_head_out, self.bbox_feat.head_feat

    def _get_head_loss(self, score, delta, target):
        """Return ``(loss_bbox_cls, loss_bbox_reg)`` for one head output.

        ``target`` is read at keys ``'labels_int32'``, ``'bbox_targets'``,
        ``'bbox_inside_weights'`` and ``'bbox_outside_weights'``. Both losses
        are reduced with ``paddle.mean``.
        """
        # bbox cls
        labels_int64 = paddle.cast(x=target['labels_int32'], dtype='int64')
        labels_int64.stop_gradient = True
        loss_bbox_cls = F.softmax_with_cross_entropy(
            logits=score, label=labels_int64)
        loss_bbox_cls = paddle.mean(loss_bbox_cls)
        # bbox reg
        loss_bbox_reg = ops.smooth_l1(
            input=delta,
            label=target['bbox_targets'],
            inside_weight=target['bbox_inside_weights'],
            outside_weight=target['bbox_outside_weights'],
            sigma=1.0)
        loss_bbox_reg = paddle.mean(loss_bbox_reg)
        return loss_bbox_cls, loss_bbox_reg

    def get_loss(self, bbox_head_out, targets):
        """Build the loss dict for paired head outputs and targets.

        Entry ``lvl`` contributes ``loss_bbox_cls_<lvl>`` and
        ``loss_bbox_reg_<lvl>``, each scaled by ``1 / 2**lvl``.

        Args:
            bbox_head_out: iterable of ``(score, delta)`` pairs.
            targets: iterable of target dicts, paired by position with
                ``bbox_head_out`` (iteration stops at the shorter one).

        Returns:
            dict: loss name -> weighted loss.
        """
        loss_bbox = {}
        for lvl, (bboxhead, target) in enumerate(zip(bbox_head_out, targets)):
            score, delta = bboxhead
            cls_name = 'loss_bbox_cls_{}'.format(lvl)
            reg_name = 'loss_bbox_reg_{}'.format(lvl)
            loss_bbox_cls, loss_bbox_reg = self._get_head_loss(score, delta,
                                                               target)
            loss_weight = 1. / 2**lvl
            loss_bbox[cls_name] = loss_bbox_cls * loss_weight
            loss_bbox[reg_name] = loss_bbox_reg * loss_weight
        return loss_bbox

    def get_prediction(self, bbox_head_out, rois):
        """Turn one head output into ``((delta, bbox_prob), rois)``.

        ``bbox_prob`` is the softmax of the scores and ``delta`` is reshaped
        to ``(-1, self.delta_dim, 4)``. ``rois`` is returned unchanged.
        """
        # The values are not used; the unpack only requires ``rois`` to be
        # a pair, exactly as before.
        _proposal, _proposal_num = rois
        score, delta = bbox_head_out
        bbox_prob = F.softmax(score)
        delta = paddle.reshape(delta, (-1, self.delta_dim, 4))
        bbox_pred = (delta, bbox_prob)
        return bbox_pred, rois

    def get_cascade_prediction(self, bbox_head_out, rois):
        """Average per-stage outputs into a single prediction.

        Scores of the stages listed in ``self.score_stage`` are softmaxed and
        averaged; deltas and proposals of the stages listed in
        ``self.delta_stage`` are averaged.

        Args:
            bbox_head_out: per-stage ``(score, delta)`` pairs, indexable by
                stage.
            rois: per-stage ``(proposal, proposal_num)`` pairs.

        Returns:
            tuple: ``((delta, bbox_prob), (proposal, proposal_num))`` where
            ``proposal_num`` is the one from the last entry of ``rois``.
        """
        proposal_list = []
        prob_list = []
        delta_list = []
        for stage, proposals in enumerate(rois):
            score, delta = bbox_head_out[stage]
            proposal, proposal_num = proposals
            if stage in self.score_stage:
                if stage < 2:
                    # For stages 0 and 1 the score is recomputed by calling
                    # this head without rois, i.e. on the cached features at
                    # ``self.roi_feat_list[-1]``.
                    # NOTE(review): the bound 2 is hard-coded; presumably it
                    # stands for "every stage before the last of three" --
                    # confirm before using a different ``roi_stages``.
                    _, head_out, _ = self(stage=stage, roi_stage=-1)
                    score = head_out[0]

                bbox_prob = F.softmax(score)
                prob_list.append(bbox_prob)
            if stage in self.delta_stage:
                proposal_list.append(proposal)
                delta_list.append(delta)
        bbox_prob = paddle.mean(paddle.stack(prob_list), axis=0)
        delta = paddle.mean(paddle.stack(delta_list), axis=0)
        proposal = paddle.mean(paddle.stack(proposal_list), axis=0)
        delta = paddle.reshape(delta, (-1, self.delta_dim, 4))
        if self.cls_agnostic:
            # Keep only index 1 of the ``delta_dim`` axis and expand it to
            # ``num_classes`` entries along that axis.
            N, _, M = delta.shape
            delta = delta[:, 1:2, :]
            delta = paddle.expand(delta, [N, self.num_classes, M])
        bboxes = (proposal, proposal_num)
        bbox_pred = (delta, bbox_prob)
        return bbox_pred, bboxes