未验证 提交 7c9ce097 编写于 作者: T Tao Luo 提交者: GitHub

refine reshape_op shape error message (#22480)

test=develop
上级 2b1386b2
...@@ -29,10 +29,11 @@ inline std::vector<int> get_new_shape( ...@@ -29,10 +29,11 @@ inline std::vector<int> get_new_shape(
auto tensor = list_new_shape_tensor[i]; auto tensor = list_new_shape_tensor[i];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensor->dims(), framework::make_ddim({1}), tensor->dims(), framework::make_ddim({1}),
"ShapeError: If the element type of 'shape' in ReshapeOp is Tensor, " platform::errors::InvalidArgument(
"the element's shape must be [1]. But received the element's shape " "If the element type of 'shape' in ReshapeOp is Tensor, "
"is [%s]", "the element's shape must be [1]. But received the element's shape "
tensor->dims()); "is [%s]",
tensor->dims()));
if (platform::is_gpu_place(tensor->place())) { if (platform::is_gpu_place(tensor->place())) {
framework::Tensor temp; framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp); TensorCopySync(*tensor, platform::CPUPlace(), &temp);
...@@ -64,10 +65,11 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -64,10 +65,11 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto ShapeTensor = ctx->Inputs("ShapeTensor"); auto ShapeTensor = ctx->Inputs("ShapeTensor");
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
ShapeTensor.size(), 0, ShapeTensor.size(), 0,
"ShapeError: When `shape` in ReshapeOp is a list or tuple " platform::errors::InvalidArgument(
"which contains Tensor, the shape's size can't be zero. " "When `shape` in ReshapeOp is a list or tuple "
"But received shape's size is %d.", "which contains Tensor, the shape's size can't be zero. "
ShapeTensor.size()); "But received shape's size is %d.",
ShapeTensor.size()));
auto infer_shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto infer_shape = ctx->Attrs().Get<std::vector<int>>("shape");
const int64_t copy_dim_val = 0; const int64_t copy_dim_val = 0;
auto in_dims = ctx->GetInputDim("X"); auto in_dims = ctx->GetInputDim("X");
...@@ -75,10 +77,11 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -75,10 +77,11 @@ class ReshapeOp : public framework::OperatorWithKernel {
if (infer_shape[i] == copy_dim_val) { if (infer_shape[i] == copy_dim_val) {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
static_cast<int>(i), in_dims.size(), static_cast<int>(i), in_dims.size(),
"ShapeError: The index of 0 in `shape` must be less than " platform::errors::InvalidArgument(
"the input tensor X's dimensions. But received shape[%d] " "The index of 0 in `shape` must be less than "
"= 0, X's dimensions = %d, X's shape = [%s].", "the input tensor X's dimensions. But received shape[%d] "
i, in_dims.size(), in_dims); "= 0, X's dimensions = %d, X's shape = [%s].",
i, in_dims.size(), in_dims));
infer_shape[i] = in_dims[i]; infer_shape[i] = in_dims[i];
} }
} }
...@@ -108,10 +111,10 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -108,10 +111,10 @@ class ReshapeOp : public framework::OperatorWithKernel {
return; return;
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(!shape.empty(), true,
!shape.empty(), true, platform::errors::InvalidArgument(
"ShapeError: The parameter 'shape' in ReshapeOp must be set. " "The parameter 'shape' in ReshapeOp must be set. "
"But received 'shape' is empty."); "But received 'shape' is empty."));
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto out_dims = ValidateShape(shape, x_dims); auto out_dims = ValidateShape(shape, x_dims);
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
...@@ -140,25 +143,28 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -140,25 +143,28 @@ class ReshapeOp : public framework::OperatorWithKernel {
if (shape[i] == unk_dim_val) { if (shape[i] == unk_dim_val) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
unk_dim_idx, -1, unk_dim_idx, -1,
"ShapeError: Only one dimension value of 'shape' in ReshapeOp can " platform::errors::InvalidArgument(
"be -1. But received shape = [%s], shape[%d] is also -1.", "Only one dimension value of 'shape' in ReshapeOp can "
framework::make_ddim(shape), i); "be -1. But received shape = [%s], shape[%d] is also -1.",
framework::make_ddim(shape), i));
unk_dim_idx = i; unk_dim_idx = i;
} else if (shape[i] == copy_dim_val) { } else if (shape[i] == copy_dim_val) {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
static_cast<int>(i), in_dims.size(), static_cast<int>(i), in_dims.size(),
"ShapeError: The index of 0 in `shape` must be less than " platform::errors::InvalidArgument(
"the input tensor X's dimensions. " "The index of 0 in `shape` must be less than "
"But received shape = [%s], shape[%d] = 0, X's shape = [%s], " "the input tensor X's dimensions. "
"X's dimensions = %d.", "But received shape = [%s], shape[%d] = 0, X's shape = [%s], "
framework::make_ddim(shape), i, in_dims, in_dims.size()); "X's dimensions = %d.",
framework::make_ddim(shape), i, in_dims, in_dims.size()));
} else { } else {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
shape[i], 0, shape[i], 0,
"ShapeError: Each dimension value of 'shape' in ReshapeOp must not " platform::errors::InvalidArgument(
"be negtive except one unknown dimension. " "Each dimension value of 'shape' in ReshapeOp must not "
"But received shape = [%s], shape[%d] = %d.", "be negtive except one unknown dimension. "
framework::make_ddim(shape), i, shape[i]); "But received shape = [%s], shape[%d] = %d.",
framework::make_ddim(shape), i, shape[i]));
} }
capacity *= (shape[i] ? shape[i] : in_dims[i]); capacity *= (shape[i] ? shape[i] : in_dims[i]);
...@@ -180,8 +186,7 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -180,8 +186,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
"The input tensor X'size must be divisible by known " "The input tensor X'size must be divisible by known "
"capacity of 'shape'. " "capacity of 'shape'. "
"But received X's shape = [%s], X's size = %d, " "But received X's shape = [%s], X's size = %d, "
"'shape' is [%s], known " "'shape' is [%s], known capacity of 'shape' is %d.",
"capacity of 'shape' is %d.",
in_dims, in_size, framework::make_ddim(shape), capacity)); in_dims, in_size, framework::make_ddim(shape), capacity));
} else { } else {
output_shape[unk_dim_idx] = -1; output_shape[unk_dim_idx] = -1;
...@@ -190,12 +195,13 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -190,12 +195,13 @@ class ReshapeOp : public framework::OperatorWithKernel {
if (all_positive) { if (all_positive) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
capacity, in_size, capacity, in_size,
"ShapeError: The 'shape' in ReshapeOp is invalid. " platform::errors::InvalidArgument(
"The input tensor X'size must be equal to the capacity of 'shape'. " "The 'shape' in ReshapeOp is invalid. "
"But received X's shape = [%s], X's size = %d, 'shape' is [%s], " "The input tensor X'size must be equal to the capacity of "
"the " "'shape'. "
"capacity of 'shape' is %d.", "But received X's shape = [%s], X's size = %d, 'shape' is "
in_dims, in_size, framework::make_ddim(shape), capacity); "[%s], the capacity of 'shape' is %d.",
in_dims, in_size, framework::make_ddim(shape), capacity));
} }
} }
return framework::make_ddim(output_shape); return framework::make_ddim(output_shape);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册