未验证 提交 1ca2cde1 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #12962 from jacquesqiao/cherry-pick-fix-auc-init

fix auc layer and add check for auc op (#12954)
...@@ -119,10 +119,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): ...@@ -119,10 +119,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
helper = LayerHelper("auc", **locals()) helper = LayerHelper("auc", **locals())
auc_out = helper.create_tmp_variable(dtype="float64") auc_out = helper.create_tmp_variable(dtype="float64")
# make tp, tn, fp, fn persistable, so that can accumulate all batches. # make tp, tn, fp, fn persistable, so that can accumulate all batches.
tp = helper.create_global_variable(persistable=True, dtype='int64') tp = helper.create_global_variable(
tn = helper.create_global_variable(persistable=True, dtype='int64') persistable=True, dtype='int64', shape=[num_thresholds])
fp = helper.create_global_variable(persistable=True, dtype='int64') tn = helper.create_global_variable(
fn = helper.create_global_variable(persistable=True, dtype='int64') persistable=True, dtype='int64', shape=[num_thresholds])
fp = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds])
fn = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds])
for var in [tp, tn, fp, fn]: for var in [tp, tn, fp, fn]:
helper.set_variable_initializer( helper.set_variable_initializer(
var, Constant( var, Constant(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册