Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
8a760fb8
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8a760fb8
编写于
5月 12, 2022
作者:
C
cuicheng01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add PPHGNet code
上级
50c1302b
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
831 addition
and
0 deletion
+831
-0
docs/zh_CN/models/PP-HGNet.md
docs/zh_CN/models/PP-HGNet.md
+24
-0
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/legendary_models/pp_hgnet.py
ppcls/arch/backbone/legendary_models/pp_hgnet.py
+372
-0
ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
+164
-0
ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
+164
-0
test_tipc/config/PPHGNet/PPHGNet_small_train_infer_python.txt
..._tipc/config/PPHGNet/PPHGNet_small_train_infer_python.txt
+53
-0
test_tipc/config/PPHGNet/PPHGNet_tiny_train_infer_python.txt
test_tipc/config/PPHGNet/PPHGNet_tiny_train_infer_python.txt
+53
-0
未找到文件。
docs/zh_CN/models/PP-HGNet.md
0 → 100644
浏览文件 @
8a760fb8
# PP-HGNet 系列
---
## 目录
*
[
1. 概述
](
#1
)
*
[
2. 精度、FLOPs 和参数量
](
#2
)
<a
name=
'1'
></a>
## 1. 概述
PP-HGNet是百度自研的一个在 GPU 端上高性能的网络,该网络在 VOVNet 的基础上融合了 ResNet_vd、PPLCNet 的优点,使用了可学习的下采样层,组合成了一个在 GPU 设备上速度快、精度高的网络,超越其他 GPU 端 SOTA 模型。
<a
name=
'2'
></a>
## 2.精度、FLOPs 和参数量
| Models | Top1 | Top5 | FLOPs
<br>
(G) | Params
<br/>
(M) |
|:--:|:--:|:--:|:--:|:--:|
| PPHGNet_tiny | 79.83 | 95.04 | 4.54 | 14.75 |
| PPHGNet_tiny_ssld | 81.95 | 96.12 | 4.54 | 14.75 |
| PPHGNet_small | 81.51 | 95.82 | 8.53 | 24.38 |
关于 Inference speed 等信息,敬请期待。
ppcls/arch/backbone/__init__.py
浏览文件 @
8a760fb8
...
...
@@ -23,6 +23,7 @@ from ppcls.arch.backbone.legendary_models.inception_v3 import InceptionV3
from
ppcls.arch.backbone.legendary_models.hrnet
import
HRNet_W18_C
,
HRNet_W30_C
,
HRNet_W32_C
,
HRNet_W40_C
,
HRNet_W44_C
,
HRNet_W48_C
,
HRNet_W60_C
,
HRNet_W64_C
,
SE_HRNet_W64_C
from
ppcls.arch.backbone.legendary_models.pp_lcnet
import
PPLCNet_x0_25
,
PPLCNet_x0_35
,
PPLCNet_x0_5
,
PPLCNet_x0_75
,
PPLCNet_x1_0
,
PPLCNet_x1_5
,
PPLCNet_x2_0
,
PPLCNet_x2_5
from
ppcls.arch.backbone.legendary_models.esnet
import
ESNet_x0_25
,
ESNet_x0_5
,
ESNet_x0_75
,
ESNet_x1_0
from
ppcls.arch.backbone.legendary_models.pp_hgnet
import
PPHGNet_tiny
,
PPHGNet_small
,
PPHGNet_base
from
ppcls.arch.backbone.model_zoo.resnet_vc
import
ResNet50_vc
from
ppcls.arch.backbone.model_zoo.resnext
import
ResNeXt50_32x4d
,
ResNeXt50_64x4d
,
ResNeXt101_32x4d
,
ResNeXt101_64x4d
,
ResNeXt152_32x4d
,
ResNeXt152_64x4d
...
...
ppcls/arch/backbone/legendary_models/pp_hgnet.py
0 → 100644
浏览文件 @
8a760fb8
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn.initializer
import
KaimingNormal
,
Constant
from
paddle.nn
import
Conv2D
,
BatchNorm2D
,
ReLU
,
AdaptiveAvgPool2D
,
MaxPool2D
from
paddle.regularizer
import
L2Decay
from
paddle
import
ParamAttr
from
ppcls.arch.backbone.base.theseus_layer
import
TheseusLayer
from
ppcls.utils.save_load
import
load_dygraph_pretrain
,
load_dygraph_pretrain_from_url
MODEL_URLS
=
{
"PPHGNet_tiny"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams"
,
"PPHGNet_small"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams"
}
__all__
=
list
(
MODEL_URLS
.
keys
())
kaiming_normal_
=
KaimingNormal
()
zeros_
=
Constant
(
value
=
0.
)
ones_
=
Constant
(
value
=
1.
)
class
ConvBNAct
(
TheseusLayer
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
groups
=
1
,
use_act
=
True
):
super
().
__init__
()
self
.
use_act
=
use_act
self
.
conv
=
Conv2D
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
=
(
kernel_size
-
1
)
//
2
,
groups
=
groups
,
bias_attr
=
False
)
self
.
bn
=
BatchNorm2D
(
out_channels
,
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
if
self
.
use_act
:
self
.
act
=
ReLU
()
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
x
=
self
.
bn
(
x
)
if
self
.
use_act
:
x
=
self
.
act
(
x
)
return
x
class
ESEModule
(
TheseusLayer
):
def
__init__
(
self
,
channels
):
super
().
__init__
()
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
conv
=
Conv2D
(
in_channels
=
channels
,
out_channels
=
channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
self
.
sigmoid
=
nn
.
Sigmoid
()
def
forward
(
self
,
x
):
identity
=
x
x
=
self
.
avg_pool
(
x
)
x
=
self
.
conv
(
x
)
x
=
self
.
sigmoid
(
x
)
return
paddle
.
multiply
(
x
=
identity
,
y
=
x
)
class
_HG_Block
(
TheseusLayer
):
def
__init__
(
self
,
in_channels
,
mid_channels
,
out_channels
,
layer_num
,
identity
=
False
,
):
super
().
__init__
()
self
.
identity
=
identity
self
.
layers
=
nn
.
LayerList
()
self
.
layers
.
append
(
ConvBNAct
(
in_channels
=
in_channels
,
out_channels
=
mid_channels
,
kernel_size
=
3
,
stride
=
1
))
for
_
in
range
(
layer_num
-
1
):
self
.
layers
.
append
(
ConvBNAct
(
in_channels
=
mid_channels
,
out_channels
=
mid_channels
,
kernel_size
=
3
,
stride
=
1
))
# feature aggregation
total_channels
=
in_channels
+
layer_num
*
mid_channels
self
.
aggregation_conv
=
ConvBNAct
(
in_channels
=
total_channels
,
out_channels
=
out_channels
,
kernel_size
=
1
,
stride
=
1
)
self
.
att
=
ESEModule
(
out_channels
)
def
forward
(
self
,
x
):
identity
=
x
output
=
[]
output
.
append
(
x
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
output
.
append
(
x
)
x
=
paddle
.
concat
(
output
,
axis
=
1
)
x
=
self
.
aggregation_conv
(
x
)
x
=
self
.
att
(
x
)
if
self
.
identity
:
x
+=
identity
return
x
class
_HG_Stage
(
TheseusLayer
):
def
__init__
(
self
,
in_channels
,
mid_channels
,
out_channels
,
block_num
,
layer_num
,
downsample
=
True
):
super
().
__init__
()
self
.
downsample
=
downsample
if
downsample
:
self
.
downsample
=
ConvBNAct
(
in_channels
=
in_channels
,
out_channels
=
in_channels
,
kernel_size
=
3
,
stride
=
2
,
groups
=
in_channels
,
use_act
=
False
)
blocks_list
=
[]
blocks_list
.
append
(
_HG_Block
(
in_channels
,
mid_channels
,
out_channels
,
layer_num
,
identity
=
False
))
for
_
in
range
(
block_num
-
1
):
blocks_list
.
append
(
_HG_Block
(
out_channels
,
mid_channels
,
out_channels
,
layer_num
,
identity
=
True
))
self
.
blocks
=
nn
.
Sequential
(
*
blocks_list
)
def
forward
(
self
,
x
):
if
self
.
downsample
:
x
=
self
.
downsample
(
x
)
x
=
self
.
blocks
(
x
)
return
x
class
PPHGNet
(
TheseusLayer
):
"""
PPHGNet
Args:
stem_channels: list. Stem channel list of PPHGNet.
stage_config: dict. The configuration of each stage of PPHGNet. such as the number of channels, stride, etc.
layer_num: int. Number of layers of HG_Block.
use_last_conv: boolean. Whether to use a 1x1 convolutional layer before the classification layer.
class_expand: int=2048. Number of channels for the last 1x1 convolutional layer.
dropout_prob: float. Parameters of dropout, 0.0 means dropout is not used.
class_num: int=1000. The number of classes.
Returns:
model: nn.Layer. Specific PPHGNet model depends on args.
"""
def
__init__
(
self
,
stem_channels
,
stage_config
,
layer_num
,
use_last_conv
=
True
,
class_expand
=
2048
,
dropout_prob
=
0.0
,
class_num
=
1000
):
super
().
__init__
()
self
.
use_last_conv
=
use_last_conv
self
.
class_expand
=
class_expand
# stem
stem_channels
.
insert
(
0
,
3
)
self
.
stem
=
nn
.
Sequential
(
*
[
ConvBNAct
(
in_channels
=
stem_channels
[
i
],
out_channels
=
stem_channels
[
i
+
1
],
kernel_size
=
3
,
stride
=
2
if
i
==
0
else
1
)
for
i
in
range
(
len
(
stem_channels
)
-
1
)
])
self
.
pool
=
nn
.
MaxPool2D
(
kernel_size
=
3
,
stride
=
2
,
padding
=
1
)
# stages
self
.
stages
=
nn
.
LayerList
()
for
k
in
stage_config
:
in_channels
,
mid_channels
,
out_channels
,
block_num
,
downsample
=
stage_config
[
k
]
self
.
stages
.
append
(
_HG_Stage
(
in_channels
,
mid_channels
,
out_channels
,
block_num
,
layer_num
,
downsample
))
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
if
self
.
use_last_conv
:
self
.
last_conv
=
Conv2D
(
in_channels
=
out_channels
,
out_channels
=
self
.
class_expand
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
,
bias_attr
=
False
)
self
.
act
=
nn
.
ReLU
()
self
.
dropout
=
nn
.
Dropout
(
p
=
dropout_prob
,
mode
=
"downscale_in_infer"
)
self
.
flatten
=
nn
.
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
self
.
fc
=
nn
.
Linear
(
self
.
class_expand
if
self
.
use_last_conv
else
out_channels
,
class_num
)
self
.
_init_weights
()
def
_init_weights
(
self
):
for
m
in
self
.
sublayers
():
if
isinstance
(
m
,
nn
.
Conv2D
):
kaiming_normal_
(
m
.
weight
)
elif
isinstance
(
m
,
(
nn
.
BatchNorm2D
)):
ones_
(
m
.
weight
)
zeros_
(
m
.
bias
)
elif
isinstance
(
m
,
nn
.
Linear
):
zeros_
(
m
.
bias
)
def
forward
(
self
,
x
):
x
=
self
.
stem
(
x
)
x
=
self
.
pool
(
x
)
for
stage
in
self
.
stages
:
x
=
stage
(
x
)
x
=
self
.
avg_pool
(
x
)
if
self
.
use_last_conv
:
x
=
self
.
last_conv
(
x
)
x
=
self
.
act
(
x
)
x
=
self
.
dropout
(
x
)
x
=
self
.
flatten
(
x
)
x
=
self
.
fc
(
x
)
return
x
def
_load_pretrained
(
pretrained
,
model
,
model_url
,
use_ssld
):
if
pretrained
is
False
:
pass
elif
pretrained
is
True
:
load_dygraph_pretrain_from_url
(
model
,
model_url
,
use_ssld
=
use_ssld
)
elif
isinstance
(
pretrained
,
str
):
load_dygraph_pretrain
(
model
,
pretrained
)
else
:
raise
RuntimeError
(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def
PPHGNet_tiny
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
PPHGNet_tiny
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_tiny` model depends on args.
"""
stage_config
=
{
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1"
:
[
96
,
96
,
224
,
1
,
False
],
"stage2"
:
[
224
,
128
,
448
,
1
,
True
],
"stage3"
:
[
448
,
160
,
512
,
2
,
True
],
"stage4"
:
[
512
,
192
,
768
,
1
,
True
],
}
model
=
PPHGNet
(
stem_channels
=
[
48
,
48
,
96
],
stage_config
=
stage_config
,
layer_num
=
5
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"PPHGNet_tiny"
],
use_ssld
)
return
model
def
PPHGNet_small
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
PPHGNet_small
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_small` model depends on args.
"""
stage_config
=
{
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1"
:
[
128
,
128
,
256
,
1
,
False
],
"stage2"
:
[
256
,
160
,
512
,
1
,
True
],
"stage3"
:
[
512
,
192
,
768
,
2
,
True
],
"stage4"
:
[
768
,
224
,
1024
,
1
,
True
],
}
model
=
PPHGNet
(
stem_channels
=
[
64
,
64
,
128
],
stage_config
=
stage_config
,
layer_num
=
6
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"PPHGNet_small"
],
use_ssld
)
return
model
def
PPHGNet_base
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
"""
PPHGNet_base
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `PPHGNet_base` model depends on args.
"""
stage_config
=
{
# in_channels, mid_channels, out_channels, blocks, downsample
"stage1"
:
[
160
,
192
,
320
,
1
,
False
],
"stage2"
:
[
320
,
224
,
640
,
2
,
True
],
"stage3"
:
[
640
,
256
,
960
,
3
,
True
],
"stage4"
:
[
960
,
288
,
1280
,
2
,
True
],
}
model
=
PPHGNet
(
stem_channels
=
[
96
,
96
,
160
],
stage_config
=
stage_config
,
layer_num
=
7
,
dropout_prob
=
0.2
,
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"PPHGNet_base"
],
use_ssld
)
return
model
ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
0 → 100644
浏览文件 @
8a760fb8
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
600
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# mixed precision training
AMP
:
scale_loss
:
128.0
use_dynamic_loss_scaling
:
True
# O1: mixed fp16
level
:
O1
# model architecture
Arch
:
name
:
PPHGNet_small
class_num
:
1000
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.5
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
interpolation
:
bicubic
backend
:
pil
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
config_str
:
rand-m7-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.25
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
batch_transform_ops
:
-
OpSampler
:
MixupOperator
:
alpha
:
0.2
prob
:
0.5
CutmixOperator
:
alpha
:
1.0
prob
:
0.5
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
16
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
236
interpolation
:
bicubic
backend
:
pil
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
16
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
236
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
0 → 100644
浏览文件 @
8a760fb8
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
600
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# mixed precision training
AMP
:
scale_loss
:
128.0
use_dynamic_loss_scaling
:
True
# O1: mixed fp16
level
:
O1
# model architecture
Arch
:
name
:
PPHGNet_tiny
class_num
:
1000
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.5
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
interpolation
:
bicubic
backend
:
pil
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
config_str
:
rand-m7-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.25
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
batch_transform_ops
:
-
OpSampler
:
MixupOperator
:
alpha
:
0.2
prob
:
0.5
CutmixOperator
:
alpha
:
1.0
prob
:
0.5
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
16
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
232
interpolation
:
bicubic
backend
:
pil
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
16
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
232
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
test_tipc/config/PPHGNet/PPHGNet_small_train_infer_python.txt
0 → 100644
浏览文件 @
8a760fb8
===========================train_params===========================
model_name:PPHGNet_small
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_small.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=236
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
test_tipc/config/PPHGNet/PPHGNet_tiny_train_infer_python.txt
0 → 100644
浏览文件 @
8a760fb8
===========================train_params===========================
model_name:PPHGNet_tiny
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/PPHGNet/PPHGNet_tiny.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=232
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录