diff --git a/paddle/fluid/operators/batch_size_like.h b/paddle/fluid/operators/batch_size_like.h index d2cf38049300578eb1626d39c0959b9beed13cdd..f24a3c316a05a8bf171812be0a6b3445488aeb58 100644 --- a/paddle/fluid/operators/batch_size_like.h +++ b/paddle/fluid/operators/batch_size_like.h @@ -26,25 +26,47 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of %s should not be null.", Type()); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of %s should not be null.", Type()); + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", Type()); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", Type()); auto &shape = ctx->Attrs().Get>("shape"); - PADDLE_ENFORCE_GT(shape.size(), 0); + PADDLE_ENFORCE_GT(shape.size(), 0, + platform::errors::InvalidArgument( + "Shape size must be larger than 0, but received: %s.", + shape.size())); std::vector shape_int64(shape.size(), 0); std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) { return static_cast(a); }); auto output_dim = framework::make_ddim(shape_int64); int input_dim_idx = ctx->Attrs().Get("input_dim_idx"); - PADDLE_ENFORCE_GE(input_dim_idx, 0); - PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx); + int input_dim_size = static_cast(ctx->GetInputDim("Input").size()); + PADDLE_ENFORCE_GE(input_dim_idx, 0, + platform::errors::InvalidArgument( + "Input dimension index must be larger " + "equal than 0, but received: %s.", + input_dim_idx)); + PADDLE_ENFORCE_GT(input_dim_size, input_dim_idx, + platform::errors::InvalidArgument( + "Input dimension size must be larger than " + "input dimension index, but received input " + "dimension size: %s, input dimension index: %s.", + input_dim_size, input_dim_idx)); int output_dim_idx = ctx->Attrs().Get("output_dim_idx"); - PADDLE_ENFORCE_GE(output_dim_idx, 0); - PADDLE_ENFORCE_GT(static_cast(shape.size()), output_dim_idx); + int output_dim_size = static_cast(shape.size()); + PADDLE_ENFORCE_GE(output_dim_idx, 0, + platform::errors::InvalidArgument( + "Output dimension index must be larger " + "equal than 0, but received: %s.", + output_dim_idx)); + PADDLE_ENFORCE_GT( + output_dim_size, output_dim_idx, + platform::errors::InvalidArgument( + "Output dimension size must be larger than output dimension index, " + "but received output dimension size: %s, output dimension index: " + "%s.", + output_dim_size, output_dim_idx)); output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx]; ctx->SetOutputDim("Out", output_dim);