未验证 提交 95a502a2 编写于 作者: L liqitong-a 提交者: GitHub

【PaddlePaddle Hackathon 2】3、为 Paddle 新增 corrcoef(皮尔逊积矩相关系数) API (#40690)

* corrcoef commit

* corrcoef commit

* Update test_corr.py

* Update linalg.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update linalg.py

* Update linalg.py

* Update linalg.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py

* Update test_corr.py
上级 bba5e083
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import numpy as np
import six
import paddle
import warnings
def numpy_corr(np_arr, rowvar=True, dtype='float64'):
return np.corrcoef(np_arr, rowvar=rowvar, dtype=dtype)
class Corr_Test(unittest.TestCase):
def setUp(self):
self.shape = [4, 5]
def test_tensor_corr_default(self):
typelist = ['float64', 'float32']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
for dtype in typelist:
np_arr = np.random.rand(*self.shape).astype(dtype)
tensor = paddle.to_tensor(np_arr, place=p)
corr = paddle.linalg.corrcoef(tensor)
np_corr = numpy_corr(np_arr, rowvar=True, dtype=dtype)
if dtype == 'float32':
self.assertTrue(
np.allclose(
np_corr, corr.numpy(), atol=1.e-5))
else:
self.assertTrue(np.allclose(np_corr, corr.numpy()))
def test_tensor_corr_rowvar(self):
typelist = ['float64', 'float32']
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for idx, p in enumerate(places):
if idx == 0:
paddle.set_device('cpu')
else:
paddle.set_device('gpu')
for dtype in typelist:
np_arr = np.random.rand(*self.shape).astype(dtype)
tensor = paddle.to_tensor(np_arr, place=p)
corr = paddle.linalg.corrcoef(tensor, rowvar=False)
np_corr = numpy_corr(np_arr, rowvar=False, dtype=dtype)
if dtype == 'float32':
self.assertTrue(
np.allclose(
np_corr, corr.numpy(), atol=1.e-5))
else:
self.assertTrue(np.allclose(np_corr, corr.numpy()))
# Input(x) only support N-D (1<=N<=2) tensor
class Corr_Test2(Corr_Test):
def setUp(self):
self.shape = [10]
class Corr_Test3(Corr_Test):
def setUp(self):
self.shape = [4, 5]
# Input(x) only support N-D (1<=N<=2) tensor
class Corr_Test4(unittest.TestCase):
def setUp(self):
self.shape = [2, 5, 2]
def test_errors(self):
def test_err():
np_arr = np.random.rand(*self.shape).astype('float64')
tensor = paddle.to_tensor(np_arr)
covrr = paddle.linalg.corrcoef(tensor)
self.assertRaises(ValueError, test_err)
# test unsupported complex input
class Corr_Comeplex_Test(unittest.TestCase):
def setUp(self):
self.dtype = 'complex128'
def test_errors(self):
paddle.enable_static()
x1 = fluid.data(name=self.dtype, shape=[2], dtype=self.dtype)
self.assertRaises(TypeError, paddle.linalg.corrcoef, x=x1)
paddle.disable_static()
class Corr_Test5(Corr_Comeplex_Test):
def setUp(self):
self.dtype = 'complex64'
if __name__ == '__main__':
unittest.main()
......@@ -16,6 +16,7 @@ from .tensor.linalg import cholesky # noqa: F401
from .tensor.linalg import norm # noqa: F401
from .tensor.linalg import eig # noqa: F401
from .tensor.linalg import cov # noqa: F401
from .tensor.linalg import corrcoef # noqa: F401
from .tensor.linalg import cond # noqa: F401
from .tensor.linalg import matrix_power # noqa: F401
from .tensor.linalg import solve # noqa: F401
......@@ -41,6 +42,7 @@ __all__ = [
'norm',
'cond',
'cov',
'corrcoef',
'inv',
'eig',
'eigvals',
......
......@@ -40,6 +40,7 @@ from .creation import complex # noqa: F401
from .linalg import matmul # noqa: F401
from .linalg import dot # noqa: F401
from .linalg import cov # noqa: F401
from .linalg import corrcoef # noqa: F401
from .linalg import norm # noqa: F401
from .linalg import cond # noqa: F401
from .linalg import transpose # noqa: F401
......@@ -278,6 +279,7 @@ tensor_method_func = [ #noqa
'matmul',
'dot',
'cov',
'corrcoef',
'norm',
'cond',
'transpose',
......
......@@ -24,6 +24,7 @@ from .logic import logical_not
from .creation import full
import paddle
import warnings
from paddle.common_ops_import import core
from paddle.common_ops_import import VarDesc
from paddle import _C_ops
......@@ -3181,3 +3182,72 @@ def lstsq(x, y, rcond=None, driver=None, name=None):
singular_values = paddle.static.data(name='singular_values', shape=[0])
return solution, residuals, rank, singular_values
def corrcoef(x, rowvar=True, name=None):
"""
A correlation coefficient matrix indicate the correlation of each pair variables in the input matrix.
For example, for an N-dimensional samples X=[x1,x2,…xN]T, then the correlation coefficient matrix
element Rij is the correlation of xi and xj. The element Rii is the covariance of xi itself.
The relationship between the correlation coefficient matrix `R` and the
covariance matrix `C`, is
.. math:: R_{ij} = \\frac{ C_{ij} } { \\sqrt{ C_{ii} * C_{jj} } }
The values of `R` are between -1 and 1.
Parameters:
x(Tensor): A N-D(N<=2) Tensor containing multiple variables and observations. By default, each row of x represents a variable. Also see rowvar below.
rowvar(Bool, optional): If rowvar is True (default), then each row represents a variable, with observations in the columns. Default: True.
name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name`.
Returns:
The correlation coefficient matrix of the variables.
Examples:
.. code-block:: python
:name: code-example1
import paddle
xt = paddle.rand((3,4))
print(paddle.linalg.corrcoef(xt))
# Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[ 1. , -0.73702252, 0.66228950],
# [-0.73702258, 1. , -0.77104872],
# [ 0.66228974, -0.77104825, 1. ]])
"""
if len(x.shape) > 2 or len(x.shape) < 1:
raise ValueError(
"Input(x) only support N-D (1<=N<=2) tensor in corrcoef, but received "
"length of Input(input) is %s." % len(x.shape))
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'corrcoef')
c = cov(x, rowvar)
if (c.ndim == 0):
# scalar covariance
# nan if incorrect value (nan, inf, 0), 1 otherwise
return c / c
d = paddle.diag(c)
if paddle.is_complex(d):
d = d.real()
stddev = paddle.sqrt(d)
c /= stddev[:, None]
c /= stddev[None, :]
# Clip to [-1, 1]. This does not guarantee
if paddle.is_complex(c):
return paddle.complex(
paddle.clip(c.real(), -1, 1), paddle.clip(c.imag(), -1, 1))
else:
c = paddle.clip(c, -1, 1)
return c
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册