From beca8b2c1b0378ef0bcb6b4dc386e76e1028e771 Mon Sep 17 00:00:00 2001 From: Yang Nie Date: Thu, 1 Sep 2022 13:13:00 +0800 Subject: [PATCH] add mobilenext add cooldown config update optimizer fix ParamAttr & update test_tipc fix tipc update tipc config remove docs of `_make_divisible` refactor the implementation of "no weight decay" fix model name remove cooldown config --- ppcls/arch/backbone/__init__.py | 1 + ppcls/arch/backbone/model_zoo/mobilenext.py | 261 ++++++++++++++++++ .../ImageNet/MobileNeXt/MobileNeXt_x1_0.yaml | 148 ++++++++++ ppcls/optimizer/optimizer.py | 6 +- .../MobileNeXt_x1_0_train_infer_python.txt | 54 ++++ 5 files changed, 468 insertions(+), 2 deletions(-) create mode 100644 ppcls/arch/backbone/model_zoo/mobilenext.py create mode 100644 ppcls/configs/ImageNet/MobileNeXt/MobileNeXt_x1_0.yaml create mode 100644 test_tipc/configs/MobileNeXt/MobileNeXt_x1_0_train_infer_python.txt diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index f598e48e..72ef6313 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -77,6 +77,7 @@ from .model_zoo.nextvit import NextViT_small_224, NextViT_base_224, NextViT_larg from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224 from .model_zoo.cvt import CvT_13_224, CvT_13_384, CvT_21_224, CvT_21_384, CvT_W24_384 from .model_zoo.micronet import MicroNet_M0, MicroNet_M1, MicroNet_M2, MicroNet_M3 +from .model_zoo.mobilenext import MobileNeXt_x0_35, MobileNeXt_x0_5, MobileNeXt_x0_75, MobileNeXt_x1_0, MobileNeXt_x1_4 from .variant_models.resnet_variant import ResNet50_last_stage_stride1 from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d diff --git a/ppcls/arch/backbone/model_zoo/mobilenext.py b/ppcls/arch/backbone/model_zoo/mobilenext.py new file mode 100644 index 00000000..7d66f13b --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/mobilenext.py @@ -0,0 +1,261 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Code was heavily based on https://github.com/zhoudaquan/rethinking_bottleneck_design +# reference: https://arxiv.org/abs/2007.02269 + +import math +import paddle.nn as nn + +from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "MobileNeXt_x0_35": "", # TODO + "MobileNeXt_x0_5": "", # TODO + "MobileNeXt_x0_75": "", # TODO + "MobileNeXt_x1_0": "", # TODO + "MobileNeXt_x1_4": "", # TODO +} + +__all__ = list(MODEL_URLS.keys()) + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +def conv_3x3_bn(inp, oup, stride): + return nn.Sequential( + nn.Conv2D( + inp, oup, 3, stride, 1, bias_attr=False), + nn.BatchNorm2D(oup), + nn.ReLU6()) + + +class SGBlock(nn.Layer): + def __init__(self, inp, oup, stride, expand_ratio, keep_3x3=False): + super(SGBlock, self).__init__() + assert stride in [1, 2] + + hidden_dim = inp // expand_ratio + if hidden_dim < oup / 6.: + hidden_dim = math.ceil(oup / 6.) + hidden_dim = _make_divisible(hidden_dim, 16) # + 16 + + self.identity = False + self.identity_div = 1 + self.expand_ratio = expand_ratio + + if expand_ratio == 2: + self.conv = nn.Sequential( + # dw + nn.Conv2D( + inp, inp, 3, 1, 1, groups=inp, bias_attr=False), + nn.BatchNorm2D(inp), + nn.ReLU6(), + # pw-linear + nn.Conv2D( + inp, hidden_dim, 1, 1, 0, bias_attr=False), + nn.BatchNorm2D(hidden_dim), + # pw-linear + nn.Conv2D( + hidden_dim, oup, 1, 1, 0, bias_attr=False), + nn.BatchNorm2D(oup), + nn.ReLU6(), + # dw + nn.Conv2D( + oup, oup, 3, stride, 1, groups=oup, bias_attr=False), + nn.BatchNorm2D(oup)) + elif inp != oup and stride == 1 and keep_3x3 == False: + self.conv = nn.Sequential( + # pw-linear + nn.Conv2D( + inp, hidden_dim, 1, 1, 0, bias_attr=False), + nn.BatchNorm2D(hidden_dim), + # pw-linear + nn.Conv2D( + hidden_dim, oup, 1, 1, 0, bias_attr=False), + nn.BatchNorm2D(oup), + nn.ReLU6()) + elif inp != oup and stride == 2 and keep_3x3 == False: + self.conv = nn.Sequential( + # pw-linear + nn.Conv2D( + inp, hidden_dim, 1, 1, 0, bias_attr=False), + nn.BatchNorm2D(hidden_dim), + # pw-linear + nn.Conv2D( + hidden_dim, oup, 1, 1, 0, bias_attr=False), + nn.BatchNorm2D(oup), + nn.ReLU6(), + # dw + nn.Conv2D( + oup, oup, 3, stride, 1, groups=oup, bias_attr=False), + nn.BatchNorm2D(oup)) + else: + if keep_3x3 == False: + self.identity = True + self.conv = nn.Sequential( + # dw + nn.Conv2D( + inp, inp, 3, 1, 1, groups=inp, bias_attr=False), + nn.BatchNorm2D(inp), + nn.ReLU6(), + # pw + nn.Conv2D( + inp, hidden_dim, 1, 1, 0, bias_attr=False), + nn.BatchNorm2D(hidden_dim), + #nn.ReLU6(), + # pw + nn.Conv2D( + hidden_dim, oup, 1, 1, 0, bias_attr=False), + nn.BatchNorm2D(oup), + nn.ReLU6(), + # dw + nn.Conv2D( + oup, oup, 3, 1, 1, groups=oup, bias_attr=False), + nn.BatchNorm2D(oup)) + + def forward(self, x): + out = self.conv(x) + + if self.identity: + if self.identity_div == 1: + out = out + x + else: + shape = x.shape + id_tensor = x[:, :shape[1] // self.identity_div, :, :] + out[:, :shape[1] // self.identity_div, :, :] = \ + out[:, :shape[1] // self.identity_div, :, :] + id_tensor + + return out + + +class MobileNeXt(nn.Layer): + def __init__(self, class_num=1000, width_mult=1.00): + super().__init__() + + # setting of inverted residual blocks + self.cfgs = [ + # t, c, n, s + [2, 96, 1, 2], + [6, 144, 1, 1], + [6, 192, 3, 2], + [6, 288, 3, 2], + [6, 384, 4, 1], + [6, 576, 4, 2], + [6, 960, 3, 1], + [6, 1280, 1, 1], + ] + + # building first layer + input_channel = _make_divisible(32 * width_mult, 4 + if width_mult == 0.1 else 8) + layers = [conv_3x3_bn(3, input_channel, 2)] + # building inverted residual blocks + block = SGBlock + for t, c, n, s in self.cfgs: + output_channel = _make_divisible(c * width_mult, 4 + if width_mult == 0.1 else 8) + if c == 1280 and width_mult < 1: + output_channel = 1280 + layers.append( + block(input_channel, output_channel, s, t, n == 1 and s == 1)) + input_channel = output_channel + for _ in range(n - 1): + layers.append(block(input_channel, output_channel, 1, t)) + input_channel = output_channel + self.features = nn.Sequential(*layers) + # building last several layers + input_channel = output_channel + output_channel = _make_divisible(input_channel, 4) + self.avgpool = nn.AdaptiveAvgPool2D((1, 1)) + self.classifier = nn.Sequential( + nn.Dropout(0.2), nn.Linear(output_channel, class_num)) + + self.apply(self._initialize_weights) + + def _initialize_weights(self, m): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + nn.initializer.Normal(std=math.sqrt(2. / n))(m.weight) + if m.bias is not None: + nn.initializer.Constant(0)(m.bias) + elif isinstance(m, nn.BatchNorm2D): + nn.initializer.Constant(1)(m.weight) + nn.initializer.Constant(0)(m.bias) + elif isinstance(m, nn.Linear): + nn.initializer.Normal(std=0.01)(m.weight) + nn.initializer.Constant(0)(m.bias) + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = x.flatten(1) + x = self.classifier(x) + return x + + +def _load_pretrained(pretrained, model, model_url, use_ssld=False): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def MobileNeXt_x0_35(pretrained=False, use_ssld=False, **kwargs): + model = MobileNeXt(width_mult=0.35, **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["MobileNeXt_x0_35"], use_ssld=use_ssld) + return model + + +def MobileNeXt_x0_5(pretrained=False, use_ssld=False, **kwargs): + model = MobileNeXt(width_mult=0.50, **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["MobileNeXt_x0_5"], use_ssld=use_ssld) + return model + + +def MobileNeXt_x0_75(pretrained=False, use_ssld=False, **kwargs): + model = MobileNeXt(width_mult=0.75, **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["MobileNeXt_x0_75"], use_ssld=use_ssld) + return model + + +def MobileNeXt_x1_0(pretrained=False, use_ssld=False, **kwargs): + model = MobileNeXt(width_mult=1.00, **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["MobileNeXt_x1_0"], use_ssld=use_ssld) + return model + + +def MobileNeXt_x1_4(pretrained=False, use_ssld=False, **kwargs): + model = MobileNeXt(width_mult=1.40, **kwargs) + _load_pretrained( + pretrained, model, MODEL_URLS["MobileNeXt_x1_4"], use_ssld=use_ssld) + return model diff --git a/ppcls/configs/ImageNet/MobileNeXt/MobileNeXt_x1_0.yaml b/ppcls/configs/ImageNet/MobileNeXt/MobileNeXt_x1_0.yaml new file mode 100644 index 00000000..5b986ca8 --- /dev/null +++ b/ppcls/configs/ImageNet/MobileNeXt/MobileNeXt_x1_0.yaml @@ -0,0 +1,148 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 200 + print_batch_step: 50 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + # training model under @to_static + to_static: False + +# model architecture +Arch: + name: MobileNeXt_x1_0 + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + use_nesterov: True + no_weight_decay_name: .bias + one_dim_param_no_weight_decay: True + lr: + name: Cosine + learning_rate: 0.1 # for total batch size 512 + eta_min: 1e-5 + warmup_epoch: 3 + warmup_start_lr: 1e-4 + by_epoch: True + regularizer: + name: 'L2' + coeff: 1e-4 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + backend: pil + - RandCropImage: + size: 224 + interpolation: random + backend: pil + - RandFlipImage: + flip_code: 1 + - ColorJitter: + brightness: 0.4 + contrast: 0.4 + saturation: 0.4 + hue: 0 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 128 # for 4 gpus + drop_last: True + shuffle: True + loader: + num_workers: 8 + use_shared_memory: True + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + backend: pil + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 256 + drop_last: False + shuffle: False + loader: + num_workers: 8 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + backend: pil + - ResizeImage: + resize_short: 256 + interpolation: bicubic + backend: pil + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/ppcls/optimizer/optimizer.py b/ppcls/optimizer/optimizer.py index f3c3d354..74eab3bc 100644 --- a/ppcls/optimizer/optimizer.py +++ b/ppcls/optimizer/optimizer.py @@ -96,7 +96,8 @@ class Momentum(object): grad_clip=None, use_nesterov=False, multi_precision=True, - no_weight_decay_name=None): + no_weight_decay_name=None, + one_dim_param_no_weight_decay=False): super().__init__() self.learning_rate = learning_rate self.momentum = momentum @@ -106,6 +107,7 @@ class Momentum(object): self.use_nesterov = use_nesterov self.no_weight_decay_name_list = no_weight_decay_name.split( ) if no_weight_decay_name else [] + self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay def __call__(self, model_list): # model_list is None in static graph @@ -118,7 +120,7 @@ class Momentum(object): if not any(nd in n for nd in self.no_weight_decay_name_list)] params_with_decay.extend(params) params = [p for n, p in m.named_parameters() \ - if any(nd in n for nd in self.no_weight_decay_name_list)] + if any(nd in n for nd in self.no_weight_decay_name_list) or (self.one_dim_param_no_weight_decay and len(p.shape) == 1)] params_without_decay.extend(params) parameters = [{ "params": params_with_decay, diff --git a/test_tipc/configs/MobileNeXt/MobileNeXt_x1_0_train_infer_python.txt b/test_tipc/configs/MobileNeXt/MobileNeXt_x1_0_train_infer_python.txt new file mode 100644 index 00000000..261f1229 --- /dev/null +++ b/test_tipc/configs/MobileNeXt/MobileNeXt_x1_0_train_infer_python.txt @@ -0,0 +1,54 @@ +===========================train_params=========================== +model_name:MobileNeXt_x1_0 +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.batch_size:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/MobileNeXt/MobileNeXt_x1_0.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/MobileNeXt/MobileNeXt_x1_0.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/MobileNeXt/MobileNeXt_x1_0.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +inference_dir:null +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.interpolation=bicubic -o PreProcess.transform_ops.0.ResizeImage.backend=pil +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:False +-o Global.cpu_num_threads:1 +-o Global.batch_size:1 +-o Global.use_tensorrt:False +-o Global.use_fp16:False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG +-o Global.save_log_path:null +-o Global.benchmark:False +null:null +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] -- GitLab