未验证 提交 852c8db3 编写于 作者: Z zhangbo9674 提交者: GitHub

refine diagonal infermeta (#49520)

上级 d0e9b18e
......@@ -605,7 +605,6 @@ void DiagonalInferMeta(const MetaTensor& input,
int offset_ = offset;
int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1;
int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2;
PADDLE_ENFORCE_GE(
x_dims.size(),
2,
......@@ -621,6 +620,15 @@ void DiagonalInferMeta(const MetaTensor& input,
-(x_dims.size()),
(x_dims.size() - 1),
axis1));
PADDLE_ENFORCE_GE(
axis1_,
0,
phi::errors::OutOfRange(
"Attr(axis1) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()),
(x_dims.size() - 1),
axis1));
PADDLE_ENFORCE_LT(
axis2_,
x_dims.size(),
......@@ -630,6 +638,15 @@ void DiagonalInferMeta(const MetaTensor& input,
-(x_dims.size()),
(x_dims.size() - 1),
axis2));
PADDLE_ENFORCE_GE(
axis2_,
0,
phi::errors::OutOfRange(
"Attr(axis2) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()),
(x_dims.size() - 1),
axis2));
PADDLE_ENFORCE_NE(
axis1_,
axis2_,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册