Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
45b1296c
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
45b1296c
编写于
5月 14, 2022
作者:
C
cuicheng01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add cls_demo_person code
上级
713dd6f9
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
828 addition
and
84 deletion
+828
-84
ppcls/arch/backbone/legendary_models/pp_lcnet.py
ppcls/arch/backbone/legendary_models/pp_lcnet.py
+74
-44
ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
...s_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
+168
-0
ppcls/configs/cls_demo/person/OtherModels/MobileNetV3_large_x1_0.yaml
...s/cls_demo/person/OtherModels/MobileNetV3_large_x1_0.yaml
+144
-0
ppcls/configs/cls_demo/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml
.../OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml
+167
-0
ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
+150
-0
ppcls/data/preprocess/__init__.py
ppcls/data/preprocess/__init__.py
+5
-5
ppcls/engine/engine.py
ppcls/engine/engine.py
+4
-4
ppcls/engine/evaluation/classification.py
ppcls/engine/evaluation/classification.py
+6
-11
ppcls/metric/__init__.py
ppcls/metric/__init__.py
+13
-3
ppcls/metric/avg_metrics.py
ppcls/metric/avg_metrics.py
+20
-0
ppcls/metric/metrics.py
ppcls/metric/metrics.py
+73
-17
ppcls/utils/misc.py
ppcls/utils/misc.py
+4
-0
未找到文件。
ppcls/arch/backbone/legendary_models/pp_lcnet.py
浏览文件 @
45b1296c
...
@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function
...
@@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function
import
paddle
import
paddle
import
paddle.nn
as
nn
import
paddle.nn
as
nn
from
paddle
import
ParamAttr
from
paddle
import
ParamAttr
from
paddle.nn
import
AdaptiveAvgPool2D
,
BatchNorm
,
Conv2D
,
Dropout
,
Linear
from
paddle.nn
import
AdaptiveAvgPool2D
,
BatchNorm
2D
,
Conv2D
,
Dropout
,
Linear
from
paddle.regularizer
import
L2Decay
from
paddle.regularizer
import
L2Decay
from
paddle.nn.initializer
import
KaimingNormal
from
paddle.nn.initializer
import
KaimingNormal
from
ppcls.arch.backbone.base.theseus_layer
import
TheseusLayer
from
ppcls.arch.backbone.base.theseus_layer
import
TheseusLayer
...
@@ -83,7 +83,8 @@ class ConvBNLayer(TheseusLayer):
...
@@ -83,7 +83,8 @@ class ConvBNLayer(TheseusLayer):
filter_size
,
filter_size
,
num_filters
,
num_filters
,
stride
,
stride
,
num_groups
=
1
):
num_groups
=
1
,
lr_mult
=
1.0
):
super
().
__init__
()
super
().
__init__
()
self
.
conv
=
Conv2D
(
self
.
conv
=
Conv2D
(
...
@@ -93,13 +94,13 @@ class ConvBNLayer(TheseusLayer):
...
@@ -93,13 +94,13 @@ class ConvBNLayer(TheseusLayer):
stride
=
stride
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
num_groups
,
groups
=
num_groups
,
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()),
weight_attr
=
ParamAttr
(
initializer
=
KaimingNormal
()
,
learning_rate
=
lr_mult
),
bias_attr
=
False
)
bias_attr
=
False
)
self
.
bn
=
BatchNorm
(
self
.
bn
=
BatchNorm
2D
(
num_filters
,
num_filters
,
param_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)
),
weight_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
),
learning_rate
=
lr_mult
),
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)))
bias_attr
=
ParamAttr
(
regularizer
=
L2Decay
(
0.0
)
,
learning_rate
=
lr_mult
))
self
.
hardswish
=
nn
.
Hardswish
()
self
.
hardswish
=
nn
.
Hardswish
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -115,7 +116,8 @@ class DepthwiseSeparable(TheseusLayer):
...
@@ -115,7 +116,8 @@ class DepthwiseSeparable(TheseusLayer):
num_filters
,
num_filters
,
stride
,
stride
,
dw_size
=
3
,
dw_size
=
3
,
use_se
=
False
):
use_se
=
False
,
lr_mult
=
1.0
):
super
().
__init__
()
super
().
__init__
()
self
.
use_se
=
use_se
self
.
use_se
=
use_se
self
.
dw_conv
=
ConvBNLayer
(
self
.
dw_conv
=
ConvBNLayer
(
...
@@ -123,14 +125,17 @@ class DepthwiseSeparable(TheseusLayer):
...
@@ -123,14 +125,17 @@ class DepthwiseSeparable(TheseusLayer):
num_filters
=
num_channels
,
num_filters
=
num_channels
,
filter_size
=
dw_size
,
filter_size
=
dw_size
,
stride
=
stride
,
stride
=
stride
,
num_groups
=
num_channels
)
num_groups
=
num_channels
,
lr_mult
=
lr_mult
)
if
use_se
:
if
use_se
:
self
.
se
=
SEModule
(
num_channels
)
self
.
se
=
SEModule
(
num_channels
,
lr_mult
=
lr_mult
)
self
.
pw_conv
=
ConvBNLayer
(
self
.
pw_conv
=
ConvBNLayer
(
num_channels
=
num_channels
,
num_channels
=
num_channels
,
filter_size
=
1
,
filter_size
=
1
,
num_filters
=
num_filters
,
num_filters
=
num_filters
,
stride
=
1
)
stride
=
1
,
lr_mult
=
lr_mult
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
dw_conv
(
x
)
x
=
self
.
dw_conv
(
x
)
...
@@ -141,7 +146,7 @@ class DepthwiseSeparable(TheseusLayer):
...
@@ -141,7 +146,7 @@ class DepthwiseSeparable(TheseusLayer):
class
SEModule
(
TheseusLayer
):
class
SEModule
(
TheseusLayer
):
def
__init__
(
self
,
channel
,
reduction
=
4
):
def
__init__
(
self
,
channel
,
reduction
=
4
,
lr_mult
=
1.0
):
super
().
__init__
()
super
().
__init__
()
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
conv1
=
Conv2D
(
self
.
conv1
=
Conv2D
(
...
@@ -149,14 +154,18 @@ class SEModule(TheseusLayer):
...
@@ -149,14 +154,18 @@ class SEModule(TheseusLayer):
out_channels
=
channel
//
reduction
,
out_channels
=
channel
//
reduction
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
)
padding
=
0
,
weight_attr
=
ParamAttr
(
learning_rate
=
lr_mult
),
bias_attr
=
ParamAttr
(
learning_rate
=
lr_mult
))
self
.
relu
=
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
self
.
conv2
=
Conv2D
(
self
.
conv2
=
Conv2D
(
in_channels
=
channel
//
reduction
,
in_channels
=
channel
//
reduction
,
out_channels
=
channel
,
out_channels
=
channel
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
)
padding
=
0
,
weight_attr
=
ParamAttr
(
learning_rate
=
lr_mult
),
bias_attr
=
ParamAttr
(
learning_rate
=
lr_mult
))
self
.
hardsigmoid
=
nn
.
Hardsigmoid
()
self
.
hardsigmoid
=
nn
.
Hardsigmoid
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -175,19 +184,34 @@ class PPLCNet(TheseusLayer):
...
@@ -175,19 +184,34 @@ class PPLCNet(TheseusLayer):
stages_pattern
,
stages_pattern
,
scale
=
1.0
,
scale
=
1.0
,
class_num
=
1000
,
class_num
=
1000
,
dropout_prob
=
0.
2
,
dropout_prob
=
0.
0
,
class_expand
=
1280
,
class_expand
=
1280
,
lr_mult_list
=
[
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
],
use_last_conv
=
True
,
return_patterns
=
None
,
return_patterns
=
None
,
return_stages
=
None
):
return_stages
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
scale
=
scale
self
.
scale
=
scale
self
.
class_expand
=
class_expand
self
.
class_expand
=
class_expand
self
.
lr_mult_list
=
lr_mult_list
self
.
use_last_conv
=
use_last_conv
if
isinstance
(
self
.
lr_mult_list
,
str
):
self
.
lr_mult_list
=
eval
(
self
.
lr_mult_list
)
assert
isinstance
(
self
.
lr_mult_list
,
(
list
,
tuple
)),
"lr_mult_list should be in (list, tuple) but got {}"
.
format
(
type
(
self
.
lr_mult_list
))
assert
len
(
self
.
lr_mult_list
)
==
6
,
"lr_mult_list length should be 5 but got {}"
.
format
(
len
(
self
.
lr_mult_list
))
self
.
conv1
=
ConvBNLayer
(
self
.
conv1
=
ConvBNLayer
(
num_channels
=
3
,
num_channels
=
3
,
filter_size
=
3
,
filter_size
=
3
,
num_filters
=
make_divisible
(
16
*
scale
),
num_filters
=
make_divisible
(
16
*
scale
),
stride
=
2
)
stride
=
2
,
lr_mult
=
self
.
lr_mult_list
[
0
])
self
.
blocks2
=
nn
.
Sequential
(
*
[
self
.
blocks2
=
nn
.
Sequential
(
*
[
DepthwiseSeparable
(
DepthwiseSeparable
(
...
@@ -195,7 +219,8 @@ class PPLCNet(TheseusLayer):
...
@@ -195,7 +219,8 @@ class PPLCNet(TheseusLayer):
num_filters
=
make_divisible
(
out_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
dw_size
=
k
,
stride
=
s
,
stride
=
s
,
use_se
=
se
)
use_se
=
se
,
lr_mult
=
self
.
lr_mult_list
[
1
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks2"
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks2"
])
])
])
...
@@ -205,7 +230,8 @@ class PPLCNet(TheseusLayer):
...
@@ -205,7 +230,8 @@ class PPLCNet(TheseusLayer):
num_filters
=
make_divisible
(
out_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
dw_size
=
k
,
stride
=
s
,
stride
=
s
,
use_se
=
se
)
use_se
=
se
,
lr_mult
=
self
.
lr_mult_list
[
2
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks3"
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks3"
])
])
])
...
@@ -215,7 +241,8 @@ class PPLCNet(TheseusLayer):
...
@@ -215,7 +241,8 @@ class PPLCNet(TheseusLayer):
num_filters
=
make_divisible
(
out_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
dw_size
=
k
,
stride
=
s
,
stride
=
s
,
use_se
=
se
)
use_se
=
se
,
lr_mult
=
self
.
lr_mult_list
[
3
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks4"
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks4"
])
])
])
...
@@ -225,7 +252,8 @@ class PPLCNet(TheseusLayer):
...
@@ -225,7 +252,8 @@ class PPLCNet(TheseusLayer):
num_filters
=
make_divisible
(
out_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
dw_size
=
k
,
stride
=
s
,
stride
=
s
,
use_se
=
se
)
use_se
=
se
,
lr_mult
=
self
.
lr_mult_list
[
4
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks5"
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks5"
])
])
])
...
@@ -235,25 +263,26 @@ class PPLCNet(TheseusLayer):
...
@@ -235,25 +263,26 @@ class PPLCNet(TheseusLayer):
num_filters
=
make_divisible
(
out_c
*
scale
),
num_filters
=
make_divisible
(
out_c
*
scale
),
dw_size
=
k
,
dw_size
=
k
,
stride
=
s
,
stride
=
s
,
use_se
=
se
)
use_se
=
se
,
lr_mult
=
self
.
lr_mult_list
[
5
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks6"
])
for
i
,
(
k
,
in_c
,
out_c
,
s
,
se
)
in
enumerate
(
NET_CONFIG
[
"blocks6"
])
])
])
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
self
.
avg_pool
=
AdaptiveAvgPool2D
(
1
)
if
self
.
use_last_conv
:
self
.
last_conv
=
Conv2D
(
self
.
last_conv
=
Conv2D
(
in_channels
=
make_divisible
(
NET_CONFIG
[
"blocks6"
][
-
1
][
2
]
*
scale
),
in_channels
=
make_divisible
(
NET_CONFIG
[
"blocks6"
][
-
1
][
2
]
*
scale
),
out_channels
=
self
.
class_expand
,
out_channels
=
self
.
class_expand
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
,
bias_attr
=
False
)
bias_attr
=
False
)
self
.
hardswish
=
nn
.
Hardswish
()
self
.
hardswish
=
nn
.
Hardswish
()
self
.
dropout
=
Dropout
(
p
=
dropout_prob
,
mode
=
"downscale_in_infer"
)
self
.
dropout
=
Dropout
(
p
=
dropout_prob
,
mode
=
"downscale_in_infer"
)
else
:
self
.
last_conv
=
None
self
.
flatten
=
nn
.
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
self
.
flatten
=
nn
.
Flatten
(
start_axis
=
1
,
stop_axis
=-
1
)
self
.
fc
=
Linear
(
self
.
class_expand
if
self
.
use_last_conv
else
NET_CONFIG
[
"blocks6"
][
-
1
][
2
],
class_num
)
self
.
fc
=
Linear
(
self
.
class_expand
,
class_num
)
super
().
init_res
(
super
().
init_res
(
stages_pattern
,
stages_pattern
,
...
@@ -270,9 +299,10 @@ class PPLCNet(TheseusLayer):
...
@@ -270,9 +299,10 @@ class PPLCNet(TheseusLayer):
x
=
self
.
blocks6
(
x
)
x
=
self
.
blocks6
(
x
)
x
=
self
.
avg_pool
(
x
)
x
=
self
.
avg_pool
(
x
)
x
=
self
.
last_conv
(
x
)
if
self
.
last_conv
is
not
None
:
x
=
self
.
hardswish
(
x
)
x
=
self
.
last_conv
(
x
)
x
=
self
.
dropout
(
x
)
x
=
self
.
hardswish
(
x
)
x
=
self
.
dropout
(
x
)
x
=
self
.
flatten
(
x
)
x
=
self
.
flatten
(
x
)
x
=
self
.
fc
(
x
)
x
=
self
.
fc
(
x
)
return
x
return
x
...
@@ -291,7 +321,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld):
...
@@ -291,7 +321,7 @@ def _load_pretrained(pretrained, model, model_url, use_ssld):
)
)
def
PPLCNet_x0_25
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
PPLCNet_x0_25
(
pretrained
=
False
,
use_ssld
=
False
,
use_sync_bn
=
False
,
**
kwargs
):
"""
"""
PPLCNet_x0_25
PPLCNet_x0_25
Args:
Args:
...
@@ -307,7 +337,7 @@ def PPLCNet_x0_25(pretrained=False, use_ssld=False, **kwargs):
...
@@ -307,7 +337,7 @@ def PPLCNet_x0_25(pretrained=False, use_ssld=False, **kwargs):
return
model
return
model
def
PPLCNet_x0_35
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
PPLCNet_x0_35
(
pretrained
=
False
,
use_ssld
=
False
,
use_sync_bn
=
False
,
**
kwargs
):
"""
"""
PPLCNet_x0_35
PPLCNet_x0_35
Args:
Args:
...
@@ -323,7 +353,7 @@ def PPLCNet_x0_35(pretrained=False, use_ssld=False, **kwargs):
...
@@ -323,7 +353,7 @@ def PPLCNet_x0_35(pretrained=False, use_ssld=False, **kwargs):
return
model
return
model
def
PPLCNet_x0_5
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
PPLCNet_x0_5
(
pretrained
=
False
,
use_ssld
=
False
,
use_sync_bn
=
False
,
**
kwargs
):
"""
"""
PPLCNet_x0_5
PPLCNet_x0_5
Args:
Args:
...
@@ -339,7 +369,7 @@ def PPLCNet_x0_5(pretrained=False, use_ssld=False, **kwargs):
...
@@ -339,7 +369,7 @@ def PPLCNet_x0_5(pretrained=False, use_ssld=False, **kwargs):
return
model
return
model
def
PPLCNet_x0_75
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
PPLCNet_x0_75
(
pretrained
=
False
,
use_ssld
=
False
,
use_sync_bn
=
False
,
**
kwargs
):
"""
"""
PPLCNet_x0_75
PPLCNet_x0_75
Args:
Args:
...
@@ -355,7 +385,7 @@ def PPLCNet_x0_75(pretrained=False, use_ssld=False, **kwargs):
...
@@ -355,7 +385,7 @@ def PPLCNet_x0_75(pretrained=False, use_ssld=False, **kwargs):
return
model
return
model
def
PPLCNet_x1_0
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
PPLCNet_x1_0
(
pretrained
=
False
,
use_ssld
=
False
,
use_sync_bn
=
False
,
**
kwargs
):
"""
"""
PPLCNet_x1_0
PPLCNet_x1_0
Args:
Args:
...
@@ -371,7 +401,7 @@ def PPLCNet_x1_0(pretrained=False, use_ssld=False, **kwargs):
...
@@ -371,7 +401,7 @@ def PPLCNet_x1_0(pretrained=False, use_ssld=False, **kwargs):
return
model
return
model
def
PPLCNet_x1_5
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
PPLCNet_x1_5
(
pretrained
=
False
,
use_ssld
=
False
,
use_sync_bn
=
False
,
**
kwargs
):
"""
"""
PPLCNet_x1_5
PPLCNet_x1_5
Args:
Args:
...
@@ -387,7 +417,7 @@ def PPLCNet_x1_5(pretrained=False, use_ssld=False, **kwargs):
...
@@ -387,7 +417,7 @@ def PPLCNet_x1_5(pretrained=False, use_ssld=False, **kwargs):
return
model
return
model
def
PPLCNet_x2_0
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
PPLCNet_x2_0
(
pretrained
=
False
,
use_ssld
=
False
,
use_sync_bn
=
False
,
**
kwargs
):
"""
"""
PPLCNet_x2_0
PPLCNet_x2_0
Args:
Args:
...
@@ -403,7 +433,7 @@ def PPLCNet_x2_0(pretrained=False, use_ssld=False, **kwargs):
...
@@ -403,7 +433,7 @@ def PPLCNet_x2_0(pretrained=False, use_ssld=False, **kwargs):
return
model
return
model
def
PPLCNet_x2_5
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
def
PPLCNet_x2_5
(
pretrained
=
False
,
use_ssld
=
False
,
use_sync_bn
=
False
,
**
kwargs
):
"""
"""
PPLCNet_x2_5
PPLCNet_x2_5
Args:
Args:
...
...
ppcls/configs/cls_demo/person/Distillation/PPLCNet_x1_0_distillation.yaml
0 → 100644
浏览文件 @
45b1296c
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output
device
:
gpu
save_interval
:
1
eval_during_train
:
True
start_eval_epoch
:
1
eval_interval
:
1
epochs
:
20
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# model architecture
Arch
:
name
:
"
DistillationModel"
class_num
:
&class_num
2
# if not null, its lengths should be same as models
pretrained_list
:
# if not null, its lengths should be same as models
freeze_params_list
:
-
True
-
False
use_sync_bn
:
True
models
:
-
Teacher
:
name
:
ResNet101_vd
class_num
:
*class_num
pretrained
:
"
./output/TEACHER_ResNet101_vd/ResNet101_vd/best_model"
-
Student
:
name
:
PPLCNet_x1_0
class_num
:
*class_num
pretrained
:
True
use_ssld
:
True
infer_model_name
:
"
Student"
# loss function config for traing/eval process
Loss
:
Train
:
-
DistillationDMLLoss
:
weight
:
1.0
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.01
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/coco/
cls_label_path
:
./dataset/coco/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
prob
:
0.0
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.0
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
16
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/coco/
cls_label_path
:
./dataset/coco/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
DistillationTopkAcc
:
model_key
:
"
Student"
topk
:
[
1
,
2
]
Eval
:
-
TprAtFpr
:
-
TopkAcc
:
topk
:
[
1
,
2
]
ppcls/configs/cls_demo/person/OtherModels/MobileNetV3_large_x1_0.yaml
0 → 100644
浏览文件 @
45b1296c
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
start_eval_epoch
:
15
epochs
:
20
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# mixed precision training
AMP
:
scale_loss
:
128.0
use_dynamic_loss_scaling
:
True
# O1: mixed fp16
level
:
O1
# model architecture
Arch
:
name
:
MobileNetV3_large_x1_0
class_num
:
2
pretrained
:
True
use_sync_bn
:
True
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.13
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00002
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
512
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
8
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
2
]
Eval
:
-
TprAtFpr
:
-
TopkAcc
:
topk
:
[
1
,
2
]
ppcls/configs/cls_demo/person/OtherModels/SwinTransformer_tiny_patch4_window7_224.yaml
0 → 100644
浏览文件 @
45b1296c
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
start_eval_epoch
:
1
epochs
:
20
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# mixed precision training
AMP
:
scale_loss
:
128.0
use_dynamic_loss_scaling
:
True
# O1: mixed fp16
level
:
O1
# model architecture
Arch
:
name
:
SwinTransformer_tiny_patch4_window7_224
class_num
:
2
pretrained
:
True
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
AdamW
beta1
:
0.9
beta2
:
0.999
epsilon
:
1e-8
weight_decay
:
0.05
no_weight_decay_name
:
absolute_pos_embed relative_position_bias_table .bias norm
one_dim_param_no_weight_decay
:
True
lr
:
name
:
Cosine
learning_rate
:
1e-4
eta_min
:
2e-6
warmup_epoch
:
5
warmup_start_lr
:
2e-7
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
interpolation
:
bicubic
backend
:
pil
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.25
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
batch_transform_ops
:
-
OpSampler
:
MixupOperator
:
alpha
:
0.8
prob
:
0.5
CutmixOperator
:
alpha
:
1.0
prob
:
0.5
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
8
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
8
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
2
]
Eval
:
-
TprAtFpr
:
-
TopkAcc
:
topk
:
[
1
,
2
]
ppcls/configs/cls_demo/person/PPLCNet/PPLCNet_x1_0.yaml
0 → 100644
浏览文件 @
45b1296c
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
start_eval_epoch
:
1
epochs
:
20
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
use_dali
:
False
# model architecture
Arch
:
name
:
PPLCNet_x1_0
class_num
:
2
pretrained
:
True
use_ssld
:
True
use_sync_bn
:
True
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
Momentum
momentum
:
0.9
lr
:
name
:
Cosine
learning_rate
:
0.01
warmup_epoch
:
5
regularizer
:
name
:
'
L2'
coeff
:
0.00004
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
prob
:
0.0
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.0
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
True
loader
:
num_workers
:
8
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/person/
cls_label_path
:
./dataset/person/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Train
:
-
TopkAcc
:
topk
:
[
1
,
2
]
Eval
:
-
TprAtFpr
:
-
TopkAcc
:
topk
:
[
1
,
2
]
ppcls/data/preprocess/__init__.py
浏览文件 @
45b1296c
...
@@ -38,7 +38,7 @@ from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, Cutmi
...
@@ -38,7 +38,7 @@ from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, Cutmi
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
import
random
def
transform
(
data
,
ops
=
[]):
def
transform
(
data
,
ops
=
[]):
""" transform """
""" transform """
...
@@ -88,16 +88,16 @@ class RandAugment(RawRandAugment):
...
@@ -88,16 +88,16 @@ class RandAugment(RawRandAugment):
class
TimmAutoAugment
(
RawTimmAutoAugment
):
class
TimmAutoAugment
(
RawTimmAutoAugment
):
""" TimmAutoAugment wrapper to auto fit different img tyeps. """
""" TimmAutoAugment wrapper to auto fit different img tyeps. """
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
prob
=
1.0
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
prob
=
prob
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
if
not
isinstance
(
img
,
Image
.
Image
):
if
not
isinstance
(
img
,
Image
.
Image
):
img
=
np
.
ascontiguousarray
(
img
)
img
=
np
.
ascontiguousarray
(
img
)
img
=
Image
.
fromarray
(
img
)
img
=
Image
.
fromarray
(
img
)
if
random
.
random
()
<
self
.
prob
:
img
=
super
().
__call__
(
img
)
img
=
super
().
__call__
(
img
)
if
isinstance
(
img
,
Image
.
Image
):
if
isinstance
(
img
,
Image
.
Image
):
img
=
np
.
asarray
(
img
)
img
=
np
.
asarray
(
img
)
...
...
ppcls/engine/engine.py
浏览文件 @
45b1296c
...
@@ -312,7 +312,7 @@ class Engine(object):
...
@@ -312,7 +312,7 @@ class Engine(object):
print_batch_step
=
self
.
config
[
'Global'
][
'print_batch_step'
]
print_batch_step
=
self
.
config
[
'Global'
][
'print_batch_step'
]
save_interval
=
self
.
config
[
"Global"
][
"save_interval"
]
save_interval
=
self
.
config
[
"Global"
][
"save_interval"
]
best_metric
=
{
best_metric
=
{
"metric"
:
0
.0
,
"metric"
:
-
1
.0
,
"epoch"
:
0
,
"epoch"
:
0
,
}
}
# key:
# key:
...
@@ -345,17 +345,17 @@ class Engine(object):
...
@@ -345,17 +345,17 @@ class Engine(object):
if
self
.
use_dali
:
if
self
.
use_dali
:
self
.
train_dataloader
.
reset
()
self
.
train_dataloader
.
reset
()
metric_msg
=
", "
.
join
([
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
self
.
output_info
[
key
].
avg
)
self
.
output_info
[
key
].
avg_info
for
key
in
self
.
output_info
for
key
in
self
.
output_info
])
])
logger
.
info
(
"[Train][Epoch {}/{}][Avg]{}"
.
format
(
logger
.
info
(
"[Train][Epoch {}/{}][Avg]{}"
.
format
(
epoch_id
,
self
.
config
[
"Global"
][
"epochs"
],
metric_msg
))
epoch_id
,
self
.
config
[
"Global"
][
"epochs"
],
metric_msg
))
self
.
output_info
.
clear
()
self
.
output_info
.
clear
()
# eval model and save model if possible
# eval model and save model if possible
start_eval_epoch
=
self
.
config
[
"Global"
].
get
(
"start_eval_epoch"
,
0
)
-
1
if
self
.
config
[
"Global"
][
if
self
.
config
[
"Global"
][
"eval_during_train"
]
and
epoch_id
%
self
.
config
[
"Global"
][
"eval_during_train"
]
and
epoch_id
%
self
.
config
[
"Global"
][
"eval_interval"
]
==
0
:
"eval_interval"
]
==
0
and
epoch_id
>
start_eval_epoch
:
acc
=
self
.
eval
(
epoch_id
)
acc
=
self
.
eval
(
epoch_id
)
if
acc
>
best_metric
[
"metric"
]:
if
acc
>
best_metric
[
"metric"
]:
best_metric
[
"metric"
]
=
acc
best_metric
[
"metric"
]
=
acc
...
...
ppcls/engine/evaluation/classification.py
浏览文件 @
45b1296c
...
@@ -23,6 +23,8 @@ from ppcls.utils import logger
...
@@ -23,6 +23,8 @@ from ppcls.utils import logger
def
classification_eval
(
engine
,
epoch_id
=
0
):
def
classification_eval
(
engine
,
epoch_id
=
0
):
if
hasattr
(
engine
.
eval_metric_func
,
"reset"
):
engine
.
eval_metric_func
.
reset
()
output_info
=
dict
()
output_info
=
dict
()
time_info
=
{
time_info
=
{
"batch_cost"
:
AverageMeter
(
"batch_cost"
:
AverageMeter
(
...
@@ -123,16 +125,7 @@ def classification_eval(engine, epoch_id=0):
...
@@ -123,16 +125,7 @@ def classification_eval(engine, epoch_id=0):
current_samples
)
current_samples
)
# calc metric
# calc metric
if
engine
.
eval_metric_func
is
not
None
:
if
engine
.
eval_metric_func
is
not
None
:
metric_dict
=
engine
.
eval_metric_func
(
preds
,
labels
)
engine
.
eval_metric_func
(
preds
,
labels
)
for
key
in
metric_dict
:
if
metric_key
is
None
:
metric_key
=
key
if
key
not
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
].
update
(
metric_dict
[
key
].
numpy
()[
0
],
current_samples
)
time_info
[
"batch_cost"
].
update
(
time
.
time
()
-
tic
)
time_info
[
"batch_cost"
].
update
(
time
.
time
()
-
tic
)
if
iter_id
%
print_batch_step
==
0
:
if
iter_id
%
print_batch_step
==
0
:
...
@@ -148,6 +141,7 @@ def classification_eval(engine, epoch_id=0):
...
@@ -148,6 +141,7 @@ def classification_eval(engine, epoch_id=0):
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
val
)
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
val
)
for
key
in
output_info
for
key
in
output_info
])
])
metric_msg
+=
", {}"
.
format
(
engine
.
eval_metric_func
.
avg_info
)
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}"
.
format
(
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}"
.
format
(
epoch_id
,
iter_id
,
epoch_id
,
iter_id
,
len
(
engine
.
eval_dataloader
),
metric_msg
,
time_msg
,
ips_msg
))
len
(
engine
.
eval_dataloader
),
metric_msg
,
time_msg
,
ips_msg
))
...
@@ -158,10 +152,11 @@ def classification_eval(engine, epoch_id=0):
...
@@ -158,10 +152,11 @@ def classification_eval(engine, epoch_id=0):
metric_msg
=
", "
.
join
([
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
])
])
metric_msg
+=
", {}"
.
format
(
engine
.
eval_metric_func
.
avg_info
)
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
))
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
))
# do not try to save best eval.model
# do not try to save best eval.model
if
engine
.
eval_metric_func
is
None
:
if
engine
.
eval_metric_func
is
None
:
return
-
1
return
-
1
# return 1st metric in the dict
# return 1st metric in the dict
return
output_info
[
metric_key
]
.
avg
return
engine
.
eval_metric_func
.
avg
ppcls/metric/__init__.py
浏览文件 @
45b1296c
...
@@ -12,17 +12,18 @@
...
@@ -12,17 +12,18 @@
#See the License for the specific language governing permissions and
#See the License for the specific language governing permissions and
#limitations under the License.
#limitations under the License.
from
paddle
import
nn
import
copy
import
copy
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
.avg_metrics
import
AvgMetrics
from
.metrics
import
TopkAcc
,
mAP
,
mINP
,
Recallk
,
Precisionk
from
.metrics
import
TopkAcc
,
mAP
,
mINP
,
Recallk
,
Precisionk
from
.metrics
import
DistillationTopkAcc
from
.metrics
import
DistillationTopkAcc
from
.metrics
import
GoogLeNetTopkAcc
from
.metrics
import
GoogLeNetTopkAcc
from
.metrics
import
HammingDistance
,
AccuracyScore
from
.metrics
import
HammingDistance
,
AccuracyScore
from
.metrics
import
TprAtFpr
class
CombinedMetrics
(
nn
.
Layer
):
class
CombinedMetrics
(
AvgMetrics
):
def
__init__
(
self
,
config_list
):
def
__init__
(
self
,
config_list
):
super
().
__init__
()
super
().
__init__
()
self
.
metric_func_list
=
[]
self
.
metric_func_list
=
[]
...
@@ -39,13 +40,22 @@ class CombinedMetrics(nn.Layer):
...
@@ -39,13 +40,22 @@ class CombinedMetrics(nn.Layer):
else
:
else
:
self
.
metric_func_list
.
append
(
eval
(
metric_name
)())
self
.
metric_func_list
.
append
(
eval
(
metric_name
)())
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
forward
(
self
,
*
args
,
**
kwargs
):
metric_dict
=
OrderedDict
()
metric_dict
=
OrderedDict
()
for
idx
,
metric_func
in
enumerate
(
self
.
metric_func_list
):
for
idx
,
metric_func
in
enumerate
(
self
.
metric_func_list
):
metric_dict
.
update
(
metric_func
(
*
args
,
**
kwargs
))
metric_dict
.
update
(
metric_func
(
*
args
,
**
kwargs
))
return
metric_dict
return
metric_dict
@
property
def
avg_info
(
self
):
return
", "
.
join
([
metric
.
avg_info
for
metric
in
self
.
metric_func_list
])
@
property
def
avg
(
self
):
return
self
.
metric_func_list
[
0
].
avg
def
build_metrics
(
config
):
def
build_metrics
(
config
):
metrics_list
=
CombinedMetrics
(
copy
.
deepcopy
(
config
))
metrics_list
=
CombinedMetrics
(
copy
.
deepcopy
(
config
))
return
metrics_list
return
metrics_list
ppcls/metric/avg_metrics.py
0 → 100644
浏览文件 @
45b1296c
from
paddle
import
nn
class
AvgMetrics
(
nn
.
Layer
):
def
__init__
(
self
):
super
().
__init__
()
self
.
avg_meters
=
{}
def
reset
(
self
):
self
.
avg_meters
=
{}
@
property
def
avg
(
self
):
if
self
.
avg_meters
:
for
metric_key
in
self
.
avg_meters
:
return
self
.
avg_meters
[
metric_key
].
avg
@
property
def
avg_info
(
self
):
return
", "
.
join
([
self
.
avg_meters
[
key
].
avg_info
for
key
in
self
.
avg_meters
])
ppcls/metric/metrics.py
浏览文件 @
45b1296c
...
@@ -22,14 +22,18 @@ from sklearn.metrics import accuracy_score as accuracy_metric
...
@@ -22,14 +22,18 @@ from sklearn.metrics import accuracy_score as accuracy_metric
from
sklearn.metrics
import
multilabel_confusion_matrix
from
sklearn.metrics
import
multilabel_confusion_matrix
from
sklearn.preprocessing
import
binarize
from
sklearn.preprocessing
import
binarize
from
ppcls.metric.avg_metrics
import
AvgMetrics
from
ppcls.utils.misc
import
AverageMeter
class
TopkAcc
(
nn
.
Layer
):
class
TopkAcc
(
AvgMetrics
):
def
__init__
(
self
,
topk
=
(
1
,
5
)):
def
__init__
(
self
,
topk
=
(
1
,
5
)):
super
().
__init__
()
super
().
__init__
()
assert
isinstance
(
topk
,
(
int
,
list
,
tuple
))
assert
isinstance
(
topk
,
(
int
,
list
,
tuple
))
if
isinstance
(
topk
,
int
):
if
isinstance
(
topk
,
int
):
topk
=
[
topk
]
topk
=
[
topk
]
self
.
topk
=
topk
self
.
topk
=
topk
self
.
avg_meters
=
{
"top{}"
.
format
(
k
):
AverageMeter
(
"top{}"
.
format
(
k
))
for
k
in
self
.
topk
}
def
forward
(
self
,
x
,
label
):
def
forward
(
self
,
x
,
label
):
if
isinstance
(
x
,
dict
):
if
isinstance
(
x
,
dict
):
...
@@ -39,6 +43,7 @@ class TopkAcc(nn.Layer):
...
@@ -39,6 +43,7 @@ class TopkAcc(nn.Layer):
for
k
in
self
.
topk
:
for
k
in
self
.
topk
:
metric_dict
[
"top{}"
.
format
(
k
)]
=
paddle
.
metric
.
accuracy
(
metric_dict
[
"top{}"
.
format
(
k
)]
=
paddle
.
metric
.
accuracy
(
x
,
label
,
k
=
k
)
x
,
label
,
k
=
k
)
self
.
avg_meters
[
"top{}"
.
format
(
k
)].
update
(
metric_dict
[
"top{}"
.
format
(
k
)].
numpy
()[
0
],
x
.
shape
[
0
])
return
metric_dict
return
metric_dict
...
@@ -129,6 +134,57 @@ class mINP(nn.Layer):
...
@@ -129,6 +134,57 @@ class mINP(nn.Layer):
return
metric_dict
return
metric_dict
class
TprAtFpr
(
nn
.
Layer
):
def
__init__
(
self
,
max_fpr
=
1
/
1000.
):
super
().
__init__
()
self
.
gt_pos_score_list
=
[]
self
.
gt_neg_score_list
=
[]
self
.
softmax
=
nn
.
Softmax
(
axis
=-
1
)
self
.
max_fpr
=
max_fpr
self
.
max_tpr
=
0.
def
forward
(
self
,
x
,
label
):
if
isinstance
(
x
,
dict
):
x
=
x
[
"logits"
]
x
=
self
.
softmax
(
x
)
for
i
,
label_i
in
enumerate
(
label
):
if
label_i
[
0
]
==
0
:
self
.
gt_neg_score_list
.
append
(
x
[
i
][
1
].
numpy
())
else
:
self
.
gt_pos_score_list
.
append
(
x
[
i
][
1
].
numpy
())
return
{}
def
reset
(
self
):
self
.
gt_pos_score_list
=
[]
self
.
gt_neg_score_list
=
[]
self
.
max_tpr
=
0.
@
property
def
avg
(
self
):
return
self
.
max_tpr
@
property
def
avg_info
(
self
):
max_tpr
=
0.
result
=
""
gt_pos_score_list
=
np
.
array
(
self
.
gt_pos_score_list
)
gt_neg_score_list
=
np
.
array
(
self
.
gt_neg_score_list
)
for
i
in
range
(
0
,
10000
):
threshold
=
i
/
10000.
if
len
(
gt_pos_score_list
)
==
0
:
continue
tpr
=
np
.
sum
(
gt_pos_score_list
>
threshold
)
/
len
(
gt_pos_score_list
)
if
len
(
gt_neg_score_list
)
==
0
and
tpr
>
max_tpr
:
max_tpr
=
tpr
result
=
"threshold: {}, fpr: {}, tpr: {:.5f}"
.
format
(
threshold
,
fpr
,
tpr
)
fpr
=
np
.
sum
(
gt_neg_score_list
>
threshold
)
/
len
(
gt_neg_score_list
)
if
fpr
<=
self
.
max_fpr
and
tpr
>
max_tpr
:
max_tpr
=
tpr
result
=
"threshold: {}, fpr: {}, tpr: {:.5f}"
.
format
(
threshold
,
fpr
,
tpr
)
self
.
max_tpr
=
max_tpr
return
result
class
Recallk
(
nn
.
Layer
):
class
Recallk
(
nn
.
Layer
):
def
__init__
(
self
,
topk
=
(
1
,
5
)):
def
__init__
(
self
,
topk
=
(
1
,
5
)):
super
().
__init__
()
super
().
__init__
()
...
@@ -241,20 +297,17 @@ class GoogLeNetTopkAcc(TopkAcc):
...
@@ -241,20 +297,17 @@ class GoogLeNetTopkAcc(TopkAcc):
return
super
().
forward
(
x
[
0
],
label
)
return
super
().
forward
(
x
[
0
],
label
)
class
MutiLabelMetric
(
object
):
class
MultiLabelMetric
(
AvgMetrics
):
def
__init__
(
self
):
def
__init__
(
self
,
bi_threshold
=
0.5
):
pass
super
().
__init__
()
self
.
bi_threshold
=
bi_threshold
def
_multi_hot_encode
(
self
,
logits
,
threshold
=
0.5
):
return
binarize
(
logits
,
threshold
=
threshold
)
def
__call__
(
self
,
output
):
def
_multi_hot_encode
(
self
,
output
):
output
=
F
.
sigmoid
(
output
)
logits
=
F
.
sigmoid
(
output
).
numpy
()
preds
=
self
.
_multi_hot_encode
(
logits
=
output
.
numpy
(),
threshold
=
0.5
)
return
binarize
(
logits
,
threshold
=
self
.
bi_threshold
)
return
preds
class
HammingDistance
(
MutiLabelMetric
):
class
HammingDistance
(
Mu
l
tiLabelMetric
):
"""
"""
Soft metric based label for multilabel classification
Soft metric based label for multilabel classification
Returns:
Returns:
...
@@ -263,16 +316,18 @@ class HammingDistance(MutiLabelMetric):
...
@@ -263,16 +316,18 @@ class HammingDistance(MutiLabelMetric):
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
avg_meters
=
{
"HammingDistance"
:
AverageMeter
(
"HammingDistance"
)}
def
__call__
(
self
,
output
,
target
):
def
forward
(
self
,
output
,
target
):
preds
=
super
().
_
_call__
(
output
)
preds
=
super
().
_
multi_hot_encode
(
output
)
metric_dict
=
dict
()
metric_dict
=
dict
()
metric_dict
[
"HammingDistance"
]
=
paddle
.
to_tensor
(
metric_dict
[
"HammingDistance"
]
=
paddle
.
to_tensor
(
hamming_loss
(
target
,
preds
))
hamming_loss
(
target
,
preds
))
self
.
avg_meters
[
"HammingDistance"
].
update
(
metric_dict
[
"HammingDistance"
].
numpy
()[
0
],
output
.
shape
[
0
])
return
metric_dict
return
metric_dict
class
AccuracyScore
(
MutiLabelMetric
):
class
AccuracyScore
(
Mu
l
tiLabelMetric
):
"""
"""
Hard metric for multilabel classification
Hard metric for multilabel classification
Args:
Args:
...
@@ -289,8 +344,8 @@ class AccuracyScore(MutiLabelMetric):
...
@@ -289,8 +344,8 @@ class AccuracyScore(MutiLabelMetric):
],
'must be one of ["sample", "label"]'
],
'must be one of ["sample", "label"]'
self
.
base
=
base
self
.
base
=
base
def
__call__
(
self
,
output
,
target
):
def
forward
(
self
,
output
,
target
):
preds
=
super
().
_
_call__
(
output
)
preds
=
super
().
_
multi_hot_encode
(
output
)
metric_dict
=
dict
()
metric_dict
=
dict
()
if
self
.
base
==
"sample"
:
if
self
.
base
==
"sample"
:
accuracy
=
accuracy_metric
(
target
,
preds
)
accuracy
=
accuracy_metric
(
target
,
preds
)
...
@@ -303,4 +358,5 @@ class AccuracyScore(MutiLabelMetric):
...
@@ -303,4 +358,5 @@ class AccuracyScore(MutiLabelMetric):
accuracy
=
(
sum
(
tps
)
+
sum
(
tns
))
/
(
accuracy
=
(
sum
(
tps
)
+
sum
(
tns
))
/
(
sum
(
tps
)
+
sum
(
tns
)
+
sum
(
fns
)
+
sum
(
fps
))
sum
(
tps
)
+
sum
(
tns
)
+
sum
(
fns
)
+
sum
(
fps
))
metric_dict
[
"AccuracyScore"
]
=
paddle
.
to_tensor
(
accuracy
)
metric_dict
[
"AccuracyScore"
]
=
paddle
.
to_tensor
(
accuracy
)
self
.
avg_meters
[
"AccuracyScore"
].
update
(
metric_dict
[
"AccuracyScore"
].
numpy
()[
0
],
output
.
shape
[
0
])
return
metric_dict
return
metric_dict
ppcls/utils/misc.py
浏览文件 @
45b1296c
...
@@ -42,6 +42,10 @@ class AverageMeter(object):
...
@@ -42,6 +42,10 @@ class AverageMeter(object):
self
.
count
+=
n
self
.
count
+=
n
self
.
avg
=
self
.
sum
/
self
.
count
self
.
avg
=
self
.
sum
/
self
.
count
@
property
def
avg_info
(
self
):
return
"{}: {:.5f}"
.
format
(
self
.
name
,
self
.
avg
)
@
property
@
property
def
total
(
self
):
def
total
(
self
):
return
'{self.name}_sum: {self.sum:{self.fmt}}{self.postfix}'
.
format
(
return
'{self.name}_sum: {self.sum:{self.fmt}}{self.postfix}'
.
format
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录