From 852c8db38cea0baecad9756b01903fa98e992395 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 4 Jan 2023 17:35:13 +0800 Subject: [PATCH] refine diagonal infermeta (#49520) --- paddle/phi/infermeta/unary.cc | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 131d504795..d7f4971724 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -605,7 +605,6 @@ void DiagonalInferMeta(const MetaTensor& input, int offset_ = offset; int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1; int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2; - PADDLE_ENFORCE_GE( x_dims.size(), 2, @@ -621,6 +620,15 @@ void DiagonalInferMeta(const MetaTensor& input, -(x_dims.size()), (x_dims.size() - 1), axis1)); + PADDLE_ENFORCE_GE( + axis1_, + 0, + phi::errors::OutOfRange( + "Attr(axis1) is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size()), + (x_dims.size() - 1), + axis1)); PADDLE_ENFORCE_LT( axis2_, x_dims.size(), @@ -630,6 +638,15 @@ void DiagonalInferMeta(const MetaTensor& input, -(x_dims.size()), (x_dims.size() - 1), axis2)); + PADDLE_ENFORCE_GE( + axis2_, + 0, + phi::errors::OutOfRange( + "Attr(axis2) is out of range (expected to be in range of [%ld, " + "%ld], but got %ld).", + -(x_dims.size()), + (x_dims.size() - 1), + axis2)); PADDLE_ENFORCE_NE( axis1_, axis2_, -- GitLab