未验证 提交 3c58b87b 编写于 作者: Q Qiao Longfei 提交者: GitHub

fix auc layer and add check for auc op (#12954)

* fix auc layer and add check for auc op

* use input to check if states are inited

* optimize code
上级 c1488b17
...@@ -60,6 +60,20 @@ class AucKernel : public framework::OpKernel<T> { ...@@ -60,6 +60,20 @@ class AucKernel : public framework::OpKernel<T> {
const T* inference_data = predict->data<T>(); const T* inference_data = predict->data<T>();
const auto* label_data = label->data<int64_t>(); const auto* label_data = label->data<int64_t>();
// check if states are inited.
auto* tp_in = ctx.Input<Tensor>("TP");
auto* fp_in = ctx.Input<Tensor>("FP");
auto* tn_in = ctx.Input<Tensor>("TN");
auto* fn_in = ctx.Input<Tensor>("FN");
PADDLE_ENFORCE(tp_in->IsInitialized(), "true_positive is not inited!");
PADDLE_ENFORCE(fp_in->IsInitialized(), "false_negative is not inited!");
PADDLE_ENFORCE(tn_in->IsInitialized(), "true_negative is not inited!");
PADDLE_ENFORCE(fn_in->IsInitialized(), "false_positive is not inited!");
PADDLE_ENFORCE_EQ(tp_in->numel(), num_thresholds, "");
PADDLE_ENFORCE_EQ(fp_in->numel(), num_thresholds, "");
PADDLE_ENFORCE_EQ(tn_in->numel(), num_thresholds, "");
PADDLE_ENFORCE_EQ(fn_in->numel(), num_thresholds, "");
auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace()); auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace()); auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace()); auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include <sys/time.h> #include <sys/time.h>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
#include <random>
#include <vector> #include <vector>
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "glog/logging.h" #include "glog/logging.h"
......
...@@ -119,10 +119,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1): ...@@ -119,10 +119,14 @@ def auc(input, label, curve='ROC', num_thresholds=200, topk=1):
helper = LayerHelper("auc", **locals()) helper = LayerHelper("auc", **locals())
auc_out = helper.create_tmp_variable(dtype="float64") auc_out = helper.create_tmp_variable(dtype="float64")
# make tp, tn, fp, fn persistable, so that can accumulate all batches. # make tp, tn, fp, fn persistable, so that can accumulate all batches.
tp = helper.create_global_variable(persistable=True, dtype='int64') tp = helper.create_global_variable(
tn = helper.create_global_variable(persistable=True, dtype='int64') persistable=True, dtype='int64', shape=[num_thresholds])
fp = helper.create_global_variable(persistable=True, dtype='int64') tn = helper.create_global_variable(
fn = helper.create_global_variable(persistable=True, dtype='int64') persistable=True, dtype='int64', shape=[num_thresholds])
fp = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds])
fn = helper.create_global_variable(
persistable=True, dtype='int64', shape=[num_thresholds])
for var in [tp, tn, fp, fn]: for var in [tp, tn, fp, fn]:
helper.set_variable_initializer( helper.set_variable_initializer(
var, Constant( var, Constant(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册