diff --git a/paddle/fluid/operators/eig_op.cc b/paddle/fluid/operators/eig_op.cc
index b53bba9fac0c40ca46e7bbf23c53c6fedf825a21..d67f7f4432b22b0e12f710430c200799b7881439 100644
--- a/paddle/fluid/operators/eig_op.cc
+++ b/paddle/fluid/operators/eig_op.cc
@@ -17,7 +17,11 @@
 #include <string>
 #include <vector>
 
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -25,37 +29,6 @@ namespace operators {
 class EigOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eig");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues", "Eig");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors", "Eig");
-
-    auto x_dims = ctx->GetInputDim("X");
-    int rank = x_dims.size();
-    PADDLE_ENFORCE_GE(rank,
-                      2,
-                      platform::errors::InvalidArgument(
-                          "Expects input tensor x to be not less than "
-                          "2 dimentions, but got dimention %d",
-                          rank));
-    PADDLE_ENFORCE_EQ(x_dims[rank - 2],
-                      x_dims[rank - 1],
-                      platform::errors::InvalidArgument(
-                          "The input matrix must be a square matrix, "
-                          "but receive a matrix with %d rows and %d colums",
-                          x_dims[rank - 2],
-                          x_dims[rank - 1]));
-
-    std::vector<int> batch_dims_vec{};
-    for (int i = 0; i < rank - 1; ++i) {
-      batch_dims_vec.emplace_back(x_dims[i]);
-    }
-
-    ctx->SetOutputDim("Eigenvectors", x_dims);
-    ctx->SetOutputDim("Eigenvalues", phi::make_ddim(batch_dims_vec));
-  }
 
  protected:
   // The output of eig is always complex-valued even for real-valued inputs
@@ -100,26 +73,6 @@ This API processes eigen decomposition for general square matrices.
 class EigGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(
-        ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues", "EigGrad");
-    OP_INOUT_CHECK(
-        ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", "EigGrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")),
-                   "Input",
-                   "Eigenvalues@GRAD",
-                   "EigGrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvectors")),
-                   "Input",
-                   "Eigenvectors@GRAD",
-                   "EigGrad");
-
-    auto dims = ctx->GetInputDim("Eigenvectors");
-    auto x_grad_name = framework::GradVarName("X");
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, dims);
-    }
-  }
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
@@ -152,27 +105,20 @@ class EigGradOpMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace operators
 }  // namespace paddle
 
-using complex64 = paddle::platform::complex<float>;
-using complex128 = paddle::platform::complex<double>;
-
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(eig,
+                            EigInferShapeFunctor,
+                            PD_INFER_META(phi::EigInferMeta));
+
+DECLARE_INFER_SHAPE_FUNCTOR(eig_grad,
+                            EigGradInferShapeFunctor,
+                            PD_INFER_META(phi::EigGradInferMeta));
+
 REGISTER_OPERATOR(eig,
                   ops::EigOp,
                   ops::EigOpMaker,
                   ops::EigGradOpMaker<paddle::framework::OpDesc>,
-                  ops::EigGradOpMaker<paddle::imperative::OpBase>);
-
-REGISTER_OPERATOR(eig_grad, ops::EigGradOp);
-
-REGISTER_OP_CPU_KERNEL(
-    eig,
-    ops::EigKernel<paddle::platform::CPUDeviceContext, float, complex64>,
-    ops::EigKernel<paddle::platform::CPUDeviceContext, double, complex128>,
-    ops::EigKernel<paddle::platform::CPUDeviceContext, complex64, complex64>,
-    ops::EigKernel<paddle::platform::CPUDeviceContext, complex128, complex128>);
-
-REGISTER_OP_CPU_KERNEL(
-    eig_grad,
-    ops::EigGradKernel<paddle::platform::CPUDeviceContext, float, complex64>,
-    ops::EigGradKernel<paddle::platform::CPUDeviceContext, double, complex128>,
-    ops::EigGradKernel<paddle::platform::CPUDeviceContext, complex64, complex64>,
-    ops::EigGradKernel<paddle::platform::CPUDeviceContext, complex128, complex128>);
+                  ops::EigGradOpMaker<paddle::imperative::OpBase>,
+                  EigInferShapeFunctor);
+
+REGISTER_OPERATOR(eig_grad, ops::EigGradOp, EigGradInferShapeFunctor);
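Note on the hunks above: the handwritten InferShape bodies move into phi::EigInferMeta / phi::EigGradInferMeta (declared later in this patch) and are wired back in through the infer-shape functors. One user-visible contract, per the comment kept in EigOp, is that eig always produces complex outputs even for real inputs. A minimal dygraph check (not part of the patch; assumes a CPU build with this change applied):

```python
import numpy as np
import paddle

paddle.set_device("cpu")
x = paddle.to_tensor(np.random.rand(3, 3))  # real float64 input
w, v = paddle.linalg.eig(x)
print(w.dtype, v.dtype)  # expected: paddle.complex128 for both outputs
```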
diff --git a/paddle/fluid/operators/eig_op.h b/paddle/fluid/operators/eig_op.h
index 82c7fe68819697fc7e901e43db54814d241a458e..c9c81f2ed8a61cd8f03d00a91c96ed0e92c49a85 100644
--- a/paddle/fluid/operators/eig_op.h
+++ b/paddle/fluid/operators/eig_op.h
@@ -57,341 +57,5 @@ inline int MatrixStride(const Tensor& matrix) {
   return dims_list[num_dims - 1] * dims_list[num_dims - 2];
 }
 
-// Transpose two axis of a Tensor
-template <typename DeviceContext, typename T>
-void TransposeTwoAxis(const Tensor& input,
-                      Tensor* transposed_input,
-                      const int axis1,
-                      const int axis2,
-                      const framework::ExecutionContext& context) {
-  std::vector<int> permute(input.dims().size());
-  std::iota(permute.begin(), permute.end(), 0);
-  permute[axis1] = axis2;
-  permute[axis2] = axis1;
-
-  transposed_input->mutable_data<T>(input.dims(), context.GetPlace());
-  auto& dev_ctx = context.template device_context<DeviceContext>();
-
-  TransCompute<DeviceContext, T>(
-      input.dims().size(), dev_ctx, input, transposed_input, permute);
-}
-
-// Apply eig to a batch of matrices, values, vectors and (intermidiate
-// tensor) info are overritten
-template <typename T>
-void LapackEig(Tensor* input,
-               Tensor* values,
-               Tensor* vectors,
-               int info,
-               const framework::ExecutionContext& context) {
-  char jobvl = 'N';
-  char jobvr = 'V';  // only right eigenvectors are computed
-  int num_dims = input->dims().size();
-  int order = input->dims()[num_dims - 1];
-
-  T* input_data = input->data<T>();
-  int lda = std::max<int>(1, order);
-  T* values_data = values->mutable_data<T>(context.GetPlace());
-  T* lvector_data = nullptr;
-  int ldvl = 1;
-  T* rvector_data = vectors->mutable_data<T>(context.GetPlace());
-  int ldvr = lda;
-  int lwork = -1;
-
-  int batch_count = BatchCount(*input);
-  int matrix_stride = MatrixStride(*input);
-  int values_stride = values->dims()[values->dims().size() - 1];
-
-  Tensor rwork;
-  phi::dtype::Real<T>* rwork_data = nullptr;
-
-  rwork.Resize(phi::make_ddim({lda * 2}));
-  rwork_data = rwork.mutable_data<phi::dtype::Real<T>>(context.GetPlace());
-
-  // call lapackEig once to compute the size of work;
-  T computed_work_size;
-  phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
-      jobvl, jobvr, order, input_data, lda, values_data, lvector_data,
-      ldvl, rvector_data, ldvr, &computed_work_size, lwork, rwork_data,
-      &info);
-
-  lwork = std::max<int>(
-      1, static_cast<int>(phi::dtype::Real<T>(computed_work_size)));
-  Tensor work;
-  work.Resize(phi::make_ddim({lwork}));
-  T* work_data = work.mutable_data<T>(context.GetPlace());
-
-  for (auto i = 0; i < batch_count; ++i) {
-    T* current_matrix = &input_data[i * matrix_stride];
-    T* current_values = &values_data[i * values_stride];
-    T* current_rvectors = &rvector_data[i * matrix_stride];
-
-    phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
-        jobvl, jobvr, order, current_matrix, lda, current_values,
-        lvector_data, ldvl, current_rvectors, ldvr, work_data, lwork,
-        rwork_data, &info);
-    PADDLE_ENFORCE_EQ(
-        info,
-        0,
-        platform::errors::PreconditionNotMet(
-            "current info is not 0, computation failed. "
-            "= 0: successful exit."
-            "< 0: if INFO = -i, the i-th argument had an illegal value."
-            "> 0: if INFO = i, the QR algorithm failed to compute all the "
-            "eigenvalues, and no eigenvectors have been computed; "
-            "elements i+1:N of WR and WI contain eigenvalues which "
-            "have converged."));
-  }
-}
-
-template <typename DeviceContext, typename T>
-void ApplyEigKernel(const Tensor& input,
-                    Tensor* values,
-                    Tensor* vectors,
-                    const framework::ExecutionContext& context) {
-  Tensor input_column_major;
-  Tensor vectors_row_major;
-  int num_dims = input.dims().size();
-
-  // transfer to column-major memory layout i.e. make_ddim from tranposed_input:
-  // [batch,row,col]->[batch,col,row]
-  TransposeTwoAxis<DeviceContext, T>(
-      input, &input_column_major, num_dims - 1, num_dims - 2, context);
-  // make sure 'vectors_row_major' holds memory before passed to LapackEig()
-  vectors_row_major.Resize(input.dims());
-  int info = 0;
-  LapackEig<T>(&input_column_major, values, &vectors_row_major, info, context);
-
-  // transfer column-major layout back
-  // vectors_row_major: column-major layout
-  // vector: original layout
-  TransposeTwoAxis<DeviceContext, T>(
-      vectors_row_major, vectors, num_dims - 1, num_dims - 2, context);
-}
-
-template <typename T, typename Tout>
-void ConstructComplexVectors(Tensor* c_vectors,
-                             const Tensor& c_values,
-                             const Tensor& r_vectors,
-                             const framework::ExecutionContext& ctx,
-                             int batch_count,
-                             int order) {
-  int matrix_stride = MatrixStride(r_vectors);
-
-  auto* c_vectors_data = c_vectors->mutable_data<Tout>(ctx.GetPlace());
-  auto* c_values_data = c_values.data<Tout>();
-  auto* r_v_data = r_vectors.data<T>();
-
-  for (int b = 0; b < batch_count; b++) {
-    auto* vecs = &r_v_data[b * matrix_stride];
-    auto* res = &c_vectors_data[b * matrix_stride];
-    auto* vals = &c_values_data[b * order];
-
-    for (int j = 0; j < order; j++) {
-      if (vals[j].imag < EPSILON) {
-        for (int i = 0; i < order; i++) {
-          res[j * order + i] = platform::complex<T>(vecs[j * order + i], 0);
-        }
-      } else {
-        for (int i = 0; i < order; i++) {
-          res[j * order + i] = platform::complex<T>(
-              vecs[j * order + i], vecs[(j + 1) * order + i]);
-          res[(j + 1) * order + i] = platform::complex<T>(
-              vecs[j * order + i], -vecs[(j + 1) * order + i]);
-        }
-        j++;
-      }
-    }
-  }
-}
-
-template <typename DeviceContext, typename T, typename Tout>
-class EigKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<Tensor>("X");
-    auto* out_values = context.Output<Tensor>("Eigenvalues");
-    auto* out_vectors = context.Output<Tensor>("Eigenvectors");
-
-    if (!framework::IsComplexType(
-            framework::TransToProtoVarType(x->dtype()))) {
-      out_values->mutable_data<Tout>(context.GetPlace());
-      out_vectors->mutable_data<Tout>(context.GetPlace());
-
-      int batch_count = BatchCount(*x);
-      int order = x->dims()[x->dims().size() - 1];
-
-      Tensor real_values;
-      Tensor real_vectors;
-      // double the size of real_values, the first half stores the real part,
-      // the next half stores the imag part
-      std::vector<int> origin_dim = phi::vectorize<int>(out_values->dims());
-      int last_item = origin_dim.back();
-      origin_dim.pop_back();
-      origin_dim.push_back(last_item * 2);
-      framework::DDim big_dim = phi::make_ddim(origin_dim);
-
-      real_values.mutable_data<phi::dtype::Real<T>>(big_dim,
-                                                    context.GetPlace());
-      real_vectors.mutable_data<phi::dtype::Real<T>>(x->dims(),
-                                                     context.GetPlace());
-
-      ApplyEigKernel<DeviceContext, phi::dtype::Real<T>>(
-          *x, &real_values, &real_vectors, context);
-
-      auto& orig_dev_ctx = context.template device_context<DeviceContext>();
-      auto& dev_ctx = static_cast<
-          const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
-          orig_dev_ctx);
-
-      // 1. extract real part & imag part from real_values
-      Tensor real_part = phi::funcs::Slice<phi::dtype::Real<T>>(
-          dev_ctx, real_values, {-1}, {0}, {order});
-      Tensor imag_part = phi::funcs::Slice<phi::dtype::Real<T>>(
-          dev_ctx, real_values, {-1}, {order}, {order * 2});
-
-      // 2. construct complex values
-      auto* real_part_data = real_part.data<phi::dtype::Real<T>>();
-      auto* imag_part_data = imag_part.data<phi::dtype::Real<T>>();
-      int out_values_numel = out_values->numel();
-      platform::ForRange<DeviceContext> for_range(
-          context.template device_context<DeviceContext>(), out_values_numel);
-      phi::funcs::RealImagToComplexFunctor<Tout> functor(
-          real_part_data,
-          imag_part_data,
-          out_values->mutable_data<Tout>(context.GetPlace()),
-          out_values_numel);
-      for_range(functor);
-
-      // 3. construct complex vectors
-      Tensor real_vector_trans =
-          phi::TransposeLast2Dim<phi::dtype::Real<T>>(dev_ctx, real_vectors);
-      Tensor out_vectors_trans;
-      out_vectors_trans.mutable_data<Tout>(x->dims(), context.GetPlace());
-      ConstructComplexVectors<phi::dtype::Real<T>, Tout>(&out_vectors_trans,
-                                                         *out_values,
-                                                         real_vector_trans,
-                                                         context,
-                                                         batch_count,
-                                                         order);
-      TransposeTwoAxis<DeviceContext, Tout>(out_vectors_trans,
-                                            out_vectors,
-                                            x->dims().size() - 1,
-                                            x->dims().size() - 2,
-                                            context);
-    } else {
-      out_values->mutable_data<T>(context.GetPlace());
-      out_vectors->mutable_data<T>(context.GetPlace());
-
-      ApplyEigKernel<DeviceContext, T>(*x, out_values, out_vectors, context);
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-void ComputeBackwardForComplexInput(
-    const Tensor& V,
-    const Tensor& L,
-    const Tensor& gL,
-    const Tensor& gV,
-    T* x_grad_data,
-    int batch_count,
-    int order,
-    const framework::ExecutionContext& context) {
-  auto& orig_dev_ctx = context.template device_context<DeviceContext>();
-  auto& dev_ctx = static_cast<
-      const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
-      orig_dev_ctx);
-
-  Tensor trans_v = phi::TransposeLast2Dim<T>(dev_ctx, V);
-  Tensor Vh = phi::Conj<T>(dev_ctx, trans_v);
-  Tensor Lconj = phi::Conj<T>(dev_ctx, L);
-  Tensor Econj = phi::Subtract<T>(dev_ctx,
-                                  phi::funcs::Unsqueeze(Lconj, -2),
-                                  phi::funcs::Unsqueeze(Lconj, -1));
-  Tensor VhgV = phi::Matmul<T>(dev_ctx, Vh, gV);
-  Tensor diag_real = phi::Real<T>(dev_ctx, VhgV);
-  Tensor diag_res =
-      phi::funcs::BatchDiag<phi::dtype::Real<T>>(dev_ctx, diag_real, batch_count);
-  Tensor diag_unsqueezed = phi::funcs::Unsqueeze(diag_res, -2);
-
-  // turn diag_unsqueezed into complex
-  auto numel = diag_unsqueezed.numel();
-  Tensor diag_unsqueezed_complex;
-  auto* data_diag_un = diag_unsqueezed.data<phi::dtype::Real<T>>();
-  auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data<T>(
-      diag_unsqueezed.dims(),
-      context.GetPlace(),
-      static_cast<size_t>(numel * sizeof(T)));
-
-  platform::ForRange<DeviceContext> for_range(orig_dev_ctx, numel);
-  phi::funcs::RealToComplexFunctor<T> functor(
-      data_diag_un, data_diag_un_com, numel);
-  for_range(functor);
-  // real tensor multiply complex tensor in broadcast manner
-  Tensor res1 = phi::Multiply<T>(dev_ctx, V, diag_unsqueezed_complex);
-  Tensor res2 = phi::Matmul<T>(dev_ctx, Vh, res1);
-  Tensor result = phi::Subtract<T>(dev_ctx, VhgV, res2);
-
-  result.mutable_data<T>(V.dims(), context.GetPlace());
-  result = phi::Divide<T>(dev_ctx, result, Econj);
-  result =
-      phi::funcs::DiagFill<T>(dev_ctx, order, order, order, 0, gL, result);
-  Tensor rhs = phi::Matmul<T>(dev_ctx, result, Vh);
-
-  // solve linear system
-  // solve(Vh, rhs, out, m, k)
-  // Vh: matrix with shape [m,m]
-  // rhs: rhs with shape [m,k]
-  // x_grad: out
-  int m = Vh.dims()[Vh.dims().size() - 1];
-  int k = rhs.dims()[rhs.dims().size() - 1];
-  auto* matrix_data = Vh.data<T>();
-  auto* rhs_data = rhs.data<T>();
-  phi::funcs::SolveLinearSystem(
-      matrix_data, rhs_data, x_grad_data, m, k, batch_count);
-}
-
-template <typename DeviceContext, typename T, typename Tout>
-class EigGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto& L = *context.Input<Tensor>("Eigenvalues");
-    auto& V = *context.Input<Tensor>("Eigenvectors");
-    auto& gL = *context.Input<Tensor>(framework::GradVarName("Eigenvalues"));
-    auto& gV = *context.Input<Tensor>(framework::GradVarName("Eigenvectors"));
-
-    auto& x_grad = *context.Output<Tensor>(framework::GradVarName("X"));
-    auto* x_grad_data = x_grad.mutable_data<Tout>(context.GetPlace());
-
-    auto& dims = V.dims();
-    framework::DDim dim_origin = dims;
-    int num_dims = dim_origin.size();
-    int batch_count = BatchCount(V);
-    const int order = dim_origin[num_dims - 1];
-
-    ComputeBackwardForComplexInput<DeviceContext, Tout>(
-        V, L, gL, gV, x_grad_data, batch_count, order, context);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index b9e7361abea7dafc9fab232e5beacd801013f009..564de7b67435e301902ed1195d31fffe317bc9dc 100644
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -2285,3 +2285,13 @@
   args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place = {})
   output : Tensor
   invoke : full_like(x, 0, dtype, place)
+
+# eig
+- api: eig
+  args: (Tensor x)
+  output: Tensor(out_w), Tensor(out_v)
+  infer_meta:
+    func: EigInferMeta
+  kernel:
+    func: eig
+  backward: eig_grad
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index b4972c68a6477bbba4db8bac7884d78dd774daeb..fa365d28dabcbcb9930cb71833ac1e5c01b00366 100644
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -557,6 +557,19 @@
     kernel :
       func : dropout_grad
 
+- backward_api : eig_grad
+  forward : eig (Tensor x) -> Tensor(out_w), Tensor(out_v)
+  args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [out_v]
+  kernel :
+    func : eig_grad
+    data_type : out_v
+  data_transform:
+    skip_transform : out_w, out_w_grad
+
 - backward_api : eigh_grad
   forward : eigh (Tensor x, str uplo) -> Tensor(out_w), Tensor(out_v)
   args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index dd2d1eb482c8edcaddc2c4e24cddb2a43a091d2f..eee75af3a329bb11d4b38d7b07a2f0dcc6be7ad6 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -223,6 +223,18 @@ void DeformableConvGradInferMeta(const MetaTensor& x,
   }
 }
 
+void EigGradInferMeta(const MetaTensor& out_w,
+                      const MetaTensor& out_v,
+                      const MetaTensor& dout_w,
+                      const MetaTensor& dout_v,
+                      MetaTensor* dx) {
+  auto dims = out_v.dims();
+
+  if (dx) {
+    dx->set_dims(dims);
+  }
+}
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 6a4eba74b47beec1b6d22b4face65728ba36813e..527a3c107f820c8bed119cec0fabdd8731e54375 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -106,6 +106,12 @@ void DeformableConvGradInferMeta(const MetaTensor& x,
                                  MetaTensor* filter_grad,
                                  MetaTensor* mask_grad);
 
+void EigGradInferMeta(const MetaTensor& out_w,
+                      const MetaTensor& out_v,
+                      const MetaTensor& dout_w,
+                      const MetaTensor& dout_v,
+                      MetaTensor* dx);
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 5958f0e71e76a96717c8d863649c2237cbbfdd90..84aa58d0f19355b1095819c19e5e437c91e61399 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -375,6 +375,32 @@ void DiagonalInferMeta(const MetaTensor& input,
   out->set_dims(phi::make_ddim(out_dims));
 }
 
+void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) {
+  auto x_dims = x.dims();
+  int rank = x_dims.size();
+  PADDLE_ENFORCE_GE(
+      rank,
+      2,
+      phi::errors::InvalidArgument("Expects input tensor x to have at least "
+                                   "2 dimensions, but got dimension %d",
+                                   rank));
+  PADDLE_ENFORCE_EQ(x_dims[rank - 2],
+                    x_dims[rank - 1],
+                    phi::errors::InvalidArgument(
+                        "The input matrix must be a square matrix, "
+                        "but received a matrix with %d rows and %d columns",
+                        x_dims[rank - 2],
+                        x_dims[rank - 1]));
+
+  std::vector<int> batch_dims_vec{};
+  for (int i = 0; i < rank - 1; ++i) {
+    batch_dims_vec.emplace_back(x_dims[i]);
+  }
+
+  out_w->set_dims(phi::make_ddim(batch_dims_vec));
+  out_v->set_dims(x_dims);
+}
+
 void EighInferMeta(const MetaTensor& x,
                    const std::string& uplo,
                    MetaTensor* out_w,
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 30db8dcae9882844f7b27e77c1721abb136d7e13..d0efbaa51ea5c96b7d5c5fb7a5063096ee992e90 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -77,6 +77,8 @@ void DiagInferMeta(const MetaTensor& x,
 void DiagonalInferMeta(
     const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);
 
+void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v);
+
 void EighInferMeta(const MetaTensor& x,
                    const std::string& uplo,
                    MetaTensor* out_w,
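EigInferMeta above encodes the same shape rule as numpy: a batched square input of shape [*, n, n] yields eigenvalues of shape [*, n] and eigenvectors of shape [*, n, n]. A quick reference check (not part of the patch):

```python
import numpy as np

a = np.random.rand(2, 5, 4, 4)   # batch dims (2, 5), order n = 4
w, v = np.linalg.eig(a)
print(w.shape, v.shape)          # (2, 5, 4) (2, 5, 4, 4)
```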
diff --git a/paddle/phi/kernels/cpu/eig.h b/paddle/phi/kernels/cpu/eig.h
new file mode 100644
index 0000000000000000000000000000000000000000..3ec862c1d471b282ff0d0035be854274e7924011
--- /dev/null
+++ b/paddle/phi/kernels/cpu/eig.h
@@ -0,0 +1,332 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <math.h>
+
+#include <algorithm>
+#include <complex>
+
+#include "Eigen/Core"
+#include "Eigen/LU"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/complex_kernel.h"
+#include "paddle/phi/kernels/elementwise_divide_kernel.h"
+#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
+#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/diag_functor.h"
+#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/slice.h"
+#include "paddle/phi/kernels/funcs/unsqueeze.h"
+#include "paddle/phi/kernels/matmul_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
+
+#define EPSILON 1e-6
+
+namespace phi {
+
+inline int BatchCount(const DenseTensor& matrix) {
+  int count = 1;
+  int num_dims = matrix.dims().size();
+  for (int i = 0; i < num_dims - 2; ++i) {
+    count *= matrix.dims()[i];
+  }
+  return count;
+}
+
+inline int MatrixStride(const DenseTensor& matrix) {
+  phi::DDim dims_list = matrix.dims();
+  int num_dims = dims_list.size();
+  return dims_list[num_dims - 1] * dims_list[num_dims - 2];
+}
+
+// only used for complex input
+template <typename T>
+void SolveLinearSystem(T* matrix_data,
+                       T* rhs_data,
+                       T* out_data,
+                       int order,
+                       int rhs_cols,
+                       int batch) {
+  using Treal = typename Eigen::NumTraits<T>::Real;
+
+  // cast paddle::complex into std::complex
+  std::complex<Treal>* matrix_data_ =
+      reinterpret_cast<std::complex<Treal>*>(matrix_data);
+  std::complex<Treal>* rhs_data_ =
+      reinterpret_cast<std::complex<Treal>*>(rhs_data);
+  std::complex<Treal>* out_data_ =
+      reinterpret_cast<std::complex<Treal>*>(out_data);
+
+  using Matrix = Eigen::Matrix<std::complex<Treal>,
+                               Eigen::Dynamic,
+                               Eigen::Dynamic,
+                               Eigen::RowMajor>;
+  using InputMatrixMap = Eigen::Map<Matrix>;
+  using OutputMatrixMap = Eigen::Map<Matrix>;
+
+  for (int i = 0; i < batch; ++i) {
+    auto input_matrix =
+        InputMatrixMap(matrix_data_ + i * order * order, order, order);
+    auto input_rhs =
+        InputMatrixMap(rhs_data_ + i * order * rhs_cols, order, rhs_cols);
+    auto output =
+        OutputMatrixMap(out_data_ + i * order * rhs_cols, order, rhs_cols);
+
+    Eigen::PartialPivLU<Matrix> lu_decomposition(order);
+    lu_decomposition.compute(input_matrix);
+
+    const Treal min_abs_piv =
+        lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
+    PADDLE_ENFORCE_GT(min_abs_piv,
+                      Treal(0),
+                      errors::InvalidArgument(
+                          "SolveLinearSystem failed: the coefficient matrix "
+                          "is singular (zero pivot in its LU factorization)."));
+
+    output = lu_decomposition.solve(input_rhs);
+  }
+}
+
+template <typename T, typename Context>
+void TransposeTwoAxis(const DenseTensor& input,
+                      DenseTensor* transposed_input,
+                      const int axis1,
+                      const int axis2,
+                      const Context& dev_ctx) {
+  std::vector<int> permute(input.dims().size());
+  std::iota(permute.begin(), permute.end(), 0);
+  permute[axis1] = axis2;
+  permute[axis2] = axis1;
+
+  transposed_input->Resize(input.dims());
+  dev_ctx.template Alloc<T>(transposed_input);
+
+  funcs::TransCompute<Context, T>(
+      input.dims().size(), dev_ctx, input, transposed_input, permute);
+}
+
+// Apply eig to a batch of matrices; values, vectors and the (intermediate
+// DenseTensor) info are overwritten
+template <typename T, typename Context>
+void LapackEig(DenseTensor* input,
+               DenseTensor* values,
+               DenseTensor* vectors,
+               int info,
+               const Context& dev_ctx) {
+  char jobvl = 'N';
+  char jobvr = 'V';  // only right eigenvectors are computed
+  int num_dims = input->dims().size();
+  int order = input->dims()[num_dims - 1];
+
+  T* input_data = input->data<T>();
+  int lda = std::max<int>(1, order);
+
+  T* values_data = dev_ctx.template Alloc<T>(values);
+  T* lvector_data = nullptr;
+  int ldvl = 1;
+  T* rvector_data = dev_ctx.template Alloc<T>(vectors);
+  int ldvr = lda;
+  int lwork = -1;
+
+  int batch_count = BatchCount(*input);
+  int matrix_stride = MatrixStride(*input);
+  int values_stride = values->dims()[values->dims().size() - 1];
+
+  DenseTensor rwork;
+  phi::dtype::Real<T>* rwork_data = nullptr;
+
+  rwork.Resize(phi::make_ddim({lda * 2}));
+  rwork_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(&rwork);
+
+  // call lapackEig once to compute the size of work;
+  T computed_work_size;
+  phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
+      jobvl, jobvr, order, input_data, lda, values_data, lvector_data,
+      ldvl, rvector_data, ldvr, &computed_work_size, lwork, rwork_data,
+      &info);
+
+  lwork = std::max<int>(
+      1, static_cast<int>(phi::dtype::Real<T>(computed_work_size)));
+  DenseTensor work;
+  work.Resize(phi::make_ddim({lwork}));
+  T* work_data = dev_ctx.template Alloc<T>(&work);
+
+  for (auto i = 0; i < batch_count; ++i) {
+    T* current_matrix = &input_data[i * matrix_stride];
+    T* current_values = &values_data[i * values_stride];
+    T* current_rvectors = &rvector_data[i * matrix_stride];
+
+    phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
+        jobvl, jobvr, order, current_matrix, lda, current_values,
+        lvector_data, ldvl, current_rvectors, ldvr, work_data, lwork,
+        rwork_data, &info);
+    PADDLE_ENFORCE_EQ(
+        info,
+        0,
+        errors::PreconditionNotMet(
+            "current info is not 0, computation failed. "
+            "= 0: successful exit."
+            "< 0: if INFO = -i, the i-th argument had an illegal value."
+            "> 0: if INFO = i, the QR algorithm failed to compute all the "
+            "eigenvalues, and no eigenvectors have been computed; "
+            "elements i+1:N of WR and WI contain eigenvalues which "
+            "have converged."));
+  }
+}
+
+template <typename T, typename Context>
+void ApplyEigKernel(const DenseTensor& input,
+                    DenseTensor* values,
+                    DenseTensor* vectors,
+                    const Context& dev_ctx) {
+  DenseTensor input_column_major;
+  DenseTensor vectors_row_major;
+  int num_dims = input.dims().size();
+
+  // transfer to column-major memory layout, i.e. make_ddim from
+  // transposed_input: [batch,row,col]->[batch,col,row]
+  TransposeTwoAxis<T, Context>(
+      input, &input_column_major, num_dims - 1, num_dims - 2, dev_ctx);
+  // make sure 'vectors_row_major' holds memory before passed to LapackEig()
+  vectors_row_major.Resize(input.dims());
+  int info = 0;
+  LapackEig<T, Context>(
+      &input_column_major, values, &vectors_row_major, info, dev_ctx);
+
+  // transfer column-major layout back
+  // vectors_row_major: column-major layout
+  // vector: original layout
+  TransposeTwoAxis<T, Context>(
+      vectors_row_major, vectors, num_dims - 1, num_dims - 2, dev_ctx);
+}
+
+template <typename T, typename Tout, typename Context>
+void ConstructComplexVectors(DenseTensor* c_vectors,
+                             const DenseTensor& c_values,
+                             const DenseTensor& r_vectors,
+                             const Context& dev_ctx,
+                             int batch_count,
+                             int order) {
+  int matrix_stride = MatrixStride(r_vectors);
+
+  auto* c_vectors_data = dev_ctx.template Alloc<Tout>(c_vectors);
+  auto* c_values_data = c_values.data<Tout>();
+  auto* r_v_data = r_vectors.data<T>();
+
+  for (int b = 0; b < batch_count; b++) {
+    auto* vecs = &r_v_data[b * matrix_stride];
+    auto* res = &c_vectors_data[b * matrix_stride];
+    auto* vals = &c_values_data[b * order];
+
+    for (int j = 0; j < order; j++) {
+      if (vals[j].imag < EPSILON) {
+        for (int i = 0; i < order; i++) {
+          res[j * order + i] = dtype::complex<T>(vecs[j * order + i], 0);
+        }
+      } else {
+        for (int i = 0; i < order; i++) {
+          res[j * order + i] =
+              dtype::complex<T>(vecs[j * order + i], vecs[(j + 1) * order + i]);
+          res[(j + 1) * order + i] = dtype::complex<T>(
+              vecs[j * order + i], -vecs[(j + 1) * order + i]);
+        }
+        j++;
+      }
+    }
+  }
+}
+
+template <typename T, typename Context>
+void ComputeBackwardForComplexInput(const DenseTensor& L,
+                                    const DenseTensor& V,
+                                    const DenseTensor& gL,
+                                    const DenseTensor& gV,
+                                    T* x_grad_data,
+                                    int batch_count,
+                                    int order,
+                                    const Context& dev_ctx) {
+  DenseTensor trans_v = phi::TransposeLast2Dim<T, Context>(dev_ctx, V);
+  DenseTensor Vh = phi::Conj<T, Context>(dev_ctx, trans_v);
+  DenseTensor Lconj = phi::Conj<T, Context>(dev_ctx, L);
+  DenseTensor Econj =
+      phi::Subtract<T, Context>(dev_ctx,
+                                phi::funcs::Unsqueeze(Lconj, -2),
+                                phi::funcs::Unsqueeze(Lconj, -1));
+  DenseTensor VhgV = phi::Matmul<T, Context>(dev_ctx, Vh, gV);
+  DenseTensor diag_real = phi::Real<T, Context>(dev_ctx, VhgV);
+  DenseTensor diag_res = phi::funcs::BatchDiag<phi::dtype::Real<T>, Context>(
+      dev_ctx, diag_real, batch_count);
+  DenseTensor diag_unsqueezed = phi::funcs::Unsqueeze(diag_res, -2);
+
+  // turn diag_unsqueezed into complex
+  auto numel = diag_unsqueezed.numel();
+  DenseTensor diag_unsqueezed_complex;
+  auto* data_diag_un = diag_unsqueezed.data<phi::dtype::Real<T>>();
+  diag_unsqueezed_complex.Resize(diag_unsqueezed.dims());
+  auto* data_diag_un_com = dev_ctx.template Alloc<T>(
+      &diag_unsqueezed_complex, static_cast<size_t>(numel * sizeof(T)));
+
+  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+  phi::funcs::RealToComplexFunctor<T> functor(
+      data_diag_un, data_diag_un_com, numel);
+  for_range(functor);
+  // real tensor multiply complex tensor in broadcast manner
+  DenseTensor res1 = phi::Multiply<T, Context>(dev_ctx, V, diag_unsqueezed_complex);
+  DenseTensor res2 = phi::Matmul<T, Context>(dev_ctx, Vh, res1);
+  DenseTensor result = phi::Subtract<T, Context>(dev_ctx, VhgV, res2);
+
+  result.Resize(V.dims());
+  dev_ctx.template Alloc<T>(&result);
+  result = phi::Divide<T, Context>(dev_ctx, result, Econj);
+  result =
+      phi::funcs::DiagFill<T, Context>(dev_ctx, order, order, order, 0, gL, result);
+  DenseTensor rhs = phi::Matmul<T, Context>(dev_ctx, result, Vh);
+
+  // solve the linear system
+  // solve(Vh, rhs, out, m, k)
+  // Vh: matrix with shape [m,m]
+  // rhs: rhs with shape [m,k]
+  // x_grad: out
+  int m = Vh.dims()[Vh.dims().size() - 1];
+  int k = rhs.dims()[rhs.dims().size() - 1];
+  auto* matrix_data = Vh.data<T>();
+  auto* rhs_data = rhs.data<T>();
+
+  SolveLinearSystem(matrix_data, rhs_data, x_grad_data, m, k, batch_count);
+}
+
+}  // namespace phi
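For reference, ComputeBackwardForComplexInput assembles the usual adjoint of the eigendecomposition. In the notation of the code (V the right eigenvectors, bars denoting incoming gradients), the editor's reading of the steps above is:

```latex
% Gradient assembled by ComputeBackwardForComplexInput (reconstructed from
% the code; Re(.) keeps only the real part of the diagonal correction):
\bar{A} \;=\; V^{-H}\Big(\operatorname{diag}(\bar{\lambda})
        \;+\; F \circ \big(V^{H}\bar{V}
        \;-\; V^{H}V\,\operatorname{diag}(\operatorname{Re}(\operatorname{diag}(V^{H}\bar{V})))\big)\Big)V^{H},
\qquad
F_{ij} \;=\; \frac{1}{\overline{\lambda_{j} - \lambda_{i}}}\ (i \neq j),\quad F_{ii} = 0 .
```

Here Econj holds the denominators of F, DiagFill writes the eigenvalue gradient onto the diagonal, and SolveLinearSystem applies the final V^{-H}.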
diff --git a/paddle/phi/kernels/cpu/eig_grad_kernel.cc b/paddle/phi/kernels/cpu/eig_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2617cea00fd72a1d67301b3d0c1c859122d1be17
--- /dev/null
+++ b/paddle/phi/kernels/cpu/eig_grad_kernel.cc
@@ -0,0 +1,50 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/eig_grad_kernel.h"
+#include "paddle/phi/kernels/cpu/eig.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EigGradKernel(const Context& dev_ctx,
+                   const DenseTensor& out_w,
+                   const DenseTensor& out_v,
+                   const DenseTensor& dout_w,
+                   const DenseTensor& dout_v,
+                   DenseTensor* dx) {
+  auto* dx_data = dev_ctx.template Alloc<phi::dtype::Complex<T>>(dx);
+
+  auto& dims = out_v.dims();
+  phi::DDim dim_origin = dims;
+  int num_dims = dim_origin.size();
+  int batch_count = BatchCount(out_v);
+  const int order = dim_origin[num_dims - 1];
+
+  ComputeBackwardForComplexInput<phi::dtype::Complex<T>, Context>(
+      out_w, out_v, dout_w, dout_v, dx_data, batch_count, order, dev_ctx);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(eig_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::EigGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
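A numpy transliteration of the same backward pass, useful for eyeballing the kernel (hypothetical helper; the test file's existing eig_backward plays this role):

```python
import numpy as np

def eig_backward_np(w, v, gw, gv):
    vh = v.conj().T
    e = w.conj()[None, :] - w.conj()[:, None]  # Econj_ij = conj(w_j - w_i)
    np.fill_diagonal(e, 1.0)                   # dodge 0/0; diagonal is overwritten
    vhgv = vh @ gv
    corr = vh @ v @ np.diag(np.real(np.diag(vhgv)))
    inner = (vhgv - corr) / e
    np.fill_diagonal(inner, gw)                # the DiagFill step
    return np.linalg.solve(vh, inner @ vh)     # x_grad = V^{-H} inner V^H

a = np.random.rand(3, 3)
w, v = np.linalg.eig(a)
print(eig_backward_np(w, v, np.ones_like(w), np.ones_like(v)).shape)  # (3, 3)
```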
diff --git a/paddle/phi/kernels/cpu/eig_kernel.cc b/paddle/phi/kernels/cpu/eig_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..42a843391872ff1acf83d5bcdab3e6296939b38c
--- /dev/null
+++ b/paddle/phi/kernels/cpu/eig_kernel.cc
@@ -0,0 +1,102 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/eig_kernel.h"
+#include "paddle/phi/kernels/cpu/eig.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EigKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               DenseTensor* out_w,
+               DenseTensor* out_v) {
+  if (!IsComplexType(x.dtype())) {
+    dev_ctx.template Alloc<phi::dtype::Complex<T>>(out_w);
+    dev_ctx.template Alloc<phi::dtype::Complex<T>>(out_v);
+
+    int batch_count = BatchCount(x);
+    int order = x.dims()[x.dims().size() - 1];
+
+    DenseTensor real_w;
+    DenseTensor real_v;
+
+    // double the size of real_w, the first half stores the real part,
+    // the next half stores the imag part
+    std::vector<int> origin_dim = phi::vectorize<int>(out_w->dims());
+    int last_item = origin_dim.back();
+    origin_dim.pop_back();
+    origin_dim.push_back(last_item * 2);
+
+    phi::DDim big_dim = phi::make_ddim(origin_dim);
+
+    real_w.Resize(big_dim);
+    dev_ctx.template Alloc<phi::dtype::Real<T>>(&real_w);
+    real_v.Resize(x.dims());
+    dev_ctx.template Alloc<phi::dtype::Real<T>>(&real_v);
+
+    phi::ApplyEigKernel<phi::dtype::Real<T>, Context>(
+        x, &real_w, &real_v, dev_ctx);
+
+    // 1. extract real part & imag part from real_w
+    DenseTensor real_part = phi::funcs::Slice<phi::dtype::Real<T>>(
+        dev_ctx, real_w, {-1}, {0}, {order});
+    DenseTensor imag_part = phi::funcs::Slice<phi::dtype::Real<T>>(
+        dev_ctx, real_w, {-1}, {order}, {order * 2});
+
+    // 2. construct complex values
+    auto* real_part_data = real_part.data<phi::dtype::Real<T>>();
+    auto* imag_part_data = imag_part.data<phi::dtype::Real<T>>();
+    int out_w_numel = out_w->numel();
+
+    phi::funcs::ForRange<Context> for_range(dev_ctx, out_w_numel);
+    phi::funcs::RealImagToComplexFunctor<phi::dtype::Complex<T>> functor(
+        real_part_data,
+        imag_part_data,
+        dev_ctx.template Alloc<phi::dtype::Complex<T>>(out_w),
+        out_w_numel);
+
+    for_range(functor);
+
+    // 3. construct complex vectors
+    DenseTensor real_vector_trans =
+        phi::TransposeLast2Dim<phi::dtype::Real<T>, Context>(dev_ctx, real_v);
+    DenseTensor out_v_trans;
+    out_v_trans.Resize(x.dims());
+    dev_ctx.template Alloc<phi::dtype::Complex<T>>(&out_v_trans);
+    phi::ConstructComplexVectors<phi::dtype::Real<T>,
+                                 phi::dtype::Complex<T>,
+                                 Context>(
+        &out_v_trans, *out_w, real_vector_trans, dev_ctx, batch_count, order);
+    TransposeTwoAxis<phi::dtype::Complex<T>, Context>(
+        out_v_trans, out_v, x.dims().size() - 1, x.dims().size() - 2, dev_ctx);
+  } else {
+    dev_ctx.template Alloc<T>(out_w);
+    dev_ctx.template Alloc<T>(out_v);
+
+    phi::ApplyEigKernel<T, Context>(x, out_w, out_v, dev_ctx);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(eig,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::EigKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
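The real-input branch of EigKernel relies on a packing trick: real_w is allocated at twice the eigenvalue count so LAPACK's WR/WI halves land in one buffer along the last axis, then Slice and RealImagToComplexFunctor recombine them. In numpy terms (not part of the patch):

```python
import numpy as np

order = 4
real_w = np.random.rand(2 * order)       # stand-in for the doubled buffer
re, im = real_w[:order], real_w[order:]  # the two funcs::Slice calls
out_w = re + 1j * im                     # RealImagToComplexFunctor
print(out_w.dtype, out_w.shape)          # complex128 (4,)
```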
diff --git a/paddle/phi/kernels/eig_grad_kernel.h b/paddle/phi/kernels/eig_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..93ae707f6f95f3f3ae43c8a232edcda3d6b3c37b
--- /dev/null
+++ b/paddle/phi/kernels/eig_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EigGradKernel(const Context& dev_ctx,
+                   const DenseTensor& out_w,
+                   const DenseTensor& out_v,
+                   const DenseTensor& dout_w,
+                   const DenseTensor& dout_v,
+                   DenseTensor* dx);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/eig_kernel.h b/paddle/phi/kernels/eig_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..fd894c35862332be52e3f4e84d5e8bb05bb72a30
--- /dev/null
+++ b/paddle/phi/kernels/eig_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EigKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               DenseTensor* out_w,
+               DenseTensor* out_v);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/eig_sig.cc b/paddle/phi/ops/compat/eig_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d9da3c7f447e398f135be346862870e744365d98
--- /dev/null
+++ b/paddle/phi/ops/compat/eig_sig.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature EigGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "eig_grad",
+      {"Eigenvalues", "Eigenvectors", "Eigenvalues@GRAD", "Eigenvectors@GRAD"},
+      {},
+      {"X@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(eig_grad, phi::EigGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_eig_op.py b/python/paddle/fluid/tests/unittests/test_eig_op.py
index b4044c9e7991ca1c2d347c691e6f21525331d8d2..4e8f69a6bda064baebfde65b82fdad0f4a93ff1c 100644
--- a/python/paddle/fluid/tests/unittests/test_eig_op.py
+++ b/python/paddle/fluid/tests/unittests/test_eig_op.py
@@ -227,6 +227,52 @@ class TestEigStatic(TestEigOp):
                          str(np.abs(fetch_vec)))
 
 
+class TestEigDyGraph(unittest.TestCase):
+
+    def test_check_output_with_place(self):
+        input_np = np.random.random([3, 3]).astype('complex')
+        expect_val, expect_vec = np.linalg.eig(input_np)
+
+        paddle.set_device("cpu")
+        paddle.disable_static()
+
+        input_tensor = paddle.to_tensor(input_np)
+        fetch_val, fetch_vec = paddle.linalg.eig(input_tensor)
+
+        self.assertTrue(
+            np.allclose(expect_val, fetch_val.numpy(), 1e-6, 1e-6),
+            "The eigenvalues have diff: \nExpected " + str(expect_val) + "\n" +
+            "But got: " + str(fetch_val))
+        self.assertTrue(
+            np.allclose(np.abs(expect_vec), np.abs(fetch_vec.numpy()), 1e-6,
+                        1e-6), "The eigenvectors have diff: \nExpected " +
+            str(np.abs(expect_vec)) + "\n" + "But got: " +
+            str(np.abs(fetch_vec.numpy())))
+
+    def test_check_grad(self):
+        test_shape = [3, 3]
+        test_type = 'float64'
+        paddle.set_device("cpu")
+
+        input_np = np.random.random(test_shape).astype(test_type)
+        real_w, real_v = np.linalg.eig(input_np)
+
+        grad_w = np.ones(real_w.shape, test_type)
+        grad_v = np.ones(real_v.shape, test_type)
+        grad_x = eig_backward(real_w, real_v, grad_w, grad_v)
+
+        with fluid.dygraph.guard():
+            x = fluid.dygraph.to_variable(input_np)
+            x.stop_gradient = False
+            w, v = paddle.linalg.eig(x)
+            (w.sum() + v.sum()).backward()
+
+        self.assertTrue(
+            np.allclose(np.abs(x.grad.numpy()), np.abs(grad_x), 1e-5, 1e-5),
+            "The grad x has diff: \nExpected " + str(np.abs(grad_x)) + "\n" +
+            "But got: " + str(np.abs(x.grad.numpy())))
+
+
 class TestEigWrongDimsError(unittest.TestCase):
 
     def test_error(self):
@@ -254,7 +300,7 @@ class TestEigUnsupportedDtypeError(unittest.TestCase):
         paddle.disable_static()
         a = (np.random.random((3, 3)) * 10).astype('int64')
         x = paddle.to_tensor(a)
-        self.assertRaises(ValueError, paddle.linalg.eig, x)
+        self.assertRaises(RuntimeError, paddle.linalg.eig, x)
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 68936855826bb03abb273ac81939462afd643982..35336228984273a22e28c9050c8517fc31d837d6 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -2269,7 +2269,9 @@ def eig(x, name=None):
             #       [ (16.50471283351188+0j)  , (-5.5034820550763515+0j) ,
             #         (-0.21026087843552282+0j)])
     """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_eig(x)
+    elif paddle.in_dynamic_mode():
         w, v = _C_ops.eig(x)
         return w, v
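An end-to-end check of the new dygraph path against numpy, in the spirit of TestEigDyGraph (hedged: assumes a CPU build with this patch applied):

```python
import numpy as np
import paddle

paddle.set_device("cpu")
a = np.random.random([3, 3]).astype("float64")
w_np, v_np = np.linalg.eig(a)

w, v = paddle.linalg.eig(paddle.to_tensor(a))
# sort both eigenvalue sets, since LAPACK and numpy may order them differently
np.testing.assert_allclose(np.sort_complex(w.numpy()), np.sort_complex(w_np),
                           rtol=1e-6, atol=1e-6)
print("eigenvalues match numpy")
```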