未验证 提交 434343c6 编写于 作者: A andyj 提交者: GitHub

fix pinv api for divide zero (#53815)

上级 2174e91c
......@@ -267,6 +267,19 @@ class LinalgPinvTestCaseHermitian5(LinalgPinvTestCase):
self.hermitian = True
class LinalgPinvTestCaseHermitian6(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
np.random.seed(123)
x = np.ones(self._input_shape).astype(self.dtype)
self._input_data = x + 0.01
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.hermitian = True
class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
......
......@@ -26,9 +26,7 @@ from ..fluid.data_feeder import (
)
from ..framework import LayerHelper, in_dygraph_mode
from .creation import full
from .logic import logical_not
from .manipulation import cast
from .math import add, multiply
__all__ = []
......@@ -2665,12 +2663,7 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None):
y = float('inf')
y = paddle.to_tensor(y, dtype=x.dtype)
condition = s > cutoff
cond_int = cast(condition, s.dtype)
cond_not_int = cast(logical_not(condition), s.dtype)
out1 = multiply(1 / s, cond_int)
out2 = multiply(1 / y, cond_not_int)
singular = add(out1, out2)
singular = paddle.where(s > cutoff, 1 / s, 1 / y)
st = _C_ops.unsqueeze(singular, [-2])
dims = list(range(len(vt.shape)))
......@@ -2690,12 +2683,7 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None):
y = float('inf')
y = paddle.to_tensor(y, dtype=s.dtype)
condition = s_abs > cutoff
cond_int = cast(condition, s.dtype)
cond_not_int = cast(logical_not(condition), s.dtype)
out1 = multiply(1 / s, cond_int)
out2 = multiply(1 / y, cond_not_int)
singular = add(out1, out2)
singular = paddle.where(s_abs > cutoff, 1 / s, 1 / y)
st = _C_ops.unsqueeze(singular, [-2])
out_1 = u * st
......@@ -2731,12 +2719,7 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None):
y = float('inf')
y = full(shape=[1], fill_value=y, dtype=dtype)
condition = s > cutoff
cond_int = cast(condition, dtype)
cond_not_int = cast(logical_not(condition), dtype)
out1 = multiply(1 / s, cond_int)
out2 = multiply(1 / y, cond_not_int)
singular = add(out1, out2)
singular = paddle.where(s > cutoff, 1 / s, 1 / y)
st = helper.create_variable_for_type_inference(dtype=dtype)
st_shape = helper.create_variable_for_type_inference(dtype=dtype)
......@@ -2817,12 +2800,7 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None):
y = float('inf')
y = full(shape=[1], fill_value=y, dtype=s_type)
condition = s_abs > cutoff
cond_int = cast(condition, s_type)
cond_not_int = cast(logical_not(condition), s_type)
out1 = multiply(1 / s, cond_int)
out2 = multiply(1 / y, cond_not_int)
singular = add(out1, out2)
singular = paddle.where(s_abs > cutoff, 1 / s, 1 / y)
st = helper.create_variable_for_type_inference(dtype=s_type)
st_shape = helper.create_variable_for_type_inference(dtype=s_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册