From 95aa53425dc20e519bc054be08782859c311f16c Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 14 Oct 2020 14:36:41 +0800 Subject: [PATCH] update the code for the topk message optimize update the code for the topk message optimize --- paddle/fluid/operators/top_k_v2_op.cc | 32 ++++++++++++++------------- paddle/fluid/operators/top_k_v2_op.cu | 9 +++++--- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/top_k_v2_op.cc b/paddle/fluid/operators/top_k_v2_op.cc index 0e3fcced19..810afc901d 100644 --- a/paddle/fluid/operators/top_k_v2_op.cc +++ b/paddle/fluid/operators/top_k_v2_op.cc @@ -23,20 +23,18 @@ class TopkV2Op : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of TopkOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of TopkOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Indices"), - "Output(Indices) of TopkOp should not be null."); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "topk_v2"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "topk_v2"); + OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "topk_v2"); auto input_dims = ctx->GetInputDim("X"); const int& dim_size = input_dims.size(); int axis = static_cast(ctx->Attrs().Get("axis")); - PADDLE_ENFORCE_EQ((axis < dim_size) && (axis >= (-1 * dim_size)), true, - "the axis of topk" - "must be [-%d, %d), but you set axis is %d", - dim_size, dim_size, axis); + PADDLE_ENFORCE_EQ( + (axis < dim_size) && (axis >= (-1 * dim_size)), true, + paddle::platform::errors::InvalidArgument( + "the axis of topk must be [-%d, %d), but you set axis is %d", + dim_size, dim_size, axis)); if (axis < 0) axis += dim_size; @@ -47,18 +45,22 @@ class TopkV2Op : public framework::OperatorWithKernel { } else { k = static_cast(ctx->Attrs().Get("k")); PADDLE_ENFORCE_EQ(k >= 1, true, - "the attribute of k in the topk must >= 1 or be a " - "Tensor, but received %d .", - k); + paddle::platform::errors::InvalidArgument( + "the attribute of k in the topk must >= 1 or be a " + "Tensor, but received %d .", + k)); } PADDLE_ENFORCE_GE(input_dims.size(), 1, - "input of topk must have >= 1d shape"); + paddle::platform::errors::InvalidArgument( + "input of topk must have >= 1d shape")); if (ctx->IsRuntime()) { PADDLE_ENFORCE_GE( input_dims[axis], k, - "input of topk op must have >= %d columns in axis of %d", k, axis); + paddle::platform::errors::InvalidArgument( + "input of topk op must have >= %d columns in axis of %d", k, + axis)); } framework::DDim dims = input_dims; diff --git a/paddle/fluid/operators/top_k_v2_op.cu b/paddle/fluid/operators/top_k_v2_op.cu index 2c94dca1e3..a2c97aee92 100644 --- a/paddle/fluid/operators/top_k_v2_op.cu +++ b/paddle/fluid/operators/top_k_v2_op.cu @@ -38,8 +38,10 @@ template class TopkV2OpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument( + "It must use CUDAPlace, you must check your device set.")); auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); @@ -194,7 +196,8 @@ class TopkV2OpGradCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE_EQ( platform::is_gpu_place(context.GetPlace()), true, - platform::errors::InvalidArgument("It must use CUDAPlace.")); + platform::errors::InvalidArgument( + "It must use CUDAPlace, you must check your device set.")); auto* x = context.Input("X"); auto* out_grad = context.Input(framework::GradVarName("Out")); auto* indices = context.Input("Indices"); -- GitLab