未验证 提交 852c8db3 编写于 作者: Z zhangbo9674 提交者: GitHub

refine diagonal infermeta (#49520)

上级 d0e9b18e
...@@ -605,7 +605,6 @@ void DiagonalInferMeta(const MetaTensor& input, ...@@ -605,7 +605,6 @@ void DiagonalInferMeta(const MetaTensor& input,
int offset_ = offset; int offset_ = offset;
int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1; int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1;
int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2; int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2;
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
x_dims.size(), x_dims.size(),
2, 2,
...@@ -621,6 +620,15 @@ void DiagonalInferMeta(const MetaTensor& input, ...@@ -621,6 +620,15 @@ void DiagonalInferMeta(const MetaTensor& input,
-(x_dims.size()), -(x_dims.size()),
(x_dims.size() - 1), (x_dims.size() - 1),
axis1)); axis1));
PADDLE_ENFORCE_GE(
axis1_,
0,
phi::errors::OutOfRange(
"Attr(axis1) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()),
(x_dims.size() - 1),
axis1));
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
axis2_, axis2_,
x_dims.size(), x_dims.size(),
...@@ -630,6 +638,15 @@ void DiagonalInferMeta(const MetaTensor& input, ...@@ -630,6 +638,15 @@ void DiagonalInferMeta(const MetaTensor& input,
-(x_dims.size()), -(x_dims.size()),
(x_dims.size() - 1), (x_dims.size() - 1),
axis2)); axis2));
PADDLE_ENFORCE_GE(
axis2_,
0,
phi::errors::OutOfRange(
"Attr(axis2) is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size()),
(x_dims.size() - 1),
axis2));
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
axis1_, axis1_,
axis2_, axis2_,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册