diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc index e7275a5933c03308dc6ec6b71193ef0c4984627a..d8cecf0957c6cb296b20746422b83d3402c11d11 100644 --- a/paddle/operators/auc_op.cc +++ b/paddle/operators/auc_op.cc @@ -22,18 +22,19 @@ class AucOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Inference"), - "Input of Inference must be initialized."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input of Inference must be initialized."); - auto *inference = ctx.Input("Inference"); - auto *label = ctx.Input("Label"); - - PADDLE_ENFORCE_EQ(inference->dims(), label->dims(), + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Inference"), + "Input of Inference must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input of Label must be initialized."); + auto inference_dim = ctx->GetInputDim("Inference"); + auto label_dim = ctx->GetInputDim("Label"); + + PADDLE_ENFORCE_EQ(inference_dim, label_dim, "inference and label should have same shape"); - ctx.Output("AUC")->Resize({1}); + ctx->SetOutputDim("AUC", {1}); + ctx->ShareLoD("Inference", /*->*/ "AUC"); } }; diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h index ad5585be30481c555bf9cc8996fecfb32addd515..be6ef29d5f6cff5b9ebdf7d8564b2e2792c3b5cb 100644 --- a/paddle/operators/auc_op.h +++ b/paddle/operators/auc_op.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" @@ -27,7 +26,7 @@ template ; template -class AucKernel : public framework::OpKernel { +class AucKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* inference = ctx.Input("Inference"); @@ -61,8 +60,7 @@ class AucKernel : public framework::OpKernel { } // Create local tensor for storing the curve: TP, FN, TN, FP - // TODO(typhoonzero): put these tensors in Scope - // TODO(typhoonzero): use op to caculate these values. + // TODO(typhoonzero): use eigen op to caculate these values. Tensor true_positive, false_positive, true_negative, false_negative; true_positive.Resize({num_thresholds});