Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
0a3ecf60
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
接近 2 年 前同步成功
通知
116
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a3ecf60
编写于
5月 11, 2022
作者:
Z
zhiboniu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add attribute strongbaseline
上级
675e60d5
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
464 addition
and
24 deletion
+464
-24
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/model_zoo/strongbaseline_attr.py
ppcls/arch/backbone/model_zoo/strongbaseline_attr.py
+98
-0
ppcls/configs/Attr/StrongBaselineAttr.yaml
ppcls/configs/Attr/StrongBaselineAttr.yaml
+110
-0
ppcls/data/__init__.py
ppcls/data/__init__.py
+1
-0
ppcls/engine/evaluation/classification.py
ppcls/engine/evaluation/classification.py
+51
-24
ppcls/loss/__init__.py
ppcls/loss/__init__.py
+1
-0
ppcls/loss/bceloss.py
ppcls/loss/bceloss.py
+59
-0
ppcls/metric/__init__.py
ppcls/metric/__init__.py
+1
-0
ppcls/metric/metrics.py
ppcls/metric/metrics.py
+58
-0
ppcls/utils/misc.py
ppcls/utils/misc.py
+84
-0
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
0a3ecf60
...
@@ -70,6 +70,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny
...
@@ -70,6 +70,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny
from
ppcls.arch.backbone.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
ppcls.arch.backbone.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
ppcls.arch.backbone.variant_models.vgg_variant
import
VGG19Sigmoid
from
ppcls.arch.backbone.variant_models.vgg_variant
import
VGG19Sigmoid
from
ppcls.arch.backbone.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
ppcls.arch.backbone.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
from
ppcls.arch.backbone.model_zoo.strongbaseline_attr
import
StrongBaselineAttr
# help whl get all the models' api (class type) and components' api (func type)
# help whl get all the models' api (class type) and components' api (func type)
...
...
ppcls/arch/backbone/model_zoo/strongbaseline_attr.py
0 → 100644
浏览文件 @
0a3ecf60
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
paddle
from
paddle
import
ParamAttr
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.nn
import
Conv2D
,
BatchNorm
,
Linear
,
Dropout
from
paddle.nn
import
AdaptiveAvgPool2D
,
MaxPool2D
,
AvgPool2D
from
paddle.nn.initializer
import
Uniform
import
math
from
ppcls.utils.save_load
import
load_dygraph_pretrain
,
load_dygraph_pretrain_from_url
,
get_weights_path_from_url
from
..legendary_models.resnet
import
ResNet50
MODEL_URLS
=
{
"StrongBaselineAttr"
:
"strongbaseline_attr_clas"
,
}
__all__
=
list
(
MODEL_URLS
.
keys
())
class
StrongBaselinePAR
(
nn
.
Layer
):
def
__init__
(
self
,
**
config
,
):
"""
A strong baseline for Pedestrian Attribute Recognition, see https://arxiv.org/abs/2107.03576
Args:
backbone (object): backbone instance
classifier (object): classifier instance
loss (object): loss instance
"""
super
(
StrongBaselinePAR
,
self
).
__init__
()
backbone_config
=
config
[
"Backbone"
]
backbone_name
=
backbone_config
.
pop
(
"name"
)
self
.
backbone
=
eval
(
backbone_name
)(
**
backbone_config
)
def
forward
(
self
,
x
):
fc_feat
=
self
.
backbone
(
x
)
output
=
F
.
sigmoid
(
fc_feat
)
return
output
def
_load_pretrained
(
pretrained
,
model
,
model_url
,
use_ssld
):
if
pretrained
is
False
:
pass
elif
pretrained
is
True
:
load_dygraph_pretrain_from_url
(
model
,
model_url
,
use_ssld
=
use_ssld
)
elif
isinstance
(
pretrained
,
str
):
load_dygraph_pretrain
(
model
,
pretrained
)
else
:
raise
RuntimeError
(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def
load_pretrained
(
model
,
local_weight_path
):
# local_weight_path = get_weights_path_from_url(model_url).replace(
# ".pdparams", "")
param_state_dict
=
paddle
.
load
(
local_weight_path
+
".pdparams"
)
model_dict
=
model
.
state_dict
()
model_dict_keys
=
list
(
model_dict
.
keys
())
param_state_dict_keys
=
list
(
param_state_dict
.
keys
())
# assert(len(model_dict_keys) == len(param_state_dict_keys)), "{} == {}".format(len(model_dict_keys), len(param_state_dict_keys))
for
idx
in
range
(
len
(
model_dict
.
keys
())):
model_key
=
model_dict_keys
[
idx
]
param_key
=
param_state_dict_keys
[
idx
]
if
model_dict
[
model_key
].
shape
==
param_state_dict
[
param_key
].
shape
:
model_dict
[
model_key
]
=
param_state_dict
[
param_key
]
else
:
print
(
"miss match idx: {} weights: {} vs {}; {} vs {}"
.
format
(
idx
,
model_key
,
param_key
,
model_dict
[
model_key
].
shape
,
param_state_dict
[
param_key
].
shape
))
model
.
set_dict
(
model_dict
)
def
StrongBaselineAttr
(
pretrained
=
True
,
use_ssld
=
False
,
**
kwargs
):
model
=
StrongBaselinePAR
(
**
kwargs
)
_load_pretrained
(
MODEL_URLS
[
"StrongBaselineAttr"
],
model
,
None
,
None
)
return
model
ppcls/configs/Attr/StrongBaselineAttr.yaml
0 → 100644
浏览文件 @
0a3ecf60
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
"
./output/"
device
:
"
gpu"
save_interval
:
5
eval_during_train
:
False
eval_interval
:
1
epochs
:
30
print_batch_step
:
20
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
256
,
192
]
save_inference_dir
:
"
./inference"
use_multilabel
:
True
metric_attr
:
True
# model architecture
Arch
:
name
:
"
StrongBaselineAttr"
Backbone
:
name
:
"
ResNet50"
class_num
:
26
# loss function config for traing/eval process
Loss
:
Train
:
-
BCELoss
:
weight
:
1.0
Eval
:
-
BCELoss
:
weight
:
1.0
Optimizer
:
name
:
Adam
lr
:
name
:
Piecewise
decay_epochs
:
[
12
,
18
,
24
,
28
]
values
:
[
0.0001
,
0.00001
,
0.000001
,
0.0000001
]
regularizer
:
name
:
'
L2'
coeff
:
0.0005
clip_norm
:
10
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
AttrDataset
image_root
:
"
dataset/xingrenfenxi/data/"
cls_label_path
:
"
dataset/xingrenfenxi/all_qiye.pkl"
split
:
'
trainval'
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
# - ResizeImage:
# size: [192, 256]
-
RandCropImage
:
size
:
[
192
,
256
]
scale
:
[
0.9
,
1.1
]
ratio
:
[
0.75
,
0.75
]
-
RandFlipImage
:
flip_code
:
1
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
True
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
AttrDataset
image_root
:
"
dataset/xingrenfenxi/data/"
cls_label_path
:
"
dataset/xingrenfenxi/all_qiye.pkl"
split
:
'
test'
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
size
:
[
192
,
256
]
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
64
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Metric
:
Eval
:
-
ATTRMetric
:
ppcls/data/__init__.py
浏览文件 @
0a3ecf60
...
@@ -30,6 +30,7 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
...
@@ -30,6 +30,7 @@ from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from
ppcls.data.dataloader.mix_dataset
import
MixDataset
from
ppcls.data.dataloader.mix_dataset
import
MixDataset
from
ppcls.data.dataloader.multi_scale_dataset
import
MultiScaleDataset
from
ppcls.data.dataloader.multi_scale_dataset
import
MultiScaleDataset
from
ppcls.data.dataloader.person_dataset
import
Market1501
,
MSMT17
from
ppcls.data.dataloader.person_dataset
import
Market1501
,
MSMT17
from
ppcls.data.dataloader.attr_dataset
import
AttrDataset
# sampler
# sampler
...
...
ppcls/engine/evaluation/classification.py
浏览文件 @
0a3ecf60
...
@@ -18,7 +18,7 @@ import time
...
@@ -18,7 +18,7 @@ import time
import
platform
import
platform
import
paddle
import
paddle
from
ppcls.utils.misc
import
AverageMeter
from
ppcls.utils.misc
import
AverageMeter
,
AttrMeter
from
ppcls.utils
import
logger
from
ppcls.utils
import
logger
...
@@ -32,6 +32,10 @@ def classification_eval(engine, epoch_id=0):
...
@@ -32,6 +32,10 @@ def classification_eval(engine, epoch_id=0):
}
}
print_batch_step
=
engine
.
config
[
"Global"
][
"print_batch_step"
]
print_batch_step
=
engine
.
config
[
"Global"
][
"print_batch_step"
]
if
engine
.
eval_metric_func
is
not
None
and
engine
.
config
[
"Global"
][
"metric_attr"
]:
output_info
[
"attr"
]
=
AttrMeter
(
threshold
=
0.5
)
metric_key
=
None
metric_key
=
None
tic
=
time
.
time
()
tic
=
time
.
time
()
accum_samples
=
0
accum_samples
=
0
...
@@ -121,17 +125,22 @@ def classification_eval(engine, epoch_id=0):
...
@@ -121,17 +125,22 @@ def classification_eval(engine, epoch_id=0):
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
].
update
(
loss_dict
[
key
].
numpy
()[
0
],
output_info
[
key
].
update
(
loss_dict
[
key
].
numpy
()[
0
],
current_samples
)
current_samples
)
# calc metric
# calc metric
if
engine
.
eval_metric_func
is
not
None
:
if
engine
.
eval_metric_func
is
not
None
:
metric_dict
=
engine
.
eval_metric_func
(
preds
,
labels
)
if
engine
.
config
[
"Global"
][
"metric_attr"
]:
for
key
in
metric_dict
:
metric_dict
=
engine
.
eval_metric_func
(
preds
,
labels
)
if
metric_key
is
None
:
metric_key
=
"attr"
metric_key
=
key
output_info
[
"attr"
].
update
(
metric_dict
)
if
key
not
in
output_info
:
else
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
metric_dict
=
engine
.
eval_metric_func
(
preds
,
labels
)
for
key
in
metric_dict
:
output_info
[
key
].
update
(
metric_dict
[
key
].
numpy
()[
0
],
if
metric_key
is
None
:
current_samples
)
metric_key
=
key
if
key
not
in
output_info
:
output_info
[
key
]
=
AverageMeter
(
key
,
'7.5f'
)
output_info
[
key
].
update
(
metric_dict
[
key
].
numpy
()[
0
],
current_samples
)
time_info
[
"batch_cost"
].
update
(
time
.
time
()
-
tic
)
time_info
[
"batch_cost"
].
update
(
time
.
time
()
-
tic
)
...
@@ -144,10 +153,13 @@ def classification_eval(engine, epoch_id=0):
...
@@ -144,10 +153,13 @@ def classification_eval(engine, epoch_id=0):
ips_msg
=
"ips: {:.5f} images/sec"
.
format
(
ips_msg
=
"ips: {:.5f} images/sec"
.
format
(
batch_size
/
time_info
[
"batch_cost"
].
avg
)
batch_size
/
time_info
[
"batch_cost"
].
avg
)
metric_msg
=
", "
.
join
([
if
engine
.
config
[
"Global"
][
"metric_attr"
]:
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
val
)
metric_msg
=
""
for
key
in
output_info
else
:
])
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
val
)
for
key
in
output_info
])
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}"
.
format
(
logger
.
info
(
"[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}"
.
format
(
epoch_id
,
iter_id
,
epoch_id
,
iter_id
,
len
(
engine
.
eval_dataloader
),
metric_msg
,
time_msg
,
ips_msg
))
len
(
engine
.
eval_dataloader
),
metric_msg
,
time_msg
,
ips_msg
))
...
@@ -155,13 +167,28 @@ def classification_eval(engine, epoch_id=0):
...
@@ -155,13 +167,28 @@ def classification_eval(engine, epoch_id=0):
tic
=
time
.
time
()
tic
=
time
.
time
()
if
engine
.
use_dali
:
if
engine
.
use_dali
:
engine
.
eval_dataloader
.
reset
()
engine
.
eval_dataloader
.
reset
()
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
if
engine
.
config
[
"Global"
][
"metric_attr"
]:
])
metric_msg
=
", "
.
join
([
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
))
"evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}"
.
format
(
*
output_info
[
"attr"
].
res
())
# do not try to save best eval.model
])
if
engine
.
eval_metric_func
is
None
:
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
))
return
-
1
# return 1st metric in the dict
# do not try to save best eval.model
return
output_info
[
metric_key
].
avg
if
engine
.
eval_metric_func
is
None
:
return
-
1
# return 1st metric in the dict
return
output_info
[
"attr"
].
res
()[
0
]
else
:
metric_msg
=
", "
.
join
([
"{}: {:.5f}"
.
format
(
key
,
output_info
[
key
].
avg
)
for
key
in
output_info
])
logger
.
info
(
"[Eval][Epoch {}][Avg]{}"
.
format
(
epoch_id
,
metric_msg
))
# do not try to save best eval.model
if
engine
.
eval_metric_func
is
None
:
return
-
1
# return 1st metric in the dict
return
output_info
[
metric_key
].
avg
ppcls/loss/__init__.py
浏览文件 @
0a3ecf60
...
@@ -26,6 +26,7 @@ from .distillationloss import DistillationKLDivLoss
...
@@ -26,6 +26,7 @@ from .distillationloss import DistillationKLDivLoss
from
.distillationloss
import
DistillationDKDLoss
from
.distillationloss
import
DistillationDKDLoss
from
.multilabelloss
import
MultiLabelLoss
from
.multilabelloss
import
MultiLabelLoss
from
.afdloss
import
AFDLoss
from
.afdloss
import
AFDLoss
from
.bceloss
import
BCELoss
from
.deephashloss
import
DSHSDLoss
from
.deephashloss
import
DSHSDLoss
from
.deephashloss
import
LCDSHLoss
from
.deephashloss
import
LCDSHLoss
...
...
ppcls/loss/bceloss.py
0 → 100644
浏览文件 @
0a3ecf60
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
def
ratio2weight
(
targets
,
ratio
):
# print(targets)
pos_weights
=
targets
*
(
1.
-
ratio
)
neg_weights
=
(
1.
-
targets
)
*
ratio
weights
=
paddle
.
exp
(
neg_weights
+
pos_weights
)
# for RAP dataloader, targets element may be 2, with or without smooth, some element must great than 1
weights
=
weights
-
weights
*
(
targets
>
1
)
return
weights
class
BCELoss
(
nn
.
Layer
):
"""BCE Loss.
Args:
"""
def
__init__
(
self
,
sample_weight
=
True
,
size_sum
=
True
,
smoothing
=
None
,
weight
=
1.0
):
super
(
BCELoss
,
self
).
__init__
()
self
.
sample_weight
=
sample_weight
self
.
size_sum
=
size_sum
self
.
hyper
=
0.8
self
.
smoothing
=
smoothing
def
forward
(
self
,
logits
,
labels
):
targets
,
ratio
=
labels
if
self
.
smoothing
is
not
None
:
targets
=
(
1
-
self
.
smoothing
)
*
targets
+
self
.
smoothing
*
(
1
-
targets
)
targets
=
paddle
.
cast
(
targets
,
'float32'
)
loss_m
=
F
.
binary_cross_entropy_with_logits
(
logits
,
targets
,
reduction
=
'none'
)
targets_mask
=
paddle
.
cast
(
targets
>
0.5
,
'float32'
)
if
self
.
sample_weight
:
weight
=
ratio2weight
(
targets_mask
,
ratio
[
0
])
weight
=
weight
*
(
targets
>
-
1
)
loss_m
=
loss_m
*
weight
loss
=
loss_m
.
sum
(
1
).
mean
()
if
self
.
size_sum
else
loss_m
.
sum
()
return
{
"BCELoss"
:
loss
}
ppcls/metric/__init__.py
浏览文件 @
0a3ecf60
...
@@ -20,6 +20,7 @@ from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
...
@@ -20,6 +20,7 @@ from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
from
.metrics
import
DistillationTopkAcc
from
.metrics
import
DistillationTopkAcc
from
.metrics
import
GoogLeNetTopkAcc
from
.metrics
import
GoogLeNetTopkAcc
from
.metrics
import
HammingDistance
,
AccuracyScore
from
.metrics
import
HammingDistance
,
AccuracyScore
from
.metrics
import
ATTRMetric
class
CombinedMetrics
(
nn
.
Layer
):
class
CombinedMetrics
(
nn
.
Layer
):
...
...
ppcls/metric/metrics.py
浏览文件 @
0a3ecf60
...
@@ -22,6 +22,8 @@ from sklearn.metrics import accuracy_score as accuracy_metric
...
@@ -22,6 +22,8 @@ from sklearn.metrics import accuracy_score as accuracy_metric
from
sklearn.metrics
import
multilabel_confusion_matrix
from
sklearn.metrics
import
multilabel_confusion_matrix
from
sklearn.preprocessing
import
binarize
from
sklearn.preprocessing
import
binarize
from
easydict
import
EasyDict
class
TopkAcc
(
nn
.
Layer
):
class
TopkAcc
(
nn
.
Layer
):
def
__init__
(
self
,
topk
=
(
1
,
5
)):
def
__init__
(
self
,
topk
=
(
1
,
5
)):
...
@@ -308,3 +310,59 @@ class AccuracyScore(MutiLabelMetric):
...
@@ -308,3 +310,59 @@ class AccuracyScore(MutiLabelMetric):
sum
(
tps
)
+
sum
(
tns
)
+
sum
(
fns
)
+
sum
(
fps
))
sum
(
tps
)
+
sum
(
tns
)
+
sum
(
fns
)
+
sum
(
fps
))
metric_dict
[
"AccuracyScore"
]
=
paddle
.
to_tensor
(
accuracy
)
metric_dict
[
"AccuracyScore"
]
=
paddle
.
to_tensor
(
accuracy
)
return
metric_dict
return
metric_dict
def
get_attr_metrics
(
gt_label
,
preds_probs
,
threshold
):
"""
index: evaluated label index
"""
pred_label
=
(
preds_probs
>
threshold
).
astype
(
int
)
eps
=
1e-20
result
=
EasyDict
()
has_fuyi
=
gt_label
==
-
1
pred_label
[
has_fuyi
]
=
-
1
###############################
# label metrics
# TP + FN
result
.
gt_pos
=
np
.
sum
((
gt_label
==
1
),
axis
=
0
).
astype
(
float
)
# TN + FP
result
.
gt_neg
=
np
.
sum
((
gt_label
==
0
),
axis
=
0
).
astype
(
float
)
# TP
result
.
true_pos
=
np
.
sum
((
gt_label
==
1
)
*
(
pred_label
==
1
),
axis
=
0
).
astype
(
float
)
# TN
result
.
true_neg
=
np
.
sum
((
gt_label
==
0
)
*
(
pred_label
==
0
),
axis
=
0
).
astype
(
float
)
# FP
result
.
false_pos
=
np
.
sum
(((
gt_label
==
0
)
*
(
pred_label
==
1
)),
axis
=
0
).
astype
(
float
)
# FN
result
.
false_neg
=
np
.
sum
(((
gt_label
==
1
)
*
(
pred_label
==
0
)),
axis
=
0
).
astype
(
float
)
################
# instance metrics
result
.
gt_pos_ins
=
np
.
sum
((
gt_label
==
1
),
axis
=
1
).
astype
(
float
)
result
.
true_pos_ins
=
np
.
sum
((
pred_label
==
1
),
axis
=
1
).
astype
(
float
)
# true positive
result
.
intersect_pos
=
np
.
sum
((
gt_label
==
1
)
*
(
pred_label
==
1
),
axis
=
1
).
astype
(
float
)
# IOU
result
.
union_pos
=
np
.
sum
(((
gt_label
==
1
)
+
(
pred_label
==
1
)),
axis
=
1
).
astype
(
float
)
return
result
class
ATTRMetric
(
nn
.
Layer
):
def
__init__
(
self
,
threshold
=
0.5
):
super
().
__init__
()
self
.
threshold
=
threshold
def
__call__
(
self
,
output
,
target
):
metric_dict
=
get_attr_metrics
(
target
[
0
].
numpy
(),
output
.
numpy
(),
self
.
threshold
)
return
metric_dict
ppcls/utils/misc.py
浏览文件 @
0a3ecf60
...
@@ -61,3 +61,87 @@ class AverageMeter(object):
...
@@ -61,3 +61,87 @@ class AverageMeter(object):
def
value
(
self
):
def
value
(
self
):
return
'{self.name}: {self.val:{self.fmt}}{self.postfix}'
.
format
(
return
'{self.name}: {self.val:{self.fmt}}{self.postfix}'
.
format
(
self
=
self
)
self
=
self
)
class
AttrMeter
(
object
):
"""
Computes and stores the average and current value
Code was based on https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
def
__init__
(
self
,
threshold
=
0.5
):
self
.
threshold
=
threshold
self
.
reset
()
def
reset
(
self
):
self
.
gt_pos
=
0
self
.
gt_neg
=
0
self
.
true_pos
=
0
self
.
true_neg
=
0
self
.
false_pos
=
0
self
.
false_neg
=
0
self
.
gt_pos_ins
=
[]
self
.
true_pos_ins
=
[]
self
.
intersect_pos
=
[]
self
.
union_pos
=
[]
def
update
(
self
,
metric_dict
):
self
.
gt_pos
+=
metric_dict
[
'gt_pos'
]
self
.
gt_neg
+=
metric_dict
[
'gt_neg'
]
self
.
true_pos
+=
metric_dict
[
'true_pos'
]
self
.
true_neg
+=
metric_dict
[
'true_neg'
]
self
.
false_pos
+=
metric_dict
[
'false_pos'
]
self
.
false_neg
+=
metric_dict
[
'false_neg'
]
self
.
gt_pos_ins
+=
metric_dict
[
'gt_pos_ins'
].
tolist
()
self
.
true_pos_ins
+=
metric_dict
[
'true_pos_ins'
].
tolist
()
self
.
intersect_pos
+=
metric_dict
[
'intersect_pos'
].
tolist
()
self
.
union_pos
+=
metric_dict
[
'union_pos'
].
tolist
()
def
res
(
self
):
import
numpy
as
np
eps
=
1e-20
label_pos_recall
=
1.0
*
self
.
true_pos
/
(
self
.
gt_pos
+
eps
)
# true positive
label_neg_recall
=
1.0
*
self
.
true_neg
/
(
self
.
gt_neg
+
eps
)
# true negative
# mean accuracy
label_ma
=
(
label_pos_recall
+
label_neg_recall
)
/
2
label_pos_recall
=
np
.
mean
(
label_pos_recall
)
label_neg_recall
=
np
.
mean
(
label_neg_recall
)
label_prec
=
(
self
.
true_pos
/
(
self
.
true_pos
+
self
.
false_pos
+
eps
))
label_acc
=
(
self
.
true_pos
/
(
self
.
true_pos
+
self
.
false_pos
+
self
.
false_neg
+
eps
))
label_f1
=
np
.
mean
(
2
*
label_prec
*
label_pos_recall
/
(
label_prec
+
label_pos_recall
+
eps
))
ma
=
(
np
.
mean
(
label_ma
))
self
.
gt_pos_ins
=
np
.
array
(
self
.
gt_pos_ins
)
self
.
true_pos_ins
=
np
.
array
(
self
.
true_pos_ins
)
self
.
intersect_pos
=
np
.
array
(
self
.
intersect_pos
)
self
.
union_pos
=
np
.
array
(
self
.
union_pos
)
instance_acc
=
self
.
intersect_pos
/
(
self
.
union_pos
+
eps
)
instance_prec
=
self
.
intersect_pos
/
(
self
.
true_pos_ins
+
eps
)
instance_recall
=
self
.
intersect_pos
/
(
self
.
gt_pos_ins
+
eps
)
instance_f1
=
2
*
instance_prec
*
instance_recall
/
(
instance_prec
+
instance_recall
+
eps
)
instance_acc
=
np
.
mean
(
instance_acc
)
instance_prec
=
np
.
mean
(
instance_prec
)
instance_recall
=
np
.
mean
(
instance_recall
)
instance_f1
=
2
*
instance_prec
*
instance_recall
/
(
instance_prec
+
instance_recall
+
eps
)
instance_acc
=
np
.
mean
(
instance_acc
)
instance_prec
=
np
.
mean
(
instance_prec
)
instance_recall
=
np
.
mean
(
instance_recall
)
instance_f1
=
np
.
mean
(
instance_f1
)
res
=
[
ma
,
label_f1
,
label_pos_recall
,
label_neg_recall
,
instance_f1
,
instance_acc
,
instance_prec
,
instance_recall
]
return
res
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录