From 95474815f976c4688393dacfe545207140eb6560 Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Wed, 13 Jul 2022 19:13:22 +0800 Subject: [PATCH] Move eigvals OP to PHI (#44183) * Move eigvals OP to PHI * Fix CI errors * Fix CI errors --- paddle/fluid/operators/eigvals_op.cc | 57 +---- paddle/fluid/operators/eigvals_op.h | 273 ----------------------- paddle/phi/api/yaml/legacy_api.yaml | 8 + paddle/phi/core/utils/data_type.h | 17 ++ paddle/phi/infermeta/unary.cc | 33 +++ paddle/phi/infermeta/unary.h | 4 + paddle/phi/kernels/cpu/eigvals_kernel.cc | 260 +++++++++++++++++++++ paddle/phi/kernels/eigvals_kernel.h | 25 +++ paddle/phi/ops/compat/eigvals_sig.cc | 25 +++ python/paddle/tensor/linalg.py | 4 +- 10 files changed, 383 insertions(+), 323 deletions(-) delete mode 100644 paddle/fluid/operators/eigvals_op.h create mode 100644 paddle/phi/kernels/cpu/eigvals_kernel.cc create mode 100644 paddle/phi/kernels/eigvals_kernel.h create mode 100644 paddle/phi/ops/compat/eigvals_sig.cc diff --git a/paddle/fluid/operators/eigvals_op.cc b/paddle/fluid/operators/eigvals_op.cc index cb81a1a64d..78bd2b37f6 100644 --- a/paddle/fluid/operators/eigvals_op.cc +++ b/paddle/fluid/operators/eigvals_op.cc @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/eigvals_op.h" - +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -36,59 +37,17 @@ class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker { class EigvalsOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvals"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Eigvals"); - - DDim x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_GE(x_dims.size(), - 2, - platform::errors::InvalidArgument( - "The dimensions of Input(X) for Eigvals operator " - "should be at least 2, " - "but received X's dimension = %d, X's shape = [%s].", - x_dims.size(), - x_dims)); - - if (ctx->IsRuntime() || !phi::contain_unknown_dim(x_dims)) { - int last_dim = x_dims.size() - 1; - PADDLE_ENFORCE_EQ(x_dims[last_dim], - x_dims[last_dim - 1], - platform::errors::InvalidArgument( - "The last two dimensions of Input(X) for Eigvals " - "operator should be equal, " - "but received X's shape = [%s].", - x_dims)); - } - - auto output_dims = vectorize(x_dims); - output_dims.resize(x_dims.size() - 1); - ctx->SetOutputDim("Out", phi::make_ddim(output_dims)); - } }; -class EigvalsOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(framework::InferVarTypeContext* ctx) const { - auto input_dtype = ctx->GetInputDataType("X"); - auto output_dtype = framework::IsComplexType(input_dtype) - ? input_dtype - : framework::ToComplexType(input_dtype); - ctx->SetOutputDataType("Out", output_dtype); - } -}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -namespace plat = paddle::platform; + +DECLARE_INFER_SHAPE_FUNCTOR(eigvals, + EigvalsInferShapeFunctor, + PD_INFER_META(phi::EigvalsInferMeta)); REGISTER_OPERATOR(eigvals, ops::EigvalsOp, ops::EigvalsOpMaker, - ops::EigvalsOpVarTypeInference); -REGISTER_OP_CPU_KERNEL( - eigvals, - ops::EigvalsKernel, - ops::EigvalsKernel, - ops::EigvalsKernel>, - ops::EigvalsKernel>); + EigvalsInferShapeFunctor); diff --git a/paddle/fluid/operators/eigvals_op.h b/paddle/fluid/operators/eigvals_op.h deleted file mode 100644 index 38560bf7c3..0000000000 --- a/paddle/fluid/operators/eigvals_op.h +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/allocation/allocator.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/core/ddim.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" -#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; -using DDim = framework::DDim; - -template -struct PaddleComplex; - -template -struct PaddleComplex< - T, - typename std::enable_if::value>::type> { - using type = paddle::platform::complex; -}; -template -struct PaddleComplex< - T, - typename std::enable_if< - std::is_same>::value || - std::is_same>::value>::type> { - using type = T; -}; - -template -using PaddleCType = typename PaddleComplex::type; -template -using Real = typename phi::dtype::Real; - -static void SpiltBatchSquareMatrix(const Tensor& input, - std::vector* output) { - DDim input_dims = input.dims(); - int last_dim = input_dims.size() - 1; - int n_dim = input_dims[last_dim]; - - DDim flattened_input_dims, flattened_output_dims; - if (input_dims.size() > 2) { - flattened_input_dims = - phi::flatten_to_3d(input_dims, last_dim - 1, last_dim); - } else { - flattened_input_dims = phi::make_ddim({1, n_dim, n_dim}); - } - - Tensor flattened_input; - flattened_input.ShareDataWith(input); - flattened_input.Resize(flattened_input_dims); - (*output) = flattened_input.Split(1, 0); -} - -static void CheckLapackEigResult(const int info, const std::string& name) { - PADDLE_ENFORCE_LE(info, - 0, - platform::errors::PreconditionNotMet( - "The QR algorithm failed to compute all the " - "eigenvalues in function %s.", - name.c_str())); - PADDLE_ENFORCE_GE( - info, - 0, - platform::errors::InvalidArgument( - "The %d-th argument has an illegal value in function %s.", - -info, - name.c_str())); -} - -template -static typename std::enable_if::value>::type -LapackEigvals(const framework::ExecutionContext& ctx, - const Tensor& input, - Tensor* output, - Tensor* work, - Tensor* rwork /*unused*/) { - Tensor a; // will be overwritten when lapackEig exit - framework::TensorCopy(input, input.place(), &a); - - Tensor w; - int64_t n_dim = input.dims()[1]; - auto* w_data = - w.mutable_data(phi::make_ddim({n_dim << 1}), ctx.GetPlace()); - - int64_t work_mem = work->memory_size(); - int64_t required_work_mem = 3 * n_dim * sizeof(T); - PADDLE_ENFORCE_GE( - work_mem, - 3 * n_dim * sizeof(T), - platform::errors::InvalidArgument( - "The memory size of the work tensor in LapackEigvals function " - "should be at least %" PRId64 " bytes, " - "but received work\'s memory size = %" PRId64 " bytes.", - required_work_mem, - work_mem)); - - int info = 0; - phi::funcs::lapackEig('N', - 'N', - static_cast(n_dim), - a.template data(), - static_cast(n_dim), - w_data, - NULL, - 1, - NULL, - 1, - work->template data(), - static_cast(work_mem / sizeof(T)), - static_cast(NULL), - &info); - - std::string name = "framework::platform::dynload::dgeev_"; - if (framework::TransToProtoVarType(input.dtype()) == - framework::proto::VarType::FP64) { - name = "framework::platform::dynload::sgeev_"; - } - CheckLapackEigResult(info, name); - - platform::ForRange for_range( - ctx.template device_context(), n_dim); - phi::funcs::RealImagToComplexFunctor> functor( - w_data, w_data + n_dim, output->template data>(), n_dim); - for_range(functor); -} - -template -typename std::enable_if>::value || - std::is_same>::value>::type -LapackEigvals(const framework::ExecutionContext& ctx, - const Tensor& input, - Tensor* output, - Tensor* work, - Tensor* rwork) { - Tensor a; // will be overwritten when lapackEig exit - framework::TensorCopy(input, input.place(), &a); - - int64_t work_mem = work->memory_size(); - int64_t n_dim = input.dims()[1]; - int64_t required_work_mem = 3 * n_dim * sizeof(T); - PADDLE_ENFORCE_GE( - work_mem, - 3 * n_dim * sizeof(T), - platform::errors::InvalidArgument( - "The memory size of the work tensor in LapackEigvals function " - "should be at least %" PRId64 " bytes, " - "but received work\'s memory size = %" PRId64 " bytes.", - required_work_mem, - work_mem)); - - int64_t rwork_mem = rwork->memory_size(); - int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::dtype::Real); - PADDLE_ENFORCE_GE( - rwork_mem, - required_rwork_mem, - platform::errors::InvalidArgument( - "The memory size of the rwork tensor in LapackEigvals function " - "should be at least %" PRId64 " bytes, " - "but received rwork\'s memory size = %" PRId64 " bytes.", - required_rwork_mem, - rwork_mem)); - - int info = 0; - phi::funcs::lapackEig>( - 'N', - 'N', - static_cast(n_dim), - a.template data(), - static_cast(n_dim), - output->template data(), - NULL, - 1, - NULL, - 1, - work->template data(), - static_cast(work_mem / sizeof(T)), - rwork->template data>(), - &info); - - std::string name = "framework::platform::dynload::cgeev_"; - if (framework::TransToProtoVarType(input.dtype()) == - framework::proto::VarType::COMPLEX64) { - name = "framework::platform::dynload::zgeev_"; - } - CheckLapackEigResult(info, name); -} - -template -class EigvalsKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const Tensor* input = ctx.Input("X"); - Tensor* output = ctx.Output("Out"); - output->mutable_data>(ctx.GetPlace()); - - std::vector input_matrices; - SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices); - - int64_t n_dim = input_matrices[0].dims()[1]; - int64_t n_batch = input_matrices.size(); - DDim output_dims = output->dims(); - output->Resize(phi::make_ddim({n_batch, n_dim})); - std::vector output_vectors = output->Split(1, 0); - - // query workspace size - T qwork; - int info; - phi::funcs::lapackEig>( - 'N', - 'N', - static_cast(n_dim), - input_matrices[0].template data(), - static_cast(n_dim), - NULL, - NULL, - 1, - NULL, - 1, - &qwork, - -1, - static_cast*>(NULL), - &info); - int64_t lwork = static_cast(qwork); - - Tensor work, rwork; - try { - work.mutable_data(phi::make_ddim({lwork}), ctx.GetPlace()); - } catch (memory::allocation::BadAlloc&) { - LOG(WARNING) << "Failed to allocate Lapack workspace with the optimal " - << "memory size = " << lwork * sizeof(T) << " bytes, " - << "try reallocating a smaller workspace with the minimum " - << "required size = " << 3 * n_dim * sizeof(T) << " bytes, " - << "this may lead to bad performance."; - lwork = 3 * n_dim; - work.mutable_data(phi::make_ddim({lwork}), ctx.GetPlace()); - } - if (framework::IsComplexType( - framework::TransToProtoVarType(input->dtype()))) { - rwork.mutable_data>(phi::make_ddim({n_dim << 1}), - ctx.GetPlace()); - } - - for (int64_t i = 0; i < n_batch; ++i) { - LapackEigvals( - ctx, input_matrices[i], &output_vectors[i], &work, &rwork); - } - output->Resize(output_dims); - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index cd01c23641..3dad0b96ae 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -536,6 +536,14 @@ func : eigh backward : eigh_grad +- api : eigvals + args : (Tensor x) + output : Tensor + infer_meta : + func : EigvalsInferMeta + kernel : + func : eigvals + - api : einsum args : (Tensor[] x, str equation) output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()} diff --git a/paddle/phi/core/utils/data_type.h b/paddle/phi/core/utils/data_type.h index 9ef8e8a356..975d55889c 100644 --- a/paddle/phi/core/utils/data_type.h +++ b/paddle/phi/core/utils/data_type.h @@ -80,4 +80,21 @@ inline void VisitDataTypeTiny(phi::DataType type, Visitor visitor) { "Not supported phi::DataType(%d) as data type.", static_cast(type))); } +inline bool IsComplexType(const DataType& type) { + return (type == DataType::COMPLEX64 || type == DataType::COMPLEX128); +} + +inline DataType ToComplexType(const DataType& type) { + switch (type) { + case DataType::FLOAT32: + return DataType::COMPLEX64; + case DataType::FLOAT64: + return DataType::COMPLEX128; + default: + PADDLE_THROW(errors::Unimplemented( + "Can not transform data type (%s) to complex type, now only support " + "float32 and float64 real value.", + type)); + } +} } // namespace phi diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 0048f130ad..f6e3b0d724 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -399,6 +399,39 @@ void EighInferMeta(const MetaTensor& x, out_v->set_dims(input_dim); } +void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_GE(x_dims.size(), + 2, + errors::InvalidArgument( + "The dimensions of Input(X) for Eigvals operator " + "should be at least 2, " + "but received X's dimension = %d, X's shape = [%s].", + x_dims.size(), + x_dims)); + + if (config.is_runtime || !phi::contain_unknown_dim(x_dims)) { + int last_dim = x_dims.size() - 1; + PADDLE_ENFORCE_EQ(x_dims[last_dim], + x_dims[last_dim - 1], + errors::InvalidArgument( + "The last two dimensions of Input(X) for Eigvals " + "operator should be equal, " + "but received X's shape = [%s].", + x_dims)); + } + + auto out_dims = vectorize(x_dims); + out_dims.resize(x_dims.size() - 1); + + const DataType& x_dtype = x.dtype(); + const DataType& out_dtype = + IsComplexType(x_dtype) ? x_dtype : ToComplexType(x_dtype); + + out->set_dims(make_ddim(out_dims)); + out->set_dtype(out_dtype); +} + void EinsumInferMeta(const std::vector& inputs, const std::string& equation, MetaTensor* out, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 0b9298cfd3..fc36e1d4f8 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -80,6 +80,10 @@ void EighInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v); +void EigvalsInferMeta(const MetaTensor& x, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void EinsumInferMeta(const std::vector& inputs, const std::string& equation, MetaTensor* out, diff --git a/paddle/phi/kernels/cpu/eigvals_kernel.cc b/paddle/phi/kernels/cpu/eigvals_kernel.cc new file mode 100644 index 0000000000..e99aa42fbd --- /dev/null +++ b/paddle/phi/kernels/cpu/eigvals_kernel.cc @@ -0,0 +1,260 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/eigvals_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" + +namespace phi { + +template +struct PaddleComplex; + +template +struct PaddleComplex< + T, + typename std::enable_if::value>::type> { + using type = dtype::complex; +}; + +template +struct PaddleComplex< + T, + typename std::enable_if< + std::is_same>::value || + std::is_same>::value>::type> { + using type = T; +}; + +template +using PaddleCType = typename PaddleComplex::type; +template +using Real = typename dtype::Real; + +inline void CheckLapackEigResult(const int info, const std::string& name) { + PADDLE_ENFORCE_LE( + info, + 0, + errors::PreconditionNotMet("The QR algorithm failed to compute all the " + "eigenvalues in function %s.", + name.c_str())); + PADDLE_ENFORCE_GE( + info, + 0, + errors::InvalidArgument( + "The %d-th argument has an illegal value in function %s.", + -info, + name.c_str())); +} + +template +typename std::enable_if::value>::type LapackEigvals( + const Context& ctx, + const DenseTensor& input, + DenseTensor* output, + DenseTensor* work, + DenseTensor* rwork /*unused*/) { + DenseTensor a; // will be overwritten when lapackEig exit + Copy(ctx, input, input.place(), /*blocking=*/true, &a); + + DenseTensor w; + int64_t n_dim = input.dims()[1]; + w.Resize(make_ddim({n_dim << 1})); + T* w_data = ctx.template Alloc(&w); + + int64_t work_mem = work->memory_size(); + int64_t required_work_mem = 3 * n_dim * sizeof(T); + PADDLE_ENFORCE_GE( + work_mem, + 3 * n_dim * sizeof(T), + errors::InvalidArgument( + "The memory size of the work tensor in LapackEigvals function " + "should be at least %" PRId64 " bytes, " + "but received work\'s memory size = %" PRId64 " bytes.", + required_work_mem, + work_mem)); + + int info = 0; + phi::funcs::lapackEig('N', + 'N', + static_cast(n_dim), + a.template data(), + static_cast(n_dim), + w_data, + NULL, + 1, + NULL, + 1, + work->template data(), + static_cast(work_mem / sizeof(T)), + static_cast(NULL), + &info); + + std::string name = "phi::backend::dynload::dgeev_"; + if (input.dtype() == DataType::FLOAT64) { + name = "phi::backend::dynload::sgeev_"; + } + CheckLapackEigResult(info, name); + + funcs::ForRange for_range(ctx, n_dim); + funcs::RealImagToComplexFunctor> functor( + w_data, w_data + n_dim, output->template data>(), n_dim); + for_range(functor); +} + +template +typename std::enable_if>::value || + std::is_same>::value>::type +LapackEigvals(const Context& ctx, + const DenseTensor& input, + DenseTensor* output, + DenseTensor* work, + DenseTensor* rwork) { + DenseTensor a; // will be overwritten when lapackEig exit + Copy(ctx, input, input.place(), /*blocking=*/true, &a); + + int64_t work_mem = work->memory_size(); + int64_t n_dim = input.dims()[1]; + int64_t required_work_mem = 3 * n_dim * sizeof(T); + PADDLE_ENFORCE_GE( + work_mem, + 3 * n_dim * sizeof(T), + errors::InvalidArgument( + "The memory size of the work tensor in LapackEigvals function " + "should be at least %" PRId64 " bytes, " + "but received work\'s memory size = %" PRId64 " bytes.", + required_work_mem, + work_mem)); + + int64_t rwork_mem = rwork->memory_size(); + int64_t required_rwork_mem = (n_dim << 1) * sizeof(dtype::Real); + PADDLE_ENFORCE_GE( + rwork_mem, + required_rwork_mem, + errors::InvalidArgument( + "The memory size of the rwork tensor in LapackEigvals function " + "should be at least %" PRId64 " bytes, " + "but received rwork\'s memory size = %" PRId64 " bytes.", + required_rwork_mem, + rwork_mem)); + + int info = 0; + phi::funcs::lapackEig>( + 'N', + 'N', + static_cast(n_dim), + a.template data(), + static_cast(n_dim), + output->template data(), + NULL, + 1, + NULL, + 1, + work->template data(), + static_cast(work_mem / sizeof(T)), + rwork->template data>(), + &info); + + std::string name = "phi::backend::dynload::cgeev_"; + if (input.dtype() == DataType::COMPLEX128) { + name = "phi::backend::dynload::zgeev_"; + } + CheckLapackEigResult(info, name); +} + +void SpiltBatchSquareMatrix(const DenseTensor& input, + std::vector* output) { + DDim input_dims = input.dims(); + int last_dim = input_dims.size() - 1; + int n_dim = input_dims[last_dim]; + + DDim flattened_input_dims, flattened_output_dims; + if (input_dims.size() > 2) { + flattened_input_dims = + phi::flatten_to_3d(input_dims, last_dim - 1, last_dim); + } else { + flattened_input_dims = phi::make_ddim({1, n_dim, n_dim}); + } + + DenseTensor flattened_input; + flattened_input.ShareDataWith(input); + flattened_input.Resize(flattened_input_dims); + (*output) = flattened_input.Split(1, 0); +} + +template +void EigvalsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { + ctx.template Alloc>(out); + + std::vector x_matrices; + SpiltBatchSquareMatrix(x, /*->*/ &x_matrices); + + int64_t n_dim = x_matrices[0].dims()[1]; + int64_t n_batch = x_matrices.size(); + DDim out_dims = out->dims(); + out->Resize(make_ddim({n_batch, n_dim})); + std::vector out_vectors = out->Split(1, 0); + + // query workspace size + T qwork; + int info; + funcs::lapackEig>('N', + 'N', + static_cast(n_dim), + x_matrices[0].template data(), + static_cast(n_dim), + NULL, + NULL, + 1, + NULL, + 1, + &qwork, + -1, + static_cast*>(NULL), + &info); + int64_t lwork = static_cast(qwork); + + DenseTensor work, rwork; + + work.Resize(make_ddim({lwork})); + ctx.template Alloc(&work); + + if (IsComplexType(x.dtype())) { + rwork.Resize(make_ddim({n_dim << 1})); + ctx.template Alloc>(&rwork); + } + + for (int64_t i = 0; i < n_batch; ++i) { + LapackEigvals( + ctx, x_matrices[i], &out_vectors[i], &work, &rwork); + } + out->Resize(out_dims); +} + +} // namespace phi + +PD_REGISTER_KERNEL(eigvals, + CPU, + ALL_LAYOUT, + phi::EigvalsKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/eigvals_kernel.h b/paddle/phi/kernels/eigvals_kernel.h new file mode 100644 index 0000000000..dd9f3370bd --- /dev/null +++ b/paddle/phi/kernels/eigvals_kernel.h @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" + +namespace phi { + +template +void EigvalsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/eigvals_sig.cc b/paddle/phi/ops/compat/eigvals_sig.cc new file mode 100644 index 0000000000..cb29126abc --- /dev/null +++ b/paddle/phi/ops/compat/eigvals_sig.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature EigvalsOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("eigvals", {"X"}, {}, {"Out"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(eigvals, phi::EigvalsOpArgumentMapping); diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 95eaee2cc0..1bc85a076a 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -2339,7 +2339,9 @@ def eigvals(x, name=None): "The last two dimensions of Input(x) should be equal, but received x's shape = {}" .format(x_shape)) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_eigvals(x) + elif paddle.in_dynamic_mode(): return _C_ops.eigvals(x) helper = LayerHelper('eigvals', **locals()) -- GitLab