Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
c2daa752
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
1 年多 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c2daa752
编写于
5月 13, 2022
作者:
C
cuicheng01
提交者:
GitHub
5月 13, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1916 from TingquanGao/dev/add_pplcnetv2
feat: add PPLCNetV2
上级
50c1302b
dce720dc
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
563 addition
and
3 deletion
+563
-3
docs/zh_CN/algorithm_introduction/ImageNet_models.md
docs/zh_CN/algorithm_introduction/ImageNet_models.md
+7
-3
docs/zh_CN/models/PP-LCNetV2.md
docs/zh_CN/models/PP-LCNetV2.md
+15
-0
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
+354
-0
ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
+133
-0
test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt
...pc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt
+53
-0
未找到文件。
docs/zh_CN/algorithm_introduction/ImageNet_models.md
浏览文件 @
c2daa752
...
...
@@ -10,7 +10,7 @@
-
[
2.1 服务器端知识蒸馏模型
](
#2.1
)
-
[
2.2 移动端知识蒸馏模型
](
#2.2
)
-
[
2.3 Intel CPU 端知识蒸馏模型
](
#2.3
)
-
[
3. PP-LCNet 系列
](
#3
)
-
[
3. PP-LCNet
& PP-LCNetV2
系列
](
#3
)
-
[
4. ResNet 系列
](
#4
)
-
[
5. 移动端系列
](
#5
)
-
[
6. SEResNeXt 与 Res2Net 系列
](
#6
)
...
...
@@ -106,9 +106,9 @@
<a
name=
"3"
></a>
## 3. PP-LCNet 系列 <sup>[[28](#ref28)]</sup>
## 3. PP-LCNet
& PP-LCNetV2
系列 <sup>[[28](#ref28)]</sup>
PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:
[
PP-LCNet 系列模型文档
](
../models/PP-LCNet.md
)
。
PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:
[
PP-LCNet 系列模型文档
](
../models/PP-LCNet.md
)
,
[
PP-LCNetV2 系列模型文档
](
../models/PP-LCNetV2.md
)
。
| 模型 | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6148 time(ms)
<br>
bs=1 | FLOPs(M) | Params(M) | 预训练模型下载地址 | inference模型下载地址 |
|:--:|:--:|:--:|:--:|----|----|----|:--:|
...
...
@@ -121,6 +121,10 @@ PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该
| PPLCNet_x2_0 |0.7518 | 0.9227 | 20.1667 | 590 | 6.54 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_0_pretrained.pdparams
)
|
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_0_infer.tar
)
|
| PPLCNet_x2_5 |0.7660 | 0.9300 | 29.595 | 906 | 9.04 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNet_x2_5_pretrained.pdparams
)
|
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNet_x2_5_infer.tar
)
|
| 模型 | Top-1 Acc | Top-5 Acc | Intel-Xeon-Gold-6271C
<br>
bs=1
<br>
OpenVINO 2021.4.2
<br>
time(ms) | FLOPs(M) | Params(M) | 预训练模型下载地址 | inference模型下载地址 |
|:--:|:--:|:--:|:--:|----|----|----|:--:|
| PPLCNetV2_base | 77.04 | 93.27 | 4.32 | 604 | 6.6 | https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams | https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPLCNetV2_base_infer.tar |
<a
name=
"4"
></a>
## 4. ResNet 系列 <sup>[[1](#ref1)]</sup>
...
...
docs/zh_CN/models/PP-LCNetV2.md
0 → 100644
浏览文件 @
c2daa752
# PP-LCNetV2 系列
---
## 概述
PP-LCNetV2 是在
[
PP-LCNet 系列模型
](
./PP-LCNet.md
)
的基础上,所提出的针对 Intel CPU 硬件平台设计的计算机视觉骨干网络,该模型更为
在不使用额外数据的前提下,PPLCNetV2_base 模型在图像分类 ImageNet 数据集上能够取得超过 77% 的 Top1 Acc,同时在 Intel CPU 平台仅有 4.4 ms 以下的延迟,如下表所示,其中延时测试基于 Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz 硬件平台,OpenVINO 2021.4.2推理平台。
| Model | Params(M) | FLOPs(M) | Top-1 Acc(
\%
) | Top-5 Acc(
\%
) | Latency(ms) |
|-------|-----------|----------|---------------|---------------|-------------|
| PPLCNetV2_base | 6.6 | 604 | 77.04 | 93.27 | 4.32 |
关于 PP-LCNetV2 系列模型的更多信息,敬请关注。
ppcls/arch/backbone/__init__.py
浏览文件 @
c2daa752
...
...
@@ -22,6 +22,7 @@ from ppcls.arch.backbone.legendary_models.vgg import VGG11, VGG13, VGG16, VGG19
from
ppcls.arch.backbone.legendary_models.inception_v3
import
InceptionV3
from
ppcls.arch.backbone.legendary_models.hrnet
import
HRNet_W18_C
,
HRNet_W30_C
,
HRNet_W32_C
,
HRNet_W40_C
,
HRNet_W44_C
,
HRNet_W48_C
,
HRNet_W60_C
,
HRNet_W64_C
,
SE_HRNet_W64_C
from
ppcls.arch.backbone.legendary_models.pp_lcnet
import
PPLCNet_x0_25
,
PPLCNet_x0_35
,
PPLCNet_x0_5
,
PPLCNet_x0_75
,
PPLCNet_x1_0
,
PPLCNet_x1_5
,
PPLCNet_x2_0
,
PPLCNet_x2_5
from
ppcls.arch.backbone.legendary_models.pp_lcnet_v2
import
PPLCNetV2_base
from
ppcls.arch.backbone.legendary_models.esnet
import
ESNet_x0_25
,
ESNet_x0_5
,
ESNet_x0_75
,
ESNet_x1_0
from
ppcls.arch.backbone.model_zoo.resnet_vc
import
ResNet50_vc
...
...
ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py
0 → 100644
浏览文件 @
c2daa752
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
from
paddle.nn
import
AdaptiveAvgPool2D
,
BatchNorm2D
,
Conv2D
,
Dropout
,
Linear
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
KaimingNormal
from
ppcls.arch.backbone.base.theseus_layer
import
TheseusLayer
from
ppcls.utils.save_load
import
load_dygraph_pretrain
,
load_dygraph_pretrain_from_url
MODEL_URLS
=
{
"PPLCNetV2_base"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams"
,
}
__all__
=
list
(
MODEL_URLS
.
keys
())
NET_CONFIG
=
{
# in_channels, kernel_size, split_pw, use_rep, use_se, use_shortcut
"stage1"
:
[
64
,
3
,
False
,
False
,
False
,
False
],
"stage2"
:
[
128
,
3
,
False
,
False
,
False
,
False
],
"stage3"
:
[
256
,
5
,
True
,
True
,
True
,
False
],
"stage4"
:
[
512
,
5
,
False
,
True
,
False
,
True
],
}
def
make_divisible
(
v
,
divisor
=
8
,
min_value
=
None
):
if
min_value
is
None
:
min_value
=
divisor
new_v
=
max
(
min_value
,
int
(
v
+
divisor
/
2
)
//
divisor
*
divisor
)
if
new_v
<
0.9
*
v
:
new_v
+=
divisor
return
new_v
class
ConvBNLayer
(
TheseusLayer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
groups
=
1
,
use_act
=
True
):
super
().
__init__
()
self
.
use_act
=
use_act
self
.
conv
=
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
out_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
padding
=
(
kernel_size
-
1
)
//
2
,
groups
=
groups
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()),
bias_attr
=
False
)
self
.
bn
=
BatchNorm2D
(
out_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
if
self
.
use_act
:
self
.
act
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
if
self
.
use_act
:
x
=
self
.
act
(
x
)
return
x
class
SEModule
(
TheseusLayer
):
def
__init__
(
self
,
channel
,
reduction
=
4
):
super
().
__init__
()
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
conv1
=
Conv2D
(
in_channels
=
channel
,
out_channels
=
channel
//
reduction
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
relu
=
nn
.
ReLU
()
self
.
conv2
=
Conv2D
(
in_channels
=
channel
//
reduction
,
out_channels
=
channel
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
hardsigmoid
=
nn
.
Sigmoid
()
def
forward
(
self
,
x
):
identity
=
x
x
=
self
.
avg_pool
(
x
)
x
=
self
.
conv1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
conv2
(
x
)
x
=
self
.
hardsigmoid
(
x
)
x
=
paddle
.
multiply
(
x
=
identity
,
y
=
x
)
return
x
class
RepDepthwiseSeparable
(
TheseusLayer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
stride
,
dw_size
=
3
,
split_pw
=
False
,
use_rep
=
False
,
use_se
=
False
,
use_shortcut
=
False
):
super
().
__init__
()
self
.
is_repped
=
False
self
.
dw_size
=
dw_size
self
.
split_pw
=
split_pw
self
.
use_rep
=
use_rep
self
.
use_se
=
use_se
self
.
use_shortcut
=
True
if
use_shortcut
and
stride
==
1
and
in_channels
==
out_channels
else
False
if
self
.
use_rep
:
self
.
dw_conv_list
=
nn
.
LayerList
()
for
kernel_size
in
range
(
self
.
dw_size
,
0
,
-
2
):
if
kernel_size
==
1
and
stride
!=
1
:
continue
dw_conv
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
in_channels
,
kernel_size
=
kernel_size
,
stride
=
stride
,
groups
=
in_channels
,
use_act
=
False
)
self
.
dw_conv_list
.
append
(
dw_conv
)
self
.
dw_conv
=
nn
.
Conv2D
(
in_channels
=
in_channels
,
out_channels
=
in_channels
,
kernel_size
=
dw_size
,
stride
=
stride
,
padding
=
(
dw_size
-
1
)
//
2
,
groups
=
in_channels
)
else
:
self
.
dw_conv
=
ConvBNLayer
(
in_channels
=
in_channels
,
out_channels
=
in_channels
,
kernel_size
=
dw_size
,
stride
=
stride
,
groups
=
in_channels
)
self
.
act
=
nn
.
ReLU
()
if
use_se
:
self
.
se
=
SEModule
(
in_channels
)
if
self
.
split_pw
:
pw_ratio
=
0.5
self
.
pw_conv_1
=
ConvBNLayer
(
in_channels
=
in_channels
,
kernel_size
=
1
,
out_channels
=
int
(
out_channels
*
pw_ratio
),
stride
=
1
)
self
.
pw_conv_2
=
ConvBNLayer
(
in_channels
=
int
(
out_channels
*
pw_ratio
),
kernel_size
=
1
,
out_channels
=
out_channels
,
stride
=
1
)
else
:
self
.
pw_conv
=
ConvBNLayer
(
in_channels
=
in_channels
,
kernel_size
=
1
,
out_channels
=
out_channels
,
stride
=
1
)
def
forward
(
self
,
x
):
if
self
.
use_rep
:
input_x
=
x
if
not
self
.
training
:
x
=
self
.
act
(
self
.
dw_conv
(
x
))
else
:
y
=
self
.
dw_conv_list
[
0
](
x
)
for
dw_conv
in
self
.
dw_conv_list
[
1
:]:
y
+=
dw_conv
(
x
)
x
=
self
.
act
(
y
)
else
:
x
=
self
.
dw_conv
(
x
)
if
self
.
use_se
:
x
=
self
.
se
(
x
)
if
self
.
split_pw
:
x
=
self
.
pw_conv_1
(
x
)
x
=
self
.
pw_conv_2
(
x
)
else
:
x
=
self
.
pw_conv
(
x
)
if
self
.
use_shortcut
:
x
=
x
+
input_x
return
x
def
eval
(
self
):
if
self
.
use_rep
:
kernel
,
bias
=
self
.
_get_equivalent_kernel_bias
()
self
.
dw_conv
.
weight
.
set_value
(
kernel
)
self
.
dw_conv
.
bias
.
set_value
(
bias
)
self
.
training
=
False
for
layer
in
self
.
sublayers
():
layer
.
eval
()
def
_get_equivalent_kernel_bias
(
self
):
kernel_sum
=
0
bias_sum
=
0
for
dw_conv
in
self
.
dw_conv_list
:
kernel
,
bias
=
self
.
_fuse_bn_tensor
(
dw_conv
)
kernel
=
self
.
_pad_tensor
(
kernel
,
to_size
=
self
.
dw_size
)
kernel_sum
+=
kernel
bias_sum
+=
bias
return
kernel_sum
,
bias_sum
def
_fuse_bn_tensor
(
self
,
branch
):
kernel
=
branch
.
conv
.
weight
running_mean
=
branch
.
bn
.
_mean
running_var
=
branch
.
bn
.
_variance
gamma
=
branch
.
bn
.
weight
beta
=
branch
.
bn
.
bias
eps
=
branch
.
bn
.
_epsilon
std
=
(
running_var
+
eps
).
sqrt
()
t
=
(
gamma
/
std
).
reshape
((
-
1
,
1
,
1
,
1
))
return
kernel
*
t
,
beta
-
running_mean
*
gamma
/
std
def
_pad_tensor
(
self
,
tensor
,
to_size
):
from_size
=
tensor
.
shape
[
-
1
]
if
from_size
==
to_size
:
return
tensor
pad
=
(
to_size
-
from_size
)
//
2
return
F
.
pad
(
tensor
,
[
pad
,
pad
,
pad
,
pad
])
class
PPLCNetV2
(
TheseusLayer
):
def
__init__
(
self
,
scale
,
depths
,
class_num
=
1000
,
dropout_prob
=
0
,
use_last_conv
=
True
,
class_expand
=
1280
):
super
().
__init__
()
self
.
scale
=
scale
self
.
use_last_conv
=
use_last_conv
self
.
class_expand
=
class_expand
self
.
stem
=
nn
.
Sequential
(
*
[
ConvBNLayer
(
in_channels
=
3
,
kernel_size
=
3
,
out_channels
=
make_divisible
(
32
*
scale
),
stride
=
2
),
RepDepthwiseSeparable
(
in_channels
=
make_divisible
(
32
*
scale
),
out_channels
=
make_divisible
(
64
*
scale
),
stride
=
1
,
dw_size
=
3
)
])
# stages
self
.
stages
=
nn
.
LayerList
()
for
depth_idx
,
k
in
enumerate
(
NET_CONFIG
):
in_channels
,
kernel_size
,
split_pw
,
use_rep
,
use_se
,
use_shortcut
=
NET_CONFIG
[
k
]
self
.
stages
.
append
(
nn
.
Sequential
(
*
[
RepDepthwiseSeparable
(
in_channels
=
make_divisible
((
in_channels
if
i
==
0
else
in_channels
*
2
)
*
scale
),
out_channels
=
make_divisible
(
in_channels
*
2
*
scale
),
stride
=
2
if
i
==
0
else
1
,
dw_size
=
kernel_size
,
split_pw
=
split_pw
,
use_rep
=
use_rep
,
use_se
=
use_se
,
use_shortcut
=
use_shortcut
)
for
i
in
range
(
depths
[
depth_idx
])
]))
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
if
self
.
use_last_conv
:
self
.
last_conv
=
Conv2D
(
in_channels
=
make_divisible
(
NET_CONFIG
[
"stage4"
][
0
]
*
2
*
scale
),
out_channels
=
self
.
class_expand
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias_attr
=
False
)
self
.
act
=
nn
.
ReLU
()
self
.
dropout
=
Dropout
(
p
=
dropout_prob
,
mode
=
"downscale_in_infer"
)
self
.
flatten
=
nn
.
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
in_features
=
self
.
class_expand
if
self
.
use_last_conv
else
NET_CONFIG
[
"stage4"
][
0
]
*
2
*
scale
self
.
fc
=
Linear
(
in_features
,
class_num
)
def
forward
(
self
,
x
):
x
=
self
.
stem
(
x
)
for
stage
in
self
.
stages
:
x
=
stage
(
x
)
x
=
self
.
avg_pool
(
x
)
if
self
.
use_last_conv
:
x
=
self
.
last_conv
(
x
)
x
=
self
.
act
(
x
)
x
=
self
.
dropout
(
x
)
x
=
self
.
flatten
(
x
)
x
=
self
.
fc
(
x
)
return
x
def
_load_pretrained
(
pretrained
,
model
,
model_url
,
use_ssld
):
if
pretrained
is
False
:
pass
elif
pretrained
is
True
:
load_dygraph_pretrain_from_url
(
model
,
model_url
,
use_ssld
=
use_ssld
)
elif
isinstance
(
pretrained
,
str
):
load_dygraph_pretrain
(
model
,
pretrained
)
else
:
raise
RuntimeError
(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def
PPLCNetV2_base
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
PPLCNetV2_base
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPLCNetV2_base` model depends on args.
"""
model
=
PPLCNetV2
(
scale
=
1.0
,
depths
=
[
2
,
2
,
6
,
2
],
dropout_prob
=
0.2
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"PPLCNetV2_base"
],
use_ssld
)
return
model
ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
0 → 100644
浏览文件 @
c2daa752
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
480
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# model architecture
Arch
:
name
:
PPLCNetV2_base
class_num
:
1000
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.8
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
MultiScaleDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
# support to specify width and height respectively:
# scales: [(160,160), (192,192), (224,224) (288,288) (320,320)]
sampler
:
name
:
MultiScaleSampler
scales
:
[
160
,
192
,
224
,
288
,
320
]
# first_bs: batch size for the first image resolution in the scales list
# divide_factor: to ensure the width and height dimensions can be devided by downsampling multiple
first_bs
:
500
divided_factor
:
32
is_training
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
test_tipc/config/PPLCNetV2/PPLCNetV2_base_train_infer_python.txt
0 → 100644
浏览文件 @
c2daa752
===========================train_params===========================
model_name:PPLCNetV2_base
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.first_bs:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml -o Global.seed=1234 -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPLCNetV2/PPLCNetV2_base.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPLCNetV2_base_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录