# efficientnet.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division

import collections
import math
import re

from paddle import fluid
from paddle.fluid.regularizer import L2Decay

from ppdet.core.workspace import register

__all__ = ['EfficientNet']

# Model-wide hyper-parameters. Every field defaults to None (see the
# __defaults__ assignments below) so callers only set what they need.
GlobalParams = collections.namedtuple('GlobalParams', [
    'batch_norm_momentum', 'batch_norm_epsilon', 'width_coefficient',
    'depth_coefficient', 'depth_divisor', 'min_depth', 'drop_connect_rate',
    'relu_fn', 'batch_norm', 'use_se', 'local_pooling', 'condconv_num_experts',
    'clip_projection_output', 'blocks_args', 'fix_head_stem'
])

# Per-stage block configuration, as decoded from a block string.
BlockArgs = collections.namedtuple('BlockArgs', [
    'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
    'expand_ratio', 'id_skip', 'stride', 'se_ratio', 'conv_type', 'fused_conv',
    'super_pixel', 'condconv'
])

# Make every field optional, defaulting to None.
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)


def _decode_block_string(block_string):
    """Decode an encoded block description such as 'r1_k3_s11_e1_i32_o16_se0.25'.

    The string is split on '_'. Each token is a key prefix followed by a value
    that starts at the token's first digit (e.g. 'k3' -> key 'k', value '3').
    Tokens without a digit (e.g. 'noskip') carry no value; 'noskip' and 'cc'
    are detected by substring search on the whole string.

    Args:
        block_string (str): encoded block description.

    Returns:
        BlockArgs: decoded arguments. Only the first character of the 's'
        option is used as the stride.

    Raises:
        ValueError: if the 's' option is missing or is not exactly two
            characters long.
        KeyError: if one of the required 'k', 'r', 'i', 'o', 'e' options is
            missing.
    """
    assert isinstance(block_string, str)

    options = {}
    for op in block_string.split('_'):
        # r'(\d.*)' splits e.g. 'se0.25' into ['se', '0.25', ''].
        splits = re.split(r'(\d.*)', op)
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value

    if 's' not in options or len(options['s']) != 2:
        raise ValueError('Strides options should be a pair of integers.')

    return BlockArgs(
        kernel_size=int(options['k']),
        num_repeat=int(options['r']),
        input_filters=int(options['i']),
        output_filters=int(options['o']),
        expand_ratio=int(options['e']),
        id_skip=('noskip' not in block_string),
        se_ratio=float(options['se']) if 'se' in options else None,
        stride=int(options['s'][0]),
        conv_type=int(options['c']) if 'c' in options else 0,
        fused_conv=int(options['f']) if 'f' in options else 0,
        super_pixel=int(options['p']) if 'p' in options else 0,
        condconv=('cc' in block_string))


def get_model_params(scale):
    """Return the block arguments and global parameters for a model scale.

    Args:
        scale (str): one of the keys of ``params_dict`` below
            ('b0' - 'b7', 'l2').

    Returns:
        tuple: ``(block_args, global_params)``. ``block_args`` is a list with
        one BlockArgs per stage, not yet scaled by width/depth.
        ``global_params`` is a GlobalParams carrying the width/depth
        coefficients and the batch-norm / drop-connect settings.

    Raises:
        KeyError: if ``scale`` is not a known scale.
    """
    # Base (b0) network: one string per stage, see _decode_block_string.
    block_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25',
        'r4_k5_s22_e6_i112_o192_se0.25',
        'r1_k3_s11_e6_i192_o320_se0.25',
    ]
    block_args = [_decode_block_string(s) for s in block_strings]

    params_dict = {
        # width, depth
        'b0': (1.0, 1.0),
        'b1': (1.0, 1.1),
        'b2': (1.1, 1.2),
        'b3': (1.2, 1.4),
        'b4': (1.4, 1.8),
        'b5': (1.6, 2.2),
        'b6': (1.8, 2.6),
        'b7': (2.0, 3.1),
        'l2': (4.3, 5.3),
    }

    w, d = params_dict[scale]

    global_params = GlobalParams(
        # NOTE: the raw (undecoded) strings are stored here; the decoded
        # list is returned separately as ``block_args``.
        blocks_args=block_strings,
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        drop_connect_rate=0 if scale == 'b0' else 0.2,
        width_coefficient=w,
        depth_coefficient=d,
        depth_divisor=8,
        min_depth=None,
        fix_head_stem=False,
        use_se=True,
        clip_projection_output=False)

    return block_args, global_params


def round_filters(filters, global_params, skip=False):
    """Round a filter count scaled by the width multiplier.

    The count is multiplied by ``global_params.width_coefficient`` and
    rounded to the nearest multiple of ``global_params.depth_divisor``. It
    is never smaller than ``min_depth`` (falling back to the divisor), and
    it is bumped up by one divisor if rounding lost more than 10%.

    Args:
        filters (int): unscaled number of filters.
        global_params: object exposing ``width_coefficient``,
            ``depth_divisor`` and ``min_depth``.
        skip (bool): if True, return ``filters`` unchanged.

    Returns:
        int: the scaled and rounded number of filters. When ``skip`` is set
        or the multiplier is falsy, ``filters`` is returned as given.
    """
    multiplier = global_params.width_coefficient
    divisor = global_params.depth_divisor
    min_depth = global_params.min_depth
    if skip or not multiplier:
        return filters

    filters *= multiplier
    min_depth = min_depth or divisor
    new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
    if new_filters < 0.9 * filters:  # prevent rounding by more than 10%
        new_filters += divisor
    return int(new_filters)


def round_repeats(repeats, global_params, skip=False):
    """Round the number of block repeats scaled by the depth multiplier.

    Args:
        repeats (int): unscaled number of repeats for a stage.
        global_params: object exposing ``depth_coefficient``.
        skip (bool): if True, return ``repeats`` unchanged.

    Returns:
        int: ``ceil(depth_coefficient * repeats)``, or ``repeats`` unchanged
        when ``skip`` is set or the multiplier is falsy.
    """
    multiplier = global_params.depth_coefficient
    if skip or not multiplier:
        return repeats
    return int(math.ceil(multiplier * repeats))


def conv2d(inputs,
           num_filters,
           filter_size,
           stride=1,
           padding='SAME',
           groups=1,
           use_bias=False,
           name='conv2d',
           use_cudnn=True):
    """Thin wrapper around ``fluid.layers.conv2d`` with named parameters.

    The weight is named ``<name>_weights``. When ``use_bias`` is True a bias
    named ``<name>_offset`` is created with L2 decay disabled; otherwise the
    layer has no bias.

    Args:
        inputs: input tensor.
        num_filters (int): number of output channels.
        filter_size (int): convolution kernel size.
        stride (int): convolution stride.
        padding: padding passed to fluid (default 'SAME').
        groups (int): number of groups (set to the channel count for a
            depthwise convolution).
        use_bias (bool): whether to add a bias.
        name (str): prefix for the layer and its parameter names.
        use_cudnn (bool): forwarded to fluid.

    Returns:
        The convolution output tensor.
    """
    param_attr = fluid.ParamAttr(name=name + '_weights')
    bias_attr = False
    if use_bias:
        # Biases are excluded from weight decay.
        bias_attr = fluid.ParamAttr(
            name=name + '_offset', regularizer=L2Decay(0.))
    feats = fluid.layers.conv2d(
        inputs,
        num_filters,
        filter_size,
        groups=groups,
        name=name,
        stride=stride,
        padding=padding,
        param_attr=param_attr,
        bias_attr=bias_attr,
        use_cudnn=use_cudnn)
    return feats


def batch_norm(inputs, momentum, eps, name=None):
    """Batch normalization whose scale/offset are excluded from weight decay.

    Args:
        inputs: input tensor.
        momentum (float): moving-average momentum.
        eps (float): epsilon added to the variance.
        name (str|None): prefix for the layer and its parameters
            (``<name>_scale``, ``<name>_offset``, ``<name>_mean``,
            ``<name>_variance``). When None, the names are left to fluid's
            automatic naming; previously the default None raised a
            TypeError on string concatenation.

    Returns:
        The normalized tensor.
    """

    def _suffixed(suffix):
        # Only derive parameter names when a prefix was provided.
        return None if name is None else name + suffix

    param_attr = fluid.ParamAttr(
        name=_suffixed('_scale'), regularizer=L2Decay(0.))
    bias_attr = fluid.ParamAttr(
        name=_suffixed('_offset'), regularizer=L2Decay(0.))
    return fluid.layers.batch_norm(
        input=inputs,
        momentum=momentum,
        epsilon=eps,
        name=name,
        moving_mean_name=_suffixed('_mean'),
        moving_variance_name=_suffixed('_variance'),
        param_attr=param_attr,
        bias_attr=bias_attr)


def _drop_connect(inputs, prob, mode):
    """Randomly zero whole samples of ``inputs`` during training.

    A per-sample keep mask of shape ``[batch, 1, 1, 1]`` is drawn and the
    kept samples are rescaled by ``1 / (1 - prob)``. Outside 'train' mode
    the input is returned unchanged.

    Args:
        inputs: input tensor; the mask shape assumes a 4-D input with the
            batch on axis 0 (presumably NCHW - TODO confirm).
        prob (float): probability of dropping a sample; must be < 1.
        mode (str): 'train' enables dropping.

    Returns:
        The (possibly) masked and rescaled tensor.
    """
    if mode != 'train':
        return inputs
    keep_prob = 1.0 - prob
    inputs_shape = fluid.layers.shape(inputs)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + fluid.layers.uniform_random(
        shape=[inputs_shape[0], 1, 1, 1], min=0., max=1.)
    binary_tensor = fluid.layers.floor(random_tensor)
    output = inputs / keep_prob * binary_tensor
    return output

def mb_conv_block(inputs,
                  input_filters,
                  output_filters,
                  expand_ratio,
                  kernel_size,
                  stride,
                  momentum,
                  eps,
                  block_arg,
                  drop_connect_rate,
                  mode,
                  se_ratio=None,
                  name=None):
    """Build one mobile inverted bottleneck (MBConv) block.

    The block is built in four steps:
    1. Optional 1x1 expansion to ``input_filters * expand_ratio`` channels.
    2. Depthwise kxk convolution with the given stride.
    3. Optional squeeze-and-excitation.
    4. 1x1 projection to ``output_filters``.

    A residual connection (with optional drop connect) is added when
    ``block_arg.id_skip`` is set, the stride is 1 and the channel counts
    match.

    Args:
        inputs: input tensor.
        input_filters (int): channels of ``inputs``.
        output_filters (int): channels of the block output.
        expand_ratio (int): expansion factor for the hidden width.
        kernel_size (int): depthwise kernel size.
        stride (int): depthwise stride.
        momentum (float): batch-norm momentum.
        eps (float): batch-norm epsilon.
        block_arg (BlockArgs): only ``id_skip`` is read here.
        drop_connect_rate (float): drop probability for the residual branch;
            falsy disables drop connect.
        mode (str): 'train' or 'test', forwarded to drop connect.
        se_ratio (float|None): squeeze ratio relative to ``input_filters``;
            None disables SE.
        name (str): prefix for all parameter names in the block.

    Returns:
        The block output tensor.
    """
    feats = inputs
    num_filters = input_filters * expand_ratio

    # Expansion
    if expand_ratio != 1:
        feats = conv2d(feats, num_filters, 1, name=name + '_expand_conv')
        feats = batch_norm(feats, momentum, eps, name=name + '_bn0')
        feats = fluid.layers.swish(feats)

    # Depthwise Convolution (groups == channels)
    feats = conv2d(
        feats,
        num_filters,
        kernel_size,
        stride,
        groups=num_filters,
        name=name + '_depthwise_conv',
        use_cudnn=False)
    feats = batch_norm(feats, momentum, eps, name=name + '_bn1')
    feats = fluid.layers.swish(feats)

    # Squeeze and Excitation
    if se_ratio is not None:
        filter_squeezed = max(1, int(input_filters * se_ratio))
        squeezed = fluid.layers.pool2d(
            feats, pool_type='avg', global_pooling=True, use_cudnn=True)
        squeezed = conv2d(
            squeezed,
            filter_squeezed,
            1,
            use_bias=True,
            name=name + '_se_reduce')
        squeezed = fluid.layers.swish(squeezed)
        squeezed = conv2d(
            squeezed, num_filters, 1, use_bias=True, name=name + '_se_expand')
        feats = feats * fluid.layers.sigmoid(squeezed)

    # Project_conv_norm (no activation after projection)
    feats = conv2d(feats, output_filters, 1, name=name + '_project_conv')
    feats = batch_norm(feats, momentum, eps, name=name + '_bn2')

    # Skip connection and drop connect. Use the same stride that was used
    # for the depthwise conv so the residual shapes are guaranteed to match.
    if block_arg.id_skip and stride == 1 and input_filters == output_filters:
        if drop_connect_rate:
            feats = _drop_connect(feats, drop_connect_rate, mode)
        feats = fluid.layers.elementwise_add(feats, inputs)

    return feats


@register
class EfficientNet(object):
    """
    EfficientNet, see https://arxiv.org/abs/1905.11946

    Args:
        scale (str): compounding scale factor, 'b0' - 'b7'.
        use_se (bool): use squeeze and excite module. When False, the SE
            branch is omitted from every block.
        norm_type (str): normalization type, 'bn' and 'sync_bn' are supported
    """
    __shared__ = ['norm_type']

    def __init__(self,
                 scale='b0',
                 use_se=True,
                 norm_type='bn'):
        assert scale in ['b' + str(i) for i in range(8)], \
            "valid scales are b0 - b7"
        assert norm_type in ['bn', 'sync_bn'], \
            "only 'bn' and 'sync_bn' are supported"

        super(EfficientNet, self).__init__()
        self.norm_type = norm_type
        self.scale = scale
        self.use_se = use_se

    def __call__(self, inputs, mode):
        """Build the backbone and return three multi-scale feature maps.

        Args:
            inputs: input image tensor fed to the stride-2 stem convolution.
            mode (str): 'train' or 'test'; drop connect is only applied in
                'train' mode.

        Returns:
            list: outputs of stages 2, 4 and 6 (strides 1/8, 1/16 and 1/32
            of the input).
        """
        assert mode in ['train', 'test'], \
            "only 'train' and 'test' mode are supported"

        blocks_args, global_params = get_model_params(self.scale)
        momentum = global_params.batch_norm_momentum
        eps = global_params.batch_norm_epsilon

        # Stem part.
        num_filters = round_filters(blocks_args[0].input_filters,
                                    global_params,
                                    global_params.fix_head_stem)
        feats = conv2d(
            inputs,
            num_filters=num_filters,
            filter_size=3,
            stride=2,
            name='_conv_stem')
        feats = batch_norm(feats, momentum=momentum, eps=eps, name='_bn0')
        feats = fluid.layers.swish(feats)

        # Scale every stage's widths and repeat counts up front, so that the
        # total block count used by the drop connect schedule is the number
        # of blocks actually built. The previous code summed the unscaled
        # repeats, so for b1-b7 the rate exceeded drop_connect_rate.
        stage_args = [
            block_arg._replace(
                input_filters=round_filters(block_arg.input_filters,
                                            global_params),
                output_filters=round_filters(block_arg.output_filters,
                                             global_params),
                num_repeat=round_repeats(block_arg.num_repeat,
                                         global_params))
            for block_arg in blocks_args
        ]
        num_blocks = sum(block_arg.num_repeat for block_arg in stage_args)

        # Builds blocks.
        feature_maps = []
        layer_count = 0
        for block_arg in stage_args:
            for repeat_idx in range(block_arg.num_repeat):
                if repeat_idx == 1:
                    # Only the first block of a stage changes width/stride;
                    # the rest map output_filters -> output_filters at
                    # stride 1.
                    block_arg = block_arg._replace(
                        input_filters=block_arg.output_filters, stride=1)

                # Drop rate grows linearly with block depth.
                drop_connect_rate = global_params.drop_connect_rate
                if drop_connect_rate:
                    drop_connect_rate *= float(layer_count) / num_blocks

                feats = mb_conv_block(
                    feats,
                    block_arg.input_filters,
                    block_arg.output_filters,
                    block_arg.expand_ratio,
                    block_arg.kernel_size,
                    block_arg.stride,
                    momentum,
                    eps,
                    block_arg,
                    drop_connect_rate,
                    mode,
                    se_ratio=block_arg.se_ratio if self.use_se else None,
                    name='_blocks.{}.'.format(layer_count))
                layer_count += 1

            feature_maps.append(feats)

        return [feature_maps[i] for i in (2, 4, 6)]  # 1/8, 1/16, 1/32