diff --git a/docs/zh_CN/algorithm_introduction/ImageNet_models.md b/docs/zh_CN/algorithm_introduction/ImageNet_models.md index ee98de442a40fb7c37b2274b756a728f7dcfc5af..4c26ea105453e954457aca71edb66394c5037153 100644 --- a/docs/zh_CN/algorithm_introduction/ImageNet_models.md +++ b/docs/zh_CN/algorithm_introduction/ImageNet_models.md @@ -10,7 +10,7 @@ - [2.1 服务器端知识蒸馏模型](#2.1) - [2.2 移动端知识蒸馏模型](#2.2) - [2.3 Intel CPU 端知识蒸馏模型](#2.3) -- [3. PP-LCNet 系列](#3) +- [3. PP-LCNet & PP-LCNetV2 系列](#3) - [4. ResNet 系列](#4) - [5. 移动端系列](#5) - [6. SEResNeXt 与 Res2Net 系列](#6) @@ -106,9 +106,9 @@ -## 3. PP-LCNet 系列 [[28](#ref28)] +## 3. PP-LCNet & PP-LCNetV2 系列 [[28](#ref28)] -PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:[PP-LCNet 系列模型文档](../models/PP-LCNet.md)。 +PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:[PP-LCNet 系列模型文档](../models/PP-LCNet.md),[PP-LCNetV2 系列模型文档](../models/PP-LCNetV2.md)。 | 模型 | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6148 time(ms)
bs=1 | FLOPs(M) | Params(M) | 预训练模型下载地址 | inference模型下载地址 | |:--:|:--:|:--:|:--:|----|----|----|:--:| @@ -121,6 +121,10 @@ PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该 | PPLCNet_x2_0 |0.7518 | 0.9227 | 20.1667 | 590 | 6.54 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar) | | PPLCNet_x2_5 |0.7660 | 0.9300 | 29.595 | 906 | 9.04 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar) | +| 模型 | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6271C
bs=1
OpenVINO 2021.4.2
time(ms) | FLOPs(M) | Params(M) | 预训练模型下载地址 | inference模型下载地址 | +|:--:|:--:|:--:|:--:|----|----|----|:--:| +| PPLCNetV2_base | 77.04 | 93.27 | 4.32 | 604 | 6.6 | https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams | https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar | + ## 4. ResNet 系列 [[1](#ref1)] diff --git a/docs/zh_CN/models/PP-LCNetV2.md b/docs/zh_CN/models/PP-LCNetV2.md new file mode 100644 index 0000000000000000000000000000000000000000..7563574694696247d553669e363df68fa00148dc --- /dev/null +++ b/docs/zh_CN/models/PP-LCNetV2.md @@ -0,0 +1,15 @@ +# PP-LCNetV2 系列 + +--- + +## 概述 + +PP-LCNetV2 是在 [PP-LCNet 系列模型](./PP-LCNet.md)的基础上,所提出的针对 Intel CPU 硬件平台设计的计算机视觉骨干网络,该模型更为 + +在不使用额外数据的前提下,PPLCNetV2_base 模型在图像分类 ImageNet 数据集上能够取得超过 77% 的 Top1 Acc,同时在 Intel CPU 平台仅有 4.4 ms 以下的延迟,如下表所示,其中延时测试基于 Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz 硬件平台,OpenVINO 2021.4.2推理平台。 + +| Model | Params(M) | FLOPs(M) | Top-1 Acc(\%) | Top-5 Acc(\%) | Latency(ms) | +|-------|-----------|----------|---------------|---------------|-------------| +| PPLCNetV2_base | 6.6 | 604 | 77.04 | 93.27 | 4.32 | + +关于 PP-LCNetV2 系列模型的更多信息,敬请关注。 diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index b62b5a64df348e257beee174eeb5bff1007f1d3e..a685cfb5b23f299e7d875470034f4f7b3f626086 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -22,6 +22,7 @@ from ppcls.arch.backbone.legendary_models.vgg import VGG11, VGG13, VGG16, VGG19 from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3 from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5 +from ppcls.arch.backbone.legendary_models.pp_lcnet_v2 import PPLCNetV2_base from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0 from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc diff --git a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce03a9c9f01d2e148e8894de6f1aaad704dcc33 --- /dev/null +++ b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py @@ -0,0 +1,354 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import, division, print_function + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Linear +from paddle.regularizer import L2Decay +from paddle.nn.initializer import KaimingNormal +from ppcls.arch.backbone.base.theseus_layer import TheseusLayer +from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url + +MODEL_URLS = { + "PPLCNetV2_base": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams", +} + +__all__ = list(MODEL_URLS.keys()) + +NET_CONFIG = { + # in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut + "stage1": [64, 3, False, False, False, False], + "stage2": [128, 3, False, False, False, False], + "stage3": [256, 5, True, True, True, False], + "stage4": [512, 5, False, True, False, True], +} + + +def make_divisible(v, divisor=8, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNLayer(TheseusLayer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups=1, + use_act=True): + super().__init__() + self.use_act = use_act + self.conv = Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(initializer=KaimingNormal()), + bias_attr=False) + + self.bn = BatchNorm2D( + out_channels, + weight_attr=ParamAttr(regularizer=L2Decay(0.0)), + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + if self.use_act: + self.act = nn.ReLU() + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + if self.use_act: + x = self.act(x) + return x + + +class SEModule(TheseusLayer): + def __init__(self, channel, reduction=4): + super().__init__() + self.avg_pool = AdaptiveAvgPool2D(1) + self.conv1 = Conv2D( + in_channels=channel, + out_channels=channel // reduction, + kernel_size=1, + stride=1, + padding=0) + self.relu = nn.ReLU() + self.conv2 = Conv2D( + in_channels=channel // reduction, + out_channels=channel, + kernel_size=1, + stride=1, + padding=0) + self.hardsigmoid = nn.Sigmoid() + + def forward(self, x): + identity = x + x = self.avg_pool(x) + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.hardsigmoid(x) + x = paddle.multiply(x=identity, y=x) + return x + + +class RepDepthwiseSeparable(TheseusLayer): + def __init__(self, + in_channels, + out_channels, + stride, + dw_size=3, + split_pw=False, + use_rep=False, + use_se=False, + use_shortcut=False): + super().__init__() + self.is_repped = False + + self.dw_size = dw_size + self.split_pw = split_pw + self.use_rep = use_rep + self.use_se = use_se + self.use_shortcut = True if use_shortcut and stride == 1 and in_channels == out_channels else False + + if self.use_rep: + self.dw_conv_list = nn.LayerList() + for kernel_size in range(self.dw_size, 0, -2): + if kernel_size == 1 and stride != 1: + continue + dw_conv = ConvBNLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=stride, + groups=in_channels, + use_act=False) + self.dw_conv_list.append(dw_conv) + self.dw_conv = nn.Conv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=dw_size, + stride=stride, + padding=(dw_size - 1) // 2, + groups=in_channels) + else: + self.dw_conv = ConvBNLayer( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=dw_size, + stride=stride, + groups=in_channels) + + self.act = nn.ReLU() + + if use_se: + self.se = SEModule(in_channels) + + if self.split_pw: + pw_ratio = 0.5 + self.pw_conv_1 = ConvBNLayer( + in_channels=in_channels, + kernel_size=1, + out_channels=int(out_channels * pw_ratio), + stride=1) + self.pw_conv_2 = ConvBNLayer( + in_channels=int(out_channels * pw_ratio), + kernel_size=1, + out_channels=out_channels, + stride=1) + else: + self.pw_conv = ConvBNLayer( + in_channels=in_channels, + kernel_size=1, + out_channels=out_channels, + stride=1) + + def forward(self, x): + if self.use_rep: + input_x = x + if not self.training: + x = self.act(self.dw_conv(x)) + else: + y = self.dw_conv_list[0](x) + for dw_conv in self.dw_conv_list[1:]: + y += dw_conv(x) + x = self.act(y) + else: + x = self.dw_conv(x) + + if self.use_se: + x = self.se(x) + if self.split_pw: + x = self.pw_conv_1(x) + x = self.pw_conv_2(x) + else: + x = self.pw_conv(x) + if self.use_shortcut: + x = x + input_x + return x + + def eval(self): + if self.use_rep: + kernel, bias = self._get_equivalent_kernel_bias() + self.dw_conv.weight.set_value(kernel) + self.dw_conv.bias.set_value(bias) + self.training = False + for layer in self.sublayers(): + layer.eval() + + def _get_equivalent_kernel_bias(self): + kernel_sum = 0 + bias_sum = 0 + for dw_conv in self.dw_conv_list: + kernel, bias = self._fuse_bn_tensor(dw_conv) + kernel = self._pad_tensor(kernel, to_size=self.dw_size) + kernel_sum += kernel + bias_sum += bias + return kernel_sum, bias_sum + + def _fuse_bn_tensor(self, branch): + kernel = branch.conv.weight + running_mean = branch.bn._mean + running_var = branch.bn._variance + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn._epsilon + std = (running_var + eps).sqrt() + t = (gamma / std).reshape((-1, 1, 1, 1)) + return kernel * t, beta - running_mean * gamma / std + + def _pad_tensor(self, tensor, to_size): + from_size = tensor.shape[-1] + if from_size == to_size: + return tensor + pad = (to_size - from_size) // 2 + return F.pad(tensor, [pad, pad, pad, pad]) + + +class PPLCNetV2(TheseusLayer): + def __init__(self, + scale, + depths, + class_num=1000, + dropout_prob=0, + use_last_conv=True, + class_expand=1280): + super().__init__() + self.scale = scale + self.use_last_conv = use_last_conv + self.class_expand = class_expand + + self.stem = nn.Sequential(* [ + ConvBNLayer( + in_channels=3, + kernel_size=3, + out_channels=make_divisible(32 * scale), + stride=2), RepDepthwiseSeparable( + in_channels=make_divisible(32 * scale), + out_channels=make_divisible(64 * scale), + stride=1, + dw_size=3) + ]) + + # stages + self.stages = nn.LayerList() + for depth_idx, k in enumerate(NET_CONFIG): + in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut = NET_CONFIG[ + k] + self.stages.append( + nn.Sequential(* [ + RepDepthwiseSeparable( + in_channels=make_divisible((in_channels if i == 0 else + in_channels * 2) * scale), + out_channels=make_divisible(in_channels * 2 * scale), + stride=2 if i == 0 else 1, + dw_size=kernel_size, + split_pw=split_pw, + use_rep=use_rep, + use_se=use_se, + use_shortcut=use_shortcut) + for i in range(depths[depth_idx]) + ])) + + self.avg_pool = AdaptiveAvgPool2D(1) + + if self.use_last_conv: + self.last_conv = Conv2D( + in_channels=make_divisible(NET_CONFIG["stage4"][0] * 2 * + scale), + out_channels=self.class_expand, + kernel_size=1, + stride=1, + padding=0, + bias_attr=False) + self.act = nn.ReLU() + self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer") + + self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) + in_features = self.class_expand if self.use_last_conv else NET_CONFIG[ + "stage4"][0] * 2 * scale + self.fc = Linear(in_features, class_num) + + def forward(self, x): + x = self.stem(x) + for stage in self.stages: + x = stage(x) + x = self.avg_pool(x) + if self.use_last_conv: + x = self.last_conv(x) + x = self.act(x) + x = self.dropout(x) + x = self.flatten(x) + x = self.fc(x) + return x + + +def _load_pretrained(pretrained, model, model_url, use_ssld): + if pretrained is False: + pass + elif pretrained is True: + load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld) + elif isinstance(pretrained, str): + load_dygraph_pretrain(model, pretrained) + else: + raise RuntimeError( + "pretrained type is not available. Please use `string` or `boolean` type." + ) + + +def PPLCNetV2_base(pretrained=False, use_ssld=False, **kwargs): + """ + PPLCNetV2_base + Args: + pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise. + If str, means the path of the pretrained model. + use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True. + Returns: + model: nn.Layer. Specific `PPLCNetV2_base` model depends on args. + """ + model = PPLCNetV2( + scale=1.0, depths=[2, 2, 6, 2], dropout_prob=0.2, **kwargs) + _load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld) + return model diff --git a/ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml b/ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..640833938bd81d8dd24c8bdd0ae1de86d8697a10 --- /dev/null +++ b/ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml @@ -0,0 +1,133 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 480 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + +# model architecture +Arch: + name: PPLCNetV2_base + class_num: 1000 + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.8 + warmup_epoch: 5 + regularizer: + name: 'L2' + coeff: 0.00004 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: MultiScaleDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/train_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + # support to specify width and height respectively: + # scales: [(160,160), (192,192), (224,224) (288,288) (320,320)] + sampler: + name: MultiScaleSampler + scales: [160, 192, 224, 288, 320] + # first_bs: batch size for the first image resolution in the scales list + # divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple + first_bs: 500 + divided_factor: 32 + is_training: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Infer: + infer_imgs: docs/images/inference_deployment/whl_demo.jpg + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: ppcls/utils/imagenet1k_label_list.txt + +Metric: + Train: + - TopkAcc: + topk: [1, 5] + Eval: + - TopkAcc: + topk: [1, 5] diff --git a/test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt b/test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..1c2806f27885e8fc3d31233b700ac9120fce6888 --- /dev/null +++ b/test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:PPLCNetV2_base +python:python3.7 +gpu_list:0|0,1 +-o Global.device:gpu +-o Global.auto_cast:null +-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120 +-o Global.output_dir:./output/ +-o DataLoader.Train.sampler.first_bs:8 +-o Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./dataset/ILSVRC2012/val +null:null +## +trainer:norm_train +norm_train:tools/train.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml -o Global.seed=1234 -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml +null:null +## +===========================infer_params========================== +-o Global.save_inference_dir:./inference +-o Global.pretrained_model: +norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml +quant_export:null +fpgm_export:null +distill_export:null +kl_quant:null +export2:null +pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams +infer_model:../inference/ +infer_export:True +infer_quant:Fasle +inference:python/predict_cls.py -c configs/inference_cls.yaml +-o Global.use_gpu:True|False +-o Global.enable_mkldnn:True|False +-o Global.cpu_num_threads:1|6 +-o Global.batch_size:1|16 +-o Global.use_tensorrt:True|False +-o Global.use_fp16:True|False +-o Global.inference_model_dir:../inference +-o Global.infer_imgs:../dataset/ILSVRC2012/val +-o Global.save_log_path:null +-o Global.benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}]