Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
330465d0
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
330465d0
编写于
7月 28, 2020
作者:
M
malin10
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
4310c411
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
302 addition
and
36 deletion
+302
-36
core/metric.py
core/metric.py
+32
-3
core/metrics/__init__.py
core/metrics/__init__.py
+1
-1
core/metrics/binary_class/auc.py
core/metrics/binary_class/auc.py
+128
-2
core/metrics/binary_class/precision_recall.py
core/metrics/binary_class/precision_recall.py
+54
-6
core/metrics/pairwise_pn.py
core/metrics/pairwise_pn.py
+18
-5
core/metrics/recall_k.py
core/metrics/recall_k.py
+24
-7
core/trainers/framework/runner.py
core/trainers/framework/runner.py
+45
-12
未找到文件。
core/metric.py
浏览文件 @
330465d0
...
...
@@ -26,7 +26,7 @@ class Metric(object):
""" """
pass
def
clear
(
self
,
scope
=
None
,
**
kwargs
):
def
clear
(
self
,
scope
=
None
):
"""
clear current value
Args:
...
...
@@ -37,20 +37,49 @@ class Metric(object):
scope
=
fluid
.
global_scope
()
place
=
fluid
.
CPUPlace
()
for
(
varname
,
dtype
)
in
self
.
_need_clear_list
:
for
key
in
self
.
_global_communicate_var
:
varname
,
dtype
=
self
.
_global_communicate_var
[
key
]
if
scope
.
find_var
(
varname
)
is
None
:
continue
var
=
scope
.
var
(
varname
).
get_tensor
()
data_array
=
np
.
zeros
(
var
.
_get_dims
()).
astype
(
dtype
)
var
.
set
(
data_array
,
place
)
def
calculate
(
self
,
scope
,
params
):
def
get_global_metric
(
self
,
fleet
,
scope
,
metric_name
,
mode
=
"sum"
):
"""
reduce metric named metric_name from all worker
Return:
metric reduce result
"""
input
=
np
.
array
(
scope
.
find_var
(
metric_name
).
get_tensor
())
if
fleet
is
None
:
return
input
fleet
.
_role_maker
.
_barrier_worker
()
old_shape
=
np
.
array
(
input
.
shape
)
input
=
input
.
reshape
(
-
1
)
output
=
np
.
copy
(
input
)
*
0
fleet
.
_role_maker
.
_all_reduce
(
input
,
output
,
mode
=
mode
)
output
=
output
.
reshape
(
old_shape
)
return
output
def
cal_global_metrics
(
self
,
fleet
,
scope
=
None
):
"""
calculate result
Args:
scope: value container
params: extend varilable for clear
"""
if
scope
is
None
:
scope
=
fluid
.
global_scope
()
global_metrics
=
dict
()
for
key
in
self
.
_global_communicate_var
:
varname
,
dtype
=
self
.
_global_communicate_var
[
key
]
global_metrics
[
key
]
=
self
.
get_global_metric
(
fleet
,
scope
,
varname
)
return
self
.
calculate
(
global_metrics
)
def
calculate
(
self
,
global_metrics
):
pass
@
abc
.
abstractmethod
...
...
core/metrics/__init__.py
浏览文件 @
330465d0
...
...
@@ -14,6 +14,6 @@
from
.recall_k
import
RecallK
from
.pairwise_pn
import
PosNegRatio
import
binary_class
from
.binary_class
import
*
__all__
=
[
'RecallK'
,
'PosNegRatio'
]
+
binary_class
.
__all__
core/metrics/binary_class/auc.py
浏览文件 @
330465d0
...
...
@@ -56,11 +56,137 @@ class AUC(Metric):
topk
=
topk
,
slide_steps
=
slide_steps
)
self
.
_need_clear_list
=
[(
stat_pos
.
name
,
"float32"
),
(
stat_neg
.
name
,
"float32"
)]
prob
=
fluid
.
layers
.
slice
(
predict
,
axes
=
[
1
],
starts
=
[
1
],
ends
=
[
2
])
label_cast
=
fluid
.
layers
.
cast
(
label
,
dtype
=
"float32"
)
label_cast
.
stop_gradient
=
True
sqrerr
,
abserr
,
prob
,
q
,
pos
,
total
=
\
fluid
.
contrib
.
layers
.
ctr_metric_bundle
(
prob
,
label_cast
)
self
.
_global_communicate_var
=
dict
()
self
.
_global_communicate_var
[
'stat_pos'
]
=
(
stat_pos
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'stat_neg'
]
=
(
stat_neg
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'total_ins_num'
]
=
(
total
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'pos_ins_num'
]
=
(
pos
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'q'
]
=
(
q
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'prob'
]
=
(
prob
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'abserr'
]
=
(
abserr
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'sqrerr'
]
=
(
sqrerr
.
name
,
"float32"
)
self
.
metrics
=
dict
()
self
.
metrics
[
"AUC"
]
=
auc_out
self
.
metrics
[
"BATCH_AUC"
]
=
batch_auc_out
def
calculate_bucket_error
(
self
,
global_pos
,
global_neg
):
"""R
"""
num_bucket
=
len
(
global_pos
)
last_ctr
=
-
1.0
impression_sum
=
0.0
ctr_sum
=
0.0
click_sum
=
0.0
error_sum
=
0.0
error_count
=
0.0
click
=
0.0
show
=
0.0
ctr
=
0.0
adjust_ctr
=
0.0
relative_error
=
0.0
actual_ctr
=
0.0
relative_ctr_error
=
0.0
k_max_span
=
0.01
k_relative_error_bound
=
0.05
for
i
in
range
(
num_bucket
):
click
=
global_pos
[
i
]
show
=
global_pos
[
i
]
+
global_neg
[
i
]
ctr
=
float
(
i
)
/
num_bucket
if
abs
(
ctr
-
last_ctr
)
>
k_max_span
:
last_ctr
=
ctr
impression_sum
=
0.0
ctr_sum
=
0.0
click_sum
=
0.0
impression_sum
+=
show
ctr_sum
+=
ctr
*
show
click_sum
+=
click
if
impression_sum
==
0
:
continue
adjust_ctr
=
ctr_sum
/
impression_sum
if
adjust_ctr
==
0
:
continue
relative_error
=
\
math
.
sqrt
((
1
-
adjust_ctr
)
/
(
adjust_ctr
*
impression_sum
))
if
relative_error
<
k_relative_error_bound
:
actual_ctr
=
click_sum
/
impression_sum
relative_ctr_error
=
abs
(
actual_ctr
/
adjust_ctr
-
1
)
error_sum
+=
relative_ctr_error
*
impression_sum
error_count
+=
impression_sum
last_ctr
=
-
1
bucket_error
=
error_sum
/
error_count
if
error_count
>
0
else
0.0
return
bucket_error
def
calculate_auc
(
self
,
global_pos
,
global_neg
):
"""R
"""
num_bucket
=
len
(
global_pos
)
area
=
0.0
pos
=
0.0
neg
=
0.0
new_pos
=
0.0
new_neg
=
0.0
total_ins_num
=
0
for
i
in
range
(
num_bucket
):
index
=
num_bucket
-
1
-
i
new_pos
=
pos
+
global_pos
[
index
]
total_ins_num
+=
global_pos
[
index
]
new_neg
=
neg
+
global_neg
[
index
]
total_ins_num
+=
global_neg
[
index
]
area
+=
(
new_neg
-
neg
)
*
(
pos
+
new_pos
)
/
2
pos
=
new_pos
neg
=
new_neg
auc_value
=
None
if
pos
*
neg
==
0
or
total_ins_num
==
0
:
auc_value
=
0.5
else
:
auc_value
=
area
/
(
pos
*
neg
)
return
auc_value
def
calculate
(
self
,
global_metrics
):
result
=
dict
()
for
key
in
self
.
_global_communicate_var
:
if
key
not
in
global_metrics
:
raise
ValueError
(
"%s not existed"
%
key
)
result
[
key
]
=
global_metrics
[
key
][
0
]
if
result
[
'total_ins_num'
]
==
0
:
result
[
'auc'
]
=
0
result
[
'bucket_error'
]
=
0
result
[
'actual_ctr'
]
=
0
result
[
'predict_ctr'
]
=
0
result
[
'mae'
]
=
0
result
[
'rmse'
]
=
0
result
[
'copc'
]
=
0
result
[
'mean_q'
]
=
0
else
:
result
[
'auc'
]
=
self
.
calculate_auc
(
result
[
'stat_pos'
],
result
[
'stat_neg'
])
result
[
'bucket_error'
]
=
self
.
calculate_auc
(
result
[
'stat_pos'
],
result
[
'stat_neg'
])
result
[
'actual_ctr'
]
=
result
[
'pos_ins_num'
]
/
result
[
'total_ins_num'
]
result
[
'mae'
]
=
result
[
'abserr'
]
/
result
[
'total_ins_num'
]
result
[
'rmse'
]
=
math
.
sqrt
(
result
[
'sqrerr'
]
/
result
[
'total_ins_num'
])
result
[
'predict_ctr'
]
=
result
[
'prob'
]
/
result
[
'total_ins_num'
]
if
abs
(
result
[
'predict_ctr'
])
>
1e-6
:
result
[
'copc'
]
=
result
[
'actual_ctr'
]
/
result
[
'predict_ctr'
]
result
[
'mean_q'
]
=
result
[
'q'
]
/
result
[
'total_ins_num'
]
result_str
=
"AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f "
\
"Actural_CTR=%.6f Predicted_CTR=%.6f COPC=%.6f MEAN Q_VALUE=%.6f Ins number=%s"
%
\
(
result
[
'auc'
],
result
[
'bucket_error'
],
result
[
'mae'
],
result
[
'rmse'
],
result
[
'actual_ctr'
],
result
[
'predict_ctr'
],
result
[
'copc'
],
result
[
'mean_q'
],
result
[
'total_ins_num'
])
return
result_str
def
get_result
(
self
):
return
self
.
metrics
core/metrics/binary_class/precision_recall.py
浏览文件 @
330465d0
...
...
@@ -36,7 +36,7 @@ class PrecisionRecall(Metric):
"PrecisionRecall expect input, label and class_num as inputs."
)
predict
=
kwargs
.
get
(
"input"
)
label
=
kwargs
.
get
(
"label"
)
num_cls
=
kwargs
.
get
(
"class_num"
)
self
.
num_cls
=
kwargs
.
get
(
"class_num"
)
if
not
isinstance
(
predict
,
Variable
):
raise
ValueError
(
"input must be Variable, but received %s"
%
...
...
@@ -56,7 +56,7 @@ class PrecisionRecall(Metric):
name
=
"states_info"
,
persistable
=
True
,
dtype
=
'float32'
,
shape
=
[
num_cls
,
4
])
shape
=
[
self
.
num_cls
,
4
])
states_info
.
stop_gradient
=
True
helper
.
set_variable_initializer
(
...
...
@@ -75,12 +75,12 @@ class PrecisionRecall(Metric):
shape
=
[
6
])
batch_states
=
fluid
.
layers
.
fill_constant
(
shape
=
[
num_cls
,
4
],
value
=
0.0
,
dtype
=
"float32"
)
shape
=
[
self
.
num_cls
,
4
],
value
=
0.0
,
dtype
=
"float32"
)
batch_states
.
stop_gradient
=
True
helper
.
append_op
(
type
=
"precision_recall"
,
attrs
=
{
'class_number'
:
num_cls
},
attrs
=
{
'class_number'
:
self
.
num_cls
},
inputs
=
{
'MaxProbs'
:
[
max_probs
],
'Indices'
:
[
indices
],
...
...
@@ -100,13 +100,61 @@ class PrecisionRecall(Metric):
batch_states
.
stop_gradient
=
True
states_info
.
stop_gradient
=
True
self
.
_need_clear_list
=
[(
"states_info"
,
"float32"
)]
self
.
_global_communicate_var
=
dict
()
self
.
_global_communicate_var
[
'states_info'
]
=
(
states_info
.
name
,
"float32"
)
self
.
metrics
=
dict
()
self
.
metrics
[
"precision_recall_f1"
]
=
accum_metrics
self
.
metrics
[
"
accum_states
"
]
=
states_info
self
.
metrics
[
"
[TP FP TN FN]
"
]
=
states_info
# self.metrics["batch_metrics"] = batch_metrics
def
calculate
(
self
,
global_metrics
):
for
key
in
self
.
_global_communicate_var
:
if
key
not
in
global_metrics
:
raise
ValueError
(
"%s not existed"
%
key
)
def
calc_precision
(
tp_count
,
fp_count
):
if
tp_count
>
0.0
or
fp_count
>
0.0
:
return
tp_count
/
(
tp_count
+
fp_count
)
return
1.0
def
calc_recall
(
tp_count
,
fn_count
):
if
tp_count
>
0.0
or
fn_count
>
0.0
:
return
tp_count
/
(
tp_count
+
fn_count
)
return
1.0
def
calc_f1_score
(
precision
,
recall
):
if
precision
>
0.0
or
recall
>
0.0
:
return
2
*
precision
*
recall
/
(
precision
+
recall
)
return
0.0
states
=
global_metrics
[
"states_info"
]
total_tp_count
=
0.0
total_fp_count
=
0.0
total_fn_count
=
0.0
macro_avg_precision
=
0.0
macro_avg_recall
=
0.0
for
i
in
range
(
self
.
num_cls
):
total_tp_count
+=
states
[
i
][
0
]
total_fp_count
+=
states
[
i
][
1
]
total_fn_count
+=
states
[
i
][
3
]
macro_avg_precision
+=
calc_precision
(
states
[
i
][
0
],
states
[
i
][
1
])
macro_avg_recall
+=
calc_recall
(
states
[
i
][
0
],
states
[
i
][
3
])
metrics
=
[]
macro_avg_precision
/=
self
.
num_cls
macro_avg_recall
/=
self
.
num_cls
metrics
.
append
(
macro_avg_precision
)
metrics
.
append
(
macro_avg_recall
)
metrics
.
append
(
calc_f1_score
(
macro_avg_precision
,
macro_avg_recall
))
micro_avg_precision
=
calc_precision
(
total_tp_count
,
total_fp_count
)
metrics
.
append
(
micro_avg_precision
)
micro_avg_recall
=
calc_recall
(
total_tp_count
,
total_fn_count
)
metrics
.
append
(
micro_avg_recall
)
metrics
.
append
(
calc_f1_score
(
micro_avg_precision
,
micro_avg_recall
))
return
"total metrics: [TP, FP, TN, FN]=%s; precision_recall_f1=%s"
%
(
str
(
states
),
str
(
np
.
array
(
metrics
).
astype
(
'float32'
)))
def
get_result
(
self
):
return
self
.
metrics
core/metrics/pairwise_pn.py
浏览文件 @
330465d0
...
...
@@ -74,13 +74,26 @@ class PosNegRatio(Metric):
outputs
=
{
"Out"
:
[
global_wrong_cnt
]})
self
.
pn
=
(
global_right_cnt
+
1.0
)
/
(
global_wrong_cnt
+
1.0
)
self
.
_need_clear_list
=
[(
"right_cnt"
,
"float32"
),
(
"wrong_cnt"
,
"float32"
)]
self
.
_global_communicate_var
=
dict
()
self
.
_global_communicate_var
[
'right_cnt'
]
=
(
global_right_cnt
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'wrong_cnt'
]
=
(
global_wrong_cnt
.
name
,
"float32"
)
self
.
metrics
=
dict
()
self
.
metrics
[
'wrong_cnt'
]
=
global_wrong_cnt
self
.
metrics
[
'right_cnt'
]
=
global_right_cnt
self
.
metrics
[
'pos_neg_ratio'
]
=
self
.
pn
self
.
metrics
[
'WrongCnt'
]
=
global_wrong_cnt
self
.
metrics
[
'RightCnt'
]
=
global_right_cnt
self
.
metrics
[
'PN'
]
=
self
.
pn
def
calculate
(
self
,
global_metrics
):
for
key
in
self
.
_global_communicate_var
:
if
key
not
in
global_metrics
:
raise
ValueError
(
"%s not existed"
%
key
)
pn
=
(
global_metrics
[
'right_cnt'
][
0
]
+
1.0
)
/
(
global_metrics
[
'wrong_cnt'
][
0
]
+
1.0
)
return
"RightCnt=%s WrongCnt=%s PN=%s"
%
(
str
(
global_metrics
[
'right_cnt'
][
0
]),
str
(
global_metrics
[
'wrong_cnt'
][
0
]),
str
(
pn
))
def
get_result
(
self
):
return
self
.
metrics
core/metrics/recall_k.py
浏览文件 @
330465d0
...
...
@@ -35,7 +35,7 @@ class RecallK(Metric):
raise
ValueError
(
"RecallK expect input and label as inputs."
)
predict
=
kwargs
.
get
(
'input'
)
label
=
kwargs
.
get
(
'label'
)
k
=
kwargs
.
get
(
"k"
,
20
)
self
.
k
=
kwargs
.
get
(
"k"
,
20
)
if
not
isinstance
(
predict
,
Variable
):
raise
ValueError
(
"input must be Variable, but received %s"
%
...
...
@@ -45,7 +45,7 @@ class RecallK(Metric):
type
(
label
))
helper
=
LayerHelper
(
"PaddleRec_RecallK"
,
**
kwargs
)
batch_accuracy
=
accuracy
(
predict
,
label
,
k
)
batch_accuracy
=
accuracy
(
predict
,
label
,
self
.
k
)
global_ins_cnt
,
_
=
helper
.
create_or_get_global_variable
(
name
=
"ins_cnt"
,
persistable
=
True
,
dtype
=
'float32'
,
shape
=
[
1
])
global_pos_cnt
,
_
=
helper
.
create_or_get_global_variable
(
...
...
@@ -75,14 +75,31 @@ class RecallK(Metric):
self
.
acc
=
global_pos_cnt
/
global_ins_cnt
self
.
_need_clear_list
=
[(
"ins_cnt"
,
"float32"
),
(
"pos_cnt"
,
"float32"
)]
self
.
_global_communicate_var
=
dict
()
self
.
_global_communicate_var
[
'ins_cnt'
]
=
(
global_ins_cnt
.
name
,
"float32"
)
self
.
_global_communicate_var
[
'pos_cnt'
]
=
(
global_pos_cnt
.
name
,
"float32"
)
metric_name
=
"
Recall@%d_ACC"
%
k
metric_name
=
"
Acc(Recall@%d)"
%
self
.
k
self
.
metrics
=
dict
()
self
.
metrics
[
"
ins_c
nt"
]
=
global_ins_cnt
self
.
metrics
[
"
pos_c
nt"
]
=
global_pos_cnt
self
.
metrics
[
"
InsC
nt"
]
=
global_ins_cnt
self
.
metrics
[
"
RecallC
nt"
]
=
global_pos_cnt
self
.
metrics
[
metric_name
]
=
self
.
acc
# self.metrics["batch_metrics"] = batch_metrics
def
calculate
(
self
,
global_metrics
):
for
key
in
self
.
_global_communicate_var
:
if
key
not
in
global_metrics
:
raise
ValueError
(
"%s not existed"
%
key
)
ins_cnt
=
global_metrics
[
'ins_cnt'
][
0
]
pos_cnt
=
global_metrics
[
'pos_cnt'
][
0
]
if
ins_cnt
==
0
:
acc
=
0
else
:
acc
=
float
(
pos_cnt
)
/
ins_cnt
return
"InsCnt=%s RecallCnt=%s Acc(Recall@%d)=%s"
%
(
str
(
ins_cnt
),
str
(
pos_cnt
),
self
.
k
,
str
(
acc
))
def
get_result
(
self
):
return
self
.
metrics
core/trainers/framework/runner.py
浏览文件 @
330465d0
...
...
@@ -20,6 +20,7 @@ import numpy as np
import
paddle.fluid
as
fluid
from
paddlerec.core.utils
import
envs
from
paddlerec.core.metric
import
Metric
__all__
=
[
"RunnerBase"
,
"SingleRunner"
,
"PSRunner"
,
"CollectiveRunner"
,
"PslibRunner"
...
...
@@ -344,17 +345,27 @@ class SingleRunner(RunnerBase):
".epochs"
))
for
epoch
in
range
(
epochs
):
for
model_dict
in
context
[
"phases"
]:
model_class
=
context
[
"model"
][
model_dict
[
"name"
]][
"model"
]
metrics
=
model_class
.
_metric
begin_time
=
time
.
time
()
result
=
self
.
_run
(
context
,
model_dict
)
end_time
=
time
.
time
()
seconds
=
end_time
-
begin_time
message
=
"epoch {} done, use time: {}"
.
format
(
epoch
,
seconds
)
if
not
result
is
None
:
for
key
in
result
:
if
key
.
upper
().
startswith
(
"BATCH_"
):
continue
message
+=
", {}: {}"
.
format
(
key
,
result
[
key
])
metrics_result
=
[]
for
key
in
metrics
:
if
isinstance
(
metrics
[
key
],
Metric
):
_str
=
metrics
[
key
].
cal_global_metrics
(
None
,
context
[
"model"
][
model_dict
[
"name"
]][
"scope"
])
elif
result
is
not
None
:
_str
=
"{}={}"
.
format
(
key
,
result
[
key
])
metrics_result
.
append
(
_str
)
if
len
(
metrics_result
)
>
0
:
message
+=
", global metrics: "
+
", "
.
join
(
metrics_result
)
print
(
message
)
with
fluid
.
scope_guard
(
context
[
"model"
][
model_dict
[
"name"
]][
"scope"
]):
train_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
...
...
@@ -376,12 +387,26 @@ class PSRunner(RunnerBase):
envs
.
get_global_env
(
"runner."
+
context
[
"runner_name"
]
+
".epochs"
))
model_dict
=
context
[
"env"
][
"phase"
][
0
]
model_class
=
context
[
"model"
][
model_dict
[
"name"
]][
"model"
]
metrics
=
model_class
.
_metrics
for
epoch
in
range
(
epochs
):
begin_time
=
time
.
time
()
self
.
_run
(
context
,
model_dict
)
result
=
self
.
_run
(
context
,
model_dict
)
end_time
=
time
.
time
()
seconds
=
end_time
-
begin_time
print
(
"epoch {} done, use time: {}"
.
format
(
epoch
,
seconds
))
message
=
"epoch {} done, use time: {}"
.
format
(
epoch
,
seconds
)
metrics_result
=
[]
for
key
in
metrics
:
if
isinstance
(
metrics
[
key
],
Metric
):
_str
=
metrics
[
key
].
cal_global_metrics
(
context
[
"fleet"
],
context
[
"model"
][
model_dict
[
"name"
]][
"scope"
])
elif
result
is
not
None
:
_str
=
"{}={}"
.
format
(
key
,
result
[
key
])
metrics_result
.
append
(
_str
)
if
len
(
metrics_result
)
>
0
:
message
+=
", global metrics: "
+
", "
.
join
(
metrics_result
)
print
(
message
)
with
fluid
.
scope_guard
(
context
[
"model"
][
model_dict
[
"name"
]][
"scope"
]):
train_prog
=
context
[
"model"
][
model_dict
[
"name"
]][
...
...
@@ -491,6 +516,8 @@ class SingleInferRunner(RunnerBase):
self
.
epoch_model_name_list
.
sort
()
for
index
,
epoch_name
in
enumerate
(
self
.
epoch_model_name_list
):
for
model_dict
in
context
[
"phases"
]:
model_class
=
context
[
"model"
][
model_dict
[
"name"
]][
"model"
]
metrics
=
model_class
.
_infer_results
self
.
_load
(
context
,
model_dict
,
self
.
epoch_model_path_list
[
index
])
begin_time
=
time
.
time
()
...
...
@@ -499,11 +526,17 @@ class SingleInferRunner(RunnerBase):
seconds
=
end_time
-
begin_time
message
=
"Infer {} of epoch {} done, use time: {}"
.
format
(
model_dict
[
"name"
],
epoch_name
,
seconds
)
if
not
result
is
None
:
for
key
in
result
:
if
key
.
upper
().
startswith
(
"BATCH_"
):
continue
message
+=
", {}: {}"
.
format
(
key
,
result
[
key
])
metrics_result
=
[]
for
key
in
metrics
:
if
isinstance
(
metrics
[
key
],
Metric
):
_str
=
metrics
[
key
].
cal_global_metrics
(
None
,
context
[
"model"
][
model_dict
[
"name"
]][
"scope"
])
elif
result
is
not
None
:
_str
=
"{}={}"
.
format
(
key
,
result
[
key
])
metrics_result
.
append
(
_str
)
if
len
(
metrics_result
)
>
0
:
message
+=
", global metrics: "
+
", "
.
join
(
metrics_result
)
print
(
message
)
context
[
"status"
]
=
"terminal_pass"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录