未验证 提交 f6e874bc 编写于 作者: R Ryan 提交者: GitHub

[Divide by 0 Error] add pinv check (#49951)

* add pinv check

* add unitest

* update unitest

* roll back

* fix not call stupid bug

* use context
上级 094e3b8c
......@@ -105,6 +105,14 @@ void SvdKernel(const Context& dev_ctx,
// int k = std::min(rows, cols);
// int col_u = full ? rows : k;
// int col_v = full ? cols : k;
PADDLE_ENFORCE_LT(
0,
rows,
errors::InvalidArgument("The row of Input(X) should be greater than 0."));
PADDLE_ENFORCE_LT(
0,
cols,
errors::InvalidArgument("The col of Input(X) should be greater than 0."));
int batches = numel / (rows * cols);
auto* U_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(U);
auto* VH_out = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH);
......
......@@ -217,6 +217,15 @@ void SvdKernel(const Context& dev_ctx,
int m = dims[rank - 2];
int n = dims[rank - 1];
PADDLE_ENFORCE_LT(
0,
m,
errors::InvalidArgument("The row of Input(X) should be greater than 0."));
PADDLE_ENFORCE_LT(
0,
n,
errors::InvalidArgument("The col of Input(X) should be greater than 0."));
auto* u_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(U);
auto* vh_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(VH);
auto* s_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(S);
......
......@@ -280,5 +280,27 @@ class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase):
self.hermitian = True
class TestDivByZero(unittest.TestCase):
def pinv_zero_input_static(self):
paddle.enable_static()
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32')
paddle.linalg.pinv(x)
def pinv_zero_input_dynamic(self):
paddle.disable_static()
array = np.array([], dtype=np.float32)
x = paddle.to_tensor(np.reshape(array, [0, 0]), dtype='float32')
paddle.linalg.pinv(x)
def test_div_by_zero(self):
with self.assertRaises(ValueError):
self.pinv_zero_input_dynamic()
self.pinv_zero_input_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册