From d1e6d5522a437ae592e8a2e2126e6ff50d9c7d08 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 12 Sep 2017 21:03:55 +0800 Subject: [PATCH] update --- paddle/operators/auc_op.cc | 4 ++-- paddle/operators/auc_op.h | 32 ++++++++++++++++---------------- paddle/pybind/pybind.cc | 1 + 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index fa18d6ca0d2..3a43f9bcc48 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -17,7 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { -class AccuracyOp : public framework::OperatorWithKernel { +class AucOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -76,5 +76,5 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AccuracyOp, ops::AccuracyOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker); REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel); diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index d4f40cd79c6..fd110c06e64 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -23,7 +23,7 @@ namespace operators { using Tensor = framework::Tensor; template -class AccuracyKernel : public framework::OpKernel { +class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* inference = ctx.Input("Inference"); @@ -45,7 +45,7 @@ class AccuracyKernel : public framework::OpKernel { thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; const int* inference_data = inference->data(); - const T* inference_prob_data = inference->data(); + const T* inference_prob_data = inference_prob->data(); const T* label_data = label->data(); size_t num_samples = inference->dims()[0]; @@ -54,17 +54,17 @@ class AccuracyKernel : public framework::OpKernel { // create local tensor for storing the curve: TP, FN, TN, FP // TODO(typhoonzero): put these tensors in Scope // TODO(typhoonzero): use op to caculate these values. - Tensor true_positive, false_positeve, true_negative, false_negative; + Tensor true_positive, false_positive, true_negative, false_negative; true_positive.Resize({num_thresholds}); false_negative.Resize({num_thresholds}); true_negative.Resize({num_thresholds}); false_positive.Resize({num_thresholds}); - int* tp_data = true_positive.mutable_data(); - int* fn_data = false_negative.mutable_data(); - int* tn_data = true_negative.mutable_data(); - int* fp_data = false_positive.mutable_data(); + int* tp_data = true_positive.mutable_data(ctx.GetPlace()); + int* fn_data = false_negative.mutable_data(ctx.GetPlace()); + int* tn_data = true_negative.mutable_data(ctx.GetPlace()); + int* fp_data = false_positive.mutable_data(ctx.GetPlace()); for (auto thresh = thresholds_list.begin(); thresh != thresholds_list.end(); thresh++) { @@ -101,15 +101,15 @@ class AccuracyKernel : public framework::OpKernel { tp_rate.Resize({num_thresholds}); fp_rate.Resize({num_thresholds}); rec_rate.Resize({num_thresholds}); - float* tp_rate_data = tp_rate.mutable_data(); - float* fp_rate_data = fp_rate.mutable_data(); - float* rec_rate_data = rec_rate.mutable_data(); + float* tp_rate_data = tp_rate.mutable_data(ctx.GetPlace()); + float* fp_rate_data = fp_rate.mutable_data(ctx.GetPlace()); + float* rec_rate_data = rec_rate.mutable_data(ctx.GetPlace()); for (int i = 0; i < num_thresholds; i++) { - tp_rate_data[i] = ((float)tp_data[i + epsilon) / (tp_data[i] + fn_data[i] + epsilon); - fp_rate_data[i] = - (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon); - rec_rate_data[i] = - ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); + tp_rate_data[i] = + ((float)tp_data[i] + epsilon) / (tp_data[i] + fn_data[i] + epsilon); + fp_rate_data[i] = (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon); + rec_rate_data[i] = + ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); } if (curve == "ROC") { @@ -118,7 +118,7 @@ class AccuracyKernel : public framework::OpKernel { auto y = (tp_rate_data[i] + tp_rate_data[i - 1]) / 2.0f; *auc_data = *auc_data + dx * y; } - } else if (curve = "PR") { + } else if (curve == "PR") { for (int i = 1; i < num_thresholds; i++) { auto dx = tp_rate_data[i] - tp_rate_data[i - 1]; auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f; diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 53985933ed1..a673b7d1a87 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -50,6 +50,7 @@ USE_OP(cos_sim); USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(scatter); USE_OP(top_k); +USE_CPU_ONLY_OP(auc); USE_OP(squared_l2_distance); namespace paddle { -- GitLab