Unverified commit 71e01d3f authored by andyjpaddle, committed by GitHub

Add linalg pinv api (#35804)

* add pinv api, test=develop
* add linalg pinv api, test=develop
* update example code, test=develop
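A minimal usage sketch of the API this commit adds (the tensor values and shapes here are illustrative, not taken from the commit):

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype='float64')
x_pinv = paddle.linalg.pinv(x, rcond=1e-15, hermitian=False)
# Moore-Penrose identity, up to numerical tolerance: x @ pinv(x) @ x == x
print(paddle.matmul(paddle.matmul(x, x_pinv), x))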
Parent cf9eae4c
@@ -104,6 +104,7 @@ from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import matrix_power # noqa: F401
from .tensor.linalg import svd # noqa: F401
from .tensor.linalg import eigh # noqa: F401
from .tensor.linalg import pinv # noqa: F401
from .tensor.logic import equal # noqa: F401
from .tensor.logic import greater_equal # noqa: F401
from .tensor.logic import greater_than # noqa: F401
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
from op_test import OpTest, skip_check_grad_ci
from gradient_checker import grad_check
from decorator_helper import prog_scope
class LinalgPinvTestCase(unittest.TestCase):
def setUp(self):
self.init_config()
self.generate_input()
self.generate_output()
self.places = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
def generate_input(self):
self._input_shape = (5, 5)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
def generate_output(self):
self._output_data = np.linalg.pinv(self._input_data, \
rcond=self.rcond, hermitian=self.hermitian)
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.hermitian = False
def test_dygraph(self):
for place in self.places:
paddle.disable_static(place)
x = paddle.to_tensor(self._input_data, place=place)
out = paddle.linalg.pinv(
x, rcond=self.rcond, hermitian=self.hermitian).numpy()
if (np.abs(out - self._output_data) < 1e-6).all():
pass
else:
print("EXPECTED: \n", self._output_data)
print("GOT : \n", out)
raise RuntimeError("Check PINV dygraph Failed")
def test_static(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
with fluid.program_guard(fluid.Program(), fluid.Program()):
x = paddle.fluid.data(
name="input",
shape=self._input_shape,
dtype=self._input_data.dtype)
out = paddle.linalg.pinv(
x, rcond=self.rcond, hermitian=self.hermitian)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"input": self._input_data},
fetch_list=[out])
if (np.abs(fetches[0] - self._output_data) < 1e-6).all():
pass
else:
print("EXPECTED: \n", self._output_data)
print("GOT : \n", fetches[0])
raise RuntimeError("Check PINV static Failed")
def test_grad(self):
for place in self.places:
x = paddle.to_tensor(
self._input_data, place=place, stop_gradient=False)
out = paddle.linalg.pinv(
x, rcond=self.rcond, hermitian=self.hermitian)
try:
out.backward()
x_grad = x.grad
except Exception:
raise RuntimeError("Check PINV Grad Failed")
class LinalgPinvTestCase1(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (4, 5)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
class LinalgPinvTestCase2(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (5, 4)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
class LinalgPinvTestCaseBatch1(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
class LinalgPinvTestCaseBatch2(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 4, 5)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
class LinalgPinvTestCaseBatch3(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 4)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
class LinalgPinvTestCaseBatch4(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 6, 5, 4)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
class LinalgPinvTestCaseBatchBig(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (2, 200, 300)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
class LinalgPinvTestCaseFP32(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
def init_config(self):
self.dtype = 'float32'
self.rcond = 1e-15
self.hermitian = False
class LinalgPinvTestCaseRcond(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
self._input_data = np.random.random(self._input_shape).astype(
self.dtype)
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-10
self.hermitian = False
class LinalgPinvTestCaseHermitian1(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (5, 5)
x = np.random.random(self._input_shape).astype(self.dtype) + \
1J * np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose().conj()
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.hermitian = True
class LinalgPinvTestCaseHermitian2(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
x = np.random.random(self._input_shape).astype(self.dtype) + \
1J * np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1)).conj()
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.hermitian = True
class LinalgPinvTestCaseHermitian3(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
x = np.random.random(self._input_shape).astype(self.dtype) + \
1J * np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1)).conj()
def init_config(self):
self.dtype = 'float32'
self.rcond = 1e-15
self.hermitian = True
class LinalgPinvTestCaseHermitian4(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (5, 5)
x = np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose()
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.hermitian = True
class LinalgPinvTestCaseHermitian5(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
x = np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1))
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.hermitian = True
class LinalgPinvTestCaseHermitianFP32(LinalgPinvTestCase):
def generate_input(self):
self._input_shape = (3, 5, 5)
x = np.random.random(self._input_shape).astype(self.dtype)
self._input_data = x + x.transpose((0, 2, 1))
def init_config(self):
self.dtype = 'float32'
self.rcond = 1e-15
self.hermitian = True
if __name__ == '__main__':
unittest.main()
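For reference, the Hermitian cases above construct their inputs as x + x.conj().T; the following is a standalone NumPy sketch (not part of the test file) of the property those cases rely on:

import numpy as np

a = np.random.random((5, 5)) + 1j * np.random.random((5, 5))
h = a + a.conj().T                       # Hermitian by construction
assert np.allclose(h, h.conj().T)
p = np.linalg.pinv(h, hermitian=True)    # reference used by generate_output()
assert np.allclose(h @ p @ h, h)         # Moore-Penrose identity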
@@ -21,6 +21,7 @@ from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import matrix_rank
from .tensor.linalg import svd
from .tensor.linalg import eigh # noqa: F401
from .tensor.linalg import pinv
__all__ = [
'cholesky', #noqa
@@ -31,5 +32,6 @@ __all__ = [
'matrix_rank',
'svd',
'matrix_power',
'eigh'
'eigh',
'pinv'
]
@@ -49,6 +49,7 @@ from .linalg import matrix_power # noqa: F401
from .linalg import multi_dot # noqa: F401
from .linalg import svd # noqa: F401
from .linalg import eigh # noqa: F401
from .linalg import pinv # noqa: F401
from .logic import equal # noqa: F401
from .logic import greater_equal # noqa: F401
from .logic import greater_than # noqa: F401
@@ -18,6 +18,8 @@ from ..fluid.data_feeder import check_variable_and_dtype, check_type
from ..fluid.framework import in_dygraph_mode, _varbase_creator, Variable
from ..fluid.layers import transpose, cast # noqa: F401
from ..fluid import layers
import paddle
from paddle.common_ops_import import core
from paddle.common_ops_import import VarDesc
from paddle import _C_ops
@@ -1635,3 +1637,275 @@ def eigh(x, UPLO='L', name=None):
'Eigenvectors': out_vector},
attrs={'UPLO': UPLO})
return out_value, out_vector
def pinv(x, rcond=1e-15, hermitian=False, name=None):
r"""
Calculate the pseudo inverse of one matrix or a batch of matrices
via SVD (singular value decomposition).
.. math::
if hermitian == False:
x = u * s * vt (SVD)
out = v * 1/s * ut
else:
x = u * s * ut (eigh)
out = u * 1/s * u.conj().transpose(-2,-1)
If x is a Hermitian (or real symmetric) matrix, the SVD is replaced by an eigenvalue decomposition (eigh).
Args:
x(Tensor): The input tensor. Its shape should be (*, m, n)
where * is zero or more batch dimensions. m and n can be
arbitrary positive numbers. The data type of x should be
float32, float64, complex64 or complex128. When the data
type is complex64 or complex128, hermitian should be set
to True.
rcond(Tensor, optional): the tolerance value below which a
singular value is treated as zero. Default: 1e-15.
hermitian(bool, optional): indicates whether x is Hermitian
if complex, or symmetric if real. Default: False.
name(str|None): A name for this layer (optional). If set to None,
the layer will be named automatically.
Returns:
Tensor: A tensor with the same data type as x; it represents the
pseudo inverse of x. Its shape is (*, n, m).
Examples:
.. code-block:: python
import paddle
x = paddle.arange(15).reshape((3, 5)).astype('float64')
input = paddle.to_tensor(x)
out = paddle.linalg.pinv(input)
print(input)
print(out)
# input:
# [[0. , 1. , 2. , 3. , 4. ],
# [5. , 6. , 7. , 8. , 9. ],
# [10., 11., 12., 13., 14.]]
# out:
# [[-0.22666667, -0.06666667, 0.09333333],
# [-0.12333333, -0.03333333, 0.05666667],
# [-0.02000000, 0.00000000, 0.02000000],
# [ 0.08333333, 0.03333333, -0.01666667],
# [ 0.18666667, 0.06666667, -0.05333333]]
# one can verify : x * out * x = x ;
# or out * x * out = out ;
"""
if in_dygraph_mode():
if not hermitian:
# combine svd and matmul op
u, s, vt = _C_ops.svd(x, 'full_matrices', False)
max_singular_val = _C_ops.reduce_max(s, 'dim', [-1], 'keep_dim', True, \
'reduce_all', False)
rcond = paddle.to_tensor(rcond, dtype=x.dtype)
cutoff = rcond * max_singular_val
y = float('inf')
y = paddle.to_tensor(y, dtype=x.dtype)
condition = s > cutoff
cond_int = layers.cast(condition, s.dtype)
cond_not_int = layers.cast(layers.logical_not(condition), s.dtype)
out1 = layers.elementwise_mul(1 / s, cond_int)
out2 = layers.elementwise_mul(1 / y, cond_not_int)
singular = layers.elementwise_add(out1, out2)
st, _ = _C_ops.unsqueeze2(singular, 'axes', [-2])
dims = list(range(len(vt.shape)))
perm = dims[:-2] + [dims[-1]] + [dims[-2]]
v, _ = _C_ops.transpose2(vt, 'axis', perm)
out_1 = v * st
out_2 = _C_ops.matmul_v2(out_1, u, 'trans_x', False, 'trans_y',
True)
return out_2
else:
# combine eigh and matmul op
s, u = _C_ops.eigh(x, 'UPLO', 'L')
s_abs = paddle.abs(s)
max_singular_val = _C_ops.reduce_max(s_abs, 'dim', [-1], 'keep_dim', True, \
'reduce_all', False)
rcond = paddle.to_tensor(rcond, dtype=s.dtype)
cutoff = rcond * max_singular_val
y = float('inf')
y = paddle.to_tensor(y, dtype=s.dtype)
condition = s_abs > cutoff
cond_int = layers.cast(condition, s.dtype)
cond_not_int = layers.cast(layers.logical_not(condition), s.dtype)
out1 = layers.elementwise_mul(1 / s, cond_int)
out2 = layers.elementwise_mul(1 / y, cond_not_int)
singular = layers.elementwise_add(out1, out2)
st, _ = _C_ops.unsqueeze2(singular, 'axes', [-2])
out_1 = u * st
u_conj = _C_ops.conj(u)
out_2 = _C_ops.matmul_v2(out_1, u_conj, 'trans_x', False, 'trans_y',
True)
return out_2
else:
if not hermitian:
helper = LayerHelper('pinv', **locals())
dtype = x.dtype
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'pinv')
u = helper.create_variable_for_type_inference(dtype)
s = helper.create_variable_for_type_inference(dtype)
vt = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='svd',
inputs={'X': [x]},
outputs={'U': u,
'VH': vt,
'S': s},
attrs={'full_matrices': False}, )
max_singular_val = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='reduce_max',
inputs={'X': s},
outputs={'Out': max_singular_val},
attrs={'dim': [-1],
'keep_dim': True,
'reduce_all': False})
rcond = layers.fill_constant(shape=[1], value=rcond, dtype=dtype)
cutoff = rcond * max_singular_val
y = float('inf')
y = layers.fill_constant(shape=[1], value=y, dtype=dtype)
condition = s > cutoff
cond_int = layers.cast(condition, dtype)
cond_not_int = layers.cast(layers.logical_not(condition), dtype)
out1 = layers.elementwise_mul(1 / s, cond_int)
out2 = layers.elementwise_mul(1 / y, cond_not_int)
singular = layers.elementwise_add(out1, out2)
st = helper.create_variable_for_type_inference(dtype=dtype)
st_shape = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='unsqueeze2',
inputs={'X': singular},
attrs={'axes': [-2]},
outputs={'Out': st,
'XShape': st_shape})
dims = list(range(len(vt.shape)))
perm = dims[:-2] + [dims[-1]] + [dims[-2]]
v = helper.create_variable_for_type_inference(dtype)
v_shape = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='transpose2',
inputs={'X': [vt]},
outputs={'Out': [v],
'XShape': [v_shape]},
attrs={'axis': perm})
out_1 = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='elementwise_mul',
inputs={'X': v,
'Y': st},
outputs={'Out': out_1},
attrs={'axis': -1,
'use_mkldnn': False})
out_1 = helper.append_activation(out_1)
out_2 = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='matmul_v2',
inputs={'X': out_1,
'Y': u},
outputs={'Out': out_2},
attrs={'trans_x': False,
'trans_y': True}, )
return out_2
else:
helper = LayerHelper('pinv', **locals())
dtype = x.dtype
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'complex64', 'complex128'],
'pinv')
if dtype == paddle.complex128:
s_type = 'float64'
elif dtype == paddle.complex64:
s_type = 'float32'
else:
s_type = dtype
u = helper.create_variable_for_type_inference(dtype)
s = helper.create_variable_for_type_inference(s_type)
helper.append_op(
type='eigh',
inputs={'X': x},
outputs={'Eigenvalues': s,
'Eigenvectors': u},
attrs={'UPLO': 'L'})
s_abs = helper.create_variable_for_type_inference(s_type)
helper.append_op(
type='abs', inputs={'X': s}, outputs={'Out': s_abs})
max_singular_val = helper.create_variable_for_type_inference(s_type)
helper.append_op(
type='reduce_max',
inputs={'X': s_abs},
outputs={'Out': max_singular_val},
attrs={'dim': [-1],
'keep_dim': True,
'reduce_all': False})
rcond = layers.fill_constant(shape=[1], value=rcond, dtype=s_type)
cutoff = rcond * max_singular_val
y = float('inf')
y = layers.fill_constant(shape=[1], value=y, dtype=s_type)
condition = s_abs > cutoff
cond_int = layers.cast(condition, s_type)
cond_not_int = layers.cast(layers.logical_not(condition), s_type)
out1 = layers.elementwise_mul(1 / s, cond_int)
out2 = layers.elementwise_mul(1 / y, cond_not_int)
singular = layers.elementwise_add(out1, out2)
st = helper.create_variable_for_type_inference(dtype=s_type)
st_shape = helper.create_variable_for_type_inference(dtype=s_type)
helper.append_op(
type='unsqueeze2',
inputs={'X': singular},
attrs={'axes': [-2]},
outputs={'Out': st,
'XShape': st_shape})
out_1 = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='elementwise_mul',
inputs={'X': u,
'Y': st},
outputs={'Out': out_1},
attrs={'axis': -1,
'use_mkldnn': False})
out_1 = helper.append_activation(out_1)
u_conj = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='conj', inputs={'X': u}, outputs={'Out': [u_conj]})
out_2 = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='matmul_v2',
inputs={'X': out_1,
'Y': u_conj},
outputs={'Out': out_2},
attrs={'trans_x': False,
'trans_y': True}, )
return out_2
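For readers tracing the op sequence above, here is a short NumPy sketch (an illustration under the assumption of real input and hermitian=False, not part of the commit) of the same rcond-cutoff construction:

import numpy as np

def pinv_reference(x, rcond=1e-15):
    # svd -> cutoff = rcond * max(s) -> drop reciprocals of small singular
    # values (the op graph divides by inf instead of writing an explicit 0)
    # -> v * 1/s * u^T
    u, s, vt = np.linalg.svd(x, full_matrices=False)
    cutoff = rcond * s.max(axis=-1, keepdims=True)
    s_inv = np.where(s > cutoff, 1.0 / s, 0.0)
    v = vt.conj().swapaxes(-2, -1)
    return (v * s_inv[..., None, :]) @ u.conj().swapaxes(-2, -1)

x = np.random.random((3, 4, 5))
assert np.allclose(pinv_reference(x), np.linalg.pinv(x))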