From 345cc8fa7115b00bc7589161346c147730bced69 Mon Sep 17 00:00:00 2001 From: From00 Date: Tue, 22 Feb 2022 19:19:52 +0800 Subject: [PATCH] Move real and imag op to phi (#39777) * Move Real OP to phi * Move Imag OP to phi * Move Real and Imag InferShape to phi * Move Real and Imag to complex_kernel * Change PT_REGISTER_XXX to PD_REGISTER_XXX --- paddle/fluid/operators/imag_op.cc | 30 +++------ paddle/fluid/operators/imag_op.cu | 28 -------- paddle/fluid/operators/imag_op.h | 67 ------------------- paddle/fluid/operators/real_op.cc | 29 +++----- paddle/fluid/operators/real_op.cu | 28 -------- paddle/fluid/operators/real_op.h | 67 ------------------- paddle/phi/kernels/complex_grad_kernel.h | 31 +++++++++ paddle/phi/kernels/complex_kernel.h | 10 +++ paddle/phi/kernels/cpu/complex_grad_kernel.cc | 33 +++++++++ paddle/phi/kernels/cpu/complex_kernel.cc | 14 ++++ paddle/phi/kernels/gpu/complex_grad_kernel.cu | 33 +++++++++ paddle/phi/kernels/gpu/complex_kernel.cu | 14 ++++ .../kernels/impl/complex_grad_kernel_impl.h | 50 ++++++++++++++ paddle/phi/kernels/impl/complex_kernel_impl.h | 28 ++++++++ paddle/phi/ops/compat/complex_sig.cc | 32 +++++++++ 15 files changed, 263 insertions(+), 231 deletions(-) delete mode 100644 paddle/fluid/operators/imag_op.cu delete mode 100644 paddle/fluid/operators/imag_op.h delete mode 100644 paddle/fluid/operators/real_op.cu delete mode 100644 paddle/fluid/operators/real_op.h create mode 100644 paddle/phi/kernels/complex_grad_kernel.h create mode 100644 paddle/phi/kernels/cpu/complex_grad_kernel.cc create mode 100644 paddle/phi/kernels/gpu/complex_grad_kernel.cu create mode 100644 paddle/phi/kernels/impl/complex_grad_kernel_impl.h create mode 100644 paddle/phi/ops/compat/complex_sig.cc diff --git a/paddle/fluid/operators/imag_op.cc b/paddle/fluid/operators/imag_op.cc index 6a195bb940..33b68d6899 100644 --- a/paddle/fluid/operators/imag_op.cc +++ b/paddle/fluid/operators/imag_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/imag_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -20,15 +23,6 @@ namespace operators { class ImagOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Imag"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Imag"); - - auto x_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", x_dims); - ctx->ShareLoD("X", "Out"); - } }; class ImagOpMaker : public framework::OpProtoAndCheckerMaker { @@ -88,19 +82,13 @@ DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer, } // namespace operators } // namespace paddle +DELCARE_INFER_SHAPE_FUNCTOR(imag, ImagInferShapeFunctor, + PT_INFER_META(phi::UnchangedInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker, ops::ImagGradOpMaker, - ops::ImagGradOpMaker); + ops::ImagGradOpMaker, + ImagInferShapeFunctor); REGISTER_OPERATOR(imag_grad, ops::ImagGradOp); - -REGISTER_OP_CPU_KERNEL(imag, ops::ImagKernel>, - ops::ImagKernel>); -REGISTER_OP_CPU_KERNEL(imag_grad, - ops::ImagGradKernel>, - ops::ImagGradKernel>); diff --git a/paddle/fluid/operators/imag_op.cu b/paddle/fluid/operators/imag_op.cu deleted file mode 100644 index 9cfb2ef7f2..0000000000 --- a/paddle/fluid/operators/imag_op.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/imag_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL(imag, - ops::ImagKernel>, - ops::ImagKernel>); -REGISTER_OP_CUDA_KERNEL(imag_grad, - ops::ImagGradKernel>, - ops::ImagGradKernel>); diff --git a/paddle/fluid/operators/imag_op.h b/paddle/fluid/operators/imag_op.h deleted file mode 100644 index 33eab2abb7..0000000000 --- a/paddle/fluid/operators/imag_op.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" - -namespace paddle { -namespace operators { - -template -class ImagKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor* x = ctx.Input("X"); - framework::Tensor* out = ctx.Output("Out"); - - auto numel = x->numel(); - auto* x_data = x->data(); - auto* out_data = out->mutable_data>( - ctx.GetPlace(), - static_cast(numel * sizeof(phi::funcs::Real))); - - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - phi::funcs::ImagFunctor functor(x_data, out_data, numel); - for_range(functor); - } -}; - -template -class ImagGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor* d_out = - ctx.Input(framework::GradVarName("Out")); - framework::Tensor* d_x = - ctx.Output(framework::GradVarName("X")); - - auto numel = d_out->numel(); - auto* dout_data = d_out->data>(); - auto* dx_data = d_x->mutable_data( - ctx.GetPlace(), static_cast(numel * sizeof(T))); - - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - phi::funcs::ImagToComplexFunctor functor(dout_data, dx_data, numel); - for_range(functor); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/real_op.cc b/paddle/fluid/operators/real_op.cc index 1174e72a76..1f3691978b 100644 --- a/paddle/fluid/operators/real_op.cc +++ b/paddle/fluid/operators/real_op.cc @@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/real_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -20,14 +23,6 @@ namespace operators { class RealOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Real"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Real"); - - auto x_dims = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", x_dims); - ctx->ShareLoD("X", "Out"); - } }; class RealOpMaker : public framework::OpProtoAndCheckerMaker { @@ -87,19 +82,13 @@ DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer, } // namespace operators } // namespace paddle +DELCARE_INFER_SHAPE_FUNCTOR(real, RealInferShapeFunctor, + PT_INFER_META(phi::UnchangedInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker, ops::RealGradOpMaker<::paddle::framework::OpDesc>, - ops::RealGradOpMaker<::paddle::imperative::OpBase>); + ops::RealGradOpMaker<::paddle::imperative::OpBase>, + RealInferShapeFunctor); REGISTER_OPERATOR(real_grad, ops::RealGradOp); - -REGISTER_OP_CPU_KERNEL(real, ops::RealKernel>, - ops::RealKernel>); -REGISTER_OP_CPU_KERNEL(real_grad, - ops::RealGradKernel>, - ops::RealGradKernel>); diff --git a/paddle/fluid/operators/real_op.cu b/paddle/fluid/operators/real_op.cu deleted file mode 100644 index 9bfb2878a6..0000000000 --- a/paddle/fluid/operators/real_op.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/real_op.h" - -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL(real, - ops::RealKernel>, - ops::RealKernel>); -REGISTER_OP_CUDA_KERNEL(real_grad, - ops::RealGradKernel>, - ops::RealGradKernel>); diff --git a/paddle/fluid/operators/real_op.h b/paddle/fluid/operators/real_op.h deleted file mode 100644 index c5a9724e8a..0000000000 --- a/paddle/fluid/operators/real_op.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/complex_functors.h" - -namespace paddle { -namespace operators { - -template -class RealKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor* x = ctx.Input("X"); - framework::Tensor* out = ctx.Output("Out"); - - auto numel = x->numel(); - auto* x_data = x->data(); - auto* out_data = out->mutable_data>( - ctx.GetPlace(), - static_cast(numel * sizeof(phi::funcs::Real))); - - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - phi::funcs::RealFunctor functor(x_data, out_data, numel); - for_range(functor); - } -}; - -template -class RealGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const { - const framework::Tensor* d_out = - ctx.Input(framework::GradVarName("Out")); - framework::Tensor* d_x = - ctx.Output(framework::GradVarName("X")); - - auto numel = d_out->numel(); - auto* dout_data = d_out->data>(); - auto* dx_data = d_x->mutable_data( - ctx.GetPlace(), static_cast(numel * sizeof(T))); - - auto& dev_ctx = ctx.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - phi::funcs::RealToComplexFunctor functor(dout_data, dx_data, numel); - for_range(functor); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/complex_grad_kernel.h b/paddle/phi/kernels/complex_grad_kernel.h new file mode 100644 index 0000000000..505d4d3744 --- /dev/null +++ b/paddle/phi/kernels/complex_grad_kernel.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void RealGradKernel(const Context& dev_ctx, + const DenseTensor& dout, + DenseTensor* dx); + +template +void ImagGradKernel(const Context& dev_ctx, + const DenseTensor& dout, + DenseTensor* dx); + +} // namespace phi diff --git a/paddle/phi/kernels/complex_kernel.h b/paddle/phi/kernels/complex_kernel.h index cfe9da2388..44bfae9820 100644 --- a/paddle/phi/kernels/complex_kernel.h +++ b/paddle/phi/kernels/complex_kernel.h @@ -50,4 +50,14 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { return x; } +template +void RealKernel(const DeviceContext& dev_ctx, + const DenseTensor& x, + DenseTensor* out); + +template +void ImagKernel(const DeviceContext& dev_ctx, + const DenseTensor& x, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/cpu/complex_grad_kernel.cc b/paddle/phi/kernels/cpu/complex_grad_kernel.cc new file mode 100644 index 0000000000..5c1d50f5bf --- /dev/null +++ b/paddle/phi/kernels/cpu/complex_grad_kernel.cc @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/complex_grad_kernel.h" +#include "paddle/phi/kernels/impl/complex_grad_kernel_impl.h" + +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(real_grad, + CPU, + ALL_LAYOUT, + phi::RealGradKernel, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(imag_grad, + CPU, + ALL_LAYOUT, + phi::ImagGradKernel, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/complex_kernel.cc b/paddle/phi/kernels/cpu/complex_kernel.cc index ae09f2a5ef..801502e167 100644 --- a/paddle/phi/kernels/cpu/complex_kernel.cc +++ b/paddle/phi/kernels/cpu/complex_kernel.cc @@ -31,3 +31,17 @@ PD_REGISTER_KERNEL(conj, double, int, int64_t) {} + +PD_REGISTER_KERNEL(real, + CPU, + ALL_LAYOUT, + phi::RealKernel, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(imag, + CPU, + ALL_LAYOUT, + phi::ImagKernel, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/complex_grad_kernel.cu b/paddle/phi/kernels/gpu/complex_grad_kernel.cu new file mode 100644 index 0000000000..ad694445d1 --- /dev/null +++ b/paddle/phi/kernels/gpu/complex_grad_kernel.cu @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/complex_grad_kernel.h" +#include "paddle/phi/kernels/impl/complex_grad_kernel_impl.h" + +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(imag_grad, + GPU, + ALL_LAYOUT, + phi::ImagGradKernel, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(real_grad, + GPU, + ALL_LAYOUT, + phi::RealGradKernel, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/complex_kernel.cu b/paddle/phi/kernels/gpu/complex_kernel.cu index 02fd408aba..d0b086718a 100644 --- a/paddle/phi/kernels/gpu/complex_kernel.cu +++ b/paddle/phi/kernels/gpu/complex_kernel.cu @@ -32,3 +32,17 @@ PD_REGISTER_KERNEL(conj, double, int, int64_t) {} + +PD_REGISTER_KERNEL(real, + GPU, + ALL_LAYOUT, + phi::RealKernel, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(imag, + GPU, + ALL_LAYOUT, + phi::ImagKernel, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/complex_grad_kernel_impl.h b/paddle/phi/kernels/impl/complex_grad_kernel_impl.h new file mode 100644 index 0000000000..febc464e6a --- /dev/null +++ b/paddle/phi/kernels/impl/complex_grad_kernel_impl.h @@ -0,0 +1,50 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" + +namespace phi { + +template +void RealGradKernel(const Context& dev_ctx, + const DenseTensor& dout, + DenseTensor* dx) { + auto numel = dout.numel(); + auto* dout_data = dout.data>(); + auto* dx_data = + dev_ctx.template Alloc(dx, static_cast(numel * sizeof(T))); + + paddle::platform::ForRange for_range(dev_ctx, numel); + phi::funcs::RealToComplexFunctor functor(dout_data, dx_data, numel); + for_range(functor); +} + +template +void ImagGradKernel(const Context& dev_ctx, + const DenseTensor& dout, + DenseTensor* dx) { + auto numel = dout.numel(); + auto* dout_data = dout.data>(); + auto* dx_data = + dev_ctx.template Alloc(dx, static_cast(numel * sizeof(T))); + + paddle::platform::ForRange for_range(dev_ctx, numel); + phi::funcs::ImagToComplexFunctor functor(dout_data, dx_data, numel); + for_range(functor); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/complex_kernel_impl.h b/paddle/phi/kernels/impl/complex_kernel_impl.h index 910a7be965..2f9b1ad046 100644 --- a/paddle/phi/kernels/impl/complex_kernel_impl.h +++ b/paddle/phi/kernels/impl/complex_kernel_impl.h @@ -33,4 +33,32 @@ void ConjKernel(const Context& dev_ctx, for_range(functor); } +template +void RealKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto numel = x.numel(); + auto* x_data = x.data(); + auto* out_data = dev_ctx.template Alloc>( + out, static_cast(numel * sizeof(phi::funcs::Real))); + + paddle::platform::ForRange for_range(dev_ctx, numel); + phi::funcs::RealFunctor functor(x_data, out_data, numel); + for_range(functor); +} + +template +void ImagKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto numel = x.numel(); + auto* x_data = x.data(); + auto* out_data = dev_ctx.template Alloc>( + out, static_cast(numel * sizeof(phi::funcs::Real))); + + paddle::platform::ForRange for_range(dev_ctx, numel); + phi::funcs::ImagFunctor functor(x_data, out_data, numel); + for_range(functor); +} + } // namespace phi diff --git a/paddle/phi/ops/compat/complex_sig.cc b/paddle/phi/ops/compat/complex_sig.cc new file mode 100644 index 0000000000..b9f59c97fb --- /dev/null +++ b/paddle/phi/ops/compat/complex_sig.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "real_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); +} + +KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "imag_grad", {GradVarName("Out")}, {}, {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(real_grad, phi::RealGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(imag_grad, phi::ImagGradOpArgumentMapping); -- GitLab