Unverified · commit c2daa752 authored by cuicheng01, committed by GitHub

Merge pull request #1916 from TingquanGao/dev/add_pplcnetv2

feat: add PPLCNetV2
......
@@ -10,7 +10,7 @@
- [2.1 Server-side knowledge distillation models](#2.1)
- [2.2 Mobile-side knowledge distillation models](#2.2)
- [2.3 Intel CPU-side knowledge distillation models](#2.3)
- [3. PP-LCNet Series](#3)
- [3. PP-LCNet & PP-LCNetV2 Series](#3)
- [4. ResNet Series](#4)
- [5. Mobile Series](#5)
- [6. SEResNeXt & Res2Net Series](#6)
......
@@ -106,9 +106,9 @@
<a name="3"></a>
## 3. PP-LCNet Series <sup>[[28](#ref28)]</sup>
## 3. PP-LCNet & PP-LCNetV2 Series <sup>[[28](#ref28)]</sup>
The accuracy and speed metrics of the PP-LCNet series models are listed in the table below. For more details on this series, please refer to the [PP-LCNet series model documentation](../models/PP-LCNet.md).
The accuracy and speed metrics of the PP-LCNet series models are listed in the table below. For more details on this series, please refer to the [PP-LCNet series model documentation](../models/PP-LCNet.md) and the [PP-LCNetV2 series model documentation](../models/PP-LCNetV2.md).
| Model | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6148 time(ms)<br>bs=1 | FLOPs(M) | Params(M) | Pretrained model download link | Inference model download link |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
......
@@ -121,6 +121,10 @@ The accuracy and speed metrics of the PP-LCNet series models are listed in the table below. For more details on this series, please refer to
| PPLCNet_x2_0 | 0.7518 | 0.9227 | 20.1667 | 590 | 6.54 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar) |
| PPLCNet_x2_5 | 0.7660 | 0.9300 | 29.595 | 906 | 9.04 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar) |
| Model | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6271C<br>bs=1<br>OpenVINO 2021.4.2<br>time(ms) | FLOPs(M) | Params(M) | Pretrained model download link | Inference model download link |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
| PPLCNetV2_base | 77.04 | 93.27 | 4.32 | 604 | 6.6 | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams) | [Download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar) |
<a name="4"></a>
## 4. ResNet Series <sup>[[1](#ref1)]</sup>
......
# PP-LCNetV2 Series
---
## Overview

PP-LCNetV2 is a computer vision backbone network designed for Intel CPU hardware platforms, built on top of the [PP-LCNet series models](./PP-LCNet.md).

Without using any extra training data, the PPLCNetV2_base model reaches a Top-1 accuracy of over 77% on the ImageNet classification dataset while keeping latency below 4.4 ms on Intel CPUs, as shown in the table below. Latency was measured on an Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz with OpenVINO 2021.4.2 as the inference runtime.

| Model | Params(M) | FLOPs(M) | Top-1 Acc(%) | Top-5 Acc(%) | Latency(ms) |
|-------|-----------|----------|---------------|---------------|-------------|
| PPLCNetV2_base | 6.6 | 604 | 77.04 | 93.27 | 4.32 |

More information on the PP-LCNetV2 series models will be published here; stay tuned.
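In the meantime, the snippet below is a quick way to try the new backbone: it instantiates `PPLCNetV2_base` from the `ppcls` package (using the import path added in this PR) and runs a forward pass on a random input. This is a minimal sketch assuming a local PaddleClas checkout or an installed `ppcls` package; set `pretrained=True` to download the weights from the URL in `MODEL_URLS`.

```python
import paddle

from ppcls.arch.backbone.legendary_models.pp_lcnet_v2 import PPLCNetV2_base

# Build the backbone; pretrained=True would download the pretrained weights
# listed in MODEL_URLS instead of using random initialization.
model = PPLCNetV2_base(pretrained=False)
model.eval()  # also fuses the re-parameterizable depthwise branches

x = paddle.rand([1, 3, 224, 224])  # one random ImageNet-sized image
with paddle.no_grad():
    logits = model(x)
print(logits.shape)  # [1, 1000]
```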
......
@@ -22,6 +22,7 @@ from ppcls.arch.backbone.legendary_models.vgg import VGG11, VGG13, VGG16, VGG19
from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3
from ppcls.arch.backbone.legendary_models.hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W64_C
from ppcls.arch.backbone.legendary_models.pp_lcnet import PPLCNet_x0_25, PPLCNet_x0_35, PPLCNet_x0_5, PPLCNet_x0_75, PPLCNet_x1_0, PPLCNet_x1_5, PPLCNet_x2_0, PPLCNet_x2_5
from ppcls.arch.backbone.legendary_models.pp_lcnet_v2 import PPLCNetV2_base
from ppcls.arch.backbone.legendary_models.esnet import ESNet_x0_25, ESNet_x0_5, ESNet_x0_75, ESNet_x1_0
from ppcls.arch.backbone.model_zoo.resnet_vc import ResNet50_vc
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn import AdaptiveAvgPool2D, BatchNorm2D, Conv2D, Dropout, Linear
from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingNormal
from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"PPLCNetV2_base":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams",
}
__all__ = list(MODEL_URLS.keys())
NET_CONFIG = {
# in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut
"stage1": [64, 3, False, False, False, False],
"stage2": [128, 3, False, False, False, False],
"stage3": [256, 5, True, True, True, False],
"stage4": [512, 5, False, True, False, True],
}
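# Round `v` to the nearest multiple of `divisor` (8 by default), never going
# more than 10% below the original value; keeps channel counts hardware-friendly.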
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNLayer(TheseusLayer):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
groups=1,
use_act=True):
super().__init__()
self.use_act = use_act
self.conv = Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)
self.bn = BatchNorm2D(
out_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
if self.use_act:
self.act = nn.ReLU()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.use_act:
x = self.act(x)
return x
class SEModule(TheseusLayer):
def __init__(self, channel, reduction=4):
super().__init__()
self.avg_pool = AdaptiveAvgPool2D(1)
self.conv1 = Conv2D(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0)
self.relu = nn.ReLU()
self.conv2 = Conv2D(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0)
        self.hardsigmoid = nn.Sigmoid()  # NOTE: a plain Sigmoid, despite the attribute name
def forward(self, x):
identity = x
x = self.avg_pool(x)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.hardsigmoid(x)
x = paddle.multiply(x=identity, y=x)
return x
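# Depthwise-separable block with optional structural re-parameterization: when
# `use_rep` is set, training runs parallel depthwise branches with shrinking
# kernel sizes (e.g. 5x5, 3x3, 1x1) and sums their outputs; eval() later folds
# all branches into the single `dw_conv` for inference.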
class RepDepthwiseSeparable(TheseusLayer):
def __init__(self,
in_channels,
out_channels,
stride,
dw_size=3,
split_pw=False,
use_rep=False,
use_se=False,
use_shortcut=False):
super().__init__()
self.is_repped = False
self.dw_size = dw_size
self.split_pw = split_pw
self.use_rep = use_rep
self.use_se = use_se
        self.use_shortcut = use_shortcut and stride == 1 and in_channels == out_channels
if self.use_rep:
self.dw_conv_list = nn.LayerList()
for kernel_size in range(self.dw_size, 0, -2):
if kernel_size == 1 and stride != 1:
continue
dw_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=stride,
groups=in_channels,
use_act=False)
self.dw_conv_list.append(dw_conv)
self.dw_conv = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=dw_size,
stride=stride,
padding=(dw_size - 1) // 2,
groups=in_channels)
else:
self.dw_conv = ConvBNLayer(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=dw_size,
stride=stride,
groups=in_channels)
self.act = nn.ReLU()
if use_se:
self.se = SEModule(in_channels)
if self.split_pw:
pw_ratio = 0.5
self.pw_conv_1 = ConvBNLayer(
in_channels=in_channels,
kernel_size=1,
out_channels=int(out_channels * pw_ratio),
stride=1)
self.pw_conv_2 = ConvBNLayer(
in_channels=int(out_channels * pw_ratio),
kernel_size=1,
out_channels=out_channels,
stride=1)
else:
self.pw_conv = ConvBNLayer(
in_channels=in_channels,
kernel_size=1,
out_channels=out_channels,
stride=1)
def forward(self, x):
if self.use_rep:
input_x = x
if not self.training:
x = self.act(self.dw_conv(x))
else:
y = self.dw_conv_list[0](x)
for dw_conv in self.dw_conv_list[1:]:
y += dw_conv(x)
x = self.act(y)
else:
x = self.dw_conv(x)
if self.use_se:
x = self.se(x)
if self.split_pw:
x = self.pw_conv_1(x)
x = self.pw_conv_2(x)
else:
x = self.pw_conv(x)
if self.use_shortcut:
x = x + input_x
return x
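    # Overridden so that switching to inference mode also folds the parallel
    # depthwise branches (and their BN statistics) into `dw_conv`.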
def eval(self):
if self.use_rep:
kernel, bias = self._get_equivalent_kernel_bias()
self.dw_conv.weight.set_value(kernel)
self.dw_conv.bias.set_value(bias)
self.training = False
for layer in self.sublayers():
layer.eval()
def _get_equivalent_kernel_bias(self):
kernel_sum = 0
bias_sum = 0
for dw_conv in self.dw_conv_list:
kernel, bias = self._fuse_bn_tensor(dw_conv)
kernel = self._pad_tensor(kernel, to_size=self.dw_size)
kernel_sum += kernel
bias_sum += bias
return kernel_sum, bias_sum
def _fuse_bn_tensor(self, branch):
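        # Fold BatchNorm into the preceding conv: with running mean mu, running
        # variance var, scale gamma, shift beta and std = sqrt(var + eps),
        #   BN(conv(x)) = conv'(x)  with  W' = W * gamma/std,
        #                                 b' = beta - mu * gamma/std.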
kernel = branch.conv.weight
running_mean = branch.bn._mean
running_var = branch.bn._variance
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
def _pad_tensor(self, tensor, to_size):
from_size = tensor.shape[-1]
if from_size == to_size:
return tensor
pad = (to_size - from_size) // 2
return F.pad(tensor, [pad, pad, pad, pad])
class PPLCNetV2(TheseusLayer):
def __init__(self,
scale,
depths,
class_num=1000,
dropout_prob=0,
use_last_conv=True,
class_expand=1280):
super().__init__()
self.scale = scale
self.use_last_conv = use_last_conv
self.class_expand = class_expand
        self.stem = nn.Sequential(*[
ConvBNLayer(
in_channels=3,
kernel_size=3,
out_channels=make_divisible(32 * scale),
stride=2), RepDepthwiseSeparable(
in_channels=make_divisible(32 * scale),
out_channels=make_divisible(64 * scale),
stride=1,
dw_size=3)
])
# stages
self.stages = nn.LayerList()
for depth_idx, k in enumerate(NET_CONFIG):
in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut = NET_CONFIG[
k]
self.stages.append(
            nn.Sequential(*[
RepDepthwiseSeparable(
in_channels=make_divisible((in_channels if i == 0 else
in_channels * 2) * scale),
out_channels=make_divisible(in_channels * 2 * scale),
stride=2 if i == 0 else 1,
dw_size=kernel_size,
split_pw=split_pw,
use_rep=use_rep,
use_se=use_se,
use_shortcut=use_shortcut)
for i in range(depths[depth_idx])
]))
self.avg_pool = AdaptiveAvgPool2D(1)
if self.use_last_conv:
self.last_conv = Conv2D(
in_channels=make_divisible(NET_CONFIG["stage4"][0] * 2 *
scale),
out_channels=self.class_expand,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.act = nn.ReLU()
self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer")
self.flatten = nn.Flatten(start_axis=1, stop_axis=-1)
        in_features = self.class_expand if self.use_last_conv else make_divisible(
            NET_CONFIG["stage4"][0] * 2 * scale)  # make_divisible keeps this an int matching the stage4 output channels when scale != 1.0
self.fc = Linear(in_features, class_num)
def forward(self, x):
x = self.stem(x)
for stage in self.stages:
x = stage(x)
x = self.avg_pool(x)
if self.use_last_conv:
x = self.last_conv(x)
x = self.act(x)
x = self.dropout(x)
x = self.flatten(x)
x = self.fc(x)
return x
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def PPLCNetV2_base(pretrained=False, use_ssld=False, **kwargs):
"""
PPLCNetV2_base
    Args:
        pretrained: bool or str, default False. If True, load the pretrained
            parameters from MODEL_URLS; if a str, it is treated as a local path
            to pretrained parameters.
        use_ssld: bool, default False. Whether to use the SSLD distillation
            pretrained model when pretrained=True.
    Returns:
        model: nn.Layer. A `PPLCNetV2_base` model configured by the args.
"""
model = PPLCNetV2(
scale=1.0, depths=[2, 2, 6, 2], dropout_prob=0.2, **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld)
return model
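As a sanity check of the branch fusion implemented by `eval()` and `_get_equivalent_kernel_bias()` above, the sketch below runs one `RepDepthwiseSeparable` block along the multi-branch path with BatchNorm in inference mode, then fuses it and compares outputs. This is an illustrative snippet assuming the PaddleClas sources are importable; note that `RepDepthwiseSeparable` is an internal class and not part of `__all__`.

```python
import paddle

from ppcls.arch.backbone.legendary_models.pp_lcnet_v2 import RepDepthwiseSeparable

paddle.seed(42)
block = RepDepthwiseSeparable(
    in_channels=8, out_channels=8, stride=1, dw_size=5,
    use_rep=True, use_shortcut=True)
x = paddle.rand([1, 8, 16, 16])

# Multi-branch path, but with BN using running statistics: put every sublayer
# in inference mode while keeping the block's own routing flag on the
# training (multi-branch) path.
for layer in block.sublayers():
    layer.training = False
block.training = True
with paddle.no_grad():
    y_branches = block(x)

# eval() folds the 5x5/3x3/1x1 branches into the single dw_conv.
block.eval()
with paddle.no_grad():
    y_fused = block(x)

print(float((y_branches - y_fused).abs().max()))  # ~1e-6: numerically equal
```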
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 480
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: PPLCNetV2_base
class_num: 1000
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.8
warmup_epoch: 5
regularizer:
name: 'L2'
coeff: 0.00004
# data loader for train and eval
DataLoader:
Train:
dataset:
name: MultiScaleDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
    # width and height can also be specified separately, e.g.:
    # scales: [(160, 160), (192, 192), (224, 224), (288, 288), (320, 320)]
sampler:
name: MultiScaleSampler
scales: [160, 192, 224, 288, 320]
      # first_bs: batch size for the first image resolution in the scales list
      # divided_factor: ensures width and height can be divided by the network's downsampling multiple
first_bs: 500
divided_factor: 32
is_training: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
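The YAML above is the training recipe consumed by PaddleClas's `tools/train.py` (the TIPC file below shows the exact command lines). As a small illustration, and assuming the file lives at the path referenced there (`ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml`), it can be loaded and overridden programmatically the same way the `-o` flags do:

```python
import yaml  # pyyaml

# Path taken from the TIPC config below; adjust for your checkout.
cfg_path = "ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml"
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

# Equivalent of `-o Global.epochs=2 -o DataLoader.Train.sampler.first_bs=8`.
cfg["Global"]["epochs"] = 2
cfg["DataLoader"]["Train"]["sampler"]["first_bs"] = 8
print(cfg["Arch"]["name"], cfg["Optimizer"]["lr"]["learning_rate"])
```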
......
===========================train_params===========================
model_name:PPLCNetV2_base
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.first_bs:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml -o Global.seed=1234 -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:False
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]