import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import Layer, Sequential
from paddle.fluid.dygraph import Conv2D, Pool2D, BatchNorm
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Constant
from ppdet.core.workspace import register, serializable
from paddle.fluid.regularizer import L2Decay
from .name_adapter import NameAdapter
from numbers import Integral


class ConvNormLayer(Layer):
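    """Conv2D followed by a normalization layer.

    The normalization is either BatchNorm (optionally sync_bn, and optionally
    with frozen parameters and statistics) or a learnable per-channel
    `affine_channel` scale/offset; parameter names are derived through
    `name_adapter`.
    """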
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 name_adapter,
                 act=None,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 lr=1.0,
                 name=None):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'affine_channel']
        self.norm_type = norm_type
        self.act = act

        self.conv = Conv2D(
            num_channels=ch_in,
            num_filters=ch_out,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=1,
            act=None,
            param_attr=ParamAttr(
                learning_rate=lr, name=name + "_weights"),
            bias_attr=False)

        bn_name = name_adapter.fix_conv_norm_name(name)
        norm_lr = 0. if freeze_norm else lr
        param_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay),
            name=bn_name + "_scale",
            trainable=not freeze_norm)
        bias_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay),
            name=bn_name + "_offset",
            trainable=not freeze_norm)

        if norm_type in ['bn', 'sync_bn']:
            global_stats = freeze_norm
            self.norm = BatchNorm(
                num_channels=ch_out,
                act=act,
                param_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=global_stats,
                moving_mean_name=bn_name + '_mean',
                moving_variance_name=bn_name + '_variance')
            norm_params = self.norm.parameters()
        elif norm_type == 'affine_channel':
            self.scale = fluid.layers.create_parameter(
                shape=[ch_out],
                dtype='float32',
                attr=param_attr,
                default_initializer=Constant(1.))

            self.offset = fluid.layers.create_parameter(
                shape=[ch_out],
                dtype='float32',
                attr=bias_attr,
                default_initializer=Constant(0.))
            norm_params = [self.scale, self.offset]

        if freeze_norm:
            for param in norm_params:
                param.stop_gradient = True

    def forward(self, inputs):
        out = self.conv(inputs)
        if self.norm_type in ['bn', 'sync_bn']:
            out = self.norm(out)
        elif self.norm_type == 'affine_channel':
            out = fluid.layers.affine_channel(
                out, scale=self.scale, bias=self.offset, act=self.act)
        return out


class BottleNeck(Layer):
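    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a
    residual connection, with a projection shortcut when `shortcut` is
    False.
    """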
    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 name_adapter,
                 name,
                 variant='b',
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True):
        super(BottleNeck, self).__init__()
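        # Variant 'a' (the original ResNet) downsamples in the first 1x1
        # conv; the later variants move the stride to the 3x3 conv.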
        if variant == 'a':
            stride1, stride2 = stride, 1
        else:
            stride1, stride2 = 1, stride

        conv_name1, conv_name2, conv_name3, \
            shortcut_name = name_adapter.fix_bottleneck_name(name)

        self.shortcut = shortcut
        if not shortcut:
            self.short = ConvNormLayer(
                ch_in=ch_in,
                ch_out=ch_out * 4,
                filter_size=1,
                stride=stride,
                name_adapter=name_adapter,
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                lr=lr,
                name=shortcut_name)

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=1,
            stride=stride1,
            name_adapter=name_adapter,
            act='relu',
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr,
            name=conv_name1)

        self.branch2b = ConvNormLayer(
            ch_in=ch_out,
            ch_out=ch_out,
            filter_size=3,
            stride=stride2,
            name_adapter=name_adapter,
            act='relu',
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr,
            name=conv_name2)

        self.branch2c = ConvNormLayer(
            ch_in=ch_out,
            ch_out=ch_out * 4,
            filter_size=1,
            stride=1,
            name_adapter=name_adapter,
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr,
            name=conv_name3)

    def forward(self, inputs):
        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        out = self.branch2a(inputs)
        out = self.branch2b(out)
        out = self.branch2c(out)

        out = fluid.layers.elementwise_add(x=short, y=out, act='relu')

        return out


class Blocks(Layer):
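    """One ResNet stage: `count` bottleneck blocks chained together. The
    first block downsamples (except in stage 2) and uses a projection
    shortcut; the rest keep resolution and use identity shortcuts.
    """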
    def __init__(self,
                 ch_in,
                 ch_out,
                 count,
                 name_adapter,
                 stage_num,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True):
        super(Blocks, self).__init__()

        self.blocks = []
        for i in range(count):
            conv_name = name_adapter.fix_layer_warp_name(stage_num, count, i)

            block = self.add_sublayer(
                conv_name,
                BottleNeck(
                    ch_in=ch_in if i == 0 else ch_out * 4,
                    ch_out=ch_out,
                    stride=2 if i == 0 and stage_num != 2 else 1,
                    shortcut=(i != 0),
                    name_adapter=name_adapter,
                    name=conv_name,
                    variant=name_adapter.variant,
                    lr=lr,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm))
            self.blocks.append(block)

    def forward(self, inputs):
        block_out = inputs
        for block in self.blocks:
            block_out = block(block_out)
        return block_out


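# Bottleneck block counts for stages res2-res5 of each supported depth.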
ResNet_cfg = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}


@register
@serializable
class ResNet(Layer):
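    """
    ResNet backbone built from bottleneck blocks.

    Args:
        depth (int): network depth, one of 50, 101, 152
        variant (str): ResNet variant, e.g. 'a', 'b', 'c' or 'd'
        lr_mult (float): learning rate multiplier for backbone parameters
        norm_type (str): normalization type, 'bn' or 'affine_channel'
        norm_decay (float): L2 decay applied to normalization parameters
        freeze_norm (bool): whether normalization parameters are frozen
        freeze_at (int): index of the stage whose output stops gradient,
            which also freezes all earlier stages
        return_idx (int|list): index or indices of the stages to return,
            0 for res2 through 3 for res5
        num_stages (int): number of residual stages to build, at most 4
    """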
    def __init__(self,
                 depth=50,
                 variant='b',
                 lr_mult=1.,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 num_stages=4):
        super(ResNet, self).__init__()
        self.depth = depth
        self.variant = variant
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        self.freeze_norm = freeze_norm
        self.freeze_at = freeze_at
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]
        assert max(return_idx) < num_stages, \
            'the maximum return index must be smaller than num_stages, ' \
            'but received maximum return index is {} and num_stages ' \
            'is {}'.format(max(return_idx), num_stages)
        self.return_idx = return_idx
        self.num_stages = num_stages

        block_nums = ResNet_cfg[depth]
        na = NameAdapter(self)

        conv1_name = na.fix_c1_stage_name()
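        # Variants 'c' and 'd' replace the single 7x7 stem conv with three
        # 3x3 convs (the "deep stem").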
        if variant in ['c', 'd']:
            conv_def = [
                [3, 32, 3, 2, "conv1_1"],
                [32, 32, 3, 1, "conv1_2"],
                [32, 64, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, 64, 7, 2, conv1_name]]
        self.conv1 = Sequential()
        for (c_in, c_out, k, s, _name) in conv_def:
            self.conv1.add_sublayer(
                _name,
                ConvNormLayer(
                    ch_in=c_in,
                    ch_out=c_out,
                    filter_size=k,
                    stride=s,
                    name_adapter=na,
                    act='relu',
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=lr_mult,
                    name=_name))

        self.pool = Pool2D(
            pool_type='max', pool_size=3, pool_stride=2, pool_padding=1)

        ch_in_list = [64, 256, 512, 1024]
        ch_out_list = [64, 128, 256, 512]

        self.res_layers = []
        for i in range(num_stages):
            stage_num = i + 2
            res_name = "res{}".format(stage_num)
            res_layer = self.add_sublayer(
                res_name,
                Blocks(
                    ch_in_list[i],
                    ch_out_list[i],
                    count=block_nums[i],
                    name_adapter=na,
                    stage_num=stage_num,
                    lr=lr_mult,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm))
            self.res_layers.append(res_layer)

    def forward(self, inputs):
        x = inputs['image']
        conv1 = self.conv1(x)
        x = self.pool(conv1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
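            # Stopping the gradient at this stage's output also blocks
            # backprop into every earlier stage, so one cut freezes the
            # whole prefix of the network.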
            if idx == self.freeze_at:
                x.stop_gradient = True
            if idx in self.return_idx:
                outs.append(x)
        return outs
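

# Minimal usage sketch (an illustrative addition, not part of the original
# module). Because of the relative `.name_adapter` import, this file must be
# imported as part of the ppdet package rather than executed directly; the
# 640x640 input size below is an arbitrary choice:
#
#     with fluid.dygraph.guard():
#         model = ResNet(depth=50, return_idx=[1, 2, 3])
#         image = fluid.dygraph.to_variable(
#             np.random.rand(1, 3, 640, 640).astype('float32'))
#         feats = model({'image': image})  # outputs of res3, res4 and res5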