diff --git a/paddle/fluid/operators/eigvalsh_op.cc b/paddle/fluid/operators/eigvalsh_op.cc
index f7abdbee84f1db4a21512560ed952396dd9e3578..9ba892b61badf3eae97581050d748a2589113f8e 100644
--- a/paddle/fluid/operators/eigvalsh_op.cc
+++ b/paddle/fluid/operators/eigvalsh_op.cc
@@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/eigvalsh_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -22,43 +26,6 @@ using framework::Tensor;
 class EigvalshOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvalsh");
-    OP_INOUT_CHECK(
-        ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues", "Eigvalsh");
-
-    auto input_dim = ctx->GetInputDim("X");
-    auto rank = input_dim.size();
-
-    PADDLE_ENFORCE_GE(rank,
-                      2,
-                      platform::errors::InvalidArgument(
-                          "The Input(X) should have at least 2 dimensions."
-                          "But received a %d dimension tensor.",
-                          rank));
-    PADDLE_ENFORCE_EQ(
-        input_dim[rank - 2],
-        input_dim[rank - 1],
-        platform::errors::InvalidArgument(
-            "Eigvalsh op is designed for square matrix, consequently"
-            "inner-most 2 dimensions of Input(X) should be symmetric."
-            "But received X's shape[-2] = %d and shape[-1] = %d.",
-            input_dim[rank - 2],
-            input_dim[rank - 1]));
-
-    std::vector<int64_t> values_dim;
-
-    for (auto i = 0; i < rank - 1; i++) {
-      values_dim.emplace_back(input_dim[i]);
-    }
-
-    ctx->SetOutputDim("Eigenvalues", phi::make_ddim(values_dim));
-
-    if (ctx->HasOutput("Eigenvectors")) {
-      ctx->SetOutputDim("Eigenvectors", input_dim);
-    }
-  }
 };
 
 class EigvalshOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -100,20 +67,6 @@ class EigvalshGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(
-        ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", "EigvalshGrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")),
-                   "Input",
-                   "Eigenvalues@GRAD",
-                   "EigvalshGrad");
-    auto dims = ctx->GetInputDim("Eigenvectors");
-    auto x_grad_name = framework::GradVarName("X");
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, dims);
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -144,30 +97,19 @@ class EigvalshGradOpMaker : public framework::SingleGradOpMaker<T> {
 
 namespace ops = paddle::operators;
 
+DECLARE_INFER_SHAPE_FUNCTOR(eigvalsh,
+                            EigvalshInferShapeFunctor,
+                            PD_INFER_META(phi::EigvalshInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(eigvalsh_grad,
+                            EigvalshGradInferShapeFunctor,
+                            PD_INFER_META(phi::EigvalshGradInferMeta));
+
 REGISTER_OPERATOR(eigvalsh,
                   ops::EigvalshOp,
                   ops::EigvalshOpMaker,
                   ops::EigvalshGradOpMaker<paddle::framework::OpDesc>,
-                  ops::EigvalshGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(eigvalsh_grad, ops::EigvalshGradOp);
-
-REGISTER_OP_CPU_KERNEL(eigvalsh,
-                       ops::EigvalshKernel<paddle::platform::CPUDeviceContext,
-                                           float, float>,
-                       ops::EigvalshKernel<paddle::platform::CPUDeviceContext,
-                                           double, double>,
-                       ops::EigvalshKernel<paddle::platform::CPUDeviceContext,
-                                           float, paddle::platform::complex<float>>,
-                       ops::EigvalshKernel<paddle::platform::CPUDeviceContext,
-                                           double, paddle::platform::complex<double>>);
-
-REGISTER_OP_CPU_KERNEL(
-    eigvalsh_grad,
-    ops::EigvalshGradKernel<paddle::platform::CPUDeviceContext,
-                            float, float>,
-    ops::EigvalshGradKernel<paddle::platform::CPUDeviceContext,
-                            double, double>,
-    ops::EigvalshGradKernel<paddle::platform::CPUDeviceContext,
-                            float, paddle::platform::complex<float>>,
-    ops::EigvalshGradKernel<paddle::platform::CPUDeviceContext,
-                            double, paddle::platform::complex<double>>);
+                  ops::EigvalshGradOpMaker<paddle::imperative::OpBase>,
+                  EigvalshInferShapeFunctor);
+REGISTER_OPERATOR(eigvalsh_grad,
+                  ops::EigvalshGradOp,
+                  EigvalshGradInferShapeFunctor);
diff --git a/paddle/fluid/operators/eigvalsh_op.cu b/paddle/fluid/operators/eigvalsh_op.cu
deleted file mode 100644
index 880570d1be09b9dcf2eab7a2a9139e39ba40e4fb..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/eigvalsh_op.cu
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/eigvalsh_op.h"
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_CUDA_KERNEL(eigvalsh,
-                        ops::EigvalshKernel<paddle::platform::CUDADeviceContext,
-                                            float, float>,
-                        ops::EigvalshKernel<paddle::platform::CUDADeviceContext,
-                                            double, double>,
-                        ops::EigvalshKernel<paddle::platform::CUDADeviceContext,
-                                            float, paddle::platform::complex<float>>,
-                        ops::EigvalshKernel<paddle::platform::CUDADeviceContext,
-                                            double, paddle::platform::complex<double>>);
-
-REGISTER_OP_CUDA_KERNEL(
-    eigvalsh_grad,
-    ops::EigvalshGradKernel<paddle::platform::CUDADeviceContext,
-                            float, float>,
-    ops::EigvalshGradKernel<paddle::platform::CUDADeviceContext,
-                            double, double>,
-    ops::EigvalshGradKernel<paddle::platform::CUDADeviceContext,
-                            float, paddle::platform::complex<float>>,
-    ops::EigvalshGradKernel<paddle::platform::CUDADeviceContext,
-                            double, paddle::platform::complex<double>>);
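
Note on the shape contract being moved: the InferShape deleted above and the
phi::EigvalshInferMeta added later in this patch implement the same rule. For an
input of shape [..., N, N], Eigenvalues has shape [..., N] and Eigenvectors keeps
the input shape. A minimal NumPy sketch of that rule (illustrative only, not part
of the patch):

    import numpy as np

    def eigvalsh_out_shapes(x_shape):
        # Same checks as EigvalshInferMeta: rank >= 2 and square trailing dims.
        assert len(x_shape) >= 2, "Input(X) should have at least 2 dimensions"
        assert x_shape[-2] == x_shape[-1], "inner-most 2 dimensions must match"
        return tuple(x_shape[:-1]), tuple(x_shape)  # (Eigenvalues, Eigenvectors)

    assert eigvalsh_out_shapes((3, 4, 4)) == ((3, 4), (3, 4, 4))
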
diff --git a/paddle/fluid/operators/eigvalsh_op.h b/paddle/fluid/operators/eigvalsh_op.h
deleted file mode 100644
index 9aa548445f6c141900432e4f230053c4456630da..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/eigvalsh_op.h
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/eigen_values_vectors.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T,
-          int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-
-template <typename DeviceContext, typename ValueType, typename T>
-class EigvalshKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto input = ctx.Input<Tensor>("X");
-    auto output_w = ctx.Output<Tensor>("Eigenvalues");
-
-    std::string lower = ctx.Attr<std::string>("UPLO");
-    bool is_lower = (lower == "L");
-    bool is_test = ctx.Attr<bool>("is_test");
-    math::MatrixEighFunctor<DeviceContext, T> functor;
-    if (is_test) {
-      functor(ctx, *input, output_w, nullptr, is_lower, false);
-    } else {
-      auto output_v = ctx.Output<Tensor>("Eigenvectors");
-      functor(ctx, *input, output_w, output_v, is_lower, true);
-    }
-  }
-};
-
-template <typename DeviceContext, typename ValueType, typename T>
-class EigvalshGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& x_grad = *ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto& output_v = *ctx.Input<Tensor>("Eigenvectors");
-    auto& output_w_grad =
-        *ctx.Input<Tensor>(framework::GradVarName("Eigenvalues"));
-
-    auto dito =
-        math::DeviceIndependenceTensorOperations<DeviceContext, T>(
-            ctx);
-    auto tV = dito.Transpose(dito.Conj(output_v));
-
-    // compute elementwise multiply of output_v and output_w_grad
-    x_grad.mutable_data<T>(output_v.dims(), ctx.GetPlace());
-    auto output_v_vector = EigenVector<T>::Flatten(output_v);
-    auto output_w_grad_vector = EigenVector<ValueType>::Flatten(output_w_grad);
-    auto result_vector = EigenVector<T>::Flatten(x_grad);
-    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-    std::vector<int> broadcast_factor;
-    broadcast_factor.push_back(output_v.dims().at(output_v.dims().size() - 1));
-    result_vector.device(place) =
-        output_v_vector * output_w_grad_vector.broadcast(broadcast_factor);
-
-    x_grad = dito.Matmul(x_grad, tV);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index 9146ea46e2008335515d71ad0cad5d11acf06ae8..fa212ea8f12bfe2272d4adb590307f37cf0ad2bd 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -682,6 +682,15 @@
   kernel :
     func : eigvals
 
+- api : eigvalsh
+  args : (Tensor x, str uplo, bool is_test)
+  output : Tensor(eigenvalues), Tensor(eigenvectors)
+  infer_meta :
+    func : EigvalshInferMeta
+  kernel :
+    func : eigvalsh
+  backward : eigvalsh_grad
+
 - api : einsum
   args : (Tensor[] x, str equation)
   output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 4f5fdbc32c5bcce03961c32d4565d221f1721021..c7fa3c13e6067f9ee0f7c579f1705ff15a7f67f6 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -667,6 +667,18 @@
   data_transform:
     skip_transform : out_w, out_w_grad
 
+- backward_api : eigvalsh_grad
+  forward : eigvalsh (Tensor x, str uplo, bool is_test) -> Tensor(eigenvalues), Tensor(eigenvectors)
+  args : (Tensor eigenvectors, Tensor eigenvalues_grad, str uplo, bool is_test)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : EigvalshGradInferMeta
+  kernel :
+    func : eigvalsh_grad
+    data_type : eigenvectors
+  data_transform :
+    skip_transform : eigenvalues_grad
+
 - backward_api : einsum_grad
   forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
   args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index a4c800017ccd987e26cb2ca3ee7719f218edff49..a8555827c0527f03f2979ea2d62c40a6468a0802 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -263,6 +263,18 @@ void EigGradInferMeta(const MetaTensor& out_w,
   }
 }
 
+void EigvalshGradInferMeta(const MetaTensor& out_v,
+                           const MetaTensor& out_w_grad,
+                           const std::string& uplo,
+                           bool is_test,
+                           MetaTensor* x_grad) {
+  auto dims = out_v.dims();
+  if (x_grad != nullptr) {
+    x_grad->set_dims(dims);
+    x_grad->set_dtype(out_v.dtype());
+  }
+}
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 8ce5dba92198ba9d5f6861395b5e138fe8764771..d9208b7c52491ffecd2cf3c1900fb4569c0a0db7 100755
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -126,6 +126,12 @@ void EigGradInferMeta(const MetaTensor& out_w,
                       const MetaTensor& dout_v,
                       MetaTensor* dx);
 
+void EigvalshGradInferMeta(const MetaTensor& out_v,
+                           const MetaTensor& out_w_grad,
+                           const std::string& uplo,
+                           bool is_test,
+                           MetaTensor* x_grad);
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
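
EigvalshGradInferMeta above only propagates shape and dtype from Eigenvectors to
X@GRAD; the arithmetic lives in eigvalsh_grad_kernel_impl.h further down. Because
only the eigenvalues are differentiated, the formula reduces to
x_grad = V @ diag(dW) @ V^H, which the kernel evaluates as a broadcast multiply
followed by a single matmul. A NumPy sketch of the same computation (illustrative
only, not part of the patch):

    import numpy as np

    def eigvalsh_x_grad(out_v, out_w_grad):
        # x_grad = V @ diag(dW) @ V^H, written the way the kernel does it:
        # scale column j of V by dW[j], then multiply by conj(V)^T.
        tV = out_v.conj().swapaxes(-2, -1)
        return (out_v * out_w_grad[..., None, :]) @ tV
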
+ "But received X's shape[-2] = %d and shape[-1] = %d.", + input_dim[rank - 2], + input_dim[rank - 1])); + + std::vector values_dim; + + for (auto i = 0; i < rank - 1; i++) { + values_dim.emplace_back(input_dim[i]); + } + + if (out_w != nullptr) { + out_w->set_dims(phi::make_ddim(values_dim)); + out_w->set_dtype(dtype::ToReal(x.dtype())); + } + if (out_v != nullptr) { + out_v->set_dims(input_dim); + out_v->set_dtype(x.dtype()); + } +} + void EinsumInferMeta(const std::vector& inputs, const std::string& equation, MetaTensor* out) { diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 21c052580a829b0bf5dfb2bcd8994dad710f821b..a37492cf7ec77a13f8bf2a1c361108cbeaf7df98 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -103,6 +103,12 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void EigvalshInferMeta(const MetaTensor& x, + const std::string& uplo, + bool is_test, + MetaTensor* out_w, + MetaTensor* out_v); + void EinsumInferMeta(const std::vector& inputs, const std::string& equation, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/eigvalsh_grad_kernel.cc b/paddle/phi/kernels/cpu/eigvalsh_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..b7b5927740e0939e61204bf0b83bd5af38e15ef9 --- /dev/null +++ b/paddle/phi/kernels/cpu/eigvalsh_grad_kernel.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/eigvalsh_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/eigvalsh_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(eigvalsh_grad, + CPU, + ALL_LAYOUT, + phi::EigvalshGradKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/eigvalsh_kernel.cc b/paddle/phi/kernels/cpu/eigvalsh_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..cfbb7bd6fbc72a2720cda9ed54179eca3f02cd40 --- /dev/null +++ b/paddle/phi/kernels/cpu/eigvalsh_kernel.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/eigvalsh_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/eigvalsh_kernel_impl.h" + +PD_REGISTER_KERNEL(eigvalsh, + CPU, + ALL_LAYOUT, + phi::EigvalshKernel, + float, + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/eigvalsh_grad_kernel.h b/paddle/phi/kernels/eigvalsh_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..934586f9b7bfbd756fc7242e63104ed2c50953d1 --- /dev/null +++ b/paddle/phi/kernels/eigvalsh_grad_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void EigvalshGradKernel(const Context& dev_ctx, + const DenseTensor& out_v, + const DenseTensor& out_w_grad, + const std::string& uplo, + bool is_test, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/eigvalsh_kernel.h b/paddle/phi/kernels/eigvalsh_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..bd586a615924c848071416bc691ac57579df31c0 --- /dev/null +++ b/paddle/phi/kernels/eigvalsh_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void EigvalshKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::string& uplo, + bool is_test, + DenseTensor* out_w, + DenseTensor* out_v); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/eigvalsh_grad_kernel.cu b/paddle/phi/kernels/gpu/eigvalsh_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..de26617d80f1b81af007a0327ae0d3ed18cff52c --- /dev/null +++ b/paddle/phi/kernels/gpu/eigvalsh_grad_kernel.cu @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
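
The is_test flag threaded through these declarations decides whether eigenvectors
are materialized: the eager API below passes x.stop_gradient as is_test, so
inference-only calls compute just the eigenvalues while training keeps V as the
intermediate for the backward pass. In NumPy terms (a reference sketch, not part
of the patch):

    import numpy as np

    def eigvalsh_ref(x, uplo="L", is_test=False):
        # is_test=True mirrors MatrixEighFunctor's has_vectors=false branch.
        if is_test:
            return np.linalg.eigvalsh(x, UPLO=uplo), None
        w, v = np.linalg.eigh(x, UPLO=uplo)  # keep V for the gradient
        return w, v
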
diff --git a/paddle/phi/kernels/gpu/eigvalsh_grad_kernel.cu b/paddle/phi/kernels/gpu/eigvalsh_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..de26617d80f1b81af007a0327ae0d3ed18cff52c
--- /dev/null
+++ b/paddle/phi/kernels/gpu/eigvalsh_grad_kernel.cu
@@ -0,0 +1,29 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/eigvalsh_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/complex.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/eigvalsh_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(eigvalsh_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::EigvalshGradKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/eigvalsh_kernel.cu b/paddle/phi/kernels/gpu/eigvalsh_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..383f036c98cf9dfb65a419e46533dc974ccf8264
--- /dev/null
+++ b/paddle/phi/kernels/gpu/eigvalsh_kernel.cu
@@ -0,0 +1,29 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/eigvalsh_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/complex.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/eigvalsh_kernel_impl.h"
+
+PD_REGISTER_KERNEL(eigvalsh,  // cuda_only
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::EigvalshKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/impl/eigvalsh_grad_kernel_impl.h b/paddle/phi/kernels/impl/eigvalsh_grad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..7248985bf294c60e370d76767b38d228c12be882
--- /dev/null
+++ b/paddle/phi/kernels/impl/eigvalsh_grad_kernel_impl.h
@@ -0,0 +1,51 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/kernels/eigvalsh_grad_kernel.h"
+
+#include "paddle/phi/kernels/complex_kernel.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/matmul_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EigvalshGradKernel(const Context& dev_ctx,
+                        const DenseTensor& out_v,
+                        const DenseTensor& out_w_grad,
+                        const std::string& uplo,
+                        bool is_test,
+                        DenseTensor* x_grad) {
+  auto tV = phi::TransposeLast2Dim<T>(dev_ctx, phi::Conj<T>(dev_ctx, out_v));
+
+  x_grad->Resize(out_v.dims());
+  dev_ctx.template Alloc<T>(x_grad);
+
+  auto output_v_vector = EigenVector<T>::Flatten(out_v);
+  auto output_w_grad_vector =
+      EigenVector<phi::dtype::Real<T>>::Flatten(out_w_grad);
+  auto result_vector = EigenVector<T>::Flatten(*x_grad);
+  auto& place = *dev_ctx.eigen_device();
+  std::vector<int> broadcast_factor;
+  broadcast_factor.push_back(out_v.dims().at(out_v.dims().size() - 1));
+  result_vector.device(place) =
+      output_v_vector * output_w_grad_vector.broadcast(broadcast_factor);
+
+  *x_grad = phi::Matmul<T>(dev_ctx, *x_grad, tV);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/eigvalsh_kernel_impl.h b/paddle/phi/kernels/impl/eigvalsh_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..e56192d9ed6f31ef761c77ccd06b08e3e47a3663
--- /dev/null
+++ b/paddle/phi/kernels/impl/eigvalsh_kernel_impl.h
@@ -0,0 +1,36 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/kernels/eigvalsh_kernel.h"
+
+#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EigvalshKernel(const Context& dev_ctx,
+                    const DenseTensor& x,
+                    const std::string& uplo,
+                    bool is_test,
+                    DenseTensor* out_w,
+                    DenseTensor* out_v) {
+  bool is_lower = (uplo == "L");
+  phi::funcs::MatrixEighFunctor<Context, T> functor;
+  if (is_test) {
+    functor(dev_ctx, x, out_w, nullptr, is_lower, false);
+  } else {
+    functor(dev_ctx, x, out_w, out_v, is_lower, true);
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/eigvalsh_sig.cc b/paddle/phi/ops/compat/eigvalsh_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b0635403355f7b4ba8044a744ec27718dd5b03ec
--- /dev/null
+++ b/paddle/phi/ops/compat/eigvalsh_sig.cc
@@ -0,0 +1,34 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature EigvalshOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "eigvalsh", {"X"}, {"UPLO", "is_test"}, {"Eigenvalues", "Eigenvectors"});
+}
+
+KernelSignature EigvalshGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("eigvalsh_grad",
+                         {"Eigenvectors", "Eigenvalues@GRAD"},
+                         {"UPLO", "is_test"},
+                         {"X@GRAD"});
+}
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(eigvalsh, phi::EigvalshOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(eigvalsh_grad, phi::EigvalshGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py b/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py
index e518491588d51ada508e8686cf8e742ae7d1f26a..e1378eb722772caddf99a39ba59ebc17d2d18651 100644
--- a/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py
+++ b/python/paddle/fluid/tests/unittests/test_eigvalsh_op.py
@@ -51,6 +51,8 @@ class TestEigvalshOp(OpTest):
     def setUp(self):
         paddle.enable_static()
         self.op_type = "eigvalsh"
+        self.python_api = paddle.linalg.eigvalsh
+        self.python_out_sig = ['Eigenvalues']
         self.init_input()
         self.init_config()
         np.random.seed(123)
@@ -69,10 +71,10 @@ class TestEigvalshOp(OpTest):
 
     def test_check_output(self):
         # Vectors in posetive or negative is equivalent
-        self.check_output(no_check_set=['Eigenvectors'])
+        self.check_output(no_check_set=['Eigenvectors'], check_eager=True)
 
     def test_grad(self):
-        self.check_grad(["X"], ["Eigenvalues"])
+        self.check_grad(["X"], ["Eigenvalues"], check_eager=True)
 
 
 class TestEigvalshUPLOCase(TestEigvalshOp):
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 4931f3ffbb919db9f42478ee8ad88ba88fec3869..4fd393da6f1bb7cf9c1bcdfc4ecab8e07fc6bf6e 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -3052,7 +3052,11 @@ def eigvalsh(x, UPLO='L', name=None):
             print(out_value)
             #[0.17157288, 5.82842712]
     """
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        values, _ = _C_ops.final_state_eigvalsh(x, UPLO, x.stop_gradient)
+        return values
+
+    elif paddle.in_dynamic_mode():
         is_test = x.stop_gradient
         values, _ = _C_ops.eigvalsh(x, 'UPLO', UPLO, 'is_test', is_test)
         return values
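
End-to-end, the new eager path can be exercised as below (a sketch assuming a
build that includes this patch; the expected eigenvalues follow from the matrix
chosen, whose exact eigenvalues are 3 - 2*sqrt(2) and 3 + 2*sqrt(2)):

    import paddle

    x = paddle.to_tensor([[1.0, 2.0], [2.0, 5.0]])
    x.stop_gradient = False          # is_test=False: eigenvectors retained
    w = paddle.linalg.eigvalsh(x)    # dygraph now routes through final_state_eigvalsh
    w.sum().backward()               # exercises the eigvalsh_grad kernel above
    print(w.numpy())                 # approximately [0.17157288, 5.82842712]
    print(x.grad.shape)              # [2, 2], same as x, per EigvalshGradInferMeta
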