efficientnet_v2.py 33.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Code was based on https://github.com/lukemelas/EfficientNet-PyTorch
# reference: https://arxiv.org/abs/1905.11946

import math
import re

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant, Normal, Uniform
from paddle.regularizer import L2Decay

H
HydrogenSulfate 已提交
29
from ....utils.config import AttrDict
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270

from ....utils.save_load import (load_dygraph_pretrain,
                                 load_dygraph_pretrain_from_url)

MODEL_URLS = {
    "EfficientNetV2_S":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_S_pretrained.pdparams",
    "EfficientNetV2_M":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_M_pretrained.pdparams",
    "EfficientNetV2_L":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_L_pretrained.pdparams",
    "EfficientNetV2_XL":
    "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_XL_pretrained.pdparams",
}

__all__ = list(MODEL_URLS.keys())

inp_shape = {
    "efficientnetv2-s": [384, 192, 192, 96, 48, 24, 24, 12],
    "efficientnetv2-m": [384, 192, 192, 96, 48, 24, 24, 12],
    "efficientnetv2-l": [384, 192, 192, 96, 48, 24, 24, 12],
    "efficientnetv2-xl": [384, 192, 192, 96, 48, 24, 24, 12],
}


def cal_padding(img_size, stride, kernel_size):
    """Calculate padding size."""
    if img_size % stride == 0:
        out_size = max(kernel_size - stride, 0)
    else:
        out_size = max(kernel_size - (img_size % stride), 0)
    return out_size // 2, out_size - out_size // 2


class Conv2ds(nn.Layer):
    """Customed Conv2D with tensorflow's padding style

    Args:
        input_channels (int): input channels
        output_channels (int): output channels
        kernel_size (int): filter size
        stride (int, optional): stride. Defaults to 1.
        padding (int, optional): padding. Defaults to 0.
        groups (int, optional): groups. Defaults to None.
        act (str, optional): act. Defaults to None.
        use_bias (bool, optional): use_bias. Defaults to None.
        padding_type (str, optional): padding_type. Defaults to None.
        model_name (str, optional): model name. Defaults to None.
        cur_stage (int, optional): current stage. Defaults to None.

    Returns:
        nn.Layer: Customed Conv2D instance
    """

    def __init__(self,
                 input_channels: int,
                 output_channels: int,
                 kernel_size: int,
                 stride=1,
                 padding=0,
                 groups=None,
                 act=None,
                 use_bias=None,
                 padding_type=None,
                 model_name=None,
                 cur_stage=None):
        super(Conv2ds, self).__init__()
        assert act in [None, "swish", "sigmoid"]
        self._act = act

        def get_padding(kernel_size, stride=1, dilation=1):
            padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
            return padding

        inps = inp_shape[model_name][cur_stage]
        self.need_crop = False
        if padding_type == "SAME":
            top_padding, bottom_padding = cal_padding(inps, stride,
                                                      kernel_size)
            left_padding, right_padding = cal_padding(inps, stride,
                                                      kernel_size)
            height_padding = bottom_padding
            width_padding = right_padding
            if top_padding != bottom_padding or left_padding != right_padding:
                height_padding = top_padding + stride
                width_padding = left_padding + stride
                self.need_crop = True
            padding = [height_padding, width_padding]
        elif padding_type == "VALID":
            height_padding = 0
            width_padding = 0
            padding = [height_padding, width_padding]
        elif padding_type == "DYNAMIC":
            padding = get_padding(kernel_size, stride)
        else:
            padding = padding_type

        groups = 1 if groups is None else groups
        self._conv = nn.Conv2D(
            input_channels,
            output_channels,
            kernel_size,
            groups=groups,
            stride=stride,
            padding=padding,
            weight_attr=None,
            bias_attr=use_bias
            if not use_bias else ParamAttr(regularizer=L2Decay(0.0)))

    def forward(self, inputs):
        x = self._conv(inputs)
        if self._act == "swish":
            x = F.swish(x)
        elif self._act == "sigmoid":
            x = F.sigmoid(x)

        if self.need_crop:
            x = x[:, :, 1:, 1:]
        return x


class BlockDecoder(object):
    """Block Decoder for readability."""

    def _decode_block_string(self, block_string):
        """Gets a block through a string notation of arguments."""
        assert isinstance(block_string, str)
        ops = block_string.split('_')
        options = AttrDict()
        for op in ops:
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

        t = AttrDict(
            kernel_size=int(options['k']),
            num_repeat=int(options['r']),
            in_channels=int(options['i']),
            out_channels=int(options['o']),
            expand_ratio=int(options['e']),
            se_ratio=float(options['se']) if 'se' in options else None,
            strides=int(options['s']),
            conv_type=int(options['c']) if 'c' in options else 0, )
        return t

    def _encode_block_string(self, block):
        """Encodes a block to a string."""
        args = [
            'r%d' % block.num_repeat,
            'k%d' % block.kernel_size,
            's%d' % block.strides,
            'e%s' % block.expand_ratio,
            'i%d' % block.in_channels,
            'o%d' % block.out_channels,
            'c%d' % block.conv_type,
            'f%d' % block.fused_conv,
        ]
        if block.se_ratio > 0 and block.se_ratio <= 1:
            args.append('se%s' % block.se_ratio)
        return '_'.join(args)

    def decode(self, string_list):
        """Decodes a list of string notations to specify blocks inside the network.

        Args:
        string_list: a list of strings, each string is a notation of block.

        Returns:
        A list of namedtuples to represent blocks arguments.
        """
        assert isinstance(string_list, list)
        blocks_args = []
        for block_string in string_list:
            blocks_args.append(self._decode_block_string(block_string))
        return blocks_args

    def encode(self, blocks_args):
        """Encodes a list of Blocks to a list of strings.

        Args:
        blocks_args: A list of namedtuples to represent blocks arguments.
        Returns:
        a list of strings, each string is a notation of block.
        """
        block_strings = []
        for block in blocks_args:
            block_strings.append(self._encode_block_string(block))
        return block_strings


#################### EfficientNet V2 configs ####################
v2_base_block = [  # The baseline config for v2 models.
    "r1_k3_s1_e1_i32_o16_c1",
    "r2_k3_s2_e4_i16_o32_c1",
    "r2_k3_s2_e4_i32_o48_c1",
    "r3_k3_s2_e4_i48_o96_se0.25",
    "r5_k3_s1_e6_i96_o112_se0.25",
    "r8_k3_s2_e6_i112_o192_se0.25",
]

v2_s_block = [  # about base * (width1.4, depth1.8)
    "r2_k3_s1_e1_i24_o24_c1",
    "r4_k3_s2_e4_i24_o48_c1",
    "r4_k3_s2_e4_i48_o64_c1",
    "r6_k3_s2_e4_i64_o128_se0.25",
    "r9_k3_s1_e6_i128_o160_se0.25",
    "r15_k3_s2_e6_i160_o256_se0.25",
]

v2_m_block = [  # about base * (width1.6, depth2.2)
    "r3_k3_s1_e1_i24_o24_c1",
    "r5_k3_s2_e4_i24_o48_c1",
    "r5_k3_s2_e4_i48_o80_c1",
    "r7_k3_s2_e4_i80_o160_se0.25",
    "r14_k3_s1_e6_i160_o176_se0.25",
    "r18_k3_s2_e6_i176_o304_se0.25",
    "r5_k3_s1_e6_i304_o512_se0.25",
]

v2_l_block = [  # about base * (width2.0, depth3.1)
    "r4_k3_s1_e1_i32_o32_c1",
    "r7_k3_s2_e4_i32_o64_c1",
    "r7_k3_s2_e4_i64_o96_c1",
    "r10_k3_s2_e4_i96_o192_se0.25",
    "r19_k3_s1_e6_i192_o224_se0.25",
    "r25_k3_s2_e6_i224_o384_se0.25",
    "r7_k3_s1_e6_i384_o640_se0.25",
]

v2_xl_block = [  # only for 21k pretraining.
    "r4_k3_s1_e1_i32_o32_c1",
    "r8_k3_s2_e4_i32_o64_c1",
    "r8_k3_s2_e4_i64_o96_c1",
    "r16_k3_s2_e4_i96_o192_se0.25",
    "r24_k3_s1_e6_i192_o256_se0.25",
    "r32_k3_s2_e6_i256_o512_se0.25",
    "r8_k3_s1_e6_i512_o640_se0.25",
]
efficientnetv2_params = {
    # params:            (block, width, depth, dropout)
271
    "efficientnetv2-s": (v2_s_block, 1.0, 1.0, np.linspace(0.1, 0.3, 4)),
272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
    "efficientnetv2-m": (v2_m_block, 1.0, 1.0, 0.3),
    "efficientnetv2-l": (v2_l_block, 1.0, 1.0, 0.4),
    "efficientnetv2-xl": (v2_xl_block, 1.0, 1.0, 0.4),
}


def efficientnetv2_config(model_name: str):
    """EfficientNetV2 model config."""
    block, width, depth, dropout = efficientnetv2_params[model_name]

    cfg = AttrDict(model=AttrDict(
        model_name=model_name,
        blocks_args=BlockDecoder().decode(block),
        width_coefficient=width,
        depth_coefficient=depth,
        dropout_rate=dropout,
        feature_size=1280,
        bn_momentum=0.9,
        bn_epsilon=1e-3,
        depth_divisor=8,
        min_depth=8,
        act_fn="silu",
        survival_prob=0.8,
        local_pooling=False,
296
        conv_dropout=0,
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758
        num_classes=1000))
    return cfg


def get_model_config(model_name: str):
    """Main entry for model name to config."""
    if model_name.startswith("efficientnetv2-"):
        return efficientnetv2_config(model_name)
    raise ValueError(f"Unknown model_name {model_name}")


################################################################################


def round_filters(filters,
                  width_coefficient,
                  depth_divisor,
                  min_depth,
                  skip=False):
    """Round number of filters based on depth multiplier."""
    multiplier = width_coefficient
    divisor = depth_divisor
    min_depth = min_depth
    if skip or not multiplier:
        return filters

    filters *= multiplier
    min_depth = min_depth or divisor
    new_filters = max(min_depth,
                      int(filters + divisor / 2) // divisor * divisor)
    return int(new_filters)


def round_repeats(repeats, multiplier, skip=False):
    """Round number of filters based on depth multiplier."""
    if skip or not multiplier:
        return repeats
    return int(math.ceil(multiplier * repeats))


def activation_fn(act_fn: str):
    """Customized non-linear activation type."""
    if not act_fn:
        return nn.Silu()
    elif act_fn in ("silu", "swish"):
        return nn.Swish()
    elif act_fn == "relu":
        return nn.ReLU()
    elif act_fn == "relu6":
        return nn.ReLU6()
    elif act_fn == "elu":
        return nn.ELU()
    elif act_fn == "leaky_relu":
        return nn.LeakyReLU()
    elif act_fn == "selu":
        return nn.SELU()
    elif act_fn == "mish":
        return nn.Mish()
    else:
        raise ValueError("Unsupported act_fn {}".format(act_fn))


def drop_path(x, training=False, survival_prob=1.0):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if not training:
        return x
    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
    keep_prob = paddle.to_tensor(survival_prob)
    random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class SE(nn.Layer):
    """Squeeze-and-excitation layer.

    Args:
        local_pooling (bool): local_pooling
        act_fn (str): act_fn
        in_channels (int): in_channels
        se_channels (int): se_channels
        out_channels (int): out_channels
        cur_stage (int): cur_stage
        padding_type (str): padding_type
        model_name (str): model_name
    """

    def __init__(self,
                 local_pooling: bool,
                 act_fn: str,
                 in_channels: int,
                 se_channels: int,
                 out_channels: int,
                 cur_stage: int,
                 padding_type: str,
                 model_name: str):
        super(SE, self).__init__()

        self._local_pooling = local_pooling
        self._act = activation_fn(act_fn)

        # Squeeze and Excitation layer.
        self._se_reduce = Conv2ds(
            in_channels,
            se_channels,
            1,
            stride=1,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)
        self._se_expand = Conv2ds(
            se_channels,
            out_channels,
            1,
            stride=1,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)

    def forward(self, x):
        if self._local_pooling:
            se_tensor = F.adaptive_avg_pool2d(x, output_size=1)
        else:
            se_tensor = paddle.mean(x, axis=[2, 3], keepdim=True)
        se_tensor = self._se_expand(self._act(self._se_reduce(se_tensor)))
        return F.sigmoid(se_tensor) * x


class MBConvBlock(nn.Layer):
    """A class of MBConv: Mobile Inverted Residual Bottleneck.

    Args:
        se_ratio (int): se_ratio
        in_channels (int): in_channels
        expand_ratio (int): expand_ratio
        kernel_size (int): kernel_size
        strides (int): strides
        out_channels (int): out_channels
        bn_momentum (float): bn_momentum
        bn_epsilon (float): bn_epsilon
        local_pooling (bool): local_pooling
        conv_dropout (float): conv_dropout
        cur_stage (int): cur_stage
        padding_type (str): padding_type
        model_name (str): model_name
    """

    def __init__(self,
                 se_ratio: int,
                 in_channels: int,
                 expand_ratio: int,
                 kernel_size: int,
                 strides: int,
                 out_channels: int,
                 bn_momentum: float,
                 bn_epsilon: float,
                 local_pooling: bool,
                 conv_dropout: float,
                 cur_stage: int,
                 padding_type: str,
                 model_name: str):
        super(MBConvBlock, self).__init__()

        self.se_ratio = se_ratio
        self.in_channels = in_channels
        self.expand_ratio = expand_ratio
        self.kernel_size = kernel_size
        self.strides = strides
        self.out_channels = out_channels

        self.bn_momentum = bn_momentum
        self.bn_epsilon = bn_epsilon

        self._local_pooling = local_pooling
        self.act_fn = None
        self.conv_dropout = conv_dropout

        self._act = activation_fn(None)
        self._has_se = (self.se_ratio is not None and 0 < self.se_ratio <= 1)
        """Builds block according to the arguments."""
        expand_channels = self.in_channels * self.expand_ratio
        kernel_size = self.kernel_size

        # Expansion phase. Called if not using fused convolutions and expansion
        # phase is necessary.
        if self.expand_ratio != 1:
            self._expand_conv = Conv2ds(
                self.in_channels,
                expand_channels,
                1,
                stride=1,
                use_bias=False,
                padding_type=padding_type,
                model_name=model_name,
                cur_stage=cur_stage)
            self._norm0 = nn.BatchNorm2D(
                expand_channels,
                self.bn_momentum,
                self.bn_epsilon,
                weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        # Depth-wise convolution phase. Called if not using fused convolutions.
        self._depthwise_conv = Conv2ds(
            expand_channels,
            expand_channels,
            kernel_size,
            padding=kernel_size // 2,
            stride=self.strides,
            groups=expand_channels,
            use_bias=False,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)

        self._norm1 = nn.BatchNorm2D(
            expand_channels,
            self.bn_momentum,
            self.bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        if self._has_se:
            num_reduced_filters = max(1, int(self.in_channels * self.se_ratio))
            self._se = SE(self._local_pooling, None, expand_channels,
                          num_reduced_filters, expand_channels, cur_stage,
                          padding_type, model_name)
        else:
            self._se = None

        # Output phase.
        self._project_conv = Conv2ds(
            expand_channels,
            self.out_channels,
            1,
            stride=1,
            use_bias=False,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)
        self._norm2 = nn.BatchNorm2D(
            self.out_channels,
            self.bn_momentum,
            self.bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self.drop_out = nn.Dropout(self.conv_dropout)

    def residual(self, inputs, x, survival_prob):
        if (self.strides == 1 and self.in_channels == self.out_channels):
            # Apply only if skip connection presents.
            if survival_prob:
                x = drop_path(x, self.training, survival_prob)
            x = paddle.add(x, inputs)

        return x

    def forward(self, inputs, survival_prob=None):
        """Implementation of call().

        Args:
            inputs: the inputs tensor.
            survival_prob: float, between 0 to 1, drop connect rate.

        Returns:
            A output tensor.
        """
        x = inputs
        if self.expand_ratio != 1:
            x = self._act(self._norm0(self._expand_conv(x)))

        x = self._act(self._norm1(self._depthwise_conv(x)))

        if self.conv_dropout and self.expand_ratio > 1:
            x = self.drop_out(x)

        if self._se:
            x = self._se(x)

        x = self._norm2(self._project_conv(x))
        x = self.residual(inputs, x, survival_prob)

        return x


class FusedMBConvBlock(MBConvBlock):
    """Fusing the proj conv1x1 and depthwise_conv into a conv2d."""

    def __init__(self, se_ratio, in_channels, expand_ratio, kernel_size,
                 strides, out_channels, bn_momentum, bn_epsilon, local_pooling,
                 conv_dropout, cur_stage, padding_type, model_name):
        """Builds block according to the arguments."""
        super(MBConvBlock, self).__init__()
        self.se_ratio = se_ratio
        self.in_channels = in_channels
        self.expand_ratio = expand_ratio
        self.kernel_size = kernel_size
        self.strides = strides
        self.out_channels = out_channels

        self.bn_momentum = bn_momentum
        self.bn_epsilon = bn_epsilon

        self._local_pooling = local_pooling
        self.act_fn = None
        self.conv_dropout = conv_dropout

        self._act = activation_fn(None)
        self._has_se = (self.se_ratio is not None and 0 < self.se_ratio <= 1)

        expand_channels = self.in_channels * self.expand_ratio
        kernel_size = self.kernel_size
        if self.expand_ratio != 1:
            # Expansion phase:
            self._expand_conv = Conv2ds(
                self.in_channels,
                expand_channels,
                kernel_size,
                padding=kernel_size // 2,
                stride=self.strides,
                use_bias=False,
                padding_type=padding_type,
                model_name=model_name,
                cur_stage=cur_stage)
            self._norm0 = nn.BatchNorm2D(
                expand_channels,
                self.bn_momentum,
                self.bn_epsilon,
                weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
                bias_attr=ParamAttr(regularizer=L2Decay(0.0)))

        if self._has_se:
            num_reduced_filters = max(1, int(self.in_channels * self.se_ratio))
            self._se = SE(self._local_pooling, None, expand_channels,
                          num_reduced_filters, expand_channels, cur_stage,
                          padding_type, model_name)
        else:
            self._se = None

        # Output phase:
        self._project_conv = Conv2ds(
            expand_channels,
            self.out_channels,
            1 if (self.expand_ratio != 1) else kernel_size,
            padding=(1 if (self.expand_ratio != 1) else kernel_size) // 2,
            stride=1 if (self.expand_ratio != 1) else self.strides,
            use_bias=False,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)
        self._norm1 = nn.BatchNorm2D(
            self.out_channels,
            self.bn_momentum,
            self.bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self.drop_out = nn.Dropout(conv_dropout)

    def forward(self, inputs, survival_prob=None):
        """Implementation of call().

        Args:
            inputs: the inputs tensor.
            training: boolean, whether the model is constructed for training.
            survival_prob: float, between 0 to 1, drop connect rate.

        Returns:
            A output tensor.
        """
        x = inputs
        if self.expand_ratio != 1:
            x = self._act(self._norm0(self._expand_conv(x)))

        if self.conv_dropout and self.expand_ratio > 1:
            x = self.drop_out(x)

        if self._se:
            x = self._se(x)

        x = self._norm1(self._project_conv(x))
        if self.expand_ratio == 1:
            x = self._act(x)  # add act if no expansion.

        x = self.residual(inputs, x, survival_prob)
        return x


class Stem(nn.Layer):
    """Stem layer at the begining of the network."""

    def __init__(self, width_coefficient, depth_divisor, min_depth, skip,
                 bn_momentum, bn_epsilon, act_fn, stem_channels, cur_stage,
                 padding_type, model_name):
        super(Stem, self).__init__()
        self._conv_stem = Conv2ds(
            3,
            round_filters(stem_channels, width_coefficient, depth_divisor,
                          min_depth, skip),
            3,
            padding=1,
            stride=2,
            use_bias=False,
            padding_type=padding_type,
            model_name=model_name,
            cur_stage=cur_stage)
        self._norm = nn.BatchNorm2D(
            round_filters(stem_channels, width_coefficient, depth_divisor,
                          min_depth, skip),
            bn_momentum,
            bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self._act = activation_fn(act_fn)

    def forward(self, inputs):
        return self._act(self._norm(self._conv_stem(inputs)))


class Head(nn.Layer):
    """Head layer for network outputs."""

    def __init__(self,
                 in_channels,
                 feature_size,
                 bn_momentum,
                 bn_epsilon,
                 act_fn,
                 dropout_rate,
                 local_pooling,
                 width_coefficient,
                 depth_divisor,
                 min_depth,
                 skip=False):
        super(Head, self).__init__()
        self.in_channels = in_channels
        self.feature_size = feature_size
        self.bn_momentum = bn_momentum
        self.bn_epsilon = bn_epsilon
        self.dropout_rate = dropout_rate
        self._local_pooling = local_pooling
        self._conv_head = nn.Conv2D(
            in_channels,
            round_filters(self.feature_size or 1280, width_coefficient,
                          depth_divisor, min_depth, skip),
            kernel_size=1,
            stride=1,
            bias_attr=False)
        self._norm = nn.BatchNorm2D(
            round_filters(self.feature_size or 1280, width_coefficient,
                          depth_divisor, min_depth, skip),
            self.bn_momentum,
            self.bn_epsilon,
            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        self._act = activation_fn(act_fn)

        self._avg_pooling = nn.AdaptiveAvgPool2D(output_size=1)

759 760 761 762
        if isinstance(self.dropout_rate,
                      (list, tuple)) or self.dropout_rate > 0:
            self._dropout = nn.Dropout(self.dropout_rate[0] if isinstance(
                self.dropout_rate, (list, tuple)) else self.dropout_rate)
763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993
        else:
            self._dropout = None

    def forward(self, x):
        """Call the layer."""
        outputs = self._act(self._norm(self._conv_head(x)))

        if self._local_pooling:
            outputs = F.adaptive_avg_pool2d(outputs, output_size=1)
            if self._dropout:
                outputs = self._dropout(outputs)
            if self._fc:
                outputs = paddle.squeeze(outputs, axis=[2, 3])
                outputs = self._fc(outputs)
        else:
            outputs = self._avg_pooling(outputs)
            if self._dropout:
                outputs = self._dropout(outputs)
        return paddle.flatten(outputs, start_axis=1)


class EfficientNetV2(nn.Layer):
    """A class implements tf.keras.Model.

        Reference: https://arxiv.org/abs/1807.11626
    """

    def __init__(self,
                 model_name,
                 blocks_args=None,
                 mconfig=None,
                 include_top=True,
                 class_num=1000,
                 padding_type="SAME"):
        """Initializes an `Model` instance.

        Args:
            model_name: A string of model name.
            model_config: A dict of model configurations or a string of hparams.
        Raises:
            ValueError: when blocks_args is not specified as a list.
        """
        super(EfficientNetV2, self).__init__()
        self.blocks_args = blocks_args
        self.mconfig = mconfig
        """Builds a model."""
        self._blocks = nn.LayerList()

        cur_stage = 0
        # Stem part.
        self._stem = Stem(
            self.mconfig.width_coefficient,
            self.mconfig.depth_divisor,
            self.mconfig.min_depth,
            False,
            self.mconfig.bn_momentum,
            self.mconfig.bn_epsilon,
            self.mconfig.act_fn,
            stem_channels=self.blocks_args[0].in_channels,
            cur_stage=cur_stage,
            padding_type=padding_type,
            model_name=model_name)
        cur_stage += 1

        # Builds blocks.
        for block_args in self.blocks_args:
            assert block_args.num_repeat > 0
            # Update block input and output filters based on depth multiplier.
            in_channels = round_filters(
                block_args.in_channels, self.mconfig.width_coefficient,
                self.mconfig.depth_divisor, self.mconfig.min_depth, False)
            out_channels = round_filters(
                block_args.out_channels, self.mconfig.width_coefficient,
                self.mconfig.depth_divisor, self.mconfig.min_depth, False)

            repeats = round_repeats(block_args.num_repeat,
                                    self.mconfig.depth_coefficient)
            block_args.update(
                dict(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    num_repeat=repeats))

            # The first block needs to take care of stride and filter size increase.
            conv_block = {
                0: MBConvBlock,
                1: FusedMBConvBlock
            }[block_args.conv_type]
            self._blocks.append(
                conv_block(block_args.se_ratio, block_args.in_channels,
                           block_args.expand_ratio, block_args.kernel_size,
                           block_args.strides, block_args.out_channels,
                           self.mconfig.bn_momentum, self.mconfig.bn_epsilon,
                           self.mconfig.local_pooling, self.mconfig.
                           conv_dropout, cur_stage, padding_type, model_name))
            if block_args.num_repeat > 1:  # rest of blocks with the same block_arg
                block_args.in_channels = block_args.out_channels
                block_args.strides = 1
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(
                    conv_block(
                        block_args.se_ratio, block_args.in_channels,
                        block_args.expand_ratio, block_args.kernel_size,
                        block_args.strides, block_args.out_channels,
                        self.mconfig.bn_momentum, self.mconfig.bn_epsilon,
                        self.mconfig.local_pooling, self.mconfig.conv_dropout,
                        cur_stage, padding_type, model_name))
            cur_stage += 1

        # Head part.
        self._head = Head(
            self.blocks_args[-1].out_channels, self.mconfig.feature_size,
            self.mconfig.bn_momentum, self.mconfig.bn_epsilon,
            self.mconfig.act_fn, self.mconfig.dropout_rate,
            self.mconfig.local_pooling, self.mconfig.width_coefficient,
            self.mconfig.depth_divisor, self.mconfig.min_depth, False)

        # top part for classification
        if include_top and class_num:
            self._fc = nn.Linear(
                self.mconfig.feature_size,
                class_num,
                bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
        else:
            self._fc = None

        # initialize weight
        def _init_weights(m):
            if isinstance(m, nn.Conv2D):
                out_filters, in_channels, kernel_height, kernel_width = m.weight.shape
                if in_channels == 1 and out_filters > in_channels:
                    out_filters = in_channels
                fan_out = int(kernel_height * kernel_width * out_filters)
                Normal(mean=0.0, std=np.sqrt(2.0 / fan_out))(m.weight)
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / np.sqrt(m.weight.shape[1])
                Uniform(-init_range, init_range)(m.weight)
                Constant(0.0)(m.bias)

        self.apply(_init_weights)

    def forward(self, inputs):
        # Calls Stem layers
        outputs = self._stem(inputs)
        # print(f"stem: {outputs.mean().item():.10f}")

        # Calls blocks.
        for idx, block in enumerate(self._blocks):
            survival_prob = self.mconfig.survival_prob
            if survival_prob:
                drop_rate = 1.0 - survival_prob
                survival_prob = 1.0 - drop_rate * float(idx) / len(
                    self._blocks)
            outputs = block(outputs, survival_prob=survival_prob)

        # Head to obtain the final feature.
        outputs = self._head(outputs)
        # Calls final dense layers and returns logits.
        if self._fc:
            outputs = self._fc(outputs)

        return outputs


def _load_pretrained(pretrained, model, model_url, use_ssld=False):
    if pretrained is False:
        pass
    elif pretrained is True:
        load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
    elif isinstance(pretrained, str):
        load_dygraph_pretrain(model, pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type."
        )


def EfficientNetV2_S(include_top=True, pretrained=False, **kwargs):
    """Get a V2 model instance.

    Returns:
        nn.Layer: A single model instantce
    """
    model_name = "efficientnetv2-s"
    model_config = efficientnetv2_config(model_name)
    model = EfficientNetV2(model_name, model_config.model.blocks_args,
                           model_config.model, include_top, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_S"])
    return model


def EfficientNetV2_M(include_top=True, pretrained=False, **kwargs):
    """Get a V2 model instance.

    Returns:
        nn.Layer: A single model instantce
    """
    model_name = "efficientnetv2-m"
    model_config = efficientnetv2_config(model_name)
    model = EfficientNetV2(model_name, model_config.model.blocks_args,
                           model_config.model, include_top, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_M"])
    return model


def EfficientNetV2_L(include_top=True, pretrained=False, **kwargs):
    """Get a V2 model instance.

    Returns:
        nn.Layer: A single model instantce
    """
    model_name = "efficientnetv2-l"
    model_config = efficientnetv2_config(model_name)
    model = EfficientNetV2(model_name, model_config.model.blocks_args,
                           model_config.model, include_top, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_L"])
    return model


def EfficientNetV2_XL(include_top=True, pretrained=False, **kwargs):
    """Get a V2 model instance.

    Returns:
        nn.Layer: A single model instantce
    """
    model_name = "efficientnetv2-xl"
    model_config = efficientnetv2_config(model_name)
    model = EfficientNetV2(model_name, model_config.model.blocks_args,
                           model_config.model, include_top, **kwargs)
    _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_XL"])
    return model