From 63309941b3f13d56afb863bf7c257ee284857028 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Mon, 9 Oct 2017 17:51:17 +0800 Subject: [PATCH] pull develop and update --- paddle/operators/auc_op.cc | 21 +++++++++++---------- paddle/operators/auc_op.h | 6 ++---- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index e7275a5933..d8cecf0957 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -22,18 +22,19 @@ class AucOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"), - "Input of Inference must be initialized."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input of Inference must be initialized."); - auto *inference = ctx.Input("Inference"); - auto *label = ctx.Input("Label"); - - PADDLE_ENFORCE_EQ(inference->dims(), label->dims(), + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Inference"), + "Input of Inference must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input of Label must be initialized."); + auto inference_dim = ctx->GetInputDim("Inference"); + auto label_dim = ctx->GetInputDim("Label"); + + PADDLE_ENFORCE_EQ(inference_dim, label_dim, "inference and label should have same shape"); - ctx.Output("AUC")->Resize({1}); + ctx->SetOutputDim("AUC", {1}); + ctx->ShareLoD("Inference", /*->*/ "AUC"); } }; diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index ad5585be30..be6ef29d5f 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" @@ -27,7 +26,7 @@ template ; template -class AucKernel : public framework::OpKernel { +class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* inference = ctx.Input("Inference"); @@ -61,8 +60,7 @@ class AucKernel : public framework::OpKernel { } // Create local tensor for storing the curve: TP, FN, TN, FP - // TODO(typhoonzero): put these tensors in Scope - // TODO(typhoonzero): use op to caculate these values. + // TODO(typhoonzero): use eigen op to caculate these values. Tensor true_positive, false_positive, true_negative, false_negative; true_positive.Resize({num_thresholds}); -- GitLab