ssd_head.py 5.4 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
5 6 7
from paddle.regularizer import L2Decay
from paddle import ParamAttr

8 9
from ..layers import AnchorGeneratorSSD

10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52

class SepConvLayer(nn.Layer):
    """Separable convolution: depthwise conv -> batch norm -> ReLU6 -> 1x1 conv.

    Args:
        in_channels (int): channels of the input; also used as the output
            channels and group count of the depthwise convolution.
        out_channels (int): channels produced by the 1x1 pointwise convolution.
        kernel_size (int): kernel size of the depthwise convolution.
        padding (int): padding of the depthwise convolution.
        conv_decay (float): L2 decay coefficient for both convolution weights.
        name (str|None): prefix for the parameter names. When None, no
            explicit names are passed to ``ParamAttr``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 padding=1,
                 conv_decay=0,
                 name=None):
        super(SepConvLayer, self).__init__()

        def _param_name(suffix):
            # `name` defaults to None; concatenating unconditionally would
            # raise TypeError, so fall back to an unnamed parameter instead.
            return None if name is None else name + suffix

        # Depthwise: one filter group per input channel, no bias.
        self.dw_conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            groups=in_channels,
            weight_attr=ParamAttr(
                name=_param_name("_dw_weights"),
                regularizer=L2Decay(conv_decay)),
            bias_attr=False)

        # Batch-norm scale/offset are created with a zero L2 coefficient.
        self.bn = nn.BatchNorm2D(
            in_channels,
            weight_attr=ParamAttr(
                name=_param_name("_bn_scale"), regularizer=L2Decay(0.)),
            bias_attr=ParamAttr(
                name=_param_name("_bn_offset"), regularizer=L2Decay(0.)))

        # Pointwise: 1x1 convolution mapping to the requested output channels.
        self.pw_conv = nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            weight_attr=ParamAttr(
                name=_param_name("_pw_weights"),
                regularizer=L2Decay(conv_decay)),
            bias_attr=False)

    def forward(self, x):
        """Apply depthwise conv, batch norm + ReLU6, then the pointwise conv."""
        x = self.dw_conv(x)
        x = F.relu6(self.bn(x))
        x = self.pw_conv(x)
        return x
Q
qingqing01 已提交
53 54 55 56 57 58 59 60


@register
class SSDHead(nn.Layer):
    """SSD head: one box conv and one class-score conv per input feature map.

    Args:
        num_classes (int): number of classes excluding background; one
            background class is added internally.
        in_channels (list|tuple): channel count of each input feature map.
        anchor_generator (dict|object): anchor generator instance, or a dict
            of keyword arguments used to construct an ``AnchorGeneratorSSD``.
        kernel_size (int): kernel size of the prediction convolutions.
        padding (int): padding of the prediction convolutions.
        use_sepconv (bool): build ``SepConvLayer`` instead of ``nn.Conv2D``.
        conv_decay (float): L2 decay passed to ``SepConvLayer``; not used
            when ``use_sepconv`` is False.
        loss (object): callable invoked by ``get_loss`` as
            ``loss(boxes, scores, gt_bbox, gt_class, prior_boxes)``.
    """
    __shared__ = ['num_classes']
    __inject__ = ['anchor_generator', 'loss']

    def __init__(self,
                 num_classes=80,
                 in_channels=(512, 1024, 512, 256, 256, 256),
                 anchor_generator=AnchorGeneratorSSD().__dict__,
                 kernel_size=3,
                 padding=1,
                 use_sepconv=False,
                 conv_decay=0.,
                 loss='SSDLoss'):
        super(SSDHead, self).__init__()
        # add background class
        self.num_classes = num_classes + 1
        self.in_channels = in_channels
        self.anchor_generator = anchor_generator
        self.loss = loss

        # A dict is treated as constructor kwargs for the anchor generator.
        if isinstance(anchor_generator, dict):
            self.anchor_generator = AnchorGeneratorSSD(**anchor_generator)

        self.num_priors = self.anchor_generator.num_priors
        self.box_convs = []
        self.score_convs = []
        for i, num_prior in enumerate(self.num_priors):
            # Box conv is registered before the score conv of the same level
            # so the sublayer order stays boxes0, scores0, boxes1, ...
            self.box_convs.append(
                self._build_conv(
                    name="boxes{}".format(i),
                    in_channels=in_channels[i],
                    out_channels=num_prior * 4,
                    kernel_size=kernel_size,
                    padding=padding,
                    use_sepconv=use_sepconv,
                    conv_decay=conv_decay))
            self.score_convs.append(
                self._build_conv(
                    name="scores{}".format(i),
                    in_channels=in_channels[i],
                    out_channels=num_prior * self.num_classes,
                    kernel_size=kernel_size,
                    padding=padding,
                    use_sepconv=use_sepconv,
                    conv_decay=conv_decay))

    def _build_conv(self, name, in_channels, out_channels, kernel_size,
                    padding, use_sepconv, conv_decay):
        """Create one prediction conv and register it as sublayer ``name``."""
        if use_sepconv:
            layer = SepConvLayer(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                conv_decay=conv_decay,
                name=name)
        else:
            layer = nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding)
        return self.add_sublayer(name, layer)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive ``in_channels`` from the ``channels`` of each input shape."""
        return {'in_channels': [i.channels for i in input_shape], }

    def forward(self, feats, image, gt_bbox=None, gt_class=None):
        """Predict boxes and class scores for every feature map.

        Args:
            feats (list): feature maps, paired in order with the box and
                score convolutions.
            image: forwarded to the anchor generator together with ``feats``.
            gt_bbox: ground-truth boxes, forwarded to the loss in training.
            gt_class: ground-truth classes, forwarded to the loss in training.

        Returns:
            In training mode, the result of ``get_loss``. Otherwise the tuple
            ``((box_preds, cls_scores), prior_boxes)``, where ``box_preds``
            and ``cls_scores`` are lists with one entry per feature map.
        """
        box_preds = []
        cls_scores = []
        for feat, box_conv, score_conv in zip(feats, self.box_convs,
                                              self.score_convs):
            # Move the channel axis last, then flatten to rows of 4 values.
            box_pred = box_conv(feat)
            box_pred = paddle.transpose(box_pred, [0, 2, 3, 1])
            box_pred = paddle.reshape(box_pred, [0, -1, 4])
            box_preds.append(box_pred)

            # Same layout change, flattened to rows of num_classes scores.
            cls_score = score_conv(feat)
            cls_score = paddle.transpose(cls_score, [0, 2, 3, 1])
            cls_score = paddle.reshape(cls_score, [0, -1, self.num_classes])
            cls_scores.append(cls_score)

        prior_boxes = self.anchor_generator(feats, image)

        if self.training:
            return self.get_loss(box_preds, cls_scores, gt_bbox, gt_class,
                                 prior_boxes)
        else:
            return (box_preds, cls_scores), prior_boxes

    def get_loss(self, boxes, scores, gt_bbox, gt_class, prior_boxes):
        """Delegate to the configured loss with the given predictions/targets."""
        return self.loss(boxes, scores, gt_bbox, gt_class, prior_boxes)