Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
2949a20f
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2949a20f
编写于
7月 29, 2020
作者:
M
malin10
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename global_metric_state_vars
上级
c1a4a6b8
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
36 addition
and
33 deletion
+36
-33
core/metric.py
core/metric.py
+10
-8
core/metrics/binary_class/auc.py
core/metrics/binary_class/auc.py
+11
-10
core/metrics/binary_class/precision_recall.py
core/metrics/binary_class/precision_recall.py
+4
-4
core/metrics/pairwise_pn.py
core/metrics/pairwise_pn.py
+5
-5
core/metrics/recall_k.py
core/metrics/recall_k.py
+6
-6
未找到文件。
core/metric.py
浏览文件 @
2949a20f
...
...
@@ -32,15 +32,16 @@ class Metric(object):
scope
=
fluid
.
global_scope
()
place
=
fluid
.
CPUPlace
()
for
key
in
self
.
_global_communicate_var
:
varname
,
dtype
=
self
.
_global_communicate_var
[
key
]
if
scope
.
find_var
(
varname
)
is
None
:
for
key
in
self
.
_global_metric_state_vars
:
varname
,
dtype
=
self
.
_global_metric_state_vars
[
key
]
var
=
scope
.
find_var
(
varname
)
if
not
var
:
continue
var
=
scope
.
var
(
varname
)
.
get_tensor
()
var
=
var
.
get_tensor
()
data_array
=
np
.
zeros
(
var
.
_get_dims
()).
astype
(
dtype
)
var
.
set
(
data_array
,
place
)
def
get_global_metric
(
self
,
fleet
,
scope
,
metric_name
,
mode
=
"sum"
):
def
get_global_metric
_state
(
self
,
fleet
,
scope
,
metric_name
,
mode
=
"sum"
):
""" """
input
=
np
.
array
(
scope
.
find_var
(
metric_name
).
get_tensor
())
if
fleet
is
None
:
...
...
@@ -59,9 +60,10 @@ class Metric(object):
scope
=
fluid
.
global_scope
()
global_metrics
=
dict
()
for
key
in
self
.
_global_communicate_var
:
varname
,
dtype
=
self
.
_global_communicate_var
[
key
]
global_metrics
[
key
]
=
self
.
get_global_metric
(
fleet
,
scope
,
varname
)
for
key
in
self
.
_global_metric_state_vars
:
varname
,
dtype
=
self
.
_global_metric_state_vars
[
key
]
global_metrics
[
key
]
=
self
.
get_global_metric_state
(
fleet
,
scope
,
varname
)
return
self
.
calculate
(
global_metrics
)
...
...
core/metrics/binary_class/auc.py
浏览文件 @
2949a20f
...
...
@@ -59,15 +59,16 @@ class AUC(Metric):
sqrerr
,
abserr
,
prob
,
q
,
pos
,
total
=
\
fluid
.
contrib
.
layers
.
ctr_metric_bundle
(
prob
,
label_cast
)
self
.
_global_communicate_var
=
dict
()
self
.
_global_communicate_var
[
'stat_pos'
]
=
(
stat_pos
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'stat_neg'
]
=
(
stat_neg
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'total_ins_num'
]
=
(
total
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'pos_ins_num'
]
=
(
pos
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'q'
]
=
(
q
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'prob'
]
=
(
prob
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'abserr'
]
=
(
abserr
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'sqrerr'
]
=
(
sqrerr
.
name
,
"float32"
)
self
.
_global_metric_state_vars
=
dict
()
self
.
_global_metric_state_vars
[
'stat_pos'
]
=
(
stat_pos
.
name
,
"float32"
)
self
.
_global_metric_state_vars
[
'stat_neg'
]
=
(
stat_neg
.
name
,
"float32"
)
self
.
_global_metric_state_vars
[
'total_ins_num'
]
=
(
total
.
name
,
"float32"
)
self
.
_global_metric_state_vars
[
'pos_ins_num'
]
=
(
pos
.
name
,
"float32"
)
self
.
_global_metric_state_vars
[
'q'
]
=
(
q
.
name
,
"float32"
)
self
.
_global_metric_state_vars
[
'prob'
]
=
(
prob
.
name
,
"float32"
)
self
.
_global_metric_state_vars
[
'abserr'
]
=
(
abserr
.
name
,
"float32"
)
self
.
_global_metric_state_vars
[
'sqrerr'
]
=
(
sqrerr
.
name
,
"float32"
)
self
.
metrics
=
dict
()
self
.
metrics
[
"AUC"
]
=
auc_out
...
...
@@ -149,7 +150,7 @@ class AUC(Metric):
def
calculate
(
self
,
global_metrics
):
result
=
dict
()
for
key
in
self
.
_global_
communicate_var
:
for
key
in
self
.
_global_
metric_state_vars
:
if
key
not
in
global_metrics
:
raise
ValueError
(
"%s not existed"
%
key
)
result
[
key
]
=
global_metrics
[
key
][
0
]
...
...
core/metrics/binary_class/precision_recall.py
浏览文件 @
2949a20f
...
...
@@ -99,8 +99,8 @@ class PrecisionRecall(Metric):
batch_states
.
stop_gradient
=
True
states_info
.
stop_gradient
=
True
self
.
_global_
communicate_var
=
dict
()
self
.
_global_
communicate_var
[
'states_info'
]
=
(
states_info
.
name
,
self
.
_global_
metric_state_vars
=
dict
()
self
.
_global_
metric_state_vars
[
'states_info'
]
=
(
states_info
.
name
,
"float32"
)
self
.
metrics
=
dict
()
...
...
@@ -110,7 +110,7 @@ class PrecisionRecall(Metric):
# self.metrics["batch_metrics"] = batch_metrics
def
calculate
(
self
,
global_metrics
):
for
key
in
self
.
_global_
communicate_var
:
for
key
in
self
.
_global_
metric_state_vars
:
if
key
not
in
global_metrics
:
raise
ValueError
(
"%s not existed"
%
key
)
...
...
core/metrics/pairwise_pn.py
浏览文件 @
2949a20f
...
...
@@ -73,10 +73,10 @@ class PosNegRatio(Metric):
outputs
=
{
"Out"
:
[
global_wrong_cnt
]})
self
.
pn
=
(
global_right_cnt
+
1.0
)
/
(
global_wrong_cnt
+
1.0
)
self
.
_global_
communicate_var
=
dict
()
self
.
_global_
communicate_var
[
'right_cnt'
]
=
(
global_right_cnt
.
name
,
self
.
_global_
metric_state_vars
=
dict
()
self
.
_global_
metric_state_vars
[
'right_cnt'
]
=
(
global_right_cnt
.
name
,
"float32"
)
self
.
_global_
communicate_var
[
'wrong_cnt'
]
=
(
global_wrong_cnt
.
name
,
self
.
_global_
metric_state_vars
[
'wrong_cnt'
]
=
(
global_wrong_cnt
.
name
,
"float32"
)
self
.
metrics
=
dict
()
...
...
core/metrics/recall_k.py
浏览文件 @
2949a20f
...
...
@@ -75,10 +75,10 @@ class RecallK(Metric):
self
.
acc
=
global_pos_cnt
/
global_ins_cnt
self
.
_global_
communicate_var
=
dict
()
self
.
_global_
communicate_var
[
'ins_cnt'
]
=
(
global_ins_cnt
.
name
,
self
.
_global_
metric_state_vars
=
dict
()
self
.
_global_
metric_state_vars
[
'ins_cnt'
]
=
(
global_ins_cnt
.
name
,
"float32"
)
self
.
_global_
communicate_var
[
'pos_cnt'
]
=
(
global_pos_cnt
.
name
,
self
.
_global_
metric_state_vars
[
'pos_cnt'
]
=
(
global_pos_cnt
.
name
,
"float32"
)
metric_name
=
"Acc(Recall@%d)"
%
self
.
k
...
...
@@ -89,7 +89,7 @@ class RecallK(Metric):
# self.metrics["batch_metrics"] = batch_metrics
def
calculate
(
self
,
global_metrics
):
for
key
in
self
.
_global_
communicate_var
:
for
key
in
self
.
_global_
metric_state_vars
:
if
key
not
in
global_metrics
:
raise
ValueError
(
"%s not existed"
%
key
)
ins_cnt
=
global_metrics
[
'ins_cnt'
][
0
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录