From bc7e2b921d4b450f082c61d92b27a9b9479a5c7b Mon Sep 17 00:00:00 2001 From: Lijunhui <1578034415@qq.com> Date: Tue, 28 Sep 2021 17:05:59 +0800 Subject: [PATCH] add API paddle.linalg.eig (#35674) * Add paddle.linalg.eig op * remove comments * remove comments * extend batch_size to the origin * add real times complex functor & destroy the backward complex output bug * terminate output diff when input real tensors * correct tiny doc errors * move functions from eig_helper to svd_helper and remove eig_helper * remove tensor.Resize * remove no longer used code * use existing lapack functions * reply review comments 21/27 * remove .cu as this op is only executed on CPU * remove const_cast & add const in argument list for read-only references * fix sample code error in CI * remove template typename Tbase and more * remove eig exposure in paddle.* * add 'name=None' in eig python implementation * handle the unittest * try to solve the unittest * solve CI coverage * remove no longer used code * polish API doc and more * reply review comments * polish unittest, commit plan B * polish unittest --- paddle/fluid/operators/eig_op.cc | 168 +++++++++ paddle/fluid/operators/eig_op.h | 330 ++++++++++++++++++ paddle/fluid/operators/math/matrix_solve.h | 40 +++ paddle/fluid/operators/svd_helper.h | 66 ++++ .../paddle/fluid/tests/unittests/op_test.py | 4 + .../fluid/tests/unittests/test_eig_op.py | 250 +++++++++++++ python/paddle/linalg.py | 2 + python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/linalg.py | 67 ++++ 9 files changed, 929 insertions(+) create mode 100644 paddle/fluid/operators/eig_op.cc create mode 100644 paddle/fluid/operators/eig_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_eig_op.py diff --git a/paddle/fluid/operators/eig_op.cc b/paddle/fluid/operators/eig_op.cc new file mode 100644 index 0000000000..c1aac4546e --- /dev/null +++ b/paddle/fluid/operators/eig_op.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/eig_op.h" +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class EigOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eig"); + OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues", + "Eig"); + OP_INOUT_CHECK(ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors", + "Eig"); + + auto x_dims = ctx->GetInputDim("X"); + int rank = x_dims.size(); + PADDLE_ENFORCE_GE(rank, 2, platform::errors::InvalidArgument( + "Expects input tensor x to be not less than " + "2 dimentions, but got dimention %d", + rank)); + PADDLE_ENFORCE_EQ(x_dims[rank - 2], x_dims[rank - 1], + platform::errors::InvalidArgument( + "The input matrix must be a square matrix, " + "but receive a matrix with %d rows and %d colums", + x_dims[rank - 2], x_dims[rank - 1])); + + std::vector batch_dims_vec{}; + for (int i = 0; i < rank - 1; ++i) { + batch_dims_vec.emplace_back(x_dims[i]); + } + + ctx->SetOutputDim("Eigenvectors", x_dims); + ctx->SetOutputDim("Eigenvalues", framework::make_ddim(batch_dims_vec)); + } + + protected: + // The output of eig is always complex-valued even for real-valued inputs + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + if (dtype != framework::proto::VarType::FP32 && + dtype != framework::proto::VarType::FP64 && + dtype != framework::proto::VarType::COMPLEX64 && + dtype != framework::proto::VarType::COMPLEX128) { + PADDLE_THROW(platform::errors::InvalidArgument( + "unsupported data type: %s!", dtype)); + } + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + +class EigOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "X", + "(Tensor), A complex-valued or real-valued tensor with shape (*, " + "n, n). The accepted datatype is one of float32, float64, complex64 " + "or complex128"); + AddOutput("Eigenvalues", + "(Tensor), The output eigenvalues tensor with shape (*, n). The " + "datatype is complex64 or complex128"); + AddOutput("Eigenvectors", + "(Tensor), The output eigenvectors tensor with shape (*, n, n). " + "The datatype is complex64 or complex128"); + + AddComment(R"DOC( + Eig Operator. + +This API processes eigen decomposition for general square matrices. + +)DOC"); + } +}; + +class EigGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues", + "EigGrad"); + OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", + "EigGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")), + "Input", "Eigenvalues@GRAD", "EigGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvectors")), + "Input", "Eigenvectors@GRAD", "EigGrad"); + + auto dims = ctx->GetInputDim("Eigenvectors"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Eigenvectors")), + ctx.device_context()); + } +}; + +template +class EigGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType(this->ForwardOpType() + "_grad"); + op->SetInput("Eigenvalues", this->Output("Eigenvalues")); + op->SetInput("Eigenvectors", this->Output("Eigenvectors")); + op->SetInput(framework::GradVarName("Eigenvalues"), + this->OutputGrad("Eigenvalues")); + op->SetInput(framework::GradVarName("Eigenvectors"), + this->OutputGrad("Eigenvectors")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +using complex64 = paddle::platform::complex; +using complex128 = paddle::platform::complex; + +namespace ops = paddle::operators; +REGISTER_OPERATOR(eig, ops::EigOp, ops::EigOpMaker, + ops::EigGradOpMaker, + ops::EigGradOpMaker); + +REGISTER_OPERATOR(eig_grad, ops::EigGradOp); + +REGISTER_OP_CPU_KERNEL( + eig, ops::EigKernel, + ops::EigKernel, + ops::EigKernel, + ops::EigKernel); + +REGISTER_OP_CPU_KERNEL( + eig_grad, + ops::EigGradKernel, + ops::EigGradKernel, + ops::EigGradKernel, + ops::EigGradKernel); diff --git a/paddle/fluid/operators/eig_op.h b/paddle/fluid/operators/eig_op.h new file mode 100644 index 0000000000..b9a3cb300b --- /dev/null +++ b/paddle/fluid/operators/eig_op.h @@ -0,0 +1,330 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/math/lapack_function.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/matrix_solve.h" +#include "paddle/fluid/operators/svd_helper.h" +#include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/for_range.h" +#define EPSILON 1e-6 + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; + +inline int BatchCount(const Tensor& matrix) { + int count = 1; + int num_dims = matrix.dims().size(); + for (int i = 0; i < num_dims - 2; ++i) { + count *= matrix.dims()[i]; + } + return count; +} + +inline int MatrixStride(const Tensor& matrix) { + framework::DDim dims_list = matrix.dims(); + int num_dims = dims_list.size(); + return dims_list[num_dims - 1] * dims_list[num_dims - 2]; +} + +// Transpose two axis of a Tensor +template +void TransposeTwoAxis(const Tensor& input, Tensor* transposed_input, + const int axis1, const int axis2, + const framework::ExecutionContext& context) { + std::vector permute(input.dims().size()); + std::iota(permute.begin(), permute.end(), 0); + permute[axis1] = axis2; + permute[axis2] = axis1; + + transposed_input->mutable_data(input.dims(), context.GetPlace()); + auto& dev_ctx = context.template device_context(); + + TransCompute(input.dims().size(), dev_ctx, input, + transposed_input, permute); +} + +// Apply eig to a batch of matrices, values, vectors and (intermidiate +// tensor) info are overritten +template +void LapackEig(Tensor* input, Tensor* values, Tensor* vectors, int info, + const framework::ExecutionContext& context) { + char jobvl = 'N'; + char jobvr = 'V'; // only right eigenvectors are computed + int num_dims = input->dims().size(); + int order = input->dims()[num_dims - 1]; + + T* input_data = input->data(); + int lda = std::max(1, order); + T* values_data = values->mutable_data(context.GetPlace()); + T* lvector_data = nullptr; + int ldvl = 1; + T* rvector_data = vectors->mutable_data(context.GetPlace()); + int ldvr = lda; + int lwork = -1; + + int batch_count = BatchCount(*input); + int matrix_stride = MatrixStride(*input); + int values_stride = values->dims()[values->dims().size() - 1]; + + Tensor rwork; + math::Real* rwork_data = nullptr; + + rwork.Resize(framework::make_ddim({lda * 2})); + rwork_data = rwork.mutable_data>(context.GetPlace()); + + // call lapackEig once to compute the size of work; + T computed_work_size; + math::lapackEig>( + jobvl, jobvr, order, input_data, lda, values_data, lvector_data, ldvl, + rvector_data, ldvr, &computed_work_size, lwork, rwork_data, &info); + + lwork = std::max(1, static_cast(math::Real(computed_work_size))); + Tensor work; + work.Resize(framework::make_ddim({lwork})); + T* work_data = work.mutable_data(context.GetPlace()); + + for (auto i = 0; i < batch_count; ++i) { + T* current_matrix = &input_data[i * matrix_stride]; + T* current_values = &values_data[i * values_stride]; + T* current_rvectors = &rvector_data[i * matrix_stride]; + + math::lapackEig>( + jobvl, jobvr, order, current_matrix, lda, current_values, lvector_data, + ldvl, current_rvectors, ldvr, work_data, lwork, rwork_data, &info); + PADDLE_ENFORCE_EQ( + info, 0, + platform::errors::PreconditionNotMet( + "current info is not 0, computation failed. " + "= 0: successful exit." + "< 0: if INFO = -i, the i-th argument had an illegal value." + "> 0: if INFO = i, the QR algorithm failed to compute all the " + "eigenvalues, and no eigenvectors have been computed; " + "elements i+1:N of WR and WI contain eigenvalues which " + "have converged.")); + } +} + +template +void ApplyEigKernel(const Tensor& input, Tensor* values, Tensor* vectors, + const framework::ExecutionContext& context) { + Tensor input_column_major; + Tensor vectors_row_major; + int num_dims = input.dims().size(); + + // transfer to column-major memory layout i.e. make_ddim from tranposed_input: + // [batch,row,col]->[batch,col,row] + TransposeTwoAxis(input, &input_column_major, num_dims - 1, + num_dims - 2, context); + // make sure 'vectors_row_major' holds memory before passed to LapackEig() + vectors_row_major.Resize(input.dims()); + int info = 0; + LapackEig(&input_column_major, values, &vectors_row_major, info, context); + + // transfer column-major layout back + // vectors_row_major: column-major layout + // vector: original layout + TransposeTwoAxis(vectors_row_major, vectors, num_dims - 1, + num_dims - 2, context); +} + +template +void ConstructComplexVectors(Tensor* c_vectors, const Tensor& c_values, + const Tensor& r_vectors, + const framework::ExecutionContext& ctx, + int batch_count, int order) { + int matrix_stride = MatrixStride(r_vectors); + + auto* c_vectors_data = c_vectors->mutable_data(ctx.GetPlace()); + auto* c_values_data = c_values.data(); + auto* r_v_data = r_vectors.data(); + + for (int b = 0; b < batch_count; b++) { + auto* vecs = &r_v_data[b * matrix_stride]; + auto* res = &c_vectors_data[b * matrix_stride]; + auto* vals = &c_values_data[b * order]; + + for (int j = 0; j < order; j++) { + if (vals[j].imag < EPSILON) { + for (int i = 0; i < order; i++) { + res[j * order + i] = platform::complex(vecs[j * order + i], 0); + } + } else { + for (int i = 0; i < order; i++) { + res[j * order + i] = platform::complex(vecs[j * order + i], + vecs[(j + 1) * order + i]); + res[(j + 1) * order + i] = platform::complex( + vecs[j * order + i], -vecs[(j + 1) * order + i]); + } + j++; + } + } + } +} + +template +class EigKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out_values = context.Output("Eigenvalues"); + auto* out_vectors = context.Output("Eigenvectors"); + + if (!framework::IsComplexType(x->type())) { + out_values->mutable_data(context.GetPlace()); + out_vectors->mutable_data(context.GetPlace()); + + int batch_count = BatchCount(*x); + int order = x->dims()[x->dims().size() - 1]; + + Tensor real_values; + Tensor real_vectors; + // double the size of real_values, the first half stores the real part, + // the next half stores the imag part + std::vector origin_dim = + framework::vectorize(out_values->dims()); + int last_item = origin_dim.back(); + origin_dim.pop_back(); + origin_dim.push_back(last_item * 2); + framework::DDim big_dim = framework::make_ddim(origin_dim); + + real_values.mutable_data>(big_dim, context.GetPlace()); + real_vectors.mutable_data>(x->dims(), context.GetPlace()); + + ApplyEigKernel>(*x, &real_values, + &real_vectors, context); + auto dito = + math::DeviceIndependenceTensorOperations, + Tout>(context); + + // 1. extract real part & imag part from real_values + Tensor real_part = dito.Slice(real_values, {-1}, {0}, {order}); + Tensor imag_part = dito.Slice(real_values, {-1}, {order}, {order * 2}); + + // 2. construct complex values + auto* real_part_data = real_part.data>(); + auto* imag_part_data = imag_part.data>(); + int out_values_numel = out_values->numel(); + platform::ForRange for_range( + context.template device_context(), out_values_numel); + math::RealImagToComplexFunctor functor( + real_part_data, imag_part_data, + out_values->mutable_data(context.GetPlace()), out_values_numel); + for_range(functor); + + // 3. construct complex vectors + Tensor real_vector_trans = dito.Transpose(real_vectors); + Tensor out_vectors_trans; + out_vectors_trans.mutable_data(x->dims(), context.GetPlace()); + ConstructComplexVectors, Tout>( + &out_vectors_trans, *out_values, real_vector_trans, context, + batch_count, order); + TransposeTwoAxis(out_vectors_trans, out_vectors, + x->dims().size() - 1, + x->dims().size() - 2, context); + } else { + out_values->mutable_data(context.GetPlace()); + out_vectors->mutable_data(context.GetPlace()); + + ApplyEigKernel(*x, out_values, out_vectors, context); + } + } +}; + +template +void ComputeBackwardForComplexInput( + const Tensor& V, const Tensor& L, const Tensor& gL, const Tensor& gV, + Tout* x_grad_data, int batch_count, int order, + const framework::ExecutionContext& context) { + auto dito = + math::DeviceIndependenceTensorOperations( + context); + + Tensor trans_v = dito.Transpose(V); + Tensor Vh = dito.Conj(trans_v); + Tensor Lconj = dito.Conj(L); + Tensor Econj = dito.Sub(dito.Unsqueeze(Lconj, -2), dito.Unsqueeze(Lconj, -1)); + Tensor VhgV = dito.Matmul(Vh, gV); + Tensor diag_real = dito.Real(VhgV); + Tensor diag_res = dito.BatchDiag(diag_real, batch_count); + Tensor diag_unsqueezed = dito.Unsqueeze(diag_res, -2); + + // turn diag_unsqueezed into complex + auto numel = diag_unsqueezed.numel(); + Tensor diag_unsqueezed_complex; + auto* data_diag_un = diag_unsqueezed.data>(); + auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data( + diag_unsqueezed.dims(), context.GetPlace(), + static_cast(numel * sizeof(Tout))); + auto& dev_ctx = context.template device_context(); + platform::ForRange for_range(dev_ctx, numel); + math::RealToComplexFunctor functor(data_diag_un, data_diag_un_com, + numel); + for_range(functor); + // real tensor multiply complex tensor in broadcast manner + Tensor res1 = dito.RealMulComplex(V, diag_unsqueezed_complex); + Tensor res2 = dito.Matmul(Vh, res1); + Tensor result = dito.Sub(VhgV, res2); + + result.mutable_data(V.dims(), context.GetPlace()); + result = dito.Div(result, Econj); + result = dito.DiagFill(order, order, order, 0, gL, result); + Tensor rhs = dito.Matmul(result, Vh); + + // solve linear system + // solve(Vh, rhs, out, m, k) + // Vh: matrix with shape [m,m] + // rhs: rhs with shape [m,k] + // x_grad: out + int m = Vh.dims()[Vh.dims().size() - 1]; + int k = rhs.dims()[rhs.dims().size() - 1]; + auto* matrix_data = Vh.data(); + auto* rhs_data = rhs.data(); + math::SolveLinearSystem(matrix_data, rhs_data, x_grad_data, m, k, + batch_count); +} + +template +class EigGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto& L = *context.Input("Eigenvalues"); + auto& V = *context.Input("Eigenvectors"); + auto& gL = *context.Input(framework::GradVarName("Eigenvalues")); + auto& gV = *context.Input(framework::GradVarName("Eigenvectors")); + + auto& x_grad = *context.Output(framework::GradVarName("X")); + auto* x_grad_data = x_grad.mutable_data(context.GetPlace()); + + auto& dims = V.dims(); + framework::DDim dim_origin = dims; + int num_dims = dim_origin.size(); + int batch_count = BatchCount(V); + const int order = dim_origin[num_dims - 1]; + + ComputeBackwardForComplexInput( + V, L, gL, gV, x_grad_data, batch_count, order, context); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_solve.h b/paddle/fluid/operators/math/matrix_solve.h index 93c37ae425..415d0c6dd8 100644 --- a/paddle/fluid/operators/math/matrix_solve.h +++ b/paddle/fluid/operators/math/matrix_solve.h @@ -70,6 +70,46 @@ void compute_solve_eigen(const DeviceContext& context, } } +// only used for complex input +template +void SolveLinearSystem(T* matrix_data, T* rhs_data, T* out_data, int order, + int rhs_cols, int batch) { + using Treal = typename Eigen::NumTraits::Real; + + // cast paddle::complex into std::complex + std::complex* matrix_data_ = + reinterpret_cast*>(matrix_data); + std::complex* rhs_data_ = + reinterpret_cast*>(rhs_data); + std::complex* out_data_ = + reinterpret_cast*>(out_data); + + using Matrix = Eigen::Matrix, Eigen::Dynamic, + Eigen::Dynamic, Eigen::RowMajor>; + using InputMatrixMap = Eigen::Map; + using OutputMatrixMap = Eigen::Map; + + for (int i = 0; i < batch; ++i) { + auto input_matrix = + InputMatrixMap(matrix_data_ + i * order * order, order, order); + auto input_rhs = + InputMatrixMap(rhs_data_ + i * order * rhs_cols, order, rhs_cols); + auto output = + OutputMatrixMap(out_data_ + i * order * rhs_cols, order, rhs_cols); + + Eigen::PartialPivLU lu_decomposition(order); + lu_decomposition.compute(input_matrix); + + const Treal min_abs_piv = + lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff(); + PADDLE_ENFORCE_GT(min_abs_piv, Treal(0), + platform::errors::InvalidArgument( + "Something's wrong with SolveLinearSystem. ")); + + output = lu_decomposition.solve(input_rhs); + } +} + template class MatrixSolveFunctor { public: diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index d592c62d49..9ba7c9a306 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -96,6 +96,20 @@ struct PowFunctor { float exp_; }; +template +struct RealMulComplexFunctor { + // x: complex number (a+bj) + // y: complex number (c+0j) pretend to be a real number + // out: complex number (ac+bcj) + inline HOSTDEVICE T operator()(T x, T y) { + PADDLE_ENFORCE_LT(y.imag, 1e-6, platform::errors::InvalidArgument( + "The image part of y must to be 0" + "but got [%d]", + y.imag)); + return platform::complex>(x.real * y.real, x.imag * y.real); + } +}; + static std::vector GetBroadcastShape(InTensors ins) { PADDLE_ENFORCE_EQ(ins.size(), 2, platform::errors::InvalidArgument( "GetBroadcastShape Receive 2 tensors" @@ -286,6 +300,45 @@ struct DeviceIndependenceTensorOperations { for_range(DiagFunctor(x.data(), x.numel(), output)); return ret; } + + // batch_diag for CPU only + Tensor BatchDiag(const Tensor& x, int batch) { + Tensor out; + auto* x_data = x.data>(); + auto numel = x.numel(); + auto* out_data = out.mutable_data>( + x.dims(), context.GetPlace(), + static_cast(numel * sizeof(math::Real))); + + auto x_dims = x.dims(); + int num_dims = x_dims.size(); + std::vector out_shape; + + for (int i = 0; i < num_dims - 1; ++i) { + out_shape.push_back(x.dims()[i]); + } + out.Resize(framework::make_ddim(out_shape)); + int order = x.dims()[num_dims - 1]; + int stride_out = order * order; + int stride_in = order + 1; + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < order; ++j) { + out_data[i * order + j] = x_data[stride_out * i + stride_in * j]; + } + } + return out; + } + + // a complex number x times a real number y, which is represented as (a+0j) + Tensor RealMulComplex(const Tensor& x, const Tensor& y) { + framework::Tensor ret; + std::vector out_shape = GetBroadcastShape({&x, &y}); + ret.Resize(framework::make_ddim(out_shape)); + ElementwiseComputeEx, DeviceContext, T>( + context, &x, &y, -1, RealMulComplexFunctor(), &ret); + return ret; + } + framework::Tensor Div(const framework::Tensor& x, const framework::Tensor& y) { framework::Tensor ret; @@ -459,6 +512,19 @@ struct DeviceIndependenceTensorOperations { return out; } + Tensor Real(const Tensor& x) { + Tensor out; + auto numel = x.numel(); + auto* out_data = out.mutable_data>( + x.dims(), context.GetPlace(), + static_cast(numel * sizeof(math::Real))); + auto* x_data = x.data(); + auto for_range = GetForRange(numel); + math::RealFunctor functor(x_data, out_data, numel); + for_range(functor); + return out; + } + Tensor DiagFill(const int m, const int n, const int num_lower_diags, const int num_upper_diags, const Tensor& scale, const Tensor& input) { diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index a50a667f66..3621d20fa2 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -134,6 +134,10 @@ def get_numeric_gradient(place, delta = np.array(delta).astype(np.float16) elif tensor_to_check_dtype == core.VarDesc.VarType.BF16: tensor_to_check_dtype = np.float32 + elif tensor_to_check_dtype == core.VarDesc.VarType.COMPLEX64: + tensor_to_check_dtype = np.complex64 + elif tensor_to_check_dtype == core.VarDesc.VarType.COMPLEX128: + tensor_tp_check_dtype = np.complex128 else: raise ValueError("Not supported data type " + str( tensor_to_check_dtype)) diff --git a/python/paddle/fluid/tests/unittests/test_eig_op.py b/python/paddle/fluid/tests/unittests/test_eig_op.py new file mode 100644 index 0000000000..bb83de7d0d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eig_op.py @@ -0,0 +1,250 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from op_test import OpTest, skip_check_grad_ci +import unittest +from paddle.fluid.op import Operator +from paddle.fluid import compiler, Program, program_guard + + +# cast output to complex for numpy.linalg.eig +def cast_to_complex(input, output): + if (input.dtype == np.float32): + output = output.astype(np.complex64) + elif (input.dtype == np.float64): + output = output.astype(np.complex128) + return output + + +# define eig backward function for a single square matrix +def eig_backward(w, v, grad_w, grad_v): + v_tran = np.transpose(v) + v_tran = np.conjugate(v_tran) + w_conj = np.conjugate(w) + w_conj_l = w_conj.reshape(1, w.size) + w_conj_r = w_conj.reshape(w.size, 1) + w_conj_2d = w_conj_l - w_conj_r + + vhgv = np.matmul(v_tran, grad_v) + real_vhgv = np.real(vhgv) + diag_real = real_vhgv.diagonal() + + diag_2d = diag_real.reshape(1, w.size) + rhs = v * diag_2d + mid = np.matmul(v_tran, rhs) + result = vhgv - mid + + res = np.divide(result, w_conj_2d) + row, col = np.diag_indices_from(res) + res[row, col] = 1.0 + + tmp = np.matmul(res, v_tran) + dx = np.linalg.solve(v_tran, tmp) + return dx + + +class TestEigOp(OpTest): + def setUp(self): + paddle.enable_static() + paddle.device.set_device("cpu") + self.op_type = "eig" + self.__class__.op_type = self.op_type + self.init_input() + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(self.x)} + self.outputs = {'Eigenvalues': self.out[0], 'Eigenvectors': self.out[1]} + + def init_input(self): + self.set_dtype() + self.set_dims() + self.x = np.random.random(self.shape).astype(self.dtype) + self.out = np.linalg.eig(self.x) + self.out = (cast_to_complex(self.x, self.out[0]), + cast_to_complex(self.x, self.out[1])) + + # for the real input, a customized checker is needed + def checker(self, outs): + actual_out_w = outs[0].flatten() + expect_out_w = self.out[0].flatten() + actual_out_v = outs[1].flatten() + expect_out_v = self.out[1].flatten() + + length_w = len(expect_out_w) + act_w_real = np.sort( + np.array([np.abs(actual_out_w[i].real) for i in range(length_w)])) + act_w_imag = np.sort( + np.array([np.abs(actual_out_w[i].imag) for i in range(length_w)])) + exp_w_real = np.sort( + np.array([np.abs(expect_out_w[i].real) for i in range(length_w)])) + exp_w_imag = np.sort( + np.array([np.abs(expect_out_w[i].imag) for i in range(length_w)])) + + for i in range(length_w): + self.assertTrue( + np.allclose(act_w_real[i], exp_w_real[i], 1e-6, 1e-5), + "The eigenvalues real part have diff: \nExpected " + + str(act_w_real[i]) + "\n" + "But got: " + str(exp_w_real[i])) + self.assertTrue( + np.allclose(act_w_imag[i], exp_w_imag[i], 1e-6, 1e-5), + "The eigenvalues image part have diff: \nExpected " + + str(act_w_imag[i]) + "\n" + "But got: " + str(exp_w_imag[i])) + + length_v = len(expect_out_v) + act_v_real = np.sort( + np.array([np.abs(actual_out_v[i].real) for i in range(length_v)])) + act_v_imag = np.sort( + np.array([np.abs(actual_out_v[i].imag) for i in range(length_v)])) + exp_v_real = np.sort( + np.array([np.abs(expect_out_v[i].real) for i in range(length_v)])) + exp_v_imag = np.sort( + np.array([np.abs(expect_out_v[i].imag) for i in range(length_v)])) + + for i in range(length_v): + self.assertTrue( + np.allclose(act_v_real[i], exp_v_real[i], 1e-6, 1e-5), + "The eigenvectors real part have diff: \nExpected " + + str(act_v_real[i]) + "\n" + "But got: " + str(exp_v_real[i])) + self.assertTrue( + np.allclose(act_v_imag[i], exp_v_imag[i], 1e-6, 1e-5), + "The eigenvectors image part have diff: \nExpected " + + str(act_v_imag[i]) + "\n" + "But got: " + str(exp_v_imag[i])) + + def set_dtype(self): + self.dtype = np.complex64 + + def set_dims(self): + self.shape = (10, 10) + + def init_grad(self): + # grad_w, grad_v complex dtype + gtype = self.dtype + if self.dtype == np.float32: + gtype = np.complex64 + elif self.dtype == np.float64: + gtype = np.complex128 + self.grad_w = np.ones(self.out[0].shape, gtype) + self.grad_v = np.ones(self.out[1].shape, gtype) + self.grad_x = eig_backward(self.out[0], self.out[1], self.grad_w, + self.grad_v) + + def test_check_output(self): + self.check_output_with_place_customized( + checker=self.checker, place=core.CPUPlace()) + + def test_check_grad(self): + self.init_grad() + self.check_grad( + ['X'], ['Eigenvalues', 'Eigenvectors'], + user_defined_grads=[self.grad_x], + user_defined_grad_outputs=[self.grad_w, self.grad_v]) + + +class TestComplex128(TestEigOp): + def set_dtype(self): + self.dtype = np.complex128 + + +@skip_check_grad_ci( + reason="For float dtype, numpy.linalg.eig forward outputs real or complex when input is real, therefore the grad computation may be not the same with paddle.linalg.eig" +) +class TestDouble(TestEigOp): + def set_dtype(self): + self.dtype = np.float64 + + def test_check_grad(self): + pass + + +@skip_check_grad_ci( + reason="For float dtype, numpy.linalg.eig forward outputs real or complex when input is real, therefore the grad computation may be not the same with paddle.linalg.eig" +) +class TestEigBatchMarices(TestEigOp): + def set_dtype(self): + self.dtype = np.float64 + + def set_dims(self): + self.shape = (3, 10, 10) + + def test_check_grad(self): + pass + + +@skip_check_grad_ci( + reason="For float dtype, numpy.linalg.eig forward outputs real or complex when input is real, therefore the grad computation may be not the same with paddle.linalg.eig" +) +class TestFloat(TestEigOp): + def set_dtype(self): + self.dtype = np.float32 + + def test_check_grad(self): + pass + + +class TestEigStatic(TestEigOp): + def test_check_output_with_place(self): + paddle.enable_static() + place = core.CPUPlace() + input_np = np.random.random([3, 3]).astype('complex') + expect_val, expect_vec = np.linalg.eig(input_np) + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[3, 3], dtype='complex') + act_val, act_vec = paddle.linalg.eig(input) + + exe = fluid.Executor(place) + fetch_val, fetch_vec = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[act_val, act_vec]) + self.assertTrue( + np.allclose(expect_val, fetch_val, 1e-6, 1e-6), + "The eigen values have diff: \nExpected " + str(expect_val) + "\n" + + "But got: " + str(fetch_val)) + self.assertTrue( + np.allclose(np.abs(expect_vec), np.abs(fetch_vec), 1e-6, 1e-6), + "The eigen vectors have diff: \nExpected " + + str(np.abs(expect_vec)) + "\n" + "But got: " + + str(np.abs(fetch_vec))) + + +class TestEigWrongDimsError(unittest.TestCase): + def test_error(self): + paddle.device.set_device("cpu") + paddle.disable_static() + a = np.random.random((3)).astype('float32') + x = paddle.to_tensor(a) + self.assertRaises(ValueError, paddle.linalg.eig, x) + + +class TestEigNotSquareError(unittest.TestCase): + def test_error(self): + paddle.device.set_device("cpu") + paddle.disable_static() + a = np.random.random((1, 2, 3)).astype('float32') + x = paddle.to_tensor(a) + self.assertRaises(ValueError, paddle.linalg.eig, x) + + +class TestEigUnsupportedDtypeError(unittest.TestCase): + def test_error(self): + paddle.device.set_device("cpu") + paddle.disable_static() + a = (np.random.random((3, 3)) * 10).astype('int64') + x = paddle.to_tensor(a) + self.assertRaises(ValueError, paddle.linalg.eig, x) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index d57d9a4bdb..726355379e 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -14,6 +14,7 @@ from .tensor.linalg import cholesky # noqa: F401 from .tensor.linalg import norm # noqa: F401 +from .tensor.linalg import eig # noqa: F401 from .tensor.linalg import cond # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor.linalg import solve # noqa: F401 @@ -32,6 +33,7 @@ __all__ = [ 'norm', 'cond', 'inv', + 'eig', 'eigvals', 'multi_dot', 'matrix_rank', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 080a06455a..b5d79b6039 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -45,6 +45,7 @@ from .linalg import cholesky # noqa: F401 from .linalg import bmm # noqa: F401 from .linalg import histogram # noqa: F401 from .linalg import mv # noqa: F401 +from .linalg import eig # noqa: F401 from .linalg import matrix_power # noqa: F401 from .linalg import eigvals # noqa: F401 from .linalg import multi_dot # noqa: F401 @@ -386,6 +387,7 @@ tensor_method_func = [ #noqa 'bitwise_xor', 'bitwise_not', 'broadcast_tensors', + 'eig', 'uniform_', 'multi_dot', 'solve', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 9ba9370a43..f112603fbb 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -23,6 +23,7 @@ import paddle from paddle.common_ops_import import core from paddle.common_ops_import import VarDesc from paddle import _C_ops +import paddle __all__ = [] @@ -1593,6 +1594,72 @@ def matrix_power(x, n, name=None): return out +def eig(x, name=None): + """ + This API performs the eigenvalue decomposition of a square matrix or a batch of square matrices. + + .. note:: + If the matrix is a Hermitian or a real symmetric matrix, please use :ref:`paddle.linalg.eigh` instead, which is much faster. + If only eigenvalues is needed, please use :ref:`paddle.linalg.eigvals` instead. + If the matrix is of any shape, please use :ref:`paddle.linalg.svd`. + This API is only supported on CPU device. + The output datatype is always complex for both real and complex input. + + Args: + x (Tensor): A tensor with shape math:`[*, N, N]`, The data type of the x should be one of ``float32``, + ``float64``, ``compplex64`` or ``complex128``. + name (str, optional): The default value is `None`. Normally there is no need for user to set + this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Eigenvalues(Tensors): A tensor with shape math:`[*, N]` refers to the eigen values. + Eigenvectors(Tensors): A tensor with shape math:`[*, N, N]` refers to the eigen vectors. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.device.set_device("cpu") + + x_data = np.array([[1.6707249, 7.2249975, 6.5045543], + [9.956216, 8.749598, 6.066444 ], + [4.4251957, 1.7983172, 0.370647 ]]).astype("float32") + x = paddle.to_tensor(x_data) + w, v = paddle.linalg.eig(x) + print(w) + # Tensor(shape=[3, 3], dtype=complex128, place=CPUPlace, stop_gradient=False, + # [[(-0.5061363550800655+0j) , (-0.7971760990842826+0j) , + # (0.18518077798279986+0j)], + # [(-0.8308237755993192+0j) , (0.3463813401919749+0j) , + # (-0.6837005269141947+0j) ], + # [(-0.23142567697893396+0j), (0.4944999840400175+0j) , + # (0.7058765252952796+0j) ]]) + + print(v) + # Tensor(shape=[3], dtype=complex128, place=CPUPlace, stop_gradient=False, + # [ (16.50471283351188+0j) , (-5.5034820550763515+0j) , + # (-0.21026087843552282+0j)]) + """ + if in_dygraph_mode(): + w, v = _C_ops.eig(x) + return w, v + + check_variable_and_dtype( + x, 'X', ['float32', 'float64', 'complex64', 'complex128'], 'eig') + helper = LayerHelper('eig', **locals()) + + w = helper.create_variable_for_type_inference(x.dtype) + v = helper.create_variable_for_type_inference(x.dtype) + + inputs = {'X': x} + outputs = {'Eigenvalues': w, 'Eigenvectors': v} + helper.append_op(type='eig', inputs=inputs, outputs=outputs) + + return w, v + + def eigvals(x, name=None): """ Compute the eigenvalues of one or more general matrices. -- GitLab