cascade_head.py 10.4 KB
Newer Older
W
wangguanzhong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
#   
# Licensed under the Apache License, Version 2.0 (the "License");   
# you may not use this file except in compliance with the License.  
# You may obtain a copy of the License at   
#   
#     http://www.apache.org/licenses/LICENSE-2.0    
#   
# Unless required by applicable law or agreed to in writing, software   
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  
# See the License for the specific language governing permissions and   
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, XavierUniform
from paddle.regularizer import L2Decay

from ppdet.core.workspace import register, create
from ppdet.modeling import ops

F
Feng Ni 已提交
24
from .bbox_head import BBoxHead, TwoFCHead, XConvNormHead
W
wangguanzhong 已提交
25 26 27 28
from .roi_extractor import RoIAlign
from ..shape_spec import ShapeSpec
from ..bbox_utils import bbox2delta, delta2bbox, clip_bbox, nonempty_bbox

F
Feng Ni 已提交
29 30
# Public API of this module: the two per-stage feature heads and the
# cascade bbox head that drives them.
__all__ = [
    'CascadeTwoFCHead',
    'CascadeXConvNormHead',
    'CascadeHead',
]

W
wangguanzhong 已提交
31 32 33 34

@register
class CascadeTwoFCHead(nn.Layer):
    """
    Cascade RCNN bbox head with two fc layers to extract features.

    One independent ``TwoFCHead`` is built per cascade stage; ``forward``
    selects the head to run with its ``stage`` argument.

    Args:
        in_channel (int): Input channel which can be derived by from_config
        out_channel (int): Output channel
        resolution (int): Resolution of input feature map, default 7
        num_cascade_stage (int): The number of cascade stage, default 3
    """
    # NOTE(review): this shared key is `num_cascade_stage` (singular) while
    # CascadeHead shares `num_cascade_stages` (plural), so the shared config
    # does not tie the two counts together. Kept unchanged for config
    # compatibility — confirm configs set both consistently.
    __shared__ = ['num_cascade_stage']

    def __init__(self,
                 in_channel=256,
                 out_channel=1024,
                 resolution=7,
                 num_cascade_stage=3):
        super(CascadeTwoFCHead, self).__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel

        # Each stage head is registered as a named sublayer ("0", "1", ...)
        # via add_sublayer; the plain list only provides indexing by stage.
        self.head_list = []
        for stage in range(num_cascade_stage):
            head_per_stage = self.add_sublayer(
                str(stage), TwoFCHead(in_channel, out_channel, resolution))
            self.head_list.append(head_per_stage)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive ``in_channel`` from the (first) input ShapeSpec."""
        s = input_shape
        s = s[0] if isinstance(s, (list, tuple)) else s
        return {'in_channel': s.channels}

    @property
    def out_shape(self):
        """Output shape: a single ShapeSpec with ``out_channel`` channels."""
        return [ShapeSpec(channels=self.out_channel, )]

    def forward(self, rois_feat, stage=0):
        """Run the head belonging to cascade ``stage`` on ``rois_feat``."""
        out = self.head_list[stage](rois_feat)
        return out


@register
class CascadeXConvNormHead(nn.Layer):
    """
    Cascade RCNN bbox head with several convolution layers.

    One independent ``XConvNormHead`` is built per cascade stage; ``forward``
    selects the head to run with its ``stage`` argument.

    Args:
        in_channel (int): Input channels which can be derived by from_config
        num_convs (int): The number of conv layers
        conv_dim (int): The number of channels for the conv layers
        out_channel (int): Output channels
        resolution (int): Resolution of input feature map
        norm_type (string): Norm type, bn, gn, sync_bn are available,
            default `gn`
        freeze_norm (bool): Whether to freeze the norm
        num_cascade_stage (int): The number of cascade stage, default 3
    """
    # NOTE(review): this shared key is `num_cascade_stage` (singular) while
    # CascadeHead shares `num_cascade_stages` (plural), so the shared config
    # does not tie the two counts together. Kept unchanged for config
    # compatibility — confirm configs set both consistently.
    __shared__ = ['norm_type', 'freeze_norm', 'num_cascade_stage']

    def __init__(self,
                 in_channel=256,
                 num_convs=4,
                 conv_dim=256,
                 out_channel=1024,
                 resolution=7,
                 norm_type='gn',
                 freeze_norm=False,
                 num_cascade_stage=3):
        super(CascadeXConvNormHead, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        # Each stage head is registered as a named sublayer ("0", "1", ...)
        # via add_sublayer; the plain list only provides indexing by stage.
        # stage_name gives each stage's XConvNormHead a distinct prefix.
        self.head_list = []
        for stage in range(num_cascade_stage):
            head_per_stage = self.add_sublayer(
                str(stage),
                XConvNormHead(
                    in_channel,
                    num_convs,
                    conv_dim,
                    out_channel,
                    resolution,
                    norm_type,
                    freeze_norm,
                    stage_name='stage{}_'.format(stage)))
            self.head_list.append(head_per_stage)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive ``in_channel`` from the (first) input ShapeSpec."""
        s = input_shape
        s = s[0] if isinstance(s, (list, tuple)) else s
        return {'in_channel': s.channels}

    @property
    def out_shape(self):
        """Output shape: a single ShapeSpec with ``out_channel`` channels."""
        return [ShapeSpec(channels=self.out_channel, )]

    def forward(self, rois_feat, stage=0):
        """Run the head belonging to cascade ``stage`` on ``rois_feat``."""
        out = self.head_list[stage](rois_feat)
        return out


@register
class CascadeHead(BBoxHead):
    """
    Cascade RCNN bbox head

    Runs ``num_cascade_stages`` refinement stages. Each stage extracts RoI
    features, predicts class scores and box deltas, and decodes the deltas
    into the boxes used as RoIs for the next stage.

    Args:
        head (nn.Layer): Extract feature in bbox head; called as
            ``head(rois_feat, stage)``
        in_channel (int): Input channel after RoI extractor
        roi_extractor (object): The module of RoI Extractor, or a dict of
            RoIAlign keyword arguments
        bbox_assigner (object): The module of Box Assigner, label and sample
            the box.
        num_classes (int): The number of classes
        bbox_weight (List[List[float]]): The weight to get the decode box and
            the length of weight is the number of cascade stage
        num_cascade_stages (int): The number of stage to refine the box
        bbox_loss (object): Box loss module, default None. Stored on the head;
            presumably consumed by the inherited ``get_loss`` — confirm.
    """
    __shared__ = ['num_classes', 'num_cascade_stages']
    __inject__ = ['bbox_assigner', 'bbox_loss']

    def __init__(self,
                 head,
                 in_channel,
                 roi_extractor=RoIAlign().__dict__,
                 bbox_assigner='BboxAssigner',
                 num_classes=80,
                 bbox_weight=[[10., 10., 5., 5.], [20.0, 20.0, 10.0, 10.0],
                              [30.0, 30.0, 15.0, 15.0]],
                 num_cascade_stages=3,
                 bbox_loss=None):
        # Calls nn.Layer.__init__ directly instead of BBoxHead.__init__;
        # presumably to avoid building BBoxHead's single-stage layers, since
        # the per-stage layers are created below — confirm.
        nn.Layer.__init__(self, )
        self.head = head
        self.roi_extractor = roi_extractor
        if isinstance(roi_extractor, dict):
            self.roi_extractor = RoIAlign(**roi_extractor)
        self.bbox_assigner = bbox_assigner

        self.num_classes = num_classes
        self.bbox_weight = bbox_weight
        self.num_cascade_stages = num_cascade_stages
        self.bbox_loss = bbox_loss

        # One score/delta predictor pair per stage. Scores cover
        # num_classes + 1 outputs (the extra one presumably background);
        # deltas have 4 outputs per RoI.
        self.bbox_score_list = []
        self.bbox_delta_list = []
        for i in range(num_cascade_stages):
            score_name = 'bbox_score_stage{}'.format(i)
            delta_name = 'bbox_delta_stage{}'.format(i)
            bbox_score = self.add_sublayer(
                score_name,
                nn.Linear(
                    in_channel,
                    self.num_classes + 1,
                    weight_attr=paddle.ParamAttr(initializer=Normal(
                        mean=0.0, std=0.01))))

            bbox_delta = self.add_sublayer(
                delta_name,
                nn.Linear(
                    in_channel,
                    4,
                    weight_attr=paddle.ParamAttr(initializer=Normal(
                        mean=0.0, std=0.001))))
            self.bbox_score_list.append(bbox_score)
            self.bbox_delta_list.append(bbox_delta)
        # Kept for backward compatibility; this class itself never writes it.
        self.assigned_label = None
        self.assigned_rois = None
        # Written by forward() during training.
        self.assigned_targets = None
        # Written by forward() at inference; initialized so that
        # get_refined_rois() returns None instead of raising AttributeError
        # when called before an inference pass.
        self.refined_rois = None

    def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None):
        """
        Run all cascade stages.

        Args:
            body_feats (list[Tensor]): Feature maps from backbone
            rois (list[Tensor]): RoIs generated from RPN module, one entry
                per image
            rois_num (Tensor): The number of RoIs in each image
            inputs (dict{Tensor}): The ground-truth of image; ``im_shape`` is
                read when refining boxes between stages

        Returns:
            In training: ``(loss, bbox_feat)``, where ``loss`` maps
            ``<name>_stage<i>`` to that stage's loss divided by
            ``num_cascade_stages``, and ``bbox_feat`` is the last stage's
            feature. In inference: ``((deltas, scores), self.head)``.
        """
        targets = []
        if self.training:
            rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs)
            targets_list = [targets]
            self.assigned_rois = (rois, rois_num)
            self.assigned_targets = targets

        pred_bbox = None
        head_out_list = []
        for i in range(self.num_cascade_stages):
            if i > 0:
                # Boxes predicted by the previous stage become this stage's RoIs.
                rois, rois_num = self._get_rois_from_boxes(pred_bbox,
                                                           inputs['im_shape'])
                if self.training:
                    rois, rois_num, targets = self.bbox_assigner(
                        rois, rois_num, inputs, i, is_cascade=True)
                    targets_list.append(targets)

            rois_feat = self.roi_extractor(body_feats, rois, rois_num)
            bbox_feat = self.head(rois_feat, i)
            scores = self.bbox_score_list[i](bbox_feat)
            deltas = self.bbox_delta_list[i](bbox_feat)
            head_out_list.append([scores, deltas, rois])
            pred_bbox = self._get_pred_bbox(deltas, rois, self.bbox_weight[i])

        if self.training:
            loss = {}
            for stage, value in enumerate(zip(head_out_list, targets_list)):
                (scores, deltas, rois), targets = value
                loss_stage = self.get_loss(scores, deltas, targets, rois,
                                           self.bbox_weight[stage])
                for k, v in loss_stage.items():
                    loss[k + "_stage{}".format(
                        stage)] = v / self.num_cascade_stages

            return loss, bbox_feat
        else:
            scores, deltas, self.refined_rois = self.get_prediction(
                head_out_list)
            return (deltas, scores), self.head

    def _get_rois_from_boxes(self, boxes, im_shape):
        """
        Clip per-image boxes to the image and pack them as RoIs.

        In training, only boxes kept by ``nonempty_bbox`` are retained. If
        none survive, index 0 is kept so the gather never gets an empty
        index.

        Returns:
            (rois, rois_num): list of per-image box tensors and the
            concatenated per-image counts.
        """
        rois = []
        for i, boxes_per_image in enumerate(boxes):
            clip_box = clip_bbox(boxes_per_image, im_shape[i])
            if self.training:
                keep = nonempty_bbox(clip_box)
                if keep.shape[0] == 0:
                    keep = paddle.zeros([1], dtype='int32')
                clip_box = paddle.gather(clip_box, keep)
            rois.append(clip_box)
        rois_num = paddle.concat([paddle.shape(r)[0] for r in rois])
        return rois, rois_num

    def _get_pred_bbox(self, deltas, proposals, weights):
        """
        Decode ``deltas`` against ``proposals`` with ``weights``.

        Proposals from all images are concatenated for decoding. The result
        is reshaped to ``[-1, deltas.shape[-1]]`` and split back into one
        tensor per image.
        """
        pred_proposals = paddle.concat(proposals) if len(
            proposals) > 1 else proposals[0]
        pred_bbox = delta2bbox(deltas, pred_proposals, weights)
        pred_bbox = paddle.reshape(pred_bbox, [-1, deltas.shape[-1]])
        num_prop = [p.shape[0] for p in proposals]
        return pred_bbox.split(num_prop)

    def get_prediction(self, head_out_list):
        """
        Combine per-stage outputs into the final prediction.

        Args:
            head_out_list (List[list]): per-stage ``[scores, deltas, rois]``

        Returns:
            (scores, deltas, rois): softmax scores averaged over all stages,
            plus the deltas and rois of the last stage.
        """
        scores_list = [F.softmax(head[0]) for head in head_out_list]
        scores = paddle.add_n(scores_list) / self.num_cascade_stages
        # Get deltas and rois from the last stage
        _, deltas, rois = head_out_list[-1]
        return scores, deltas, rois

    def get_refined_rois(self, ):
        """Return the last-stage RoIs from the latest inference pass (or None)."""
        return self.refined_rois