未验证 提交 226a6567 编写于 作者: G gouzil 提交者: GitHub

[Divide by 0 Error] add eig check (#49971)

* [Divide by 0 Error] add eig check

* [Divide by 0 Error] eig check migrate to c++

* [Divide by 0 Error] Fix class name error
上级 5dfddaea
...@@ -31,6 +31,11 @@ void EigKernel(const Context& dev_ctx, ...@@ -31,6 +31,11 @@ void EigKernel(const Context& dev_ctx,
int batch_count = BatchCount(x); int batch_count = BatchCount(x);
int order = x.dims()[x.dims().size() - 1]; int order = x.dims()[x.dims().size() - 1];
PADDLE_ENFORCE_LT(0,
order,
errors::InvalidArgument(
"The order of Input(X) should be greater than 0."));
DenseTensor real_w; DenseTensor real_w;
DenseTensor real_v; DenseTensor real_v;
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class TestEigAPIError(unittest.TestCase):
def test_errors(self):
# The size of input in Eig should not be 0.
def test_0_size():
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32')
paddle.linalg.eig(x)
self.assertRaises(ValueError, test_0_size)
if __name__ == '__main__':
unittest.main()
...@@ -2323,6 +2323,7 @@ def eig(x, name=None): ...@@ -2323,6 +2323,7 @@ def eig(x, name=None):
# [ (16.50471283351188+0j) , (-5.5034820550763515+0j) , # [ (16.50471283351188+0j) , (-5.5034820550763515+0j) ,
# (-0.21026087843552282+0j)]) # (-0.21026087843552282+0j)])
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.eig(x) return _C_ops.eig(x)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册