From 192eb4d5d08b31e014de58155e6b4e4b417f94e3 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 5 Jan 2023 13:15:43 +0800 Subject: [PATCH] support generate static graph code for imag and real op (#49523) --- .../fluid/operators/generator/parse_utils.py | 15 ++- .../generator/templates/operator_utils.c.j2 | 3 + paddle/fluid/operators/imag_op.cc | 101 ------------------ paddle/fluid/operators/real_op.cc | 101 ------------------ paddle/phi/api/lib/api_custom_impl.cc | 52 --------- paddle/phi/api/yaml/backward.yaml | 20 ++++ paddle/phi/api/yaml/generator/api_base.py | 13 ++- paddle/phi/api/yaml/generator/api_gen.py | 1 + .../api/yaml/generator/backward_api_gen.py | 1 + paddle/phi/api/yaml/legacy_backward.yaml | 12 --- paddle/phi/api/yaml/legacy_ops.yaml | 18 ---- paddle/phi/api/yaml/op_compat.yaml | 14 +++ paddle/phi/api/yaml/ops.yaml | 18 ++++ paddle/phi/kernels/cpu/complex_grad_kernel.cc | 8 +- paddle/phi/kernels/gpu/complex_grad_kernel.cu | 8 +- paddle/phi/ops/compat/complex_sig.cc | 30 ------ 16 files changed, 95 insertions(+), 320 deletions(-) delete mode 100644 paddle/fluid/operators/imag_op.cc delete mode 100644 paddle/fluid/operators/real_op.cc delete mode 100644 paddle/phi/ops/compat/complex_sig.cc diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index 8e80cdecf3..ba8587c3a1 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -187,7 +187,20 @@ def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]: kernel['layout'] = parse_candidates(kernel_config["layout"]) if 'data_type' in kernel_config: - kernel['data_type'] = parse_candidates(kernel_config["data_type"]) + data_type_item = parse_candidates(kernel_config["data_type"]) + params_num = len(data_type_item['candidates']) + data_type_item['to_complex_flag'] = [False] * params_num + for i in range(params_num): + complex_match_result = re.match( + r"complex\((?P\w+)\)", + data_type_item['candidates'][i], + ) + if complex_match_result: + data_type_item['candidates'][i] = complex_match_result.group( + 'param_name' + ) + data_type_item['to_complex_flag'][i] = True + kernel['data_type'] = data_type_item kernel_funcs = re.compile(r'([a-zA-Z0-9_]+)\s*({[^}]+})?').findall( kernel_config['func'] diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 index 207589edd5..0112c98afc 100644 --- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 +++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 @@ -262,6 +262,9 @@ phi::KernelKey GetExpectedKernelType( {% set inputs = op["inputs"] | map(attribute="name") | list %} {% if data_type_arg in inputs %} auto data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, {{data_type_arg | to_opmaker_name}}); + {% if kernel["data_type"]["to_complex_flag"][0] %} + data_type = framework::ToComplexType(data_type); + {% endif %} {% else %}{# it is an attribute and probably named dtype#} auto data_type = framework::proto::VarType::Type(ctx.Attr("{{data_type_arg}}")); {% endif %} diff --git a/paddle/fluid/operators/imag_op.cc b/paddle/fluid/operators/imag_op.cc deleted file mode 100644 index a2fdd53e03..0000000000 --- a/paddle/fluid/operators/imag_op.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class ImagOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -class ImagOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of imag op."); - AddOutput("Out", "(Tensor), The output tensor of imag op."); - AddComment(R"DOC( -Imag Operator. - -This operator is used to get a new tensor containing imaginary values -from a tensor with complex data type. - -)DOC"); - } -}; - -class ImagGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@Grad", - "ImagGrad"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), - "Output", - "X@Grad", - "ImagGrad"); - - auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); - ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto dtype = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); - auto complex_dtype = framework::ToComplexType(dtype); - return phi::KernelKey(complex_dtype, ctx.GetPlace()); - } -}; - -template -class ImagGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("imag_grad"); - grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - } -}; - -DECLARE_INPLACE_OP_INFERER(ImagOpInplaceInferer, {"X", "Out"}); -DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer, - {framework::GradVarName("Out"), - framework::GradVarName("X")}); - -} // namespace operators -} // namespace paddle - -DECLARE_INFER_SHAPE_FUNCTOR(imag, - ImagInferShapeFunctor, - PD_INFER_META(phi::RealAndImagInferMeta)); - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(imag, - ops::ImagOp, - ops::ImagOpMaker, - ops::ImagGradOpMaker, - ops::ImagGradOpMaker, - ImagInferShapeFunctor); -REGISTER_OPERATOR(imag_grad, ops::ImagGradOp); diff --git a/paddle/fluid/operators/real_op.cc b/paddle/fluid/operators/real_op.cc deleted file mode 100644 index 94cdc2d658..0000000000 --- a/paddle/fluid/operators/real_op.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/core/infermeta_utils.h" -#include "paddle/phi/infermeta/unary.h" - -namespace paddle { -namespace operators { - -class RealOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; -}; - -class RealOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor), The input tensor of real op."); - AddOutput("Out", "(Tensor), The output tensor of real op."); - AddComment(R"DOC( -Real Operator. - -This operator is used to get a new tensor containing real values -from a tensor with complex data type. - -)DOC"); - } -}; - -class RealGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out@Grad", - "RealGrad"); - OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), - "Output", - "X@Grad", - "RealGrad"); - - auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); - ctx->SetOutputDim(framework::GradVarName("X"), dout_dims); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto dtype = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); - auto complex_dtype = framework::ToComplexType(dtype); - return phi::KernelKey(complex_dtype, ctx.GetPlace()); - } -}; - -template -class RealGradOpMaker : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker::SingleGradOpMaker; - void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("real_grad"); - grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - } -}; - -DECLARE_INPLACE_OP_INFERER(RealOpInplaceInferer, {"X", "Out"}); -DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer, - {framework::GradVarName("Out"), - framework::GradVarName("X")}); - -} // namespace operators -} // namespace paddle - -DECLARE_INFER_SHAPE_FUNCTOR(real, - RealInferShapeFunctor, - PD_INFER_META(phi::RealAndImagInferMeta)); - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(real, - ops::RealOp, - ops::RealOpMaker, - ops::RealGradOpMaker<::paddle::framework::OpDesc>, - ops::RealGradOpMaker<::paddle::imperative::OpBase>, - RealInferShapeFunctor); -REGISTER_OPERATOR(real_grad, ops::RealGradOp); diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 77b2fa59c3..2bb5957df2 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -141,32 +141,6 @@ Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) { ////////////////// Backward(grad) api impls ////////////////////// -void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad) { - phi::KernelKey kernel_key{ParseBackend(out_grad), - out_grad.layout(), - phi::dtype::ToComplex(out_grad.dtype())}; - auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "imag_grad", kernel_key); - const auto& kernel = kernel_result.kernel; - - VLOG(6) << "imag_grad API kernel key: " << kernel_key; - VLOG(6) << "imag_grad API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); - - auto dense_out_grad = TensorToDenseTensor(out_grad); - - auto kernel_out = SetKernelOutput(x_grad); - phi::MetaTensor meta_out(kernel_out); - phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out); - - using kernel_signature = void (*)( - const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*); - - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out); -} - void embedding_grad_impl(const Tensor& x, const Tensor& weight, const Tensor& out_grad, @@ -290,31 +264,5 @@ void embedding_grad_impl(const Tensor& x, } } -void real_grad_impl(const Tensor& out_grad, Tensor* x_grad) { - phi::KernelKey kernel_key{ParseBackend(out_grad), - out_grad.layout(), - phi::dtype::ToComplex(out_grad.dtype())}; - auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "real_grad", kernel_key); - const auto& kernel = kernel_result.kernel; - - VLOG(6) << "real_grad API kernel key: " << kernel_key; - VLOG(6) << "real_grad API kernel: " << kernel; - - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); - - auto dense_out_grad = TensorToDenseTensor(out_grad); - - auto kernel_out = SetKernelOutput(x_grad); - phi::MetaTensor meta_out(kernel_out); - phi::RealAndImagGradInferMeta(*dense_out_grad, &meta_out); - - using kernel_signature = void (*)( - const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*); - - auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out); -} - } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 8f107f02da..0eee2b6a2b 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -563,6 +563,16 @@ func : hard_sigmoid_grad inplace : (out_grad -> x_grad) +- backward_op : imag_grad + forward : imag (Tensor x) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : RealAndImagGradInferMeta + kernel : + func : imag_grad + data_type : complex(out_grad) + - backward_op : index_sample_grad forward : index_sample (Tensor x, Tensor index) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad) @@ -868,6 +878,16 @@ kernel : func : qr_grad +- backward_op : real_grad + forward : real (Tensor x) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : RealAndImagGradInferMeta + kernel : + func : real_grad + data_type : complex(out_grad) + - backward_op : reciprocal_grad forward : reciprocal (Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad) diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index f40583f63c..e1def4e913 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -486,6 +486,17 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d ) if kernel['data_type'] is not None: + + def process_data_type_args(args_item): + args_item = args_item.strip() + complex_match_result = re.match( + r"complex\((?P\w+)\)", args_item + ) + if complex_match_result: + return f"phi::dtype::ToComplex(ParseDataType({complex_match_result.group('param_name')}))" + else: + return f"ParseDataType({args_item})" + if '>' in kernel['data_type']: vars_list = kernel['data_type'].split('>') assert ( @@ -511,7 +522,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d kernel_select_code = ( kernel_select_code + f""" - kernel_data_type = ParseDataType({vars_list[0].strip()}); + kernel_data_type = {process_data_type_args(vars_list[0])}; """ ) diff --git a/paddle/phi/api/yaml/generator/api_gen.py b/paddle/phi/api/yaml/generator/api_gen.py index 63e6d6cb50..0a05ec6eb3 100644 --- a/paddle/phi/api/yaml/generator/api_gen.py +++ b/paddle/phi/api/yaml/generator/api_gen.py @@ -335,6 +335,7 @@ def source_include(header_file_path): #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/multiary.h" diff --git a/paddle/phi/api/yaml/generator/backward_api_gen.py b/paddle/phi/api/yaml/generator/backward_api_gen.py index 4d10f8b56b..f01200ec3a 100644 --- a/paddle/phi/api/yaml/generator/backward_api_gen.py +++ b/paddle/phi/api/yaml/generator/backward_api_gen.py @@ -280,6 +280,7 @@ def source_include(header_file_path): #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" +#include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/api/include/api.h" #include "paddle/phi/infermeta/backward.h" diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index acc7b670ba..5621b2c7db 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -617,12 +617,6 @@ kernel : func : huber_loss_grad -- backward_op : imag_grad - forward : imag (Tensor x) -> Tensor(out) - args : (Tensor out_grad) - output : Tensor(x_grad) - invoke : imag_grad_impl(out_grad, x_grad) - - backward_op : index_add_grad forward : index_add(Tensor x, Tensor index, Tensor add_value, int axis) -> Tensor(out) args : (Tensor index, Tensor add_value, Tensor out_grad, int axis) @@ -1125,12 +1119,6 @@ data_type : x optional : boxes_num -- backward_op : real_grad - forward : real (Tensor x) -> Tensor(out) - args : (Tensor out_grad) - output : Tensor(x_grad) - invoke : real_grad_impl(out_grad, x_grad) - - backward_op : relu6_grad forward : relu6 (Tensor x) -> Tensor(out) args : (Tensor out, Tensor out_grad, float threshold = 6) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 6dfff5d510..b0294c245a 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -900,15 +900,6 @@ func : huber_loss backward : huber_loss_grad -- op : imag - args : (Tensor x) - output : Tensor - infer_meta : - func : RealAndImagInferMeta - kernel : - func : imag - backward : imag_grad - - op : increment args : (Tensor x, float value = 1.0) output : Tensor(out) @@ -1507,15 +1498,6 @@ data_type : dtype backend : place -- op : real - args : (Tensor x) - output : Tensor - infer_meta : - func : RealAndImagInferMeta - kernel : - func : real - backward : real_grad - - op : relu6 args : (Tensor x) output : Tensor diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index cb6f67fbdf..ceee57a771 100644 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -646,6 +646,13 @@ outputs : out : Out +- op : imag + backward : imag_grad + inputs : + x : X + outputs : + out : Out + - op : index_sample inputs : {x : X, index : Index} @@ -997,6 +1004,13 @@ extra : attrs : [float moving_rate = 0.9] +- op : real + backward : real_grad + inputs : + x : X + outputs : + out : Out + - op : reciprocal backward : reciprocal_grad inputs : diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index e5378ce077..127d856e37 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -517,6 +517,15 @@ kernel : func : histogram +- op : imag + args : (Tensor x) + output : Tensor (out) + infer_meta : + func : RealAndImagInferMeta + kernel : + func : imag + backward : imag_grad + - op : index_sample args : (Tensor x, Tensor index) output : Tensor @@ -839,6 +848,15 @@ func : qr backward : qr_grad +- op : real + args : (Tensor x) + output : Tensor (out) + infer_meta : + func : RealAndImagInferMeta + kernel : + func : real + backward : real_grad + - op : reciprocal args : (Tensor x) output : Tensor(out) diff --git a/paddle/phi/kernels/cpu/complex_grad_kernel.cc b/paddle/phi/kernels/cpu/complex_grad_kernel.cc index 049022f01e..1053700a13 100644 --- a/paddle/phi/kernels/cpu/complex_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/complex_grad_kernel.cc @@ -23,14 +23,18 @@ PD_REGISTER_KERNEL(real_grad, ALL_LAYOUT, phi::RealGradKernel, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} PD_REGISTER_KERNEL(imag_grad, CPU, ALL_LAYOUT, phi::ImagGradKernel, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} PD_REGISTER_KERNEL( complex_grad, CPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) { diff --git a/paddle/phi/kernels/gpu/complex_grad_kernel.cu b/paddle/phi/kernels/gpu/complex_grad_kernel.cu index e9fd5e1fa5..b2a6e4117c 100644 --- a/paddle/phi/kernels/gpu/complex_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/complex_grad_kernel.cu @@ -23,14 +23,18 @@ PD_REGISTER_KERNEL(imag_grad, ALL_LAYOUT, phi::ImagGradKernel, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} PD_REGISTER_KERNEL(real_grad, GPU, ALL_LAYOUT, phi::RealGradKernel, phi::dtype::complex, - phi::dtype::complex) {} + phi::dtype::complex) { + kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} PD_REGISTER_KERNEL( complex_grad, GPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) { diff --git a/paddle/phi/ops/compat/complex_sig.cc b/paddle/phi/ops/compat/complex_sig.cc deleted file mode 100644 index 88156677d3..0000000000 --- a/paddle/phi/ops/compat/complex_sig.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature RealGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("real_grad", {"Out@GRAD"}, {}, {"X@GRAD"}); -} - -KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(real_grad, phi::RealGradOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(imag_grad, phi::ImagGradOpArgumentMapping); -- GitLab