From 2aca8d90813170d364ed0dde6580ffc08451597a Mon Sep 17 00:00:00 2001
From: crystal <62974595+Zjq9409@users.noreply.github.com>
Date: Wed, 9 Mar 2022 18:51:47 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90phi=E3=80=91migrate=20eigh=20op=20to?=
 =?UTF-8?q?=20phi=20(#40213)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* migrate eigh to phi

* optimize code

* modify code according to comment

* conflict resolution
---
 paddle/fluid/operators/eigh_op.cc             |  63 +--
 paddle/fluid/operators/eigh_op.cu             |  32 --
 paddle/fluid/operators/eigh_op.h              |  74 ----
 paddle/phi/infermeta/unary.cc                 |  32 ++
 paddle/phi/infermeta/unary.h                  |   5 +
 paddle/phi/kernels/CMakeLists.txt             |   3 +-
 paddle/phi/kernels/cpu/eigh_grad_kernel.cc    |  28 ++
 paddle/phi/kernels/cpu/eigh_kernel.cc         |  43 ++
 paddle/phi/kernels/eigh_grad_kernel.h         |  29 ++
 paddle/phi/kernels/eigh_kernel.h              |  29 ++
 .../kernels/funcs/values_vectors_functor.h    | 386 ++++++++++++++++++
 paddle/phi/kernels/gpu/eigh_grad_kernel.cu    |  29 ++
 paddle/phi/kernels/gpu/eigh_kernel.cu         |  48 +++
 .../phi/kernels/impl/eigh_grad_kernel_impl.h  |  79 ++++
 paddle/phi/ops/compat/eigh_sig.cc             |  31 ++
 15 files changed, 751 insertions(+), 160 deletions(-)
 delete mode 100644 paddle/fluid/operators/eigh_op.cu
 delete mode 100644 paddle/fluid/operators/eigh_op.h
 create mode 100644 paddle/phi/kernels/cpu/eigh_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/eigh_kernel.cc
 create mode 100644 paddle/phi/kernels/eigh_grad_kernel.h
 create mode 100644 paddle/phi/kernels/eigh_kernel.h
 create mode 100644 paddle/phi/kernels/funcs/values_vectors_functor.h
 create mode 100644 paddle/phi/kernels/gpu/eigh_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/eigh_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/eigh_grad_kernel_impl.h
 create mode 100644 paddle/phi/ops/compat/eigh_sig.cc

diff --git a/paddle/fluid/operators/eigh_op.cc b/paddle/fluid/operators/eigh_op.cc
index 553d0e679cc..4e33c567eb6 100644
--- a/paddle/fluid/operators/eigh_op.cc
+++ b/paddle/fluid/operators/eigh_op.cc
@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/eigh_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -22,42 +25,9 @@ using framework::Tensor;
 class EighOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigh");
-    OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues",
-                   "Eigh");
-    OP_INOUT_CHECK(ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors",
-                   "Eigh");
-
-    auto input_dim = ctx->GetInputDim("X");
-    auto rank = input_dim.size();
-
-    PADDLE_ENFORCE_GE(rank, 2,
-                      platform::errors::InvalidArgument(
-                          "The Input(X) should have at least 2 dimensions."
-                          "But received a %d dimension tensor.",
-                          rank));
-    PADDLE_ENFORCE_EQ(
-        input_dim[rank - 2], input_dim[rank - 1],
-        platform::errors::InvalidArgument(
-            "Eigh op is designed for square matrix, consequently"
-            "inner-most 2 dimensions of Input(X) should be symmetric."
- "But received X's shape[-2] = %d and shape[-1] = %d.", - input_dim[rank - 2], input_dim[rank - 1])); - - std::vector values_dim; - - for (auto i = 0; i < rank - 1; i++) { - values_dim.emplace_back(input_dim[i]); - } - - ctx->SetOutputDim("Eigenvalues", phi::make_ddim(values_dim)); - ctx->SetOutputDim("Eigenvectors", input_dim); - } }; -class EignOpMaker : public framework::OpProtoAndCheckerMaker { +class EighOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", @@ -140,24 +110,11 @@ class EighGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(eigh, EighInferShapeFunctor, + PD_INFER_META(phi::EighInferMeta)); -REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker, +REGISTER_OPERATOR(eigh, ops::EighOp, ops::EighOpMaker, ops::EighGradOpMaker, - ops::EighGradOpMaker); + ops::EighGradOpMaker, + EighInferShapeFunctor); REGISTER_OPERATOR(eigh_grad, ops::EighGradOp); - -REGISTER_OP_CPU_KERNEL( - eigh, ops::EighKernel, - ops::EighKernel, - ops::EighKernel>, - ops::EighKernel>); - -REGISTER_OP_CPU_KERNEL( - eigh_grad, ops::EighGradKernel, - ops::EighGradKernel, - ops::EighGradKernel>, - ops::EighGradKernel>); diff --git a/paddle/fluid/operators/eigh_op.cu b/paddle/fluid/operators/eigh_op.cu deleted file mode 100644 index 827c551637d..00000000000 --- a/paddle/fluid/operators/eigh_op.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/eigh_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - eigh, ops::EighKernel, - ops::EighKernel, - ops::EighKernel>, - ops::EighKernel>); - -REGISTER_OP_CUDA_KERNEL( - eigh_grad, ops::EighGradKernel, - ops::EighGradKernel, - ops::EighGradKernel>, - ops::EighGradKernel>); diff --git a/paddle/fluid/operators/eigh_op.h b/paddle/fluid/operators/eigh_op.h deleted file mode 100644 index 5279ec75093..00000000000 --- a/paddle/fluid/operators/eigh_op.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/eigen_values_vectors.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class EighKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto input = ctx.Input<Tensor>("X");
-    auto output_w = ctx.Output<Tensor>("Eigenvalues");
-    auto output_v = ctx.Output<Tensor>("Eigenvectors");
-    std::string lower = ctx.Attr<std::string>("UPLO");
-    bool is_lower = (lower == "L");
-    math::MatrixEighFunctor<DeviceContext, T> functor;
-    functor(ctx, *input, output_w, output_v, is_lower, true);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class EighGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    using ValueType = phi::dtype::Real<T>;
-    auto& x_grad = *ctx.Output<Tensor>(framework::GradVarName("X"));
-    x_grad.mutable_data<T>(ctx.GetPlace());
-    auto& output_w = *ctx.Input<Tensor>("Eigenvalues");
-    auto& output_v = *ctx.Input<Tensor>("Eigenvectors");
-    auto& output_w_grad =
-        *ctx.Input<Tensor>(framework::GradVarName("Eigenvalues"));
-    auto& output_v_grad =
-        *ctx.Input<Tensor>(framework::GradVarName("Eigenvectors"));
-
-    auto& dims = output_v.dims();
-    const int m = dims[dims.size() - 1];
-    auto dito =
-        math::DeviceIndependenceTensorOperations<DeviceContext, T, ValueType>(
-            ctx);
-    auto tV = dito.Transpose(dito.Conj(output_v));
-    auto W = dito.template Sub<ValueType>(dito.Unsqueeze(output_w, -2),
-                                          dito.Unsqueeze(output_w, -1));
-    Tensor result = dito.Matmul(tV, output_v_grad);
-    result.mutable_data<T>(dims, ctx.GetPlace());
-    std::vector<int> out_shape = phi::vectorize<int>(dims);
-    auto constant = dito.Fill(out_shape, 0.5);
-    result = dito.Sub(result, dito.Conj(dito.Transpose(result)));
-    result = dito.Mul(result, constant);
-    result = dito.Div(result, W);
-    result = dito.DiagFill(m, m, m, 0, output_w_grad, result);
-    x_grad = dito.Matmul(output_v, dito.Matmul(result, tV));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 32744659163..544a5593014 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -1123,6 +1123,38 @@ void TransposeInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
+void EighInferMeta(const MetaTensor& x,
+                   const std::string& uplo,
+                   MetaTensor* out_w,
+                   MetaTensor* out_v) {
+  auto input_dim = x.dims();
+  auto rank = input_dim.size();
+
+  PADDLE_ENFORCE_GE(rank,
+                    2,
+                    phi::errors::InvalidArgument(
+                        "The Input(X) should have at least 2 dimensions. "
+                        "But received a %d dimension tensor.",
+                        rank));
+  PADDLE_ENFORCE_EQ(
+      input_dim[rank - 2],
+      input_dim[rank - 1],
+      phi::errors::InvalidArgument(
+          "Eigh op is designed for square matrices; consequently the "
+          "inner-most 2 dimensions of Input(X) should be equal. "
+ "But received X's shape[-2] = %d and shape[-1] = %d.", + input_dim[rank - 2], + input_dim[rank - 1])); + + std::vector values_dim; + + for (auto i = 0; i < rank - 1; i++) { + values_dim.emplace_back(input_dim[i]); + } + out_w->set_dims(phi::make_ddim(values_dim)); + out_v->set_dims(input_dim); +} + } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 735a77faefe..c57e1bdec8d 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -163,4 +163,9 @@ void TransposeInferMeta(const MetaTensor& x, const std::vector& axis, MetaTensor* out); +void EighInferMeta(const MetaTensor& x, + const std::string& uplo, + MetaTensor* out_w, + MetaTensor* out_v); + } // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index de3b5b53f46..71e0d9e3479 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel) # Some kernels depend on some targets that are not commonly used. # These targets are not suitable for common dependencies. # In this case, you need to manually generate them here. -set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel) +set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel) kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel) kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) @@ -38,6 +38,7 @@ kernel_library(put_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_k kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel) +kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function) # 4. auto parse and build kernel targets by cmake register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} ) diff --git a/paddle/phi/kernels/cpu/eigh_grad_kernel.cc b/paddle/phi/kernels/cpu/eigh_grad_kernel.cc new file mode 100644 index 00000000000..5135778db56 --- /dev/null +++ b/paddle/phi/kernels/cpu/eigh_grad_kernel.cc @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/phi/kernels/eigh_grad_kernel.h"
+#include "paddle/phi/kernels/impl/eigh_grad_kernel_impl.h"
+
+#include "paddle/phi/common/complex.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(eigh_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::EighGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/eigh_kernel.cc b/paddle/phi/kernels/cpu/eigh_kernel.cc
new file mode 100644
index 00000000000..92fd20ca9b8
--- /dev/null
+++ b/paddle/phi/kernels/cpu/eigh_kernel.cc
@@ -0,0 +1,43 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/eigh_kernel.h"
+#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EighKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::string& uplo,
+                DenseTensor* out_w,
+                DenseTensor* out_v) {
+  bool is_lower = (uplo == "L");
+  phi::funcs::MatrixEighFunctor<Context, T> functor;
+  functor(dev_ctx, x, out_w, out_v, is_lower, true);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(eigh,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::EighKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/eigh_grad_kernel.h b/paddle/phi/kernels/eigh_grad_kernel.h
new file mode 100644
index 00000000000..73df76e676a
--- /dev/null
+++ b/paddle/phi/kernels/eigh_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EighGradKernel(const Context& dev_ctx,
+                    const DenseTensor& out_w,
+                    const DenseTensor& out_v,
+                    const DenseTensor& dout_w,
+                    const DenseTensor& dout_v,
+                    DenseTensor* dx);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/eigh_kernel.h b/paddle/phi/kernels/eigh_kernel.h
new file mode 100644
index 00000000000..dd28752d929
--- /dev/null
+++ b/paddle/phi/kernels/eigh_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EighKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::string& uplo,
+                DenseTensor* out_w,
+                DenseTensor* out_v);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/funcs/values_vectors_functor.h b/paddle/phi/kernels/funcs/values_vectors_functor.h
new file mode 100644
index 00000000000..b3189fc5cc3
--- /dev/null
+++ b/paddle/phi/kernels/funcs/values_vectors_functor.h
@@ -0,0 +1,386 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/memory/memory.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/phi/backends/dynload/cusolver.h"
+#endif  // PADDLE_WITH_CUDA
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
+
+namespace phi {
+namespace funcs {
+
+inline int64_t GetBatchSize(phi::DDim dims) {
+  int64_t batch_size = 1;
+  auto dim_size = dims.size();
+  for (int i = 0; i < dim_size - 2; i++) {
+    batch_size *= dims[i];
+  }
+  return batch_size;
+}
+
+static void CheckEighResult(const int batch, const int info) {
+  PADDLE_ENFORCE_LE(
+      info,
+      0,
+      phi::errors::PreconditionNotMet(
+          "For batch [%d]: the [%d] off-diagonal elements of an intermediate "
+          "tridiagonal form did not converge to zero.",
+          batch,
+          info));
+  PADDLE_ENFORCE_GE(
+      info,
+      0,
+      phi::errors::PreconditionNotMet(
+          "For batch [%d]: the [%d] argument had an illegal value.",
+          batch,
+          info));
+}
+
+template <typename DeviceContext, typename T>
+struct MatrixEighFunctor {
+  void operator()(const DeviceContext &dev_ctx,
+                  const DenseTensor &input,
+                  DenseTensor *eigen_values,
+                  DenseTensor *eigen_vectors,
+                  bool is_lower,
+                  bool has_vectors);
+};
+
+// Calculates the eigenvalues and eigenvectors of Hermitian or real
+// symmetric matrices, and uses the variable has_vectors to
+// control whether to return the eigenvectors.
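+//
+// A note on the workspace handling in this specialization (an
+// illustrative sketch of the standard LAPACK convention, not extra
+// API): calling ?syevd/?heevd with lwork = -1 performs no computation
+// and instead reports the optimal work/rwork/iwork sizes, which are
+// then allocated once and reused for every matrix in the batch:
+//
+//   lapackEigh<T, ValueType>(..., &lwork_opt, /*lwork=*/-1, ...);
+//   // allocate work/iwork (and rwork for complex T) using the
+//   // returned sizes, then factorize each batch element:
+//   lapackEigh<T, ValueType>(..., work_data, lwork, ...);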
+template <typename T>
+struct MatrixEighFunctor<CPUContext, T> {
+ public:
+  void operator()(const CPUContext &dev_ctx,
+                  const DenseTensor &input,
+                  DenseTensor *eigen_values,
+                  DenseTensor *eigen_vectors,
+                  bool is_lower,
+                  bool has_vectors) {
+    using ValueType = phi::dtype::Real<T>;
+    ValueType *out_value = dev_ctx.template Alloc<ValueType>(eigen_values);
+
+    DenseTensor input_trans;
+    // LAPACK assumes column-major storage; transposing the last two
+    // dimensions gives the input a contiguous column-major layout.
+    input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input);
+    T *input_vector = input_trans.data<T>();
+
+    auto dims = input.dims();
+    int dim_size = dims.size();
+    int64_t batch_size = GetBatchSize(dims);
+
+    int vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
+    int values_stride = dims[dim_size - 1];
+    char uplo = is_lower ? 'L' : 'U';
+    char jobz = has_vectors ? 'V' : 'N';
+    int n = dims[dim_size - 1];
+    int64_t lda = std::max(1, n);
+    // If lwork = -1, the LAPACK routine performs a workspace query and
+    // returns the optimal workspace sizes instead of computing.
+    int lwork = -1;      // The length of the array work
+    int lrwork = -1;     // The dimension of the array rwork; rwork is a REAL array
+    int liwork = -1;     // The dimension of the array iwork
+    int iwork_opt = -1;  // The optimal length of the array iwork
+    T lwork_opt = static_cast<T>(-1);  // The optimal length of the array work
+    ValueType rwork_opt =
+        static_cast<ValueType>(-1);  // The optimal length of the array rwork
+
+    int info = 0;
+    // Call lapackEigh to query the optimal sizes of the workspaces
+    phi::funcs::lapackEigh<T, ValueType>(jobz,
+                                         uplo,
+                                         n,
+                                         input_vector,
+                                         lda,
+                                         out_value,
+                                         &lwork_opt,
+                                         lwork,
+                                         &rwork_opt,
+                                         lrwork,
+                                         &iwork_opt,
+                                         liwork,
+                                         &info);
+    lwork = std::max(1, static_cast<int>(lwork_opt));
+    liwork = std::max(1, iwork_opt);
+
+    DenseTensor rwork_tensor;
+    ValueType *rwork_data = nullptr;
+
+    // The rwork workspace is only needed for complex inputs.
+    if (input.type() == phi::DataType::COMPLEX64 ||
+        input.type() == phi::DataType::COMPLEX128) {
+      lrwork = std::max(1, static_cast<int>(rwork_opt));
+
+      rwork_tensor.Resize(phi::make_ddim({lrwork}));
+      rwork_data = dev_ctx.template Alloc<ValueType>(&rwork_tensor);
+    }
+
+    DenseTensor iwork_tensor, work_tensor;
+
+    iwork_tensor.Resize(phi::make_ddim({liwork}));
+    int *iwork_data = dev_ctx.template Alloc<int>(&iwork_tensor);
+
+    work_tensor.Resize(phi::make_ddim({lwork}));
+    T *work_data = dev_ctx.template Alloc<T>(&work_tensor);
+
+    for (auto i = 0; i < batch_size; i++) {
+      auto *value_data = out_value + i * values_stride;
+      auto *input_data = input_vector + i * vector_stride;
+      phi::funcs::lapackEigh<T, ValueType>(jobz,
+                                           uplo,
+                                           n,
+                                           input_data,
+                                           lda,
+                                           value_data,
+                                           work_data,
+                                           lwork,
+                                           rwork_data,
+                                           lrwork,
+                                           iwork_data,
+                                           liwork,
+                                           &info);
+      CheckEighResult(i, info);
+    }
+    if (has_vectors) {
+      PADDLE_ENFORCE_NOT_NULL(eigen_vectors,
+                              phi::errors::InvalidArgument(
+                                  "When has_vectors is true, "
+                                  "the eigenvectors need to be calculated, "
+                                  "so the eigenvectors must be provided."));
+      input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input_trans);
+      eigen_vectors->ShareDataWith(input_trans);
+    }
+  }
+};
+
+#ifdef PADDLE_WITH_CUDA
+
+// Calculates the eigenvalues and eigenvectors of Hermitian or real
+// symmetric matrices on GPU, and uses the variable has_vectors
+// to control whether to return the eigenvectors.
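+//
+// The call sequence below follows cusolver's usual pattern (sketched
+// here for orientation; see the member functions for the real calls):
+//   1. query the workspace size once: cusolverDn*_bufferSize / EvdBuffer
+//   2. allocate the workspace and the per-batch devInfo buffer
+//   3. solve each matrix in the batch: cusolverDnSsyevj / Evd
+//   4. copy devInfo back to the host and validate it via CheckEighResult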
+template <typename T>
+struct MatrixEighFunctor<GPUContext, T> {
+ public:
+  void operator()(const GPUContext &dev_ctx,
+                  const DenseTensor &input,
+                  DenseTensor *eigen_values,
+                  DenseTensor *eigen_vectors,
+                  bool is_lower,
+                  bool has_vectors) {
+    using ValueType = phi::dtype::Real<T>;
+    ValueType *out_value = dev_ctx.template Alloc<ValueType>(eigen_values);
+
+    DenseTensor input_trans;
+    input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input);
+    T *input_vector = input_trans.data<T>();
+    auto &dims = input.dims();
+    int dim_size = dims.size();
+    int64_t batch_size = GetBatchSize(dims);
+
+    cublasFillMode_t uplo =
+        is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
+    cusolverEigMode_t jobz =
+        has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
+
+    int n = dims[dim_size - 1];
+    int lda = std::max(1, n);
+    auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
+    auto values_stride = dims[dim_size - 1];
+    int lwork = 0;
+    auto info = paddle::memory::Alloc(dev_ctx, sizeof(int) * batch_size);
+    auto *info_ptr = reinterpret_cast<int *>(info->ptr());
+
+    // For float32 inputs whose inner-most matrix dimensions lie between
+    // [*, 32, 32] and [*, 512, 512], the Jacobi-based solver syevj has
+    // better performance than syevd.
+    bool use_syevj = (input.dtype() == phi::DataType::FLOAT32 &&
+                      values_stride >= 32 && values_stride <= 512);
+    syevjInfo_t syevj_params;
+    if (use_syevj) {
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          dynload::cusolverDnCreateSyevjInfo(&syevj_params));
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize(
+          dev_ctx.cusolver_dn_handle(),
+          jobz,
+          uplo,
+          n,
+          reinterpret_cast<const float *>(input_vector),
+          lda,
+          reinterpret_cast<const float *>(out_value),
+          &lwork,
+          syevj_params));
+    } else {
+      EvdBuffer(dev_ctx.cusolver_dn_handle(),
+                jobz,
+                uplo,
+                n,
+                input_vector,
+                lda,
+                out_value,
+                &lwork);
+    }
+    auto work = paddle::memory::Alloc(dev_ctx, sizeof(T) * lwork);
+    auto *work_ptr = reinterpret_cast<T *>(work->ptr());
+    for (auto i = 0; i < batch_size; i++) {
+      auto *input_data = input_vector + i * vector_stride;
+      auto *value_data = out_value + i * values_stride;
+      auto handle = dev_ctx.cusolver_dn_handle();
+      if (use_syevj) {
+        PADDLE_ENFORCE_GPU_SUCCESS(
+            dynload::cusolverDnSsyevj(handle,
+                                      jobz,
+                                      uplo,
+                                      n,
+                                      reinterpret_cast<float *>(input_data),
+                                      lda,
+                                      reinterpret_cast<float *>(value_data),
+                                      reinterpret_cast<float *>(work_ptr),
+                                      lwork,
+                                      info_ptr,
+                                      syevj_params));
+      } else {
+        Evd(handle,
+            jobz,
+            uplo,
+            n,
+            input_data,
+            lda,
+            value_data,
+            work_ptr,
+            lwork,
+            info_ptr);
+      }
+      int error_info = 0;
+      paddle::memory::Copy(phi::CPUPlace(),
+                           &error_info,
+                           dev_ctx.GetPlace(),
+                           info_ptr,
+                           sizeof(int),
+                           dev_ctx.stream());
+      CheckEighResult(i, error_info);
+    }
+
+    if (use_syevj) {
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          dynload::cusolverDnDestroySyevjInfo(syevj_params));
+    }
+    if (has_vectors) {
+      PADDLE_ENFORCE_NOT_NULL(eigen_vectors,
+                              phi::errors::InvalidArgument(
+                                  "When has_vectors is true, "
+                                  "the eigenvectors need to be calculated, "
+                                  "so the eigenvectors must be provided."));
+      input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input_trans);
+      eigen_vectors->ShareDataWith(input_trans);
+    }
+  }
+
+  using ValueType = phi::dtype::Real<T>;
+  inline void EvdBuffer(cusolverDnHandle_t handle,
+                        cusolverEigMode_t jobz,
+                        cublasFillMode_t uplo,
+                        int n,
+                        const T *A,
+                        int lda,
+                        const ValueType *W,
+                        int *lwork) const;
+
+  inline void Evd(cusolverDnHandle_t handle,
+                  cusolverEigMode_t jobz,
+                  cublasFillMode_t uplo,
+                  int n,
+                  T *A,
+                  int lda,
+                  ValueType *W,
+                  T *work,
+                  int lwork,
+                  int *devInfo) const;
+};
+
+using phi::dtype::complex;
+
+#define FUNC_WITH_TYPES(m)                     \
+  m(float, Ssy, float) m(double, Dsy, double)  \
+      m(complex<float>, Che, cuComplex)        \
+          m(complex<double>, Zhe, cuDoubleComplex)
+
+#define EVDBUFFER_INSTANCE(T, C, CastType)                              \
+  template <>                                                           \
+  inline void MatrixEighFunctor<GPUContext, T>::EvdBuffer(              \
+      cusolverDnHandle_t handle,                                        \
+      cusolverEigMode_t jobz,                                           \
+      cublasFillMode_t uplo,                                            \
+      int n,                                                            \
+      const T *A,                                                       \
+      int lda,                                                          \
+      const ValueType *W,                                               \
+      int *lwork) const {                                               \
+    PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDn##C##evd_bufferSize(  \
+        handle,                                                         \
+        jobz,                                                           \
+        uplo,                                                           \
+        n,                                                              \
+        reinterpret_cast<const CastType *>(A),                          \
+        lda,                                                            \
+        W,                                                              \
+        lwork));                                                        \
+  }
+
+FUNC_WITH_TYPES(EVDBUFFER_INSTANCE);
+
+#define EVD_INSTANCE(T, C, CastType)                                    \
+  template <>                                                           \
+  inline void MatrixEighFunctor<GPUContext, T>::Evd(                    \
+      cusolverDnHandle_t handle,                                        \
+      cusolverEigMode_t jobz,                                           \
+      cublasFillMode_t uplo,                                            \
+      int n,                                                            \
+      T *A,                                                             \
+      int lda,                                                          \
+      ValueType *W,                                                     \
+      T *work,                                                          \
+      int lwork,                                                        \
+      int *devInfo) const {                                             \
+    PADDLE_ENFORCE_GPU_SUCCESS(                                         \
+        dynload::cusolverDn##C##evd(handle,                             \
+                                    jobz,                               \
+                                    uplo,                               \
+                                    n,                                  \
+                                    reinterpret_cast<CastType *>(A),    \
+                                    lda,                                \
+                                    W,                                  \
+                                    reinterpret_cast<CastType *>(work), \
+                                    lwork,                              \
+                                    devInfo));                          \
+  }
+
+FUNC_WITH_TYPES(EVD_INSTANCE);
+
+#undef FUNC_WITH_TYPES
+#undef EVDBUFFER_INSTANCE
+#undef EVD_INSTANCE
+
+#endif  // PADDLE_WITH_CUDA
+
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/eigh_grad_kernel.cu b/paddle/phi/kernels/gpu/eigh_grad_kernel.cu
new file mode 100644
index 00000000000..fdf61dc7399
--- /dev/null
+++ b/paddle/phi/kernels/gpu/eigh_grad_kernel.cu
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/eigh_grad_kernel.h"
+#include "paddle/phi/kernels/impl/eigh_grad_kernel_impl.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+
+PD_REGISTER_KERNEL(eigh_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::EighGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/eigh_kernel.cu b/paddle/phi/kernels/gpu/eigh_kernel.cu
new file mode 100644
index 00000000000..4ff3b371b6a
--- /dev/null
+++ b/paddle/phi/kernels/gpu/eigh_kernel.cu
@@ -0,0 +1,48 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PADDLE_WITH_HIP
+// HIP does not support cusolver
+
+#include "paddle/phi/kernels/eigh_kernel.h"
+#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EighKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const std::string& uplo,
+                DenseTensor* out_w,
+                DenseTensor* out_v) {
+  bool is_lower = (uplo == "L");
+  phi::funcs::MatrixEighFunctor<Context, T> functor;
+  functor(dev_ctx, x, out_w, out_v, is_lower, true);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(eigh,  // cuda_only
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::EighKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
+
+#endif  // not PADDLE_WITH_HIP
diff --git a/paddle/phi/kernels/impl/eigh_grad_kernel_impl.h b/paddle/phi/kernels/impl/eigh_grad_kernel_impl.h
new file mode 100644
index 00000000000..2f0530b638f
--- /dev/null
+++ b/paddle/phi/kernels/impl/eigh_grad_kernel_impl.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/kernels/cast_kernel.h"
+#include "paddle/phi/kernels/complex_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/diag_functor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/unsqueeze.h"
+#include "paddle/phi/kernels/math_kernel.h"
+#include "paddle/phi/kernels/matmul_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
+
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EighGradKernel(const Context& dev_ctx,
+                    const DenseTensor& out_w,
+                    const DenseTensor& out_v,
+                    const DenseTensor& dout_w,
+                    const DenseTensor& dout_v,
+                    DenseTensor* dx) {
+  dev_ctx.template Alloc<T>(dx);
+  auto& dims = out_v.dims();
+  const int m = dims[dims.size() - 1];
+  // tV = conj(V)^T, where V holds the eigenvectors as columns.
+  DenseTensor tV =
+      phi::TransposeLast2Dim<T>(dev_ctx, phi::Conj<T>(dev_ctx, out_v));
+  // W[i][j] = w[j] - w[i], the pairwise eigenvalue gaps.
+  DenseTensor W =
+      phi::Subtract<phi::dtype::Real<T>>(dev_ctx,
+                                         phi::funcs::Unsqueeze(out_w, -2),
+                                         phi::funcs::Unsqueeze(out_w, -1));
+  DenseTensor result = phi::Matmul<T>(dev_ctx, tV, dout_v);
+  result.Resize(dims);
+  dev_ctx.template Alloc<T>(&result);
+
+  std::vector<int> out_shape = phi::vectorize<int>(dims);
+  DenseTensor constant;
+  constant.Resize(phi::make_ddim(out_shape));
+  dev_ctx.template Alloc<T>(&constant);
+  phi::funcs::SetConstant<Context, T>()(dev_ctx, &constant, T(0.5));
+  // Keep only the anti-Hermitian part: 0.5 * (result - conj(result)^T).
+  result = phi::Subtract<T>(
+      dev_ctx,
+      result,
+      phi::Conj<T>(dev_ctx, phi::TransposeLast2Dim<T>(dev_ctx, result)));
+  result = phi::Multiply<T>(dev_ctx, result, constant);
+  // Divide by the eigenvalue gaps; for complex T the divisor W is real,
+  // so fall back to an element-wise Eigen expression instead of Divide.
+  if (result.type() != W.type()) {
+    auto x_vector = EigenVector<T>::Flatten(result);
+    auto y_vector = EigenVector<phi::dtype::Real<T>>::Flatten(W);
+    auto out_vector = EigenVector<T>::Flatten(result);
+    auto& place =
+        *dev_ctx.eigen_device();
+    out_vector.device(place) = x_vector / y_vector;
+  } else {
+    result = phi::Divide<T>(dev_ctx, result, W);
+  }
+  // Off-diagonal entries keep the divided dV-term; the diagonal is
+  // filled with the eigenvalue gradient dW.
+  result = phi::funcs::DiagFill<T, phi::dtype::Real<T>>(
+      dev_ctx, m, m, m, 0, dout_w, result);
+  // dx = V * result * conj(V)^T
+  *dx = phi::Matmul<T>(dev_ctx, out_v, phi::Matmul<T>(dev_ctx, result, tV));
+}
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/eigh_sig.cc b/paddle/phi/ops/compat/eigh_sig.cc
new file mode 100644
index 00000000000..e50a9a5a12a
--- /dev/null
+++ b/paddle/phi/ops/compat/eigh_sig.cc
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("eigh_grad",
+                         {"Eigenvalues",
+                          "Eigenvectors",
+                          GradVarName("Eigenvalues"),
+                          GradVarName("Eigenvectors")},
+                         {},
+                         {GradVarName("X")});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(eigh_grad, phi::EighGradOpArgumentMapping);
-- 
GitLab
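
Reviewer note: the contract the kernel above implements (real ascending eigenvalues in out_w, eigenvectors as columns of out_v, only one triangle of the input referenced) can be sanity-checked outside Paddle with a few lines of Eigen. This is an illustrative sketch only, not part of the patch or of phi's API:

  // Standalone check of the eigh decomposition contract, using Eigen.
  #include <Eigen/Dense>
  #include <iostream>

  int main() {
    // Build a random symmetric 4x4 matrix.
    Eigen::MatrixXd M = Eigen::MatrixXd::Random(4, 4);
    Eigen::MatrixXd A = 0.5 * (M + M.transpose());

    // Like eigh with UPLO='L', SelfAdjointEigenSolver reads only one
    // triangle of the self-adjoint input.
    Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> es(A);
    Eigen::VectorXd w = es.eigenvalues();   // real, ascending
    Eigen::MatrixXd V = es.eigenvectors();  // orthonormal columns

    // A == V * diag(w) * V^T should hold to round-off.
    double err = (V * w.asDiagonal() * V.transpose() - A).norm();
    std::cout << "reconstruction error: " << err << "\n";  // ~1e-15
    return 0;
  }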