未验证 提交 dc713116 编写于 作者: W wangchaochaohu 提交者: GitHub

refine the error message for bath size like OP (#27446)

* refine the error message for bath size like
上级 5c8fdb59
......@@ -26,25 +26,47 @@ class BatchSizeLikeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of %s should not be null.", Type());
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of %s should not be null.", Type());
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", Type());
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", Type());
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE_GT(shape.size(), 0);
PADDLE_ENFORCE_GT(shape.size(), 0,
platform::errors::InvalidArgument(
"Shape size must be larger than 0, but received: %s.",
shape.size()));
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto output_dim = framework::make_ddim(shape_int64);
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
PADDLE_ENFORCE_GE(input_dim_idx, 0);
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
int input_dim_size = static_cast<int>(ctx->GetInputDim("Input").size());
PADDLE_ENFORCE_GE(input_dim_idx, 0,
platform::errors::InvalidArgument(
"Input dimension index must be larger "
"equal than 0, but received: %s.",
input_dim_idx));
PADDLE_ENFORCE_GT(input_dim_size, input_dim_idx,
platform::errors::InvalidArgument(
"Input dimension size must be larger than "
"input dimension index, but received input "
"dimension size: %s, input dimension index: %s.",
input_dim_size, input_dim_idx));
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
PADDLE_ENFORCE_GE(output_dim_idx, 0);
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
int output_dim_size = static_cast<int>(shape.size());
PADDLE_ENFORCE_GE(output_dim_idx, 0,
platform::errors::InvalidArgument(
"Output dimension index must be larger "
"equal than 0, but received: %s.",
output_dim_idx));
PADDLE_ENFORCE_GT(
output_dim_size, output_dim_idx,
platform::errors::InvalidArgument(
"Output dimension size must be larger than output dimension index, "
"but received output dimension size: %s, output dimension index: "
"%s.",
output_dim_size, output_dim_idx));
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
ctx->SetOutputDim("Out", output_dim);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册