From 4a105f803e6bd5789b81fd6a0a9f4b27add4e911 Mon Sep 17 00:00:00 2001 From: Jiawei Wang Date: Thu, 14 May 2020 16:17:22 +0800 Subject: [PATCH] SamplingID Op fix error print (#24521) * fix error print for sampling_id_op * fix spell err * fix spell err test=develop --- paddle/fluid/operators/sampling_id_op.cc | 20 ++++++++++++-------- paddle/fluid/operators/sampling_id_op.h | 12 +++++++++--- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index bf0739557bc..128bc3f7fac 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -24,16 +24,20 @@ class SamplingIdOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of SamplingIdOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of SamplingIdOp should not be null."); - PADDLE_ENFORCE_LT(ctx->Attrs().Get("min"), - ctx->Attrs().Get("max"), "min must less then max"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SampleIn"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "X", "SampleOut"); + PADDLE_ENFORCE_LT( + ctx->Attrs().Get("min"), ctx->Attrs().Get("max"), + platform::errors::InvalidArgument( + "min must less then max, but here min is %f, max is %f", + ctx->Attrs().Get("min"), ctx->Attrs().Get("max"))); auto input_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE(input_dims.size() == 2, - "Input(X, Filter) should be 2-D tensor."); + PADDLE_ENFORCE_EQ( + input_dims.size(), 2, + platform::errors::InvalidArgument( + "Input(X, Filter) should be 2-D tensor. But X dim is %d", + input_dims.size())); auto dim0 = input_dims[0]; framework::DDim dims = framework::make_ddim({dim0}); diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 133d3f72dbd..5ec32c98f7f 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -36,9 +36,15 @@ class SamplingIdKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); - PADDLE_ENFORCE_GE(batch_size, 0, - "batch_size(dims[0]) must be nonnegative."); - PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + PADDLE_ENFORCE_GE( + batch_size, 0, + platform::errors::InvalidArgument( + "batch_size(dims[0]) must be nonnegative. but it is %d.", + batch_size)); + PADDLE_ENFORCE_GE( + width, 0, + platform::errors::InvalidArgument( + "width(dims[1]) must be nonnegative. but it is %d.", width)); std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); -- GitLab