Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
c6525f0b
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c6525f0b
编写于
6月 01, 2022
作者:
C
cuicheng01
提交者:
GitHub
6月 01, 2022
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1949 from cuicheng01/develop
update PPHGNet-base config
上级
0f520aab
ee5121a3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
346 addition
and
2 deletion
+346
-2
docs/zh_CN/algorithm_introduction/ImageNet_models.md
docs/zh_CN/algorithm_introduction/ImageNet_models.md
+5
-0
docs/zh_CN/models/PP-HGNet.md
docs/zh_CN/models/PP-HGNet.md
+5
-0
ppcls/arch/backbone/legendary_models/pp_hgnet.py
ppcls/arch/backbone/legendary_models/pp_hgnet.py
+3
-2
ppcls/configs/ImageNet/Distillation/res2net200_vd_distill_pphgnet_base.yaml
...eNet/Distillation/res2net200_vd_distill_pphgnet_base.yaml
+169
-0
ppcls/configs/ImageNet/PPHGNet/PPHGNet_base.yaml
ppcls/configs/ImageNet/PPHGNet/PPHGNet_base.yaml
+164
-0
未找到文件。
docs/zh_CN/algorithm_introduction/ImageNet_models.md
浏览文件 @
c6525f0b
...
@@ -133,6 +133,8 @@ PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该
...
@@ -133,6 +133,8 @@ PP-LCNet 系列模型的精度、速度指标如下表所示,更多关于该
**
: 基于 Intel-Xeon-Gold-6271C 硬件平台与 OpenVINO 2021.4.2 推理平台。
**
: 基于 Intel-Xeon-Gold-6271C 硬件平台与 OpenVINO 2021.4.2 推理平台。
<a
name=
"PPHGNet"
></a>
## PP-HGNet 系列
## PP-HGNet 系列
PP-HGNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:
[
PP-HGNet 系列模型文档
](
../models/PP-HGNet.md
)
。
PP-HGNet 系列模型的精度、速度指标如下表所示,更多关于该系列的模型介绍可以参考:
[
PP-HGNet 系列模型文档
](
../models/PP-HGNet.md
)
。
...
@@ -140,7 +142,10 @@ PP-HGNet 系列模型的精度、速度指标如下表所示,更多关于该
...
@@ -140,7 +142,10 @@ PP-HGNet 系列模型的精度、速度指标如下表所示,更多关于该
| 模型 | Top-1 Acc | Top-5 Acc | time(ms)
<br>
bs=1 | time(ms)
<br>
bs=4 | time(ms)
<br/>
bs=8 | FLOPs(G) | Params(M) | 预训练模型下载地址 | inference模型下载地址 |
| 模型 | Top-1 Acc | Top-5 Acc | time(ms)
<br>
bs=1 | time(ms)
<br>
bs=4 | time(ms)
<br/>
bs=8 | FLOPs(G) | Params(M) | 预训练模型下载地址 | inference模型下载地址 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| PPHGNet_tiny | 0.7983 | 0.9504 | 1.77 | - | - | 4.54 | 14.75 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams
)
|
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_tiny_infer.tar
)
|
| PPHGNet_tiny | 0.7983 | 0.9504 | 1.77 | - | - | 4.54 | 14.75 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams
)
|
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_tiny_infer.tar
)
|
| PPHGNet_tiny_ssld | 0.8195 | 0.9612 | 1.77 | - | - | 4.54 | 14.75 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_ssld_pretrained.pdparams
)
|
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_tiny_ssld_infer.tar
)
|
| PPHGNet_small | 0.8151 | 0.9582 | 2.52 | - | - | 8.53 | 24.38 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams
)
|
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_small_infer.tar
)
|
| PPHGNet_small | 0.8151 | 0.9582 | 2.52 | - | - | 8.53 | 24.38 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams
)
|
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_small_infer.tar
)
|
| PPHGNet_small_ssld | 0.8382 | 0.9681 | 2.52 | - | - | 8.53 | 24.38 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_ssld_pretrained.pdparams
)
|
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_small_ssld_infer.tar
)
|
| PPHGNet_base_ssld | 0.8500 | 0.9735 | 5.97 | - | - | 25.14 | 71.62 |
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_base_ssld_pretrained.pdparams
)
|
[
下载链接
](
https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/PPHGNet_base_ssld_infer.tar
)
|
<a
name=
"ResNet"
></a>
<a
name=
"ResNet"
></a>
...
...
docs/zh_CN/models/PP-HGNet.md
浏览文件 @
c6525f0b
...
@@ -46,6 +46,11 @@ PP-HGNet 与其他模型的比较如下,其中测试机器为 NVIDIA® Tesla®
...
@@ -46,6 +46,11 @@ PP-HGNet 与其他模型的比较如下,其中测试机器为 NVIDIA® Tesla®
| SwinTransformer_tiny | 81.2 | 95.5 | 6.59 |
| SwinTransformer_tiny | 81.2 | 95.5 | 6.59 |
|
<b>
PPHGNet_small
<b>
|
<b>
81.51
<b>
|
<b>
95.82
<b>
|
<b>
2.52
<b>
|
|
<b>
PPHGNet_small
<b>
|
<b>
81.51
<b>
|
<b>
95.82
<b>
|
<b>
2.52
<b>
|
|
<b>
PPHGNet_small_ssld
<b>
|
<b>
83.82
<b>
|
<b>
96.81
<b>
|
<b>
2.52
<b>
|
|
<b>
PPHGNet_small_ssld
<b>
|
<b>
83.82
<b>
|
<b>
96.81
<b>
|
<b>
2.52
<b>
|
| Res2Net200_vd_26w_4s_ssld| 85.13 | 97.42 | 11.45 |
| ResNeXt101_32x48d_wsl | 85.37 | 97.69 | 55.07 |
| SwinTransformer_base | 85.2 | 97.5 | 13.53 |
|
<b>
PPHGNet_base_ssld
<b>
|
<b>
85.00
<b>
|
<b>
97.35
<b>
|
<b>
5.97
<b>
|
关于更多 PP-HGNet 的介绍以及下游任务的表现,敬请期待。
关于更多 PP-HGNet 的介绍以及下游任务的表现,敬请期待。
ppcls/arch/backbone/legendary_models/pp_hgnet.py
浏览文件 @
c6525f0b
...
@@ -27,7 +27,8 @@ MODEL_URLS = {
...
@@ -27,7 +27,8 @@ MODEL_URLS = {
"PPHGNet_tiny"
:
"PPHGNet_tiny"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams"
,
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_tiny_pretrained.pdparams"
,
"PPHGNet_small"
:
"PPHGNet_small"
:
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams"
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/PPHGNet_small_pretrained.pdparams"
,
"PPHGNet_base"
:
""
}
}
__all__
=
list
(
MODEL_URLS
.
keys
())
__all__
=
list
(
MODEL_URLS
.
keys
())
...
@@ -344,7 +345,7 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
...
@@ -344,7 +345,7 @@ def PPHGNet_small(pretrained=False, use_ssld=False, **kwargs):
return
model
return
model
def
PPHGNet_base
(
pretrained
=
False
,
use_ssld
=
Fals
e
,
**
kwargs
):
def
PPHGNet_base
(
pretrained
=
False
,
use_ssld
=
Tru
e
,
**
kwargs
):
"""
"""
PPHGNet_base
PPHGNet_base
Args:
Args:
...
...
ppcls/configs/ImageNet/Distillation/res2net200_vd_distill_pphgnet_base.yaml
0 → 100644
浏览文件 @
c6525f0b
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
"
./output/"
device
:
"
gpu"
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
360
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
"
./inference"
use_dali
:
false
# mixed precision training
AMP
:
scale_loss
:
128.0
use_dynamic_loss_scaling
:
True
# O1: mixed fp16
level
:
O1
# model architecture
Arch
:
name
:
"
DistillationModel"
class_num
:
&class_num
1000
# if not null, its lengths should be same as models
pretrained_list
:
# if not null, its lengths should be same as models
freeze_params_list
:
-
True
-
False
models
:
-
Teacher
:
name
:
Res2Net200_vd_26w_4s
class_num
:
*class_num
pretrained
:
True
use_ssld
:
True
-
Student
:
name
:
PPHGNet_base
class_num
:
*class_num
pretrained
:
False
infer_model_name
:
"
Student"
# loss function config for traing/eval process
Loss
:
Train
:
-
DistillationCELoss
:
weight
:
1.0
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.5
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
"
./dataset/ILSVRC2012/"
cls_label_path
:
"
./dataset/ILSVRC2012/train_list.txt"
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
interpolation
:
bicubic
backend
:
pil
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
config_str
:
rand-m7-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
0.00392157
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
8
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
"
./dataset/ILSVRC2012/"
cls_label_path
:
"
./dataset/ILSVRC2012/val_list.txt"
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
236
interpolation
:
bicubic
backend
:
pil
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
0.00392157
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
8
use_shared_memory
:
True
Infer
:
infer_imgs
:
"
docs/images/inference_deployment/whl_demo.jpg"
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
236
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
DistillationPostProcess
func
:
Topk
topk
:
5
class_id_map_file
:
"
ppcls/utils/imagenet1k_label_list.txt"
Metric
:
Train
:
-
DistillationTopkAcc
:
model_key
:
"
Student"
topk
:
[
1
,
5
]
Eval
:
-
DistillationTopkAcc
:
model_key
:
"
Student"
topk
:
[
1
,
5
]
ppcls/configs/ImageNet/PPHGNet/PPHGNet_base.yaml
0 → 100644
浏览文件 @
c6525f0b
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
600
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# mixed precision training
AMP
:
scale_loss
:
128.0
use_dynamic_loss_scaling
:
True
# O1: mixed fp16
level
:
O1
# model architecture
Arch
:
name
:
PPHGNet_base
class_num
:
1000
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.5
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
interpolation
:
bicubic
backend
:
pil
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
config_str
:
rand-m15-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.4
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
batch_transform_ops
:
-
OpSampler
:
MixupOperator
:
alpha
:
0.4
prob
:
0.5
CutmixOperator
:
alpha
:
1.0
prob
:
0.5
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
16
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
236
interpolation
:
bicubic
backend
:
pil
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
16
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
236
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
5
]
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录