Unverified commit 4ebb4764, authored by NetPunk, committed by GitHub

【Hackathon 4 No.9】Add pca_lowrank API to Paddle (#53743)

Parent 7309f8ab
...@@ -29,6 +29,7 @@ from .tensor.linalg import matrix_power # noqa: F401
from .tensor.linalg import matrix_rank # noqa: F401
from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
from .tensor.linalg import pca_lowrank # noqa: F401
from .tensor.linalg import pinv # noqa: F401
from .tensor.linalg import qr # noqa: F401
from .tensor.linalg import slogdet # noqa: F401
...@@ -50,6 +51,7 @@ __all__ = [
'matrix_rank',
'svd',
'qr',
'pca_lowrank',
'lu',
'lu_unpack',
'matrix_power',
......
...@@ -28,6 +28,7 @@ from .unary import square
from .unary import log1p
from .unary import abs
from .unary import pow
from .unary import pca_lowrank
from .unary import cast
from .unary import neg
from .unary import coalesce
...@@ -69,6 +70,7 @@ __all__ = [
'log1p',
'abs',
'pow',
'pca_lowrank',
'cast',
'neg',
'deg2rad',
......
...@@ -14,6 +14,7 @@
import numpy as np
import paddle
from paddle import _C_ops, in_dynamic_mode
from paddle.common_ops_import import Variable
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
...@@ -920,3 +921,205 @@ def slice(x, axes, starts, ends, name=None):
type=op_type, inputs={'x': x}, outputs={'out': out}, attrs=attrs
)
return out
def pca_lowrank(x, q=None, center=True, niter=2, name=None):
r"""
Performs linear Principal Component Analysis (PCA) on a sparse matrix.
Let :math:`X` be the input matrix or a batch of input matrices, the output should satisfies:
.. math::
X = U * diag(S) * V^{T}
Args:
x (Tensor): The input tensor. Its shape should be `[N, M]`,
N and M can be arbitraty positive number.
The data type of x should be float32 or float64.
q (int, optional): a slightly overestimated rank of :math:`X`.
Default value is :math:`q=min(6,N,M)`.
center (bool, optional): if True, center the input tensor.
Default value is True.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
- Tensor U, is N x q matrix.
- Tensor S, is a vector with length q.
- Tensor V, is M x q matrix.
tuple (U, S, V): which is the nearly optimal approximation of a singular value decomposition of a centered matrix :math:`X`.
Examples:
.. code-block:: python
import paddle
format = "coo"
dense_x = paddle.randn((5, 5), dtype='float64')
if format == "coo":
sparse_x = dense_x.to_sparse_coo(len(dense_x.shape))
else:
sparse_x = dense_x.to_sparse_csr()
print("sparse.pca_lowrank API only support CUDA 11.x")
U, S, V = None, None, None
# use the code below when your device's CUDA version is >= 11.0
# U, S, V = paddle.sparse.pca_lowrank(sparse_x)
print(U)
# Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [[ 0.02206024, 0.53170082, -0.22392168, -0.48450657, 0.65720625],
# [ 0.02206024, 0.53170082, -0.22392168, -0.32690402, -0.74819812],
# [ 0.02206024, 0.53170082, -0.22392168, 0.81141059, 0.09099187],
# [ 0.15045792, 0.37840027, 0.91333217, -0.00000000, 0.00000000],
# [ 0.98787775, -0.09325209, -0.12410317, -0.00000000, -0.00000000]])
print(S)
# Tensor(shape=[5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [2.28621761, 0.93618564, 0.53234942, 0.00000000, 0.00000000])
print(V)
# Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [[ 0.26828910, -0.57116436, -0.26548201, 0.67342660, -0.27894114],
# [-0.19592125, -0.31629129, 0.02001645, -0.50484498, -0.77865626],
# [-0.82913017, -0.09391036, 0.37975388, 0.39938099, -0.00241046],
# [-0.41163516, 0.27490410, -0.86666276, 0.03382656, -0.05230341],
# [ 0.18092947, 0.69952818, 0.18385126, 0.36190987, -0.55959343]])
"""
def get_floating_dtype(x):
dtype = x.dtype
if dtype in (paddle.float16, paddle.float32, paddle.float64):
return dtype
return paddle.float32
def conjugate(x):
if x.is_complex():
return x.conj()
return x
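# Swap the last two axes; sparse COO input is routed to
# paddle.sparse.transpose, dense input to paddle.transpose.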
def transpose(x):
shape = x.shape
perm = list(range(0, len(shape)))
perm = perm[:-2] + [perm[-1]] + [perm[-2]]
if x.is_sparse():
return paddle.sparse.transpose(x, perm)
return paddle.transpose(x, perm)
def transjugate(x):
return conjugate(transpose(x))
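# Randomized range finder: return an orthonormal Q whose columns
# approximately span the top-q column space of x (or of x - M).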
def get_approximate_basis(x, q, niter=2, M=None):
niter = 2 if niter is None else niter
m, n = x.shape[-2:]
qr = paddle.linalg.qr
R = paddle.randn((n, q), dtype=x.dtype)
A_t = transpose(x)
A_H = conjugate(A_t)
if M is None:
Q = qr(paddle.sparse.matmul(x, R))[0]
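# Power iterations sharpen Q toward the dominant singular subspace.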
for i in range(niter):
Q = qr(paddle.sparse.matmul(A_H, Q))[0]
Q = qr(paddle.sparse.matmul(x, Q))[0]
else:
M_H = transjugate(M)
Q = qr(paddle.sparse.matmul(x, R) - paddle.matmul(M, R))[0]
for i in range(niter):
Q = qr(paddle.sparse.matmul(A_H, Q) - paddle.matmul(M_H, Q))[0]
Q = qr(paddle.sparse.matmul(x, Q) - paddle.matmul(M, Q))[0]
return Q
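# Randomized SVD: project x onto the basis Q, take an exact SVD of the
# small projected matrix, then map the factors back through Q.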
def svd_lowrank(x, q=6, niter=2, M=None):
q = 6 if q is None else q
m, n = x.shape[-2:]
if M is None:
M_t = None
else:
M_t = transpose(M)
A_t = transpose(x)
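# Work on whichever orientation makes the projected matrix tall and
# skinny, so the exact SVD below stays cheap.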
if m < n or n > q:
Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
Q_c = conjugate(Q)
if M is None:
B_t = paddle.sparse.matmul(x, Q_c)
else:
B_t = paddle.sparse.matmul(x, Q_c) - paddle.matmul(M, Q_c)
assert B_t.shape[-2] == m, (B_t.shape, m)
assert B_t.shape[-1] == q, (B_t.shape, q)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
V = transjugate(Vh)
V = Q.matmul(V)
else:
Q = get_approximate_basis(x, q, niter=niter, M=M)
Q_c = conjugate(Q)
if M is None:
B = paddle.sparse.matmul(A_t, Q_c)
else:
B = paddle.sparse.matmul(A_t, Q_c) - paddle.matmul(M_t, Q_c)
B_t = transpose(B)
assert B_t.shape[-2] == q, (B_t.shape, q)
assert B_t.shape[-1] == n, (B_t.shape, n)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
V = transjugate(Vh)
U = Q.matmul(U)
return U, S, V
if not paddle.is_tensor(x):
raise ValueError(f'Input must be tensor, but got {type(x)}')
if not x.is_sparse():
raise ValueError('Input must be sparse, but got dense')
cuda_version = paddle.version.cuda()
if (
cuda_version is None
or cuda_version == 'False'
or int(cuda_version.split('.')[0]) < 11
):
raise ValueError('sparse.pca_lowrank API only support CUDA 11.x')
(m, n) = x.shape[-2:]
if q is None:
q = min(6, m, n)
elif not (q >= 0 and q <= min(m, n)):
raise ValueError(
'q(={}) must be non-negative integer'
' and not greater than min(m, n)={}'.format(q, min(m, n))
)
if not (niter >= 0):
raise ValueError(f'niter(={niter}) must be non-negative integer')
dtype = get_floating_dtype(x)
if not center:
return svd_lowrank(x, q, niter=niter, M=None)
if len(x.shape) != 2:
raise ValueError('input is expected to be 2-dimensional tensor')
# TODO: complement sparse_csr_tensor test
# when sparse.sum with axis(-2) is implemented
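# Center implicitly: build the rank-1 matrix of column means
# (M = ones(m, 1) @ c) and let svd_lowrank subtract it on the fly, so
# the sparse input never needs to be densified.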
s_sum = paddle.sparse.sum(x, axis=-2)
s_val = s_sum.values() / m
c = paddle.sparse.sparse_coo_tensor(
s_sum.indices(), s_val, dtype=s_sum.dtype, place=s_sum.place
)
column_indices = c.indices()[0]
indices = paddle.zeros((2, len(column_indices)), dtype=column_indices.dtype)
indices[0] = column_indices
C_t = paddle.sparse.sparse_coo_tensor(
indices, c.values(), (n, 1), dtype=dtype, place=x.place
)
ones_m1_t = paddle.ones(x.shape[:-2] + [1, m], dtype=dtype)
M = transpose(paddle.matmul(C_t.to_dense(), ones_m1_t))
return svd_lowrank(x, q, niter=niter, M=M)
...@@ -46,6 +46,7 @@ from .linalg import dot # noqa: F401
from .linalg import cov # noqa: F401
from .linalg import corrcoef # noqa: F401
from .linalg import norm # noqa: F401
from .linalg import pca_lowrank # noqa: F401
from .linalg import cond # noqa: F401
from .linalg import transpose # noqa: F401
from .linalg import lstsq # noqa: F401
...@@ -333,6 +334,7 @@ tensor_method_func = [ # noqa
'mv',
'matrix_power',
'qr',
'pca_lowrank',
'eigvals',
'eigvalsh',
'abs',
......
...@@ -1963,6 +1963,159 @@ def svd(x, full_matrices=False, name=None):
return u, s, vh
def pca_lowrank(x, q=None, center=True, niter=2, name=None):
r"""
Performs linear Principal Component Analysis (PCA) on a low-rank matrix or batches of such matrices.
Let :math:`X` be the input matrix or a batch of input matrices, the output should satisfies:
.. math::
X = U * diag(S) * V^{T}
Args:
x (Tensor): The input tensor. Its shape should be `[..., N, M]`,
where `...` is zero or more batch dimensions. N and M can be arbitraty
positive number. The data type of x should be float32 or float64.
q (int, optional): a slightly overestimated rank of :math:`X`.
Default value is :math:`q=min(6,N,M)`.
center (bool, optional): if True, center the input tensor.
Default value is True.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
- Tensor U, is N x q matrix.
- Tensor S, is a vector with length q.
- Tensor V, is M x q matrix.
tuple (U, S, V): which is the nearly optimal approximation of a singular value decomposition of a centered matrix :math:`X`.
Examples:
.. code-block:: python
import paddle
x = paddle.randn((5, 5), dtype='float64')
U, S, V = paddle.linalg.pca_lowrank(x)
print(U)
# Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [[ 0.41057070, 0.40364287, 0.59099574, -0.34529432, 0.44721360],
# [-0.30243321, 0.55670611, -0.15025419, 0.61321785, 0.44721360],
# [ 0.57427340, -0.15936327, -0.66414981, -0.06097905, 0.44721360],
# [-0.63897516, -0.09968973, -0.17298615, -0.59316819, 0.44721360],
# [-0.04343573, -0.70129598, 0.39639442, 0.38622370, 0.44721360]])
print(S)
# Tensor(shape=[5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [3.33724265, 2.57573259, 1.69479048, 0.68069312, 0.00000000])
print(V)
# Tensor(shape=[5, 5], dtype=float64, place=Place(gpu:0), stop_gradient=True,
# [[ 0.09800724, -0.32627008, -0.23593953, 0.81840445, 0.39810690],
# [-0.60100303, 0.63741176, -0.01953663, 0.09023999, 0.47326173],
# [ 0.25073864, -0.21305240, -0.32662950, -0.54786156, 0.69634740],
# [ 0.33057205, 0.48282641, -0.75998527, 0.06744040, -0.27472705],
# [ 0.67604895, 0.45688227, 0.50959437, 0.13179682, 0.23908071]])
"""
def conjugate(x):
if x.is_complex():
return x.conj()
return x
def transpose(x):
shape = x.shape
perm = list(range(0, len(shape)))
perm = perm[:-2] + [perm[-1]] + [perm[-2]]
return paddle.transpose(x, perm)
def transjugate(x):
return conjugate(transpose(x))
def get_approximate_basis(x, q, niter=2, M=None):
niter = 2 if niter is None else niter
m, n = x.shape[-2:]
qr = paddle.linalg.qr
R = paddle.randn((n, q), dtype=x.dtype)
A_t = transpose(x)
A_H = conjugate(A_t)
if M is None:
Q = qr(paddle.matmul(x, R))[0]
for i in range(niter):
Q = qr(paddle.matmul(A_H, Q))[0]
Q = qr(paddle.matmul(x, Q))[0]
else:
M_H = transjugate(M)
Q = qr(paddle.matmul(x, R) - paddle.matmul(M, R))[0]
for i in range(niter):
Q = qr(paddle.matmul(A_H, Q) - paddle.matmul(M_H, Q))[0]
Q = qr(paddle.matmul(x, Q) - paddle.matmul(M, Q))[0]
return Q
def svd_lowrank(x, q=6, niter=2, M=None):
q = 6 if q is None else q
m, n = x.shape[-2:]
if M is None:
M_t = None
else:
M_t = transpose(M)
A_t = transpose(x)
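# Choose the orientation that keeps the projected matrix tall and
# skinny before the exact SVD.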
if m < n or n > q:
Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
Q_c = conjugate(Q)
if M is None:
B_t = paddle.matmul(x, Q_c)
else:
B_t = paddle.matmul(x, Q_c) - paddle.matmul(M, Q_c)
assert B_t.shape[-2] == m, (B_t.shape, m)
assert B_t.shape[-1] == q, (B_t.shape, q)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
V = transjugate(Vh)
V = Q.matmul(V)
else:
Q = get_approximate_basis(x, q, niter=niter, M=M)
Q_c = conjugate(Q)
if M is None:
B = paddle.matmul(A_t, Q_c)
else:
B = paddle.matmul(A_t, Q_c) - paddle.matmul(M_t, Q_c)
B_t = transpose(B)
assert B_t.shape[-2] == q, (B_t.shape, q)
assert B_t.shape[-1] == n, (B_t.shape, n)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
V = transjugate(Vh)
U = Q.matmul(U)
return U, S, V
if not paddle.is_tensor(x):
raise ValueError(f'Input must be tensor, but got {type(x)}')
(m, n) = x.shape[-2:]
if q is None:
q = min(6, m, n)
elif not (q >= 0 and q <= min(m, n)):
raise ValueError(
'q(={}) must be non-negative integer'
' and not greater than min(m, n)={}'.format(q, min(m, n))
)
if not (niter >= 0):
raise ValueError(f'niter(={niter}) must be non-negative integer')
if not center:
return svd_lowrank(x, q, niter=niter, M=None)
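# Subtract the per-column mean so that PCA is performed on centered data;
# the principal directions are then the columns of V.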
C = x.mean(axis=-2, keepdim=True)
return svd_lowrank(x - C, q, niter=niter, M=None)
def matrix_power(x, n, name=None):
r"""
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class TestPcaLowrankAPI(unittest.TestCase):
def transpose(self, x):
shape = x.shape
perm = list(range(0, len(shape)))
perm = perm[:-2] + [perm[-1]] + [perm[-2]]
return paddle.transpose(x, perm)
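# Draw a Gaussian matrix and rescale its singular values to a known,
# well-conditioned spectrum in (0, 1].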
def random_matrix(self, rows, columns, *batch_dims, **kwargs):
dtype = kwargs.get('dtype', paddle.float64)
x = paddle.randn(batch_dims + (rows, columns), dtype=dtype)
if x.numel() == 0:
return x
u, _, vh = paddle.linalg.svd(x, full_matrices=False)
k = min(rows, columns)
s = paddle.linspace(1 / (k + 1), 1, k, dtype=dtype)
return (u * s.unsqueeze(-2)) @ vh
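# A matrix of rank `rank` is the product of (rows x rank) and
# (rank x columns) factors.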
def random_lowrank_matrix(self, rank, rows, columns, *batch_dims, **kwargs):
B = self.random_matrix(rows, rank, *batch_dims, **kwargs)
C = self.random_matrix(rank, columns, *batch_dims, **kwargs)
return B.matmul(C)
def run_subtest(
self, guess_rank, actual_rank, matrix_size, batches, pca, **options
):
if isinstance(matrix_size, int):
rows = columns = matrix_size
else:
rows, columns = matrix_size
a_input = self.random_lowrank_matrix(
actual_rank, rows, columns, *batches
)
a = a_input
u, s, v = pca(a_input, q=guess_rank, **options)
self.assertEqual(s.shape[-1], guess_rank)
self.assertEqual(u.shape[-2], rows)
self.assertEqual(u.shape[-1], guess_rank)
self.assertEqual(v.shape[-1], guess_rank)
self.assertEqual(v.shape[-2], columns)
A1 = u.matmul(paddle.nn.functional.diag_embed(s)).matmul(
self.transpose(v)
)
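# pca_lowrank factors the centered matrix, so reconstruct it and
# compare against the input with its column means removed.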
ones_m1 = paddle.ones(batches + (rows, 1), dtype=a.dtype)
c = a.sum(axis=-2) / rows
c = c.reshape(batches + (1, columns))
A2 = a - ones_m1.matmul(c)
np.testing.assert_allclose(A1.numpy(), A2.numpy(), atol=1e-5)
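# Counting singular values above tolerance should recover the rank
# used to construct the input.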
detect_rank = (s.abs() > 1e-5).sum(axis=-1)
left = actual_rank * paddle.ones(batches, dtype=paddle.int64)
if not left.shape:
np.testing.assert_allclose(int(left), int(detect_rank))
else:
np.testing.assert_allclose(left.numpy(), detect_rank.numpy())
S = paddle.linalg.svd(A2, full_matrices=False)[1]
left = s[..., :actual_rank]
right = S[..., :actual_rank]
np.testing.assert_allclose(left.numpy(), right.numpy())
def test_forward(self):
pca_lowrank = paddle.linalg.pca_lowrank
all_batches = [(), (1,), (3,), (2, 3)]
for actual_rank, size in [
(2, (17, 4)),
(2, (100, 4)),
(6, (100, 40)),
]:
for batches in all_batches:
for guess_rank in [
actual_rank,
actual_rank + 2,
actual_rank + 6,
]:
if guess_rank <= min(*size):
self.run_subtest(
guess_rank, actual_rank, size, batches, pca_lowrank
)
self.run_subtest(
guess_rank,
actual_rank,
size[::-1],
batches,
pca_lowrank,
)
x = np.random.randn(5, 5).astype('float64')
x = paddle.to_tensor(x)
q = None
U, S, V = pca_lowrank(x, q, center=False)
def test_errors(self):
pca_lowrank = paddle.linalg.pca_lowrank
x = np.random.randn(5, 5).astype('float64')
x = paddle.to_tensor(x)
def test_x_not_tensor():
U, S, V = pca_lowrank(x.numpy())
self.assertRaises(ValueError, test_x_not_tensor)
def test_q_range():
q = -1
U, S, V = pca_lowrank(x, q)
self.assertRaises(ValueError, test_q_range)
def test_niter_range():
n = -1
U, S, V = pca_lowrank(x, niter=n)
self.assertRaises(ValueError, test_niter_range)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import unittest
import numpy as np
import paddle
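# Parse the CUDA toolkit version from `nvcc --version`, encoded as
# major * 1000 + minor * 10 (e.g. 11.2 -> 11020); returns -1 if nvcc
# is unavailable.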
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
class TestSparsePcaLowrankAPI(unittest.TestCase):
def transpose(self, x):
shape = x.shape
perm = list(range(0, len(shape)))
perm = perm[:-2] + [perm[-1]] + [perm[-2]]
return paddle.transpose(x, perm)
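# Sample a sparse COO matrix with roughly `density` nonzero entries;
# values are damped away from the diagonal so the spectrum falls off
# quickly.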
def random_sparse_matrix(self, rows, columns, density=0.01, **kwargs):
dtype = kwargs.get('dtype', paddle.float64)
nonzero_elements = max(
min(rows, columns), int(rows * columns * density)
)
row_indices = [i % rows for i in range(nonzero_elements)]
column_indices = [i % columns for i in range(nonzero_elements)]
random.shuffle(column_indices)
indices = [row_indices, column_indices]
values = paddle.randn((nonzero_elements,), dtype=dtype)
values *= paddle.to_tensor(
[-float(i - j) ** 2 for i, j in zip(*indices)], dtype=dtype
).exp()
indices_tensor = paddle.to_tensor(indices)
x = paddle.sparse.sparse_coo_tensor(
indices_tensor, values, (rows, columns)
)
return paddle.sparse.coalesce(x)
def run_subtest(self, guess_rank, matrix_size, batches, pca, **options):
density = options.pop('density', 0.5)
if isinstance(matrix_size, int):
rows = columns = matrix_size
else:
rows, columns = matrix_size
a_input = self.random_sparse_matrix(rows, columns, density)
a = a_input.to_dense()
u, s, v = pca(a_input, q=guess_rank, **options)
self.assertEqual(s.shape[-1], guess_rank)
self.assertEqual(u.shape[-2], rows)
self.assertEqual(u.shape[-1], guess_rank)
self.assertEqual(v.shape[-1], guess_rank)
self.assertEqual(v.shape[-2], columns)
A1 = u.matmul(paddle.nn.functional.diag_embed(s)).matmul(
self.transpose(v)
)
ones_m1 = paddle.ones(batches + (rows, 1), dtype=a.dtype)
c = a.sum(axis=-2) / rows
c = c.reshape(batches + (1, columns))
A2 = a - ones_m1.matmul(c)
np.testing.assert_allclose(A1.numpy(), A2.numpy(), atol=1e-5)
@unittest.skipIf(
not paddle.is_compiled_with_cuda() or get_cuda_version() < 11000,
"only support cuda>=11.0",
)
def test_sparse(self):
pca_lowrank = paddle.sparse.pca_lowrank
for guess_rank, size in [
(4, (17, 4)),
(4, (4, 17)),
(16, (17, 17)),
(21, (100, 40)),
]:
for density in [0.005, 0.01]:
self.run_subtest(
guess_rank, size, (), pca_lowrank, density=density
)
def test_errors(self):
pca_lowrank = paddle.sparse.pca_lowrank
x = np.random.randn(5, 5).astype('float64')
dense_x = paddle.to_tensor(x)
sparse_x = dense_x.to_sparse_coo(len(x.shape))
def test_x_not_tensor():
U, S, V = pca_lowrank(x)
self.assertRaises(ValueError, test_x_not_tensor)
def test_x_not_sparse():
U, S, V = pca_lowrank(sparse_x.to_dense())
self.assertRaises(ValueError, test_x_not_sparse)
def test_q_range():
q = -1
U, S, V = pca_lowrank(sparse_x, q)
self.assertRaises(ValueError, test_q_range)
def test_niter_range():
n = -1
U, S, V = pca_lowrank(sparse_x, niter=n)
self.assertRaises(ValueError, test_niter_range)
def test_x_wrong_shape():
x = np.random.randn(5, 5, 5).astype('float64')
dense_x = paddle.to_tensor(x)
sparse_x = dense_x.to_sparse_coo(len(x.shape))
U, S, V = pca_lowrank(sparse_x)
self.assertRaises(ValueError, test_x_wrong_shape)
if __name__ == "__main__":
unittest.main()