diff --git a/docs/zh_CN/algorithm_introduction/ImageNet_models.md b/docs/zh_CN/algorithm_introduction/ImageNet_models.md
index ee98de442a40fb7c37b2274b756a728f7dcfc5af..4c26ea105453e954457aca71edb66394c5037153 100644
--- a/docs/zh_CN/algorithm_introduction/ImageNet_models.md
+++ b/docs/zh_CN/algorithm_introduction/ImageNet_models.md
@@ -10,7 +10,7 @@
- [2.1 服务器端知识蒸馏模型](#2.1)
- [2.2 移动端知识蒸馏模型](#2.2)
- [2.3 Intel CPU 端知识蒸馏模型](#2.3)
-- [3. PP-LCNet 系列](#3)
+- [3. PP-LCNet & PP-LCNetV2 系列](#3)
- [4. ResNet 系列](#4)
- [5. 移动端系列](#5)
- [6. SEResNeXt 与 Res2Net 系列](#6)
@@ -106,9 +106,9 @@
-## 3. PP-LCNet 系列 [[28](#ref28)]
+## 3. PP-LCNet & PP-LCNetV2 系列 [[28](#ref28)]
-PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:[PP-LCNet 系列模型文档](../models/PP-LCNet.md)。
+PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:[PP-LCNet 系列模型文档](../models/PP-LCNet.md),[PP-LCNetV2 系列模型文档](../models/PP-LCNetV2.md)。
| 模型 | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6148 time(ms)
bs=1 | FLOPs(M) | Params(M) | 预训练模型下载地址 | inference模型下载地址 |
|:--:|:--:|:--:|:--:|----|----|----|:--:|
@@ -121,6 +121,10 @@ PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该
| PPLCNet_x2_0 |0.7518 | 0.9227 | 20.1667 | 590 | 6.54 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar) |
| PPLCNet_x2_5 |0.7660 | 0.9300 | 29.595 | 906 | 9.04 | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams) | [下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar) |
+| 模型 | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6271C
bs=1
OpenVINO 2021.4.2
time(ms) | FLOPs(M) | Params(M) | 预训练模型下载地址 | inference模型下载地址 |
+|:--:|:--:|:--:|:--:|----|----|----|:--:|
+| PPLCNetV2_base | 77.04 | 93.27 | 4.32 | 604 | 6.6 | https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams | https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar |
+
## 4. ResNet 系列 [[1](#ref1)]
diff --git a/docs/zh_CN/models/PP-LCNetV2.md b/docs/zh_CN/models/PP-LCNetV2.md
new file mode 100644
index 0000000000000000000000000000000000000000..7563574694696247d553669e363df68fa00148dc
--- /dev/null
+++ b/docs/zh_CN/models/PP-LCNetV2.md
@@ -0,0 +1,15 @@
+# PP-LCNetV2 系列
+
+---
+
+## 概述
+
+PP-LCNetV2 是在 [PP-LCNet 系列模型](./PP-LCNet.md)的基础上,所提出的针对 Intel CPU 硬件平台设计的计算机视觉骨干网络,该模型更为
+
+在不使用额外数据的前提下,PPLCNetV2_base 模型在图像分类 ImageNet 数据集上能够取得超过 77% 的 Top1 Acc,同时在 Intel CPU 平台仅有 4.4 ms 以下的延迟,如下表所示,其中延时测试基于 Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz 硬件平台,OpenVINO 2021.4.2推理平台。
+
+| Model | Params(M) | FLOPs(M) | Top-1 Acc(\%) | Top-5 Acc(\%) | Latency(ms) |
+|-------|-----------|----------|---------------|---------------|-------------|
+| PPLCNetV2_base | 6.6 | 604 | 77.04 | 93.27 | 4.32 |
+
+关于 PP-LCNetV2 系列模型的更多信息,敬请关注。
diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py
index b62b5a64df348e257beee174eeb5bff1007f1d3e..a685cfb5b23f299e7d875470034f4f7b3f626086 100644
--- a/ppcls/arch/backbone/__init__.py
+++ b/ppcls/arch/backbone/__init__.py
@@ -22,6 +22,7 @@ from ppcls.arch.backbone.legendary_models.vgg import VGG11, VGG13, VGG16, VGG19
from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3
from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C
from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5
+from ppcls.arch.backbone.legendary_models.pp_lcnet_v2 import PPLCNetV2_base
from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0
from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc
diff --git a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ce03a9c9f01d2e148e8894de6f1aaad704dcc33
--- /dev/null
+++ b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
@@ -0,0 +1,354 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import, division, print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Linear
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import KaimingNormal
+from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
+from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
+
+MODEL_URLS = {
+ "PPLCNetV2_base":
+ "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams",
+}
+
+__all__ = list(MODEL_URLS.keys())
+
+NET_CONFIG = {
+ # in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut
+ "stage1": [64, 3, False, False, False, False],
+ "stage2": [128, 3, False, False, False, False],
+ "stage3": [256, 5, True, True, True, False],
+ "stage4": [512, 5, False, True, False, True],
+}
+
+
+def make_divisible(v, divisor=8, min_value=None):
+ if min_value is None:
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
+
+
+class ConvBNLayer(TheseusLayer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ groups=1,
+ use_act=True):
+ super().__init__()
+ self.use_act = use_act
+ self.conv = Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(initializer=KaimingNormal()),
+ bias_attr=False)
+
+ self.bn = BatchNorm2D(
+ out_channels,
+ weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
+ bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+ if self.use_act:
+ self.act = nn.ReLU()
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.use_act:
+ x = self.act(x)
+ return x
+
+
+class SEModule(TheseusLayer):
+ def __init__(self, channel, reduction=4):
+ super().__init__()
+ self.avg_pool = AdaptiveAvgPool2D(1)
+ self.conv1 = Conv2D(
+ in_channels=channel,
+ out_channels=channel // reduction,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.relu = nn.ReLU()
+ self.conv2 = Conv2D(
+ in_channels=channel // reduction,
+ out_channels=channel,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.hardsigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ identity = x
+ x = self.avg_pool(x)
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.hardsigmoid(x)
+ x = paddle.multiply(x=identity, y=x)
+ return x
+
+
+class RepDepthwiseSeparable(TheseusLayer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ stride,
+ dw_size=3,
+ split_pw=False,
+ use_rep=False,
+ use_se=False,
+ use_shortcut=False):
+ super().__init__()
+ self.is_repped = False
+
+ self.dw_size = dw_size
+ self.split_pw = split_pw
+ self.use_rep = use_rep
+ self.use_se = use_se
+ self.use_shortcut = True if use_shortcut and stride == 1 and in_channels == out_channels else False
+
+ if self.use_rep:
+ self.dw_conv_list = nn.LayerList()
+ for kernel_size in range(self.dw_size, 0, -2):
+ if kernel_size == 1 and stride != 1:
+ continue
+ dw_conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ groups=in_channels,
+ use_act=False)
+ self.dw_conv_list.append(dw_conv)
+ self.dw_conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=dw_size,
+ stride=stride,
+ padding=(dw_size - 1) // 2,
+ groups=in_channels)
+ else:
+ self.dw_conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=dw_size,
+ stride=stride,
+ groups=in_channels)
+
+ self.act = nn.ReLU()
+
+ if use_se:
+ self.se = SEModule(in_channels)
+
+ if self.split_pw:
+ pw_ratio = 0.5
+ self.pw_conv_1 = ConvBNLayer(
+ in_channels=in_channels,
+ kernel_size=1,
+ out_channels=int(out_channels * pw_ratio),
+ stride=1)
+ self.pw_conv_2 = ConvBNLayer(
+ in_channels=int(out_channels * pw_ratio),
+ kernel_size=1,
+ out_channels=out_channels,
+ stride=1)
+ else:
+ self.pw_conv = ConvBNLayer(
+ in_channels=in_channels,
+ kernel_size=1,
+ out_channels=out_channels,
+ stride=1)
+
+ def forward(self, x):
+ if self.use_rep:
+ input_x = x
+ if not self.training:
+ x = self.act(self.dw_conv(x))
+ else:
+ y = self.dw_conv_list[0](x)
+ for dw_conv in self.dw_conv_list[1:]:
+ y += dw_conv(x)
+ x = self.act(y)
+ else:
+ x = self.dw_conv(x)
+
+ if self.use_se:
+ x = self.se(x)
+ if self.split_pw:
+ x = self.pw_conv_1(x)
+ x = self.pw_conv_2(x)
+ else:
+ x = self.pw_conv(x)
+ if self.use_shortcut:
+ x = x + input_x
+ return x
+
+ def eval(self):
+ if self.use_rep:
+ kernel, bias = self._get_equivalent_kernel_bias()
+ self.dw_conv.weight.set_value(kernel)
+ self.dw_conv.bias.set_value(bias)
+ self.training = False
+ for layer in self.sublayers():
+ layer.eval()
+
+ def _get_equivalent_kernel_bias(self):
+ kernel_sum = 0
+ bias_sum = 0
+ for dw_conv in self.dw_conv_list:
+ kernel, bias = self._fuse_bn_tensor(dw_conv)
+ kernel = self._pad_tensor(kernel, to_size=self.dw_size)
+ kernel_sum += kernel
+ bias_sum += bias
+ return kernel_sum, bias_sum
+
+ def _fuse_bn_tensor(self, branch):
+ kernel = branch.conv.weight
+ running_mean = branch.bn._mean
+ running_var = branch.bn._variance
+ gamma = branch.bn.weight
+ beta = branch.bn.bias
+ eps = branch.bn._epsilon
+ std = (running_var + eps).sqrt()
+ t = (gamma / std).reshape((-1, 1, 1, 1))
+ return kernel * t, beta - running_mean * gamma / std
+
+ def _pad_tensor(self, tensor, to_size):
+ from_size = tensor.shape[-1]
+ if from_size == to_size:
+ return tensor
+ pad = (to_size - from_size) // 2
+ return F.pad(tensor, [pad, pad, pad, pad])
+
+
+class PPLCNetV2(TheseusLayer):
+ def __init__(self,
+ scale,
+ depths,
+ class_num=1000,
+ dropout_prob=0,
+ use_last_conv=True,
+ class_expand=1280):
+ super().__init__()
+ self.scale = scale
+ self.use_last_conv = use_last_conv
+ self.class_expand = class_expand
+
+ self.stem = nn.Sequential(* [
+ ConvBNLayer(
+ in_channels=3,
+ kernel_size=3,
+ out_channels=make_divisible(32 * scale),
+ stride=2), RepDepthwiseSeparable(
+ in_channels=make_divisible(32 * scale),
+ out_channels=make_divisible(64 * scale),
+ stride=1,
+ dw_size=3)
+ ])
+
+ # stages
+ self.stages = nn.LayerList()
+ for depth_idx, k in enumerate(NET_CONFIG):
+ in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut = NET_CONFIG[
+ k]
+ self.stages.append(
+ nn.Sequential(* [
+ RepDepthwiseSeparable(
+ in_channels=make_divisible((in_channels if i == 0 else
+ in_channels * 2) * scale),
+ out_channels=make_divisible(in_channels * 2 * scale),
+ stride=2 if i == 0 else 1,
+ dw_size=kernel_size,
+ split_pw=split_pw,
+ use_rep=use_rep,
+ use_se=use_se,
+ use_shortcut=use_shortcut)
+ for i in range(depths[depth_idx])
+ ]))
+
+ self.avg_pool = AdaptiveAvgPool2D(1)
+
+ if self.use_last_conv:
+ self.last_conv = Conv2D(
+ in_channels=make_divisible(NET_CONFIG["stage4"][0] * 2 *
+ scale),
+ out_channels=self.class_expand,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias_attr=False)
+ self.act = nn.ReLU()
+ self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
+
+ self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
+ in_features = self.class_expand if self.use_last_conv else NET_CONFIG[
+ "stage4"][0] * 2 * scale
+ self.fc = Linear(in_features, class_num)
+
+ def forward(self, x):
+ x = self.stem(x)
+ for stage in self.stages:
+ x = stage(x)
+ x = self.avg_pool(x)
+ if self.use_last_conv:
+ x = self.last_conv(x)
+ x = self.act(x)
+ x = self.dropout(x)
+ x = self.flatten(x)
+ x = self.fc(x)
+ return x
+
+
+def _load_pretrained(pretrained, model, model_url, use_ssld):
+ if pretrained is False:
+ pass
+ elif pretrained is True:
+ load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
+ elif isinstance(pretrained, str):
+ load_dygraph_pretrain(model, pretrained)
+ else:
+ raise RuntimeError(
+ "pretrained type is not available. Please use `string` or `boolean` type."
+ )
+
+
+def PPLCNetV2_base(pretrained=False, use_ssld=False, **kwargs):
+ """
+ PPLCNetV2_base
+ Args:
+ pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
+ If str, means the path of the pretrained model.
+ use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
+ Returns:
+ model: nn.Layer. Specific `PPLCNetV2_base` model depends on args.
+ """
+ model = PPLCNetV2(
+ scale=1.0, depths=[2, 2, 6, 2], dropout_prob=0.2, **kwargs)
+ _load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld)
+ return model
diff --git a/ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml b/ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..640833938bd81d8dd24c8bdd0ae1de86d8697a10
--- /dev/null
+++ b/ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
@@ -0,0 +1,133 @@
+# global configs
+Global:
+ checkpoints: null
+ pretrained_model: null
+ output_dir: ./output/
+ device: gpu
+ save_interval: 1
+ eval_during_train: True
+ eval_interval: 1
+ epochs: 480
+ print_batch_step: 10
+ use_visualdl: False
+ # used for static mode and model export
+ image_shape: [3, 224, 224]
+ save_inference_dir: ./inference
+
+# model architecture
+Arch:
+ name: PPLCNetV2_base
+ class_num: 1000
+
+# loss function config for traing/eval process
+Loss:
+ Train:
+ - CELoss:
+ weight: 1.0
+ epsilon: 0.1
+ Eval:
+ - CELoss:
+ weight: 1.0
+
+Optimizer:
+ name: Momentum
+ momentum: 0.9
+ lr:
+ name: Cosine
+ learning_rate: 0.8
+ warmup_epoch: 5
+ regularizer:
+ name: 'L2'
+ coeff: 0.00004
+
+# data loader for train and eval
+DataLoader:
+ Train:
+ dataset:
+ name: MultiScaleDataset
+ image_root: ./dataset/ILSVRC2012/
+ cls_label_path: ./dataset/ILSVRC2012/train_list.txt
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ - RandFlipImage:
+ flip_code: 1
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+
+ # support to specify width and height respectively:
+ # scales: [(160,160), (192,192), (224,224) (288,288) (320,320)]
+ sampler:
+ name: MultiScaleSampler
+ scales: [160, 192, 224, 288, 320]
+ # first_bs: batch size for the first image resolution in the scales list
+ # divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple
+ first_bs: 500
+ divided_factor: 32
+ is_training: True
+ loader:
+ num_workers: 4
+ use_shared_memory: True
+
+ Eval:
+ dataset:
+ name: ImageNetDataset
+ image_root: ./dataset/ILSVRC2012/
+ cls_label_path: ./dataset/ILSVRC2012/val_list.txt
+ transform_ops:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ sampler:
+ name: DistributedBatchSampler
+ batch_size: 64
+ drop_last: False
+ shuffle: False
+ loader:
+ num_workers: 4
+ use_shared_memory: True
+
+Infer:
+ infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+ batch_size: 10
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+ PostProcess:
+ name: Topk
+ topk: 5
+ class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+ Train:
+ - TopkAcc:
+ topk: [1, 5]
+ Eval:
+ - TopkAcc:
+ topk: [1, 5]
diff --git a/test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt b/test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1c2806f27885e8fc3d31233b700ac9120fce6888
--- /dev/null
+++ b/test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt
@@ -0,0 +1,53 @@
+===========================train_params===========================
+model_name:PPLCNetV2_base
+python:python3.7
+gpu_list:0|0,1
+-o Global.device:gpu
+-o Global.auto_cast:null
+-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
+-o Global.output_dir:./output/
+-o DataLoader.Train.sampler.first_bs:8
+-o Global.pretrained_model:null
+train_model_name:latest
+train_infer_img_dir:./dataset/ILSVRC2012/val
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml -o Global.seed=1234 -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:tools/eval.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
+null:null
+##
+===========================infer_params==========================
+-o Global.save_inference_dir:./inference
+-o Global.pretrained_model:
+norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
+quant_export:null
+fpgm_export:null
+distill_export:null
+kl_quant:null
+export2:null
+pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams
+infer_model:../inference/
+infer_export:True
+infer_quant:Fasle
+inference:python/predict_cls.py -c configs/inference_cls.yaml
+-o Global.use_gpu:True|False
+-o Global.enable_mkldnn:True|False
+-o Global.cpu_num_threads:1|6
+-o Global.batch_size:1|16
+-o Global.use_tensorrt:True|False
+-o Global.use_fp16:True|False
+-o Global.inference_model_dir:../inference
+-o Global.infer_imgs:../dataset/ILSVRC2012/val
+-o Global.save_log_path:null
+-o Global.benchmark:True
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]