From d411a038af0293917086344d3506c86afd4f34ea Mon Sep 17 00:00:00 2001 From: From00 Date: Sat, 18 Sep 2021 00:26:59 +0800 Subject: [PATCH] Add new API "eigvals" in linalg (#35720) * Add linalg.eigvals API * pre-commit check * Adjust code style * Fix conflict * Improve code style * Modify the test code to ignore testing CUDA kernel * Sort ouput data before checking in test code * Set timeout value for UT * Improve API example code to pass CI * Fix bug for None fetch_list in Windows * Delete grad Op --- paddle/fluid/framework/ddim.cc | 28 ++ paddle/fluid/framework/ddim.h | 7 + paddle/fluid/operators/eigvals_op.cc | 89 +++++ paddle/fluid/operators/eigvals_op.h | 129 ++++++++ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../paddle/fluid/tests/unittests/op_test.py | 6 + .../fluid/tests/unittests/test_eigvals_op.py | 307 ++++++++++++++++++ python/paddle/linalg.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/linalg.py | 60 ++++ 10 files changed, 631 insertions(+) create mode 100644 paddle/fluid/operators/eigvals_op.cc create mode 100644 paddle/fluid/operators/eigvals_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_eigvals_op.py diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index fe7d2430662..8bac8b7df6d 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -107,6 +107,34 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { return os; } +DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims) { + PADDLE_ENFORCE_GE(src.size(), 3, + platform::errors::InvalidArgument( + "The rank of src dim should be at least 3 " + "in flatten_to_3d, but received %d.", + src.size())); + PADDLE_ENFORCE_EQ((num_row_dims >= 1 && num_row_dims < src.size()), true, + platform::errors::InvalidArgument( + "The num_row_dims should be inside [1, %d] " + "in flatten_to_3d, but received %d.", + src.size() - 1, num_row_dims)); + PADDLE_ENFORCE_EQ((num_col_dims >= 2 && num_col_dims <= src.size()), true, + platform::errors::InvalidArgument( + "The num_col_dims should be inside [2, %d] " + "in flatten_to_3d, but received %d.", + src.size(), num_col_dims)); + PADDLE_ENFORCE_GE( + num_col_dims, num_row_dims, + platform::errors::InvalidArgument( + "The num_row_dims should be less than num_col_dims in flatten_to_3d," + "but received num_row_dims = %d, num_col_dims = %d.", + num_row_dims, num_col_dims)); + + return DDim({product(slice_ddim(src, 0, num_row_dims)), + product(slice_ddim(src, num_row_dims, num_col_dims)), + product(slice_ddim(src, num_col_dims, src.size()))}); +} + DDim flatten_to_2d(const DDim& src, int num_col_dims) { return DDim({product(slice_ddim(src, 0, num_col_dims)), product(slice_ddim(src, num_col_dims, src.size()))}); diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index e69fb4e7619..565e0b430df 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -230,6 +230,13 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); +/** +* \brief Flatten dim to 3d +* e.g., DDim d = mak_ddim({1, 2, 3, 4, 5, 6}) +* flatten_to_3d(d, 2, 4); ===> {1*2, 3*4, 5*6} ===> {2, 12, 30} +*/ +DDim flatten_to_3d(const DDim& src, int num_row_dims, int num_col_dims); + // Reshape a tensor to a matrix. The matrix's first dimension(column length) // will be the product of tensor's first `num_col_dims` dimensions. DDim flatten_to_2d(const DDim& src, int num_col_dims); diff --git a/paddle/fluid/operators/eigvals_op.cc b/paddle/fluid/operators/eigvals_op.cc new file mode 100644 index 00000000000..dcf35019095 --- /dev/null +++ b/paddle/fluid/operators/eigvals_op.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/eigvals_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), A complex- or real-valued tensor with shape (*, n, n)" + "where * is zero or more batch dimensions"); + AddOutput("Out", + "(Tensor) The output tensor with shape (*,n) cointaining the " + "eigenvalues of X."); + AddComment(R"DOC(eigvals operator + Return the eigenvalues of one or more square matrices. The eigenvalues are complex even when the input matrices are real. + )DOC"); + } +}; + +class EigvalsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvals"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Eigvals"); + + DDim x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + platform::errors::InvalidArgument( + "The dimensions of Input(X) for Eigvals operator " + "should be at least 2, " + "but received X's dimension = %d, X's shape = [%s].", + x_dims.size(), x_dims)); + + if (ctx->IsRuntime() || !framework::contain_unknown_dim(x_dims)) { + int last_dim = x_dims.size() - 1; + PADDLE_ENFORCE_EQ(x_dims[last_dim], x_dims[last_dim - 1], + platform::errors::InvalidArgument( + "The last two dimensions of Input(X) for Eigvals " + "operator should be equal, " + "but received X's shape = [%s].", + x_dims)); + } + + auto output_dims = vectorize(x_dims); + output_dims.resize(x_dims.size() - 1); + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); + } +}; + +class EigvalsOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext* ctx) const { + auto input_dtype = ctx->GetInputDataType("X"); + auto output_dtype = framework::IsComplexType(input_dtype) + ? input_dtype + : framework::ToComplexType(input_dtype); + ctx->SetOutputDataType("Out", output_dtype); + } +}; +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(eigvals, ops::EigvalsOp, ops::EigvalsOpMaker, + ops::EigvalsOpVarTypeInference); +REGISTER_OP_CPU_KERNEL(eigvals, + ops::EigvalsKernel, + ops::EigvalsKernel, + ops::EigvalsKernel>, + ops::EigvalsKernel>); diff --git a/paddle/fluid/operators/eigvals_op.h b/paddle/fluid/operators/eigvals_op.h new file mode 100644 index 00000000000..998dcd9f1ef --- /dev/null +++ b/paddle/fluid/operators/eigvals_op.h @@ -0,0 +1,129 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "Eigen/Dense" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +template +struct PaddleComplex { + using Type = paddle::platform::complex; +}; +template <> +struct PaddleComplex> { + using Type = paddle::platform::complex; +}; +template <> +struct PaddleComplex> { + using Type = paddle::platform::complex; +}; + +template +struct StdComplex { + using Type = std::complex; +}; +template <> +struct StdComplex> { + using Type = std::complex; +}; +template <> +struct StdComplex> { + using Type = std::complex; +}; + +template +using PaddleCType = typename PaddleComplex::Type; +template +using StdCType = typename StdComplex::Type; +template +using EigenMatrixPaddle = Eigen::Matrix; +template +using EigenVectorPaddle = Eigen::Matrix, Eigen::Dynamic, 1>; +template +using EigenMatrixStd = + Eigen::Matrix, Eigen::Dynamic, Eigen::Dynamic>; +template +using EigenVectorStd = Eigen::Matrix, Eigen::Dynamic, 1>; + +static void SpiltBatchSquareMatrix(const Tensor &input, + std::vector *output) { + DDim input_dims = input.dims(); + int last_dim = input_dims.size() - 1; + int n_dim = input_dims[last_dim]; + + DDim flattened_input_dims, flattened_output_dims; + if (input_dims.size() > 2) { + flattened_input_dims = flatten_to_3d(input_dims, last_dim - 1, last_dim); + } else { + flattened_input_dims = framework::make_ddim({1, n_dim, n_dim}); + } + + Tensor flattened_input; + flattened_input.ShareDataWith(input); + flattened_input.Resize(flattened_input_dims); + (*output) = flattened_input.Split(1, 0); +} + +template +class EigvalsKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const Tensor *input = ctx.Input("X"); + Tensor *output = ctx.Output("Out"); + + auto input_type = input->type(); + auto output_type = framework::IsComplexType(input_type) + ? input_type + : framework::ToComplexType(input_type); + output->mutable_data(ctx.GetPlace(), output_type); + + std::vector input_matrices; + SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices); + + int n_dim = input_matrices[0].dims()[1]; + int n_batch = input_matrices.size(); + + DDim output_dims = output->dims(); + output->Resize(framework::make_ddim({n_batch, n_dim})); + std::vector output_vectors = output->Split(1, 0); + + Eigen::Map> input_emp(NULL, n_dim, n_dim); + Eigen::Map> output_evp(NULL, n_dim); + EigenMatrixStd input_ems; + EigenVectorStd output_evs; + + for (int i = 0; i < n_batch; ++i) { + new (&input_emp) Eigen::Map>( + input_matrices[i].data(), n_dim, n_dim); + new (&output_evp) Eigen::Map>( + output_vectors[i].data>(), n_dim); + input_ems = input_emp.template cast>(); + output_evs = input_ems.eigenvalues(); + output_evp = output_evs.template cast>(); + } + output->Resize(output_dims); + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c17c78c0080..00067095209 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1033,3 +1033,4 @@ if(WITH_GPU OR WITH_ROCM) set_tests_properties(test_rank_attention_op PROPERTIES TIMEOUT 120) endif() set_tests_properties(test_inplace_addto_strategy PROPERTIES TIMEOUT 120) +set_tests_properties(test_eigvals_op PROPERTIES TIMEOUT 400) diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 2f9c0530227..a50a667f663 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1368,6 +1368,12 @@ class OpTest(unittest.TestCase): outs.sort(key=len) checker(outs) + def check_output_with_place_customized(self, checker, place): + outs = self.calc_output(place) + outs = [np.array(out) for out in outs] + outs.sort(key=len) + checker(outs) + def _assert_is_close(self, numeric_grads, analytic_grads, names, max_relative_error, msg_prefix): for a, b, name in six.moves.zip(numeric_grads, analytic_grads, names): diff --git a/python/paddle/fluid/tests/unittests/test_eigvals_op.py b/python/paddle/fluid/tests/unittests/test_eigvals_op.py new file mode 100644 index 00000000000..eff9d4ea6e8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eigvals_op.py @@ -0,0 +1,307 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import unittest +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np +from op_test import OpTest + +np.set_printoptions(threshold=np.inf) + + +def np_eigvals(a): + res = np.linalg.eigvals(a) + if (a.dtype == np.float32 or a.dtype == np.complex64): + res = res.astype(np.complex64) + else: + res = res.astype(np.complex128) + + return res + + +class TestEigvalsOp(OpTest): + def setUp(self): + np.random.seed(0) + paddle.enable_static() + self.op_type = "eigvals" + self.set_dtype() + self.set_input_dims() + self.set_input_data() + + np_output = np_eigvals(self.input_data) + + self.inputs = {'X': self.input_data} + self.outputs = {'Out': np_output} + + def set_dtype(self): + self.dtype = np.float32 + + def set_input_dims(self): + self.input_dims = (5, 5) + + def set_input_data(self): + if (self.dtype == np.float32 or self.dtype == np.float64): + self.input_data = np.random.random(self.input_dims).astype( + self.dtype) + else: + self.input_data = ( + np.random.random(self.input_dims) + + np.random.random(self.input_dims) * 1j).astype(self.dtype) + + def test_check_output(self): + self.__class__.no_need_check_grad = True + self.check_output_with_place_customized( + checker=self.verify_output, place=core.CPUPlace()) + + def verify_output(self, outs): + actual_outs = np.sort(np.array(outs[0])) + expect_outs = np.sort(np.array(self.outputs['Out'])) + self.assertTrue( + actual_outs.shape == expect_outs.shape, "Output shape has diff.\n" + "Expect shape " + str(expect_outs.shape) + "\n" + "But Got" + + str(actual_outs.shape) + " in class " + self.__class__.__name__) + + n_dim = actual_outs.shape[-1] + for actual_row, expect_row in zip( + actual_outs.reshape((-1, n_dim)), + expect_outs.reshape((-1, n_dim))): + is_mapped_index = np.zeros((n_dim, )) + for i in range(n_dim): + is_mapped = False + for j in range(n_dim): + if is_mapped_index[j] == 0 and np.isclose( + np.array(actual_row[i]), + np.array(expect_row[j]), + atol=1e-5): + is_mapped_index[j] = True + is_mapped = True + break + self.assertTrue( + is_mapped, + "Output has diff in class " + self.__class__.__name__ + + "\nExpect " + str(expect_outs) + "\n" + "But Got" + + str(actual_outs) + "\nThe data " + str(actual_row[i]) + + " in " + str(actual_row) + " mismatch.") + + +class TestEigvalsOpFloat64(TestEigvalsOp): + def set_dtype(self): + self.dtype = np.float64 + + +class TestEigvalsOpComplex64(TestEigvalsOp): + def set_dtype(self): + self.dtype = np.complex64 + + +class TestEigvalsOpComplex128(TestEigvalsOp): + def set_dtype(self): + self.dtype = np.complex128 + + +class TestEigvalsOpLargeScare(TestEigvalsOp): + def set_input_dims(self): + self.input_dims = (128, 128) + + +class TestEigvalsOpLargeScareFloat64(TestEigvalsOpLargeScare): + def set_dtype(self): + self.dtype = np.float64 + + +class TestEigvalsOpLargeScareComplex64(TestEigvalsOpLargeScare): + def set_dtype(self): + self.dtype = np.complex64 + + +class TestEigvalsOpLargeScareComplex128(TestEigvalsOpLargeScare): + def set_dtype(self): + self.dtype = np.complex128 + + +class TestEigvalsOpBatch1(TestEigvalsOp): + def set_input_dims(self): + self.input_dims = (1, 2, 3, 4, 4) + + +class TestEigvalsOpBatch2(TestEigvalsOp): + def set_input_dims(self): + self.input_dims = (3, 1, 4, 5, 5) + + +class TestEigvalsOpBatch3(TestEigvalsOp): + def set_input_dims(self): + self.input_dims = (6, 2, 9, 6, 6) + + +class TestEigvalsAPI(unittest.TestCase): + def setUp(self): + np.random.seed(0) + + self.small_dims = [6, 6] + self.large_dims = [128, 128] + self.batch_dims = [6, 9, 2, 2] + + self.set_dtype() + + self.input_dims = self.small_dims + self.set_input_data() + self.small_input = np.copy(self.input_data) + + self.input_dims = self.large_dims + self.set_input_data() + self.large_input = np.copy(self.input_data) + + self.input_dims = self.batch_dims + self.set_input_data() + self.batch_input = np.copy(self.input_data) + + def set_dtype(self): + self.dtype = np.float32 + + def set_input_data(self): + if (self.dtype == np.float32 or self.dtype == np.float64): + self.input_data = np.random.random(self.input_dims).astype( + self.dtype) + else: + self.input_data = ( + np.random.random(self.input_dims) + + np.random.random(self.input_dims) * 1j).astype(self.dtype) + + def verify_output(self, actural_outs, expect_outs): + actual_outs = np.array(actural_outs) + expect_outs = np.array(expect_outs) + self.assertTrue( + actual_outs.shape == expect_outs.shape, "Output shape has diff." + "\nExpect shape " + str(expect_outs.shape) + "\n" + "But Got" + + str(actual_outs.shape) + " in class " + self.__class__.__name__) + + n_dim = actual_outs.shape[-1] + for actual_row, expect_row in zip( + actual_outs.reshape((-1, n_dim)), + expect_outs.reshape((-1, n_dim))): + is_mapped_index = np.zeros((n_dim, )) + for i in range(n_dim): + is_mapped = False + for j in range(n_dim): + if is_mapped_index[j] == 0 and np.isclose( + np.array(actual_row[i]), + np.array(expect_row[j]), + atol=1e-5): + is_mapped_index[j] = True + is_mapped = True + break + self.assertTrue( + is_mapped, + "Output has diff in class " + self.__class__.__name__ + + "\nExpect " + str(expect_outs) + "\n" + "But Got" + + str(actual_outs) + "\nThe data " + str(actual_row[i]) + + " in " + str(actual_row) + " mismatch.") + + def run_dygraph(self, place): + paddle.disable_static() + paddle.set_device("cpu") + small_input_tensor = paddle.to_tensor(self.small_input, place=place) + large_input_tensor = paddle.to_tensor(self.large_input, place=place) + batch_input_tensor = paddle.to_tensor(self.batch_input, place=place) + + paddle_outs = paddle.linalg.eigvals(small_input_tensor, name='small_x') + np_outs = np_eigvals(self.small_input) + self.verify_output(paddle_outs, np_outs) + + paddle_outs = paddle.linalg.eigvals(large_input_tensor, name='large_x') + np_outs = np_eigvals(self.large_input) + self.verify_output(paddle_outs, np_outs) + + paddle_outs = paddle.linalg.eigvals(batch_input_tensor, name='small_x') + np_outs = np_eigvals(self.batch_input) + self.verify_output(paddle_outs, np_outs) + + def run_static(self, place): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + small_input_tensor = paddle.static.data( + name='small_x', shape=self.small_dims, dtype=self.dtype) + large_input_tensor = paddle.static.data( + name='large_x', shape=self.large_dims, dtype=self.dtype) + batch_input_tensor = paddle.static.data( + name='batch_x', shape=self.batch_dims, dtype=self.dtype) + + small_outs = paddle.linalg.eigvals( + small_input_tensor, name='small_x') + large_outs = paddle.linalg.eigvals( + large_input_tensor, name='large_x') + batch_outs = paddle.linalg.eigvals( + batch_input_tensor, name='batch_x') + + exe = paddle.static.Executor(place) + + paddle_outs = exe.run( + feed={ + "small_x": self.small_input, + "large_x": self.large_input, + "batch_x": self.batch_input + }, + fetch_list=[small_outs, large_outs, batch_outs]) + + np_outs = np_eigvals(self.small_input) + self.verify_output(paddle_outs[0], np_outs) + + np_outs = np_eigvals(self.large_input) + self.verify_output(paddle_outs[1], np_outs) + + np_outs = np_eigvals(self.batch_input) + self.verify_output(paddle_outs[2], np_outs) + + def test_cases(self): + places = [core.CPUPlace()] + #if core.is_compiled_with_cuda(): + # places.append(core.CUDAPlace(0)) + for place in places: + self.run_dygraph(place) + self.run_static(place) + + def test_error(self): + paddle.disable_static() + x = paddle.to_tensor([1]) + with self.assertRaises(BaseException): + paddle.linalg.eigvals(x) + + self.input_dims = [1, 2, 3, 4] + self.set_input_data() + x = paddle.to_tensor(self.input_data) + with self.assertRaises(BaseException): + paddle.linalg.eigvals(x) + + +class TestEigvalsAPIFloat64(TestEigvalsAPI): + def set_dtype(self): + self.dtype = np.float64 + + +class TestEigvalsAPIComplex64(TestEigvalsAPI): + def set_dtype(self): + self.dtype = np.complex64 + + +class TestEigvalsAPIComplex128(TestEigvalsAPI): + def set_dtype(self): + self.dtype = np.complex128 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index cad6580b661..361bee09def 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -17,6 +17,7 @@ from .tensor.linalg import norm # noqa: F401 from .tensor.linalg import cond # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor import inverse as inv # noqa: F401 +from .tensor.linalg import eigvals # noqa: F401 from .tensor.linalg import multi_dot # noqa: F401 from .tensor.linalg import matrix_rank from .tensor.linalg import svd @@ -28,6 +29,7 @@ __all__ = [ 'norm', 'cond', 'inv', + 'eigvals', 'multi_dot', 'matrix_rank', 'svd', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index fc21efc1bb1..052ffb12d47 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -46,6 +46,7 @@ from .linalg import bmm # noqa: F401 from .linalg import histogram # noqa: F401 from .linalg import mv # noqa: F401 from .linalg import matrix_power # noqa: F401 +from .linalg import eigvals # noqa: F401 from .linalg import multi_dot # noqa: F401 from .linalg import svd # noqa: F401 from .linalg import eigh # noqa: F401 @@ -231,6 +232,7 @@ tensor_method_func = [ #noqa 'histogram', 'mv', 'matrix_power', + 'eigvals', 'abs', 'acos', 'all', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 3873e9f2978..c7862f61894 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1490,6 +1490,66 @@ def matrix_power(x, n, name=None): return out +def eigvals(x, name=None): + """ + Compute the eigenvalues of one or more general matrices. + + Warning: + The gradient kernel of this operator does not yet developed. + If you need back propagation through this operator, please replace it with paddle.linalg.eig. + + Args: + x (Tensor): A square matrix or a batch of square matrices whose eigenvalues will be computed. + Its shape should be `[*, M, M]`, where `*` is zero or more batch dimensions. + Its data type should be float32, float64, complex64, or complex128. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: A tensor containing the unsorted eigenvalues which has the same batch dimensions with `x`. + The eigenvalues are complex-valued even when `x` is real. + + Examples: + .. code-block:: python + + import paddle + + paddle.set_device("cpu") + paddle.seed(1234) + + x = paddle.rand(shape=[3, 3], dtype='float64') + # [[0.02773777, 0.93004224, 0.06911496], + # [0.24831591, 0.45733623, 0.07717843], + # [0.48016702, 0.14235102, 0.42620817]]) + + print(paddle.linalg.eigvals(x)) + # [(-0.27078833542132674+0j), (0.29962280156230725+0j), (0.8824477020120244+0j)] #complex128 + """ + + check_variable_and_dtype(x, 'dtype', + ['float32', 'float64', 'complex64', 'complex128'], + 'eigvals') + + x_shape = list(x.shape) + if len(x_shape) < 2: + raise ValueError( + "The dimension of Input(x) should be at least 2, but received x's dimention = {}, x's shape = {}". + format(len(x_shape), x_shape)) + + if x_shape[-1] != x_shape[-2]: + raise ValueError( + "The last two dimensions of Input(x) should be equal, but received x's shape = {}". + format(x_shape)) + + if in_dygraph_mode(): + return _C_ops.eigvals(x) + + helper = LayerHelper('eigvals', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op(type='eigvals', inputs={'X': x}, outputs={'Out': out}) + return out + + def multi_dot(x, name=None): """ Multi_dot is an operator that calculates multiple matrix multiplications. -- GitLab