diff --git a/deploy/utils/config.py b/deploy/utils/config.py index b5a2be503cf087a602c60f4cfd6f9673a3153bac..30628f32bda3a1bc19d5f60a21d9e574a3f83eda 100644 --- a/deploy/utils/config.py +++ b/deploy/utils/config.py @@ -33,7 +33,7 @@ class AttrDict(dict): self[key] = value def __deepcopy__(self, content): - return copy.deepcopy(dict(self)) + return AttrDict(copy.deepcopy(dict(self))) def create_attr_dict(yaml_config): diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 8cadb7529abaf6c623fdc3931ec1ffea7a22889f..edc5385fed4a5d9854aba55fe8acf7192e6846b0 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -38,6 +38,7 @@ from .model_zoo.dpn import DPN68, DPN92, DPN98, DPN107, DPN131 from .model_zoo.dsnet import DSNet_tiny, DSNet_small, DSNet_base from .model_zoo.densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264 from .model_zoo.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7, EfficientNetB0_small +from .model_zoo.efficientnet_v2 import EfficientNetV2_S from .model_zoo.resnest import ResNeSt50_fast_1s1x64d, ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269 from .model_zoo.googlenet import GoogLeNet from .model_zoo.mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0 diff --git a/ppcls/arch/backbone/model_zoo/efficientnet_v2.py b/ppcls/arch/backbone/model_zoo/efficientnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..37514378d221e6f64817f2638d4e0386e7cb7a40 --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/efficientnet_v2.py @@ -0,0 +1,991 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Code was based on https://github.com/lukemelas/EfficientNet-PyTorch +# reference: https://arxiv.org/abs/1905.11946 + +import math +import re + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn.initializer import Constant, Normal, Uniform +from paddle.regularizer import L2Decay + +from ppcls.utils.config import AttrDict + +from ....utils.save_load import (load_dygraph_pretrain, + load_dygraph_pretrain_from_url) + +MODEL_URLS = { + "EfficientNetV2_S": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_S_pretrained.pdparams", + "EfficientNetV2_M": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_M_pretrained.pdparams", + "EfficientNetV2_L": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_L_pretrained.pdparams", + "EfficientNetV2_XL": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/EfficientNetV2_XL_pretrained.pdparams", +} + +__all__ = list(MODEL_URLS.keys()) + +inp_shape = { + "efficientnetv2-s": [384, 192, 192, 96, 48, 24, 24, 12], + "efficientnetv2-m": [384, 192, 192, 96, 48, 24, 24, 12], + "efficientnetv2-l": [384, 192, 192, 96, 48, 24, 24, 12], + "efficientnetv2-xl": [384, 192, 192, 96, 48, 24, 24, 12], +} + + +def cal_padding(img_size, stride, kernel_size): + """Calculate padding size.""" + if img_size % stride == 0: + out_size = max(kernel_size - stride, 0) + else: + out_size = max(kernel_size - (img_size % stride), 0) + return out_size // 2, out_size - out_size // 2 + + +class Conv2ds(nn.Layer): + """Customed Conv2D with tensorflow's padding style + + Args: + input_channels (int): input channels + output_channels (int): output channels + kernel_size (int): filter size + stride (int, optional): stride. Defaults to 1. + padding (int, optional): padding. Defaults to 0. + groups (int, optional): groups. Defaults to None. + act (str, optional): act. Defaults to None. + use_bias (bool, optional): use_bias. Defaults to None. + padding_type (str, optional): padding_type. Defaults to None. + model_name (str, optional): model name. Defaults to None. + cur_stage (int, optional): current stage. Defaults to None. + + Returns: + nn.Layer: Customed Conv2D instance + """ + + def __init__(self, + input_channels: int, + output_channels: int, + kernel_size: int, + stride=1, + padding=0, + groups=None, + act=None, + use_bias=None, + padding_type=None, + model_name=None, + cur_stage=None): + super(Conv2ds, self).__init__() + assert act in [None, "swish", "sigmoid"] + self._act = act + + def get_padding(kernel_size, stride=1, dilation=1): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + inps = inp_shape[model_name][cur_stage] + self.need_crop = False + if padding_type == "SAME": + top_padding, bottom_padding = cal_padding(inps, stride, + kernel_size) + left_padding, right_padding = cal_padding(inps, stride, + kernel_size) + height_padding = bottom_padding + width_padding = right_padding + if top_padding != bottom_padding or left_padding != right_padding: + height_padding = top_padding + stride + width_padding = left_padding + stride + self.need_crop = True + padding = [height_padding, width_padding] + elif padding_type == "VALID": + height_padding = 0 + width_padding = 0 + padding = [height_padding, width_padding] + elif padding_type == "DYNAMIC": + padding = get_padding(kernel_size, stride) + else: + padding = padding_type + + groups = 1 if groups is None else groups + self._conv = nn.Conv2D( + input_channels, + output_channels, + kernel_size, + groups=groups, + stride=stride, + padding=padding, + weight_attr=None, + bias_attr=use_bias + if not use_bias else ParamAttr(regularizer=L2Decay(0.0))) + + def forward(self, inputs): + x = self._conv(inputs) + if self._act == "swish": + x = F.swish(x) + elif self._act == "sigmoid": + x = F.sigmoid(x) + + if self.need_crop: + x = x[:, :, 1:, 1:] + return x + + +class BlockDecoder(object): + """Block Decoder for readability.""" + + def _decode_block_string(self, block_string): + """Gets a block through a string notation of arguments.""" + assert isinstance(block_string, str) + ops = block_string.split('_') + options = AttrDict() + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + t = AttrDict( + kernel_size=int(options['k']), + num_repeat=int(options['r']), + in_channels=int(options['i']), + out_channels=int(options['o']), + expand_ratio=int(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + strides=int(options['s']), + conv_type=int(options['c']) if 'c' in options else 0, ) + return t + + def _encode_block_string(self, block): + """Encodes a block to a string.""" + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d' % block.strides, + 'e%s' % block.expand_ratio, + 'i%d' % block.in_channels, + 'o%d' % block.out_channels, + 'c%d' % block.conv_type, + 'f%d' % block.fused_conv, + ] + if block.se_ratio > 0 and block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + return '_'.join(args) + + def decode(self, string_list): + """Decodes a list of string notations to specify blocks inside the network. + + Args: + string_list: a list of strings, each string is a notation of block. + + Returns: + A list of namedtuples to represent blocks arguments. + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(self._decode_block_string(block_string)) + return blocks_args + + def encode(self, blocks_args): + """Encodes a list of Blocks to a list of strings. + + Args: + blocks_args: A list of namedtuples to represent blocks arguments. + Returns: + a list of strings, each string is a notation of block. + """ + block_strings = [] + for block in blocks_args: + block_strings.append(self._encode_block_string(block)) + return block_strings + + +#################### EfficientNet V2 configs #################### +v2_base_block = [ # The baseline config for v2 models. + "r1_k3_s1_e1_i32_o16_c1", + "r2_k3_s2_e4_i16_o32_c1", + "r2_k3_s2_e4_i32_o48_c1", + "r3_k3_s2_e4_i48_o96_se0.25", + "r5_k3_s1_e6_i96_o112_se0.25", + "r8_k3_s2_e6_i112_o192_se0.25", +] + +v2_s_block = [ # about base * (width1.4, depth1.8) + "r2_k3_s1_e1_i24_o24_c1", + "r4_k3_s2_e4_i24_o48_c1", + "r4_k3_s2_e4_i48_o64_c1", + "r6_k3_s2_e4_i64_o128_se0.25", + "r9_k3_s1_e6_i128_o160_se0.25", + "r15_k3_s2_e6_i160_o256_se0.25", +] + +v2_m_block = [ # about base * (width1.6, depth2.2) + "r3_k3_s1_e1_i24_o24_c1", + "r5_k3_s2_e4_i24_o48_c1", + "r5_k3_s2_e4_i48_o80_c1", + "r7_k3_s2_e4_i80_o160_se0.25", + "r14_k3_s1_e6_i160_o176_se0.25", + "r18_k3_s2_e6_i176_o304_se0.25", + "r5_k3_s1_e6_i304_o512_se0.25", +] + +v2_l_block = [ # about base * (width2.0, depth3.1) + "r4_k3_s1_e1_i32_o32_c1", + "r7_k3_s2_e4_i32_o64_c1", + "r7_k3_s2_e4_i64_o96_c1", + "r10_k3_s2_e4_i96_o192_se0.25", + "r19_k3_s1_e6_i192_o224_se0.25", + "r25_k3_s2_e6_i224_o384_se0.25", + "r7_k3_s1_e6_i384_o640_se0.25", +] + +v2_xl_block = [ # only for 21k pretraining. + "r4_k3_s1_e1_i32_o32_c1", + "r8_k3_s2_e4_i32_o64_c1", + "r8_k3_s2_e4_i64_o96_c1", + "r16_k3_s2_e4_i96_o192_se0.25", + "r24_k3_s1_e6_i192_o256_se0.25", + "r32_k3_s2_e6_i256_o512_se0.25", + "r8_k3_s1_e6_i512_o640_se0.25", +] +efficientnetv2_params = { + # params: (block, width, depth, dropout) + "efficientnetv2-s": (v2_s_block, 1.0, 1.0, 0.2), + "efficientnetv2-m": (v2_m_block, 1.0, 1.0, 0.3), + "efficientnetv2-l": (v2_l_block, 1.0, 1.0, 0.4), + "efficientnetv2-xl": (v2_xl_block, 1.0, 1.0, 0.4), +} + + +def efficientnetv2_config(model_name: str): + """EfficientNetV2 model config.""" + block, width, depth, dropout = efficientnetv2_params[model_name] + + cfg = AttrDict(model=AttrDict( + model_name=model_name, + blocks_args=BlockDecoder().decode(block), + width_coefficient=width, + depth_coefficient=depth, + dropout_rate=dropout, + feature_size=1280, + bn_momentum=0.9, + bn_epsilon=1e-3, + depth_divisor=8, + min_depth=8, + act_fn="silu", + survival_prob=0.8, + local_pooling=False, + conv_dropout=None, + num_classes=1000)) + return cfg + + +def get_model_config(model_name: str): + """Main entry for model name to config.""" + if model_name.startswith("efficientnetv2-"): + return efficientnetv2_config(model_name) + raise ValueError(f"Unknown model_name {model_name}") + + +################################################################################ + + +def round_filters(filters, + width_coefficient, + depth_divisor, + min_depth, + skip=False): + """Round number of filters based on depth multiplier.""" + multiplier = width_coefficient + divisor = depth_divisor + min_depth = min_depth + if skip or not multiplier: + return filters + + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, + int(filters + divisor / 2) // divisor * divisor) + return int(new_filters) + + +def round_repeats(repeats, multiplier, skip=False): + """Round number of filters based on depth multiplier.""" + if skip or not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +def activation_fn(act_fn: str): + """Customized non-linear activation type.""" + if not act_fn: + return nn.Silu() + elif act_fn in ("silu", "swish"): + return nn.Swish() + elif act_fn == "relu": + return nn.ReLU() + elif act_fn == "relu6": + return nn.ReLU6() + elif act_fn == "elu": + return nn.ELU() + elif act_fn == "leaky_relu": + return nn.LeakyReLU() + elif act_fn == "selu": + return nn.SELU() + elif act_fn == "mish": + return nn.Mish() + else: + raise ValueError("Unsupported act_fn {}".format(act_fn)) + + +def drop_path(x, training=False, survival_prob=1.0): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... + """ + if not training: + return x + shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1) + keep_prob = paddle.to_tensor(survival_prob) + random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype) + random_tensor = paddle.floor(random_tensor) # binarize + output = x.divide(keep_prob) * random_tensor + return output + + +class SE(nn.Layer): + """Squeeze-and-excitation layer. + + Args: + local_pooling (bool): local_pooling + act_fn (str): act_fn + in_channels (int): in_channels + se_channels (int): se_channels + out_channels (int): out_channels + cur_stage (int): cur_stage + padding_type (str): padding_type + model_name (str): model_name + """ + + def __init__(self, + local_pooling: bool, + act_fn: str, + in_channels: int, + se_channels: int, + out_channels: int, + cur_stage: int, + padding_type: str, + model_name: str): + super(SE, self).__init__() + + self._local_pooling = local_pooling + self._act = activation_fn(act_fn) + + # Squeeze and Excitation layer. + self._se_reduce = Conv2ds( + in_channels, + se_channels, + 1, + stride=1, + padding_type=padding_type, + model_name=model_name, + cur_stage=cur_stage) + self._se_expand = Conv2ds( + se_channels, + out_channels, + 1, + stride=1, + padding_type=padding_type, + model_name=model_name, + cur_stage=cur_stage) + + def forward(self, x): + if self._local_pooling: + se_tensor = F.adaptive_avg_pool2d(x, output_size=1) + else: + se_tensor = paddle.mean(x, axis=[2, 3], keepdim=True) + se_tensor = self._se_expand(self._act(self._se_reduce(se_tensor))) + return F.sigmoid(se_tensor) * x + + +class MBConvBlock(nn.Layer): + """A class of MBConv: Mobile Inverted Residual Bottleneck. + + Args: + se_ratio (int): se_ratio + in_channels (int): in_channels + expand_ratio (int): expand_ratio + kernel_size (int): kernel_size + strides (int): strides + out_channels (int): out_channels + bn_momentum (float): bn_momentum + bn_epsilon (float): bn_epsilon + local_pooling (bool): local_pooling + conv_dropout (float): conv_dropout + cur_stage (int): cur_stage + padding_type (str): padding_type + model_name (str): model_name + """ + + def __init__(self, + se_ratio: int, + in_channels: int, + expand_ratio: int, + kernel_size: int, + strides: int, + out_channels: int, + bn_momentum: float, + bn_epsilon: float, + local_pooling: bool, + conv_dropout: float, + cur_stage: int, + padding_type: str, + model_name: str): + super(MBConvBlock, self).__init__() + + self.se_ratio = se_ratio + self.in_channels = in_channels + self.expand_ratio = expand_ratio + self.kernel_size = kernel_size + self.strides = strides + self.out_channels = out_channels + + self.bn_momentum = bn_momentum + self.bn_epsilon = bn_epsilon + + self._local_pooling = local_pooling + self.act_fn = None + self.conv_dropout = conv_dropout + + self._act = activation_fn(None) + self._has_se = (self.se_ratio is not None and 0 < self.se_ratio <= 1) + """Builds block according to the arguments.""" + expand_channels = self.in_channels * self.expand_ratio + kernel_size = self.kernel_size + + # Expansion phase. Called if not using fused convolutions and expansion + # phase is necessary. + if self.expand_ratio != 1: + self._expand_conv = Conv2ds( + self.in_channels, + expand_channels, + 1, + stride=1, + use_bias=False, + padding_type=padding_type, + model_name=model_name, + cur_stage=cur_stage) + self._norm0 = nn.BatchNorm2D( + expand_channels, + self.bn_momentum, + self.bn_epsilon, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + # Depth-wise convolution phase. Called if not using fused convolutions. + self._depthwise_conv = Conv2ds( + expand_channels, + expand_channels, + kernel_size, + padding=kernel_size // 2, + stride=self.strides, + groups=expand_channels, + use_bias=False, + padding_type=padding_type, + model_name=model_name, + cur_stage=cur_stage) + + self._norm1 = nn.BatchNorm2D( + expand_channels, + self.bn_momentum, + self.bn_epsilon, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + if self._has_se: + num_reduced_filters = max(1, int(self.in_channels * self.se_ratio)) + self._se = SE(self._local_pooling, None, expand_channels, + num_reduced_filters, expand_channels, cur_stage, + padding_type, model_name) + else: + self._se = None + + # Output phase. + self._project_conv = Conv2ds( + expand_channels, + self.out_channels, + 1, + stride=1, + use_bias=False, + padding_type=padding_type, + model_name=model_name, + cur_stage=cur_stage) + self._norm2 = nn.BatchNorm2D( + self.out_channels, + self.bn_momentum, + self.bn_epsilon, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self.drop_out = nn.Dropout(self.conv_dropout) + + def residual(self, inputs, x, survival_prob): + if (self.strides == 1 and self.in_channels == self.out_channels): + # Apply only if skip connection presents. + if survival_prob: + x = drop_path(x, self.training, survival_prob) + x = paddle.add(x, inputs) + + return x + + def forward(self, inputs, survival_prob=None): + """Implementation of call(). + + Args: + inputs: the inputs tensor. + survival_prob: float, between 0 to 1, drop connect rate. + + Returns: + A output tensor. + """ + x = inputs + if self.expand_ratio != 1: + x = self._act(self._norm0(self._expand_conv(x))) + + x = self._act(self._norm1(self._depthwise_conv(x))) + + if self.conv_dropout and self.expand_ratio > 1: + x = self.drop_out(x) + + if self._se: + x = self._se(x) + + x = self._norm2(self._project_conv(x)) + x = self.residual(inputs, x, survival_prob) + + return x + + +class FusedMBConvBlock(MBConvBlock): + """Fusing the proj conv1x1 and depthwise_conv into a conv2d.""" + + def __init__(self, se_ratio, in_channels, expand_ratio, kernel_size, + strides, out_channels, bn_momentum, bn_epsilon, local_pooling, + conv_dropout, cur_stage, padding_type, model_name): + """Builds block according to the arguments.""" + super(MBConvBlock, self).__init__() + self.se_ratio = se_ratio + self.in_channels = in_channels + self.expand_ratio = expand_ratio + self.kernel_size = kernel_size + self.strides = strides + self.out_channels = out_channels + + self.bn_momentum = bn_momentum + self.bn_epsilon = bn_epsilon + + self._local_pooling = local_pooling + self.act_fn = None + self.conv_dropout = conv_dropout + + self._act = activation_fn(None) + self._has_se = (self.se_ratio is not None and 0 < self.se_ratio <= 1) + + expand_channels = self.in_channels * self.expand_ratio + kernel_size = self.kernel_size + if self.expand_ratio != 1: + # Expansion phase: + self._expand_conv = Conv2ds( + self.in_channels, + expand_channels, + kernel_size, + padding=kernel_size // 2, + stride=self.strides, + use_bias=False, + padding_type=padding_type, + model_name=model_name, + cur_stage=cur_stage) + self._norm0 = nn.BatchNorm2D( + expand_channels, + self.bn_momentum, + self.bn_epsilon, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + + if self._has_se: + num_reduced_filters = max(1, int(self.in_channels * self.se_ratio)) + self._se = SE(self._local_pooling, None, expand_channels, + num_reduced_filters, expand_channels, cur_stage, + padding_type, model_name) + else: + self._se = None + + # Output phase: + self._project_conv = Conv2ds( + expand_channels, + self.out_channels, + 1 if (self.expand_ratio != 1) else kernel_size, + padding=(1 if (self.expand_ratio != 1) else kernel_size) // 2, + stride=1 if (self.expand_ratio != 1) else self.strides, + use_bias=False, + padding_type=padding_type, + model_name=model_name, + cur_stage=cur_stage) + self._norm1 = nn.BatchNorm2D( + self.out_channels, + self.bn_momentum, + self.bn_epsilon, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self.drop_out = nn.Dropout(conv_dropout) + + def forward(self, inputs, survival_prob=None): + """Implementation of call(). + + Args: + inputs: the inputs tensor. + training: boolean, whether the model is constructed for training. + survival_prob: float, between 0 to 1, drop connect rate. + + Returns: + A output tensor. + """ + x = inputs + if self.expand_ratio != 1: + x = self._act(self._norm0(self._expand_conv(x))) + + if self.conv_dropout and self.expand_ratio > 1: + x = self.drop_out(x) + + if self._se: + x = self._se(x) + + x = self._norm1(self._project_conv(x)) + if self.expand_ratio == 1: + x = self._act(x) # add act if no expansion. + + x = self.residual(inputs, x, survival_prob) + return x + + +class Stem(nn.Layer): + """Stem layer at the begining of the network.""" + + def __init__(self, width_coefficient, depth_divisor, min_depth, skip, + bn_momentum, bn_epsilon, act_fn, stem_channels, cur_stage, + padding_type, model_name): + super(Stem, self).__init__() + self._conv_stem = Conv2ds( + 3, + round_filters(stem_channels, width_coefficient, depth_divisor, + min_depth, skip), + 3, + padding=1, + stride=2, + use_bias=False, + padding_type=padding_type, + model_name=model_name, + cur_stage=cur_stage) + self._norm = nn.BatchNorm2D( + round_filters(stem_channels, width_coefficient, depth_divisor, + min_depth, skip), + bn_momentum, + bn_epsilon, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self._act = activation_fn(act_fn) + + def forward(self, inputs): + return self._act(self._norm(self._conv_stem(inputs))) + + +class Head(nn.Layer): + """Head layer for network outputs.""" + + def __init__(self, + in_channels, + feature_size, + bn_momentum, + bn_epsilon, + act_fn, + dropout_rate, + local_pooling, + width_coefficient, + depth_divisor, + min_depth, + skip=False): + super(Head, self).__init__() + self.in_channels = in_channels + self.feature_size = feature_size + self.bn_momentum = bn_momentum + self.bn_epsilon = bn_epsilon + self.dropout_rate = dropout_rate + self._local_pooling = local_pooling + self._conv_head = nn.Conv2D( + in_channels, + round_filters(self.feature_size or 1280, width_coefficient, + depth_divisor, min_depth, skip), + kernel_size=1, + stride=1, + bias_attr=False) + self._norm = nn.BatchNorm2D( + round_filters(self.feature_size or 1280, width_coefficient, + depth_divisor, min_depth, skip), + self.bn_momentum, + self.bn_epsilon, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + self._act = activation_fn(act_fn) + + self._avg_pooling = nn.AdaptiveAvgPool2D(output_size=1) + + if self.dropout_rate > 0: + self._dropout = nn.Dropout(self.dropout_rate) + else: + self._dropout = None + + def forward(self, x): + """Call the layer.""" + outputs = self._act(self._norm(self._conv_head(x))) + + if self._local_pooling: + outputs = F.adaptive_avg_pool2d(outputs, output_size=1) + if self._dropout: + outputs = self._dropout(outputs) + if self._fc: + outputs = paddle.squeeze(outputs, axis=[2, 3]) + outputs = self._fc(outputs) + else: + outputs = self._avg_pooling(outputs) + if self._dropout: + outputs = self._dropout(outputs) + return paddle.flatten(outputs, start_axis=1) + + +class EfficientNetV2(nn.Layer): + """A class implements tf.keras.Model. + + Reference: https://arxiv.org/abs/1807.11626 + """ + + def __init__(self, + model_name, + blocks_args=None, + mconfig=None, + include_top=True, + class_num=1000, + padding_type="SAME"): + """Initializes an `Model` instance. + + Args: + model_name: A string of model name. + model_config: A dict of model configurations or a string of hparams. + Raises: + ValueError: when blocks_args is not specified as a list. + """ + super(EfficientNetV2, self).__init__() + self.blocks_args = blocks_args + self.mconfig = mconfig + """Builds a model.""" + self._blocks = nn.LayerList() + + cur_stage = 0 + # Stem part. + self._stem = Stem( + self.mconfig.width_coefficient, + self.mconfig.depth_divisor, + self.mconfig.min_depth, + False, + self.mconfig.bn_momentum, + self.mconfig.bn_epsilon, + self.mconfig.act_fn, + stem_channels=self.blocks_args[0].in_channels, + cur_stage=cur_stage, + padding_type=padding_type, + model_name=model_name) + cur_stage += 1 + + # Builds blocks. + for block_args in self.blocks_args: + assert block_args.num_repeat > 0 + # Update block input and output filters based on depth multiplier. + in_channels = round_filters( + block_args.in_channels, self.mconfig.width_coefficient, + self.mconfig.depth_divisor, self.mconfig.min_depth, False) + out_channels = round_filters( + block_args.out_channels, self.mconfig.width_coefficient, + self.mconfig.depth_divisor, self.mconfig.min_depth, False) + + repeats = round_repeats(block_args.num_repeat, + self.mconfig.depth_coefficient) + block_args.update( + dict( + in_channels=in_channels, + out_channels=out_channels, + num_repeat=repeats)) + + # The first block needs to take care of stride and filter size increase. + conv_block = { + 0: MBConvBlock, + 1: FusedMBConvBlock + }[block_args.conv_type] + self._blocks.append( + conv_block(block_args.se_ratio, block_args.in_channels, + block_args.expand_ratio, block_args.kernel_size, + block_args.strides, block_args.out_channels, + self.mconfig.bn_momentum, self.mconfig.bn_epsilon, + self.mconfig.local_pooling, self.mconfig. + conv_dropout, cur_stage, padding_type, model_name)) + if block_args.num_repeat > 1: # rest of blocks with the same block_arg + block_args.in_channels = block_args.out_channels + block_args.strides = 1 + for _ in range(block_args.num_repeat - 1): + self._blocks.append( + conv_block( + block_args.se_ratio, block_args.in_channels, + block_args.expand_ratio, block_args.kernel_size, + block_args.strides, block_args.out_channels, + self.mconfig.bn_momentum, self.mconfig.bn_epsilon, + self.mconfig.local_pooling, self.mconfig.conv_dropout, + cur_stage, padding_type, model_name)) + cur_stage += 1 + + # Head part. + self._head = Head( + self.blocks_args[-1].out_channels, self.mconfig.feature_size, + self.mconfig.bn_momentum, self.mconfig.bn_epsilon, + self.mconfig.act_fn, self.mconfig.dropout_rate, + self.mconfig.local_pooling, self.mconfig.width_coefficient, + self.mconfig.depth_divisor, self.mconfig.min_depth, False) + + # top part for classification + if include_top and class_num: + self._fc = nn.Linear( + self.mconfig.feature_size, + class_num, + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + else: + self._fc = None + + # initialize weight + def _init_weights(m): + if isinstance(m, nn.Conv2D): + out_filters, in_channels, kernel_height, kernel_width = m.weight.shape + if in_channels == 1 and out_filters > in_channels: + out_filters = in_channels + fan_out = int(kernel_height * kernel_width * out_filters) + Normal(mean=0.0, std=np.sqrt(2.0 / fan_out))(m.weight) + elif isinstance(m, nn.Linear): + init_range = 1.0 / np.sqrt(m.weight.shape[1]) + Uniform(-init_range, init_range)(m.weight) + Constant(0.0)(m.bias) + + self.apply(_init_weights) + + def forward(self, inputs): + # Calls Stem layers + outputs = self._stem(inputs) + # print(f"stem: {outputs.mean().item():.10f}") + + # Calls blocks. + for idx, block in enumerate(self._blocks): + survival_prob = self.mconfig.survival_prob + if survival_prob: + drop_rate = 1.0 - survival_prob + survival_prob = 1.0 - drop_rate * float(idx) / len( + self._blocks) + outputs = block(outputs, survival_prob=survival_prob) + + # Head to obtain the final feature. + outputs = self._head(outputs) + # Calls final dense layers and returns logits. + if self._fc: + outputs = self._fc(outputs) + + return outputs + + +def _load_pretrained(pretrained, model, model_url, use_ssld=False): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def EfficientNetV2_S(include_top=True, pretrained=False, **kwargs): + """Get a V2 model instance. + + Returns: + nn.Layer: A single model instantce + """ + model_name = "efficientnetv2-s" + model_config = efficientnetv2_config(model_name) + model = EfficientNetV2(model_name, model_config.model.blocks_args, + model_config.model, include_top, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_S"]) + return model + + +def EfficientNetV2_M(include_top=True, pretrained=False, **kwargs): + """Get a V2 model instance. + + Returns: + nn.Layer: A single model instantce + """ + model_name = "efficientnetv2-m" + model_config = efficientnetv2_config(model_name) + model = EfficientNetV2(model_name, model_config.model.blocks_args, + model_config.model, include_top, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_M"]) + return model + + +def EfficientNetV2_L(include_top=True, pretrained=False, **kwargs): + """Get a V2 model instance. + + Returns: + nn.Layer: A single model instantce + """ + model_name = "efficientnetv2-l" + model_config = efficientnetv2_config(model_name) + model = EfficientNetV2(model_name, model_config.model.blocks_args, + model_config.model, include_top, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_L"]) + return model + + +def EfficientNetV2_XL(include_top=True, pretrained=False, **kwargs): + """Get a V2 model instance. + + Returns: + nn.Layer: A single model instantce + """ + model_name = "efficientnetv2-xl" + model_config = efficientnetv2_config(model_name) + model = EfficientNetV2(model_name, model_config.model.blocks_args, + model_config.model, include_top, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["EfficientNetV2_XL"]) + return model diff --git a/ppcls/configs/ImageNet/EfficientNetV2/EfficientNetV2_S.yaml b/ppcls/configs/ImageNet/EfficientNetV2/EfficientNetV2_S.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5644e6bad577828fd65fff45ae09e9edfc32f70 --- /dev/null +++ b/ppcls/configs/ImageNet/EfficientNetV2/EfficientNetV2_S.yaml @@ -0,0 +1,142 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 100 + eval_during_train: True + eval_interval: 1 + epochs: 350 + print_batch_step: 20 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 384, 384] + save_inference_dir: ./inference + train_mode: efficientnetv2 # progressive training + +AMP: + scale_loss: 65536 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + +EMA: + decay: 0.9999 + +# model architecture +Arch: + name: EfficientNetV2_S + class_num: 1000 + use_sync_bn: True + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.65 # 8gpux128bs + warmup_epoch: 5 + regularizer: + name: L2 + coeff: 0.00001 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + scale: [0.05, 1.0] + size: 224 + - RandFlipImage: + flip_code: 1 + - RandAugmentV2: + num_layers: 2 + magnitude: 5 + - NormalizeImage: + scale: 1.0 + mean: [128.0, 128.0, 128.0] + std: [128.0, 128.0, 128.0] + order: "" + + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: True + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - CropImageAtRatio: + size: 384 + pad: 32 + interpolation: bilinear + - NormalizeImage: + scale: 1.0 + mean: [128.0, 128.0, 128.0] + std: [128.0, 128.0, 128.0] + order: "" + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 8 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - CropImageAtRatio: + size: 384 + pad: 32 + interpolation: bilinear + - NormalizeImage: + scale: 1.0 + mean: [128.0, 128.0, 128.0] + std: [128.0, 128.0, 128.0] + order: "" + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index 63586fde369e0fe389a2db11305927640afabf6c..d34ba300689db40add888fa7173c8da3c3c67296 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -15,6 +15,7 @@ from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment from ppcls.data.preprocess.ops.randaugment import RandomApply +from ppcls.data.preprocess.ops.randaugment import RandAugmentV2 as RawRandAugmentV2 from ppcls.data.preprocess.ops.timm_autoaugment import RawTimmAutoAugment from ppcls.data.preprocess.ops.cutout import Cutout @@ -25,6 +26,7 @@ from ppcls.data.preprocess.ops.grid import GridMask from ppcls.data.preprocess.ops.operators import DecodeImage from ppcls.data.preprocess.ops.operators import ResizeImage from ppcls.data.preprocess.ops.operators import CropImage +from ppcls.data.preprocess.ops.operators import CropImageAtRatio from ppcls.data.preprocess.ops.operators import CenterCrop, Resize from ppcls.data.preprocess.ops.operators import RandCropImage from ppcls.data.preprocess.ops.operators import RandCropImageV2 @@ -101,6 +103,13 @@ class RandAugment(RawRandAugment): return img +class RandAugmentV2(RawRandAugmentV2): + """ RandAugmentV2 wrapper to auto fit different img types """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + class TimmAutoAugment(RawTimmAutoAugment): """ TimmAutoAugment wrapper to auto fit different img tyeps. """ diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index 46dc615034c16e626ef53ac831157a98a624b96e..5fb08ed9c12477c00f96fe4b7d7a79d9f5dfc11e 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -319,6 +319,25 @@ class CropImage(object): return img[h_start:h_end, w_start:w_end, :] +class CropImageAtRatio(object): + """ crop image with specified size and padding""" + + def __init__(self, size: int, pad: int, interpolation="bilinear"): + self.size = size + self.ratio = size / (size + pad) + self.interpolation = interpolation + + def __call__(self, img): + height, width = img.shape[:2] + crop_size = int(self.ratio * min(height, width)) + + y = (height - crop_size) // 2 + x = (width - crop_size) // 2 + + crop_img = img[y:y + crop_size, x:x + crop_size, :] + return F.resize(crop_img, [self.size, self.size], self.interpolation) + + class Padv2(object): def __init__(self, size=None, diff --git a/ppcls/data/preprocess/ops/randaugment.py b/ppcls/data/preprocess/ops/randaugment.py index a05a78f5624c1296d88b38169665060db0f2699a..ecc6cb51d6249f7762c7d0c2c34afe6567608aec 100644 --- a/ppcls/data/preprocess/ops/randaugment.py +++ b/ppcls/data/preprocess/ops/randaugment.py @@ -15,12 +15,60 @@ # This code is based on https://github.com/heartInsert/randaugment # reference: https://arxiv.org/abs/1909.13719 -from PIL import Image, ImageEnhance, ImageOps -import numpy as np import random from .operators import RawColorJitter from paddle.vision.transforms import transforms as T +import numpy as np +from PIL import Image, ImageEnhance, ImageOps + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def cutout(image, pad_size, replace=0): + image_np = np.array(image) + image_height, image_width, _ = image_np.shape + + # Sample the center location in the image where the zero mask will be applied. + cutout_center_height = np.random.randint(0, image_height + 1) + cutout_center_width = np.random.randint(0, image_width + 1) + + lower_pad = np.maximum(0, cutout_center_height - pad_size) + upper_pad = np.maximum(0, image_height - cutout_center_height - pad_size) + left_pad = np.maximum(0, cutout_center_width - pad_size) + right_pad = np.maximum(0, image_width - cutout_center_width - pad_size) + + cutout_shape = [ + image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad) + ] + padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] + mask = np.pad(np.zeros( + cutout_shape, dtype=image_np.dtype), + padding_dims, + constant_values=1) + mask = np.expand_dims(mask, -1) + mask = np.tile(mask, [1, 1, 3]) + image_np = np.where( + np.equal(mask, 0), + np.full_like( + image_np, fill_value=replace, dtype=image_np.dtype), + image_np) + return Image.fromarray(image_np) + class RandAugment(object): def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)): @@ -95,10 +143,10 @@ class RandAugment(object): "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 1 + magnitude * rnd_ch_op([-1, 1])), - "autocontrast": lambda img, magnitude: + "autocontrast": lambda img, _: ImageOps.autocontrast(img), - "equalize": lambda img, magnitude: ImageOps.equalize(img), - "invert": lambda img, magnitude: ImageOps.invert(img) + "equalize": lambda img, _: ImageOps.equalize(img), + "invert": lambda img, _: ImageOps.invert(img) } def __call__(self, img): @@ -121,4 +169,85 @@ class RandomApply(object): def __call__(self, img): timg = self.trans(img) - return timg \ No newline at end of file + return timg + + +## RandAugment_EfficientNetV2 code below ## +class RandAugmentV2(RandAugment): + """Customed RandAugment for EfficientNetV2""" + + def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)): + super().__init__(num_layers, magnitude, fillcolor) + abso_level = self.magnitude / self.max_level # [5.0~10.0/10.0]=[0.5, 1.0] + self.level_map = { + "shearX": 0.3 * abso_level, + "shearY": 0.3 * abso_level, + "translateX": 100.0 * abso_level, + "translateY": 100.0 * abso_level, + "rotate": 30 * abso_level, + "color": 1.8 * abso_level + 0.1, + "posterize": int(4.0 * abso_level), + "solarize": int(256.0 * abso_level), + "solarize_add": int(110.0 * abso_level), + "contrast": 1.8 * abso_level + 0.1, + "sharpness": 1.8 * abso_level + 0.1, + "brightness": 1.8 * abso_level + 0.1, + "autocontrast": 0, + "equalize": 0, + "invert": 0, + "cutout": int(40 * abso_level) + } + + def rotate_with_fill(img, magnitude): + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite(rot, + Image.new("RGBA", rot.size, (128, ) * 4), + rot).convert(img.mode) + + rnd_ch_op = random.choice + + self.func = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0), + Image.NEAREST, + fillcolor=fillcolor), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0), + Image.NEAREST, + fillcolor=fillcolor), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, magnitude * rnd_ch_op([-1, 1]), 0, 1, 0), + Image.NEAREST, + fillcolor=fillcolor), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, 0, 1, magnitude * rnd_ch_op([-1, 1])), + Image.NEAREST, + fillcolor=fillcolor), + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude * rnd_ch_op([-1, 1])), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(magnitude), + "posterize": lambda img, magnitude: + ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: + ImageOps.solarize(img, magnitude), + "solarize_add": lambda img, magnitude: + solarize_add(img, magnitude), + "contrast": lambda img, magnitude: + ImageEnhance.Contrast(img).enhance(magnitude), + "sharpness": lambda img, magnitude: + ImageEnhance.Sharpness(img).enhance(magnitude), + "brightness": lambda img, magnitude: + ImageEnhance.Brightness(img).enhance(magnitude), + "autocontrast": lambda img, _: + ImageOps.autocontrast(img), + "equalize": lambda img, _: ImageOps.equalize(img), + "invert": lambda img, _: ImageOps.invert(img), + "cutout": lambda img, magnitude: cutout(img, magnitude, replace=fillcolor[0]) + } diff --git a/ppcls/engine/train/__init__.py b/ppcls/engine/train/__init__.py index c438ab0c1c86eb175f067c7ae5fdfa626686611d..e933c955be675116751c29aea6e59715420d87de 100644 --- a/ppcls/engine/train/__init__.py +++ b/ppcls/engine/train/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from ppcls.engine.train.train import train_epoch +from ppcls.engine.train.train_efficientnetv2 import train_epoch_efficientnetv2 from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch -from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl \ No newline at end of file +from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl diff --git a/ppcls/engine/train/train_efficientnetv2.py b/ppcls/engine/train/train_efficientnetv2.py new file mode 100644 index 0000000000000000000000000000000000000000..b9eaeeedf213700f39e7e8cc8f4b25fb23f1cb14 --- /dev/null +++ b/ppcls/engine/train/train_efficientnetv2.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import, division, print_function + +import time + +import numpy as np + +from ppcls.data import build_dataloader +from ppcls.utils import logger + +from .train import train_epoch + + +def train_epoch_efficientnetv2(engine, epoch_id, print_batch_step): + # 1. Build training hyper-parameters for different training stage + num_stage = 4 + ratio_list = [(i + 1) / num_stage for i in range(num_stage)] + ram_list = np.linspace(5, 10, num_stage) + # dropout_rate_list = np.linspace(0.0, 0.2, num_stage) + stones = [ + int(engine.config["Global"]["epochs"] * ratio_list[i]) + for i in range(num_stage) + ] + image_size_list = [ + int(128 + (300 - 128) * ratio_list[i]) for i in range(num_stage) + ] + stage_id = 0 + for i in range(num_stage): + if epoch_id > stones[i]: + stage_id = i + 1 + + # 2. Adjust training hyper-parameters for different training stage + if not hasattr(engine, 'last_stage') or engine.last_stage < stage_id: + engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][1][ + "RandCropImage"]["size"] = image_size_list[stage_id] + engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][3][ + "RandAugment"]["magnitude"] = ram_list[stage_id] + engine.train_dataloader = build_dataloader( + engine.config["DataLoader"], + "Train", + engine.device, + engine.use_dali, + seed=epoch_id) + engine.train_dataloader_iter = iter(engine.train_dataloader) + engine.last_stage = stage_id + logger.info( + f"Training stage: [{stage_id+1}/{num_stage}](random_aug_magnitude={ram_list[stage_id]}, train_image_size={image_size_list[stage_id]})" + ) + + # 3. Train one epoch as usual at current stage + train_epoch(engine, epoch_id, print_batch_step) diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py index b323cfe51746537241a304f0f0a2f51021ab4e50..1a0fc19336923deafae523e573a9a15a1677b291 100644 --- a/ppcls/utils/config.py +++ b/ppcls/utils/config.py @@ -33,7 +33,7 @@ class AttrDict(dict): self[key] = value def __deepcopy__(self, content): - return copy.deepcopy(dict(self)) + return AttrDict(copy.deepcopy(dict(self))) def create_attr_dict(yaml_config):