diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index bf0739557bcda4a0b953f43df086b404b6fbe41b..128bc3f7fac5044c1c6503b1b7db9b76c5234b7b 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -24,16 +24,20 @@ class SamplingIdOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of SamplingIdOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of SamplingIdOp should not be null."); - PADDLE_ENFORCE_LT(ctx->Attrs().Get("min"), - ctx->Attrs().Get("max"), "min must less then max"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SampleIn"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "X", "SampleOut"); + PADDLE_ENFORCE_LT( + ctx->Attrs().Get("min"), ctx->Attrs().Get("max"), + platform::errors::InvalidArgument( + "min must less then max, but here min is %f, max is %f", + ctx->Attrs().Get("min"), ctx->Attrs().Get("max"))); auto input_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE(input_dims.size() == 2, - "Input(X, Filter) should be 2-D tensor."); + PADDLE_ENFORCE_EQ( + input_dims.size(), 2, + platform::errors::InvalidArgument( + "Input(X, Filter) should be 2-D tensor. But X dim is %d", + input_dims.size())); auto dim0 = input_dims[0]; framework::DDim dims = framework::make_ddim({dim0}); diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h index 133d3f72dbd6ab13c98d124369038309c94cba5b..5ec32c98f7f84abb255ec996d0cf6a58e6312ec3 100644 --- a/paddle/fluid/operators/sampling_id_op.h +++ b/paddle/fluid/operators/sampling_id_op.h @@ -36,9 +36,15 @@ class SamplingIdKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); const int width = static_cast(input->dims()[1]); - PADDLE_ENFORCE_GE(batch_size, 0, - "batch_size(dims[0]) must be nonnegative."); - PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + PADDLE_ENFORCE_GE( + batch_size, 0, + platform::errors::InvalidArgument( + "batch_size(dims[0]) must be nonnegative. but it is %d.", + batch_size)); + PADDLE_ENFORCE_GE( + width, 0, + platform::errors::InvalidArgument( + "width(dims[1]) must be nonnegative. but it is %d.", width)); std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector);