From b1aa693e888083226933105a42907f5109753fd7 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Wed, 13 Jul 2022 14:59:41 +0800
Subject: [PATCH] [Phi] Migrate complex_op into Phi & Add complex api yaml
 (#44233)

* mv to phi
* refine infermeta code position
* refine grad code
* add api yaml and add final_state_api
* refine code
---
 paddle/fluid/operators/complex_op.cc          |  78 +++--------
 paddle/fluid/operators/complex_op.cu          |  28 ----
 paddle/fluid/operators/complex_op.h           | 123 ------------------
 paddle/phi/api/yaml/legacy_api.yaml           |   9 ++
 paddle/phi/api/yaml/legacy_backward.yaml      |  10 ++
 paddle/phi/infermeta/backward.cc              |  17 +++
 paddle/phi/infermeta/backward.h               |   6 +
 paddle/phi/infermeta/binary.cc                |  32 +++++
 paddle/phi/infermeta/binary.h                 |   4 +
 paddle/phi/kernels/complex_grad_kernel.h      |   8 ++
 paddle/phi/kernels/complex_kernel.h           |   6 +
 paddle/phi/kernels/cpu/complex_grad_kernel.cc |   5 +
 paddle/phi/kernels/cpu/complex_kernel.cc      |   5 +
 paddle/phi/kernels/gpu/complex_grad_kernel.cu |   5 +
 paddle/phi/kernels/gpu/complex_kernel.cu      |   5 +
 .../kernels/impl/complex_grad_kernel_impl.h   |  48 +++++++
 paddle/phi/kernels/impl/complex_kernel_impl.h |  43 ++++++
 paddle/phi/ops/compat/complex_sig.cc          |   7 +
 .../fluid/tests/unittests/test_complex_op.py  |   1 +
 python/paddle/tensor/creation.py              |   3 +
 20 files changed, 230 insertions(+), 213 deletions(-)
 delete mode 100644 paddle/fluid/operators/complex_op.cu
 delete mode 100644 paddle/fluid/operators/complex_op.h

diff --git a/paddle/fluid/operators/complex_op.cc b/paddle/fluid/operators/complex_op.cc
index d6d93fe9581..778f5831c0f 100644
--- a/paddle/fluid/operators/complex_op.cc
+++ b/paddle/fluid/operators/complex_op.cc
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/complex_op.h"
-
-#include <memory>
-
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/binary.h"

 namespace paddle {
 namespace operators {
@@ -59,36 +59,6 @@ class ComplexOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "complex");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "complex");
-
-    if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
-      ctx->ShareDim("X", /*->*/ "Out");
-      // NOTE(chenfeiyu): lod & broadcasting is intrinsically contradictory
-      // so tensors with lod are not supported here
-    } else {
-      auto x_dims = ctx->GetInputDim("X");
-      auto y_dims = ctx->GetInputDim("Y");
-      int max_dim = std::max(x_dims.size(), y_dims.size());
-
-      // start align axis
-      int axis = std::abs(x_dims.size() - y_dims.size());
-      std::vector<int> x_dims_array(max_dim);
-      std::vector<int> y_dims_array(max_dim);
-      std::vector<int> out_dims_array(max_dim);
-      GetBroadcastDimsArrays(x_dims,
-                             y_dims,
-                             x_dims_array.data(),
-                             y_dims_array.data(),
-                             out_dims_array.data(),
-                             max_dim,
-                             axis);
-      ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -101,25 +71,6 @@ class ComplexGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex_grad");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_complex_gradgrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   framework::GradVarName("Out"),
-                   "complex_grad");
-
-    auto x_grad_name = framework::GradVarName("X");
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->ShareDim("X", /*->*/ x_grad_name);
-    }
-
-    auto y_grad_name = framework::GradVarName("Y");
-    if (ctx->HasOutput(y_grad_name)) {
-      ctx->ShareDim("Y", /*->*/ y_grad_name);
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -135,18 +86,21 @@

 namespace ops = paddle::operators;

+DECLARE_INFER_SHAPE_FUNCTOR(complex,
+                            ComplexInferShapeFunctor,
+                            PD_INFER_META(phi::ComplexInferMeta));
+
 REGISTER_OPERATOR(complex,
                   ops::ComplexOp,
                   ops::ComplexOpMaker,
                   ops::ComplexGradOpMaker<paddle::framework::OpDesc>,
-                  ops::ComplexGradOpMaker<paddle::imperative::OpBase>);
-
-REGISTER_OPERATOR(complex_grad, ops::ComplexGradOp);
+                  ops::ComplexGradOpMaker<paddle::imperative::OpBase>,
+                  ComplexInferShapeFunctor);

-REGISTER_OP_CPU_KERNEL(complex,
-                       ops::ComplexKernel<phi::CPUContext, float>,
-                       ops::ComplexKernel<phi::CPUContext, double>);
+DECLARE_INFER_SHAPE_FUNCTOR(complex_grad,
+                            ComplexGradInferShapeFunctor,
+                            PD_INFER_META(phi::ComplexGradInferMeta));

-REGISTER_OP_CPU_KERNEL(complex_grad,
-                       ops::ComplexGradKernel<phi::CPUContext, float>,
-                       ops::ComplexGradKernel<phi::CPUContext, double>);
+REGISTER_OPERATOR(complex_grad,
+                  ops::ComplexGradOp,
+                  ComplexGradInferShapeFunctor);
diff --git a/paddle/fluid/operators/complex_op.cu b/paddle/fluid/operators/complex_op.cu
deleted file mode 100644
index c9bc2d459e7..00000000000
--- a/paddle/fluid/operators/complex_op.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/operators/complex_op.h"
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_CUDA_KERNEL(
-    complex,
-    ops::ComplexKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ComplexKernel<paddle::platform::CUDADeviceContext, double>);
-
-REGISTER_OP_CUDA_KERNEL(
-    complex_grad,
-    ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/complex_op.h b/paddle/fluid/operators/complex_op.h
deleted file mode 100644
index 5fb19b46ec6..00000000000
--- a/paddle/fluid/operators/complex_op.h
+++ /dev/null
@@ -1,123 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
-#include "paddle/fluid/platform/complex.h"
-#include "paddle/phi/kernels/funcs/complex_functors.h"
-
-namespace paddle {
-namespace operators {
-
-// functors to use with ElementwiseComputeEx
-template <typename T>
-struct RealAndImagToComplexFunctor {
-  inline HOSTDEVICE platform::complex<T> operator()(const T x, const T y) {
-    return platform::complex<T>(x, y);
-  }
-};
-
-template <typename T>
-struct ImagAndRealToComplexFunctor {
-  inline HOSTDEVICE platform::complex<T> operator()(const T y, const T x) {
-    return platform::complex<T>(x, y);
-  }
-};
-
-template <typename T>
-struct ComplexGradForRealFunctor {
-  inline HOSTDEVICE T operator()(const T x,
-                                 const T y,
-                                 const platform::complex<T> out,
-                                 const platform::complex<T> dout) {
-    return dout.real;
-  }
-};
-
-template <typename T>
-struct ComplexGradForImagFunctor {
-  inline HOSTDEVICE T operator()(const T x,
-                                 const T y,
-                                 const platform::complex<T> out,
-                                 const platform::complex<T> dout) {
-    return dout.imag;
-  }
-};
-
-template <typename DeviceContext, typename T>
-class ComplexKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const auto* x = ctx.Input<framework::Tensor>("X");
-    const auto* y = ctx.Input<framework::Tensor>("Y");
-    auto* z = ctx.Output<framework::Tensor>("Out");
-
-    using C = platform::complex<T>;
-    z->mutable_data<C>(ctx.GetPlace());
-
-// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
-// facility functions
-#if defined(__NVCC__) || defined(__HIPCC__)
-    ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
-        ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
-#else
-    auto x_dims = x->dims();
-    auto y_dims = y->dims();
-    if (x_dims.size() >= y_dims.size()) {
-      ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
-          ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
-    } else {
-      ElementwiseComputeEx<ImagAndRealToComplexFunctor<T>, DeviceContext, T, C>(
-          ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), z);
-    }
-#endif
-  }
-};
-
-template <typename DeviceContext, typename T>
-class ComplexGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    using Tensor = framework::Tensor;
-
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    using C = platform::complex<T>;
-
-    // skip out in a hacky way
-    auto* out = dout;
-    ElemwiseGradCompute<DeviceContext,
-                        T,
-                        ComplexGradForRealFunctor<T>,
-                        ComplexGradForImagFunctor<T>,
-                        C>(ctx,
-                           *x,
-                           *y,
-                           *out,
-                           *dout,
-                           /*axis*/ -1,
-                           dx,
-                           dy,
-                           ComplexGradForRealFunctor<T>(),
-                           ComplexGradForImagFunctor<T>());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index aa86c0f34db..cd01c236410 100644
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -342,6 +342,15 @@
     func : clip
   backward : clip_grad

+- api : complex
+  args : (Tensor x, Tensor y)
+  output : Tensor
+  infer_meta :
+    func : ComplexInferMeta
+  kernel :
+    func : complex
+  backward : complex_grad
+
 - api : concat
   args : (Tensor[] x, Scalar(int64_t) axis)
   output : Tensor
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index f01598e6434..b4972c68a64 100644
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -306,6 +306,16 @@
   backward : clip_double_grad
   inplace : (out_grad -> x_grad)

+- backward_api : complex_grad
+  forward : complex (Tensor x, Tensor y) -> Tensor(out)
+  args : (Tensor x, Tensor y, Tensor out_grad)
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : ComplexGradInferMeta
+  kernel :
+    func : complex_grad
+    data_type : x
+
 - backward_api : concat_double_grad
   forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis) -> Tensor[](grad_x)
   args : (Tensor[] grad_x_grad, Scalar axis = 0)
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index f59ea5549bd..dd2d1eb482c 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -83,6 +83,23 @@ void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
   x_grad->set_dtype(out_grad.dtype());
 }

+void ComplexGradInferMeta(const MetaTensor& x,
+                          const MetaTensor& y,
+                          const MetaTensor& dout,
+                          MetaTensor* dx,
+                          MetaTensor* dy) {
+  auto x_dims = x.dims();
+  if (dx) {
+    dx->set_dims(x_dims);
+    dx->set_dtype(x.dtype());
+  }
+  auto y_dims = y.dims();
+  if (dy) {
+    dy->set_dims(y_dims);
+    dy->set_dtype(y.dtype());
+  }
+}
+
 void ConvTransposeGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& filter,
                                 const MetaTensor& dout,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 0e7ed640d8f..6a4eba74b47 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -42,6 +42,12 @@ void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
                                  const std::string& data_format,
                                  MetaTensor* x_grad);

+void ComplexGradInferMeta(const MetaTensor& x,
+                          const MetaTensor& y,
+                          const MetaTensor& dout,
+                          MetaTensor* dx,
+                          MetaTensor* dy);
+
 void ConvTransposeGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& filter,
                                 const MetaTensor& dout,
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 269286d76d9..460b0a9e1bd 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -21,6 +21,7 @@ limitations under the License. */

 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/layout.h"
+#include "paddle/phi/common/type_traits.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
@@ -358,6 +359,37 @@ void CompareAllInferMeta(const MetaTensor& x,
   out->set_dtype(DataType::BOOL);
 }

+void ComplexInferMeta(const MetaTensor& x,
+                      const MetaTensor& y,
+                      MetaTensor* out) {
+  if (x.dims() == y.dims()) {
+    auto sizes = vectorize(x.dims());
+    out->set_dims(phi::make_ddim(sizes));
+    out->set_dtype(dtype::ToComplex(x.dtype()));
+    // NOTE(chenfeiyu): lod & broadcasting is intrinsically contradictory
+    // so tensors with lod are not supported here
+  } else {
+    auto x_dims = x.dims();
+    auto y_dims = y.dims();
+    int max_dim = std::max(x_dims.size(), y_dims.size());
+
+    // start align axis
+    int axis = std::abs(x_dims.size() - y_dims.size());
+    std::vector<int> x_dims_array(max_dim);
+    std::vector<int> y_dims_array(max_dim);
+    std::vector<int> out_dims_array(max_dim);
+    phi::funcs::GetBroadcastDimsArrays(x_dims,
+                                       y_dims,
+                                       x_dims_array.data(),
+                                       y_dims_array.data(),
+                                       out_dims_array.data(),
+                                       max_dim,
+                                       axis);
+    out->set_dims(phi::make_ddim(out_dims_array));
+    out->set_dtype(dtype::ToComplex(x.dtype()));
+  }
+}
+
 void ConvInferMeta(const MetaTensor& input,
                    const MetaTensor& filter,
                    const std::vector<int>& strides,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 9709edf63cc..12922ed536a 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -74,6 +74,10 @@ void CompareInferMeta(const MetaTensor& x,
                       int axis,
                       MetaTensor* out);

+void ComplexInferMeta(const MetaTensor& x,
+                      const MetaTensor& y,
+                      MetaTensor* out);
+
 void ConvInferMeta(const MetaTensor& input,
                    const MetaTensor& filter,
                    const std::vector<int>& strides,
diff --git a/paddle/phi/kernels/complex_grad_kernel.h b/paddle/phi/kernels/complex_grad_kernel.h
index be13e2826ea..91c47538e95 100644
--- a/paddle/phi/kernels/complex_grad_kernel.h
+++ b/paddle/phi/kernels/complex_grad_kernel.h
@@ -28,4 +28,12 @@ void ImagGradKernel(const Context& dev_ctx,
                     const DenseTensor& dout,
                     DenseTensor* dx);

+template <typename T, typename Context>
+void ComplexGradKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       const DenseTensor& dout,
+                       DenseTensor* dx,
+                       DenseTensor* dy);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/complex_kernel.h b/paddle/phi/kernels/complex_kernel.h
index 07f93f9b926..ad66b890b3d 100644
--- a/paddle/phi/kernels/complex_kernel.h
+++ b/paddle/phi/kernels/complex_kernel.h
@@ -30,6 +30,12 @@ void RealKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
 template <typename T, typename Context>
 void ImagKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);

+template <typename T, typename Context>
+void ComplexKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   DenseTensor* out);
+
 // If T is complex
 template <
     typename T,
diff --git a/paddle/phi/kernels/cpu/complex_grad_kernel.cc b/paddle/phi/kernels/cpu/complex_grad_kernel.cc
index 11b7a058346..049022f01e7 100644
--- a/paddle/phi/kernels/cpu/complex_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/complex_grad_kernel.cc
@@ -31,3 +31,8 @@ PD_REGISTER_KERNEL(imag_grad,
                    phi::ImagGradKernel,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(
+    complex_grad, CPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) {
+  kernel->InputAt(2).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
+}
diff --git a/paddle/phi/kernels/cpu/complex_kernel.cc b/paddle/phi/kernels/cpu/complex_kernel.cc
index bef0b7b747a..9e6c72ae7c1 100644
--- a/paddle/phi/kernels/cpu/complex_kernel.cc
+++ b/paddle/phi/kernels/cpu/complex_kernel.cc
@@ -49,3 +49,8 @@ PD_REGISTER_KERNEL(imag,
                    phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
 }
+
+PD_REGISTER_KERNEL(
+    complex, CPU, ALL_LAYOUT, phi::ComplexKernel, float, double) {
+  kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
+}
diff --git a/paddle/phi/kernels/gpu/complex_grad_kernel.cu b/paddle/phi/kernels/gpu/complex_grad_kernel.cu
index 450b32291c4..e9fd5e1fa58 100644
--- a/paddle/phi/kernels/gpu/complex_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/complex_grad_kernel.cu
@@ -31,3 +31,8 @@ PD_REGISTER_KERNEL(real_grad,
                    phi::RealGradKernel,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(
+    complex_grad, GPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) {
+  kernel->InputAt(2).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
+}
diff --git a/paddle/phi/kernels/gpu/complex_kernel.cu b/paddle/phi/kernels/gpu/complex_kernel.cu
index d0ee78202b0..5c5bf104128 100644
--- a/paddle/phi/kernels/gpu/complex_kernel.cu
+++ b/paddle/phi/kernels/gpu/complex_kernel.cu
@@ -50,3 +50,8 @@ PD_REGISTER_KERNEL(imag,
                    phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
 }
+
+PD_REGISTER_KERNEL(
+    complex, GPU, ALL_LAYOUT, phi::ComplexKernel, float, double) {
+  kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
+}
diff --git a/paddle/phi/kernels/impl/complex_grad_kernel_impl.h b/paddle/phi/kernels/impl/complex_grad_kernel_impl.h
index 03896a2353d..f7366b32e11 100644
--- a/paddle/phi/kernels/impl/complex_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/complex_grad_kernel_impl.h
@@ -15,6 +15,7 @@
 #pragma once

 #include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/elementwise_grad_base.h"
 #include "paddle/phi/kernels/funcs/for_range.h"

 namespace phi {
@@ -47,4 +48,51 @@ void ImagGradKernel(const Context& dev_ctx,
   for_range(functor);
 }

+template <typename T>
+struct ComplexGradForRealFunctor {
+  inline HOSTDEVICE T operator()(const T x,
+                                 const T y,
+                                 const phi::dtype::complex<T> out,
+                                 const phi::dtype::complex<T> dout) {
+    return dout.real;
+  }
+};
+
+template <typename T>
+struct ComplexGradForImagFunctor {
+  inline HOSTDEVICE T operator()(const T x,
+                                 const T y,
+                                 const phi::dtype::complex<T> out,
+                                 const phi::dtype::complex<T> dout) {
+    return dout.imag;
+  }
+};
+
+template <typename T, typename Context>
+void ComplexGradKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       const DenseTensor& dout,
+                       DenseTensor* dx,
+                       DenseTensor* dy) {
+  using C = phi::dtype::complex<T>;
+
+  // skip out in a hacky way
+  auto out = dout;
+  phi::funcs::ElemwiseGradCompute<Context,
+                                  T,
+                                  ComplexGradForRealFunctor<T>,
+                                  ComplexGradForImagFunctor<T>,
+                                  C>(dev_ctx,
+                                     x,
+                                     y,
+                                     out,
+                                     dout,
+                                     /*axis*/ -1,
+                                     dx,
+                                     dy,
+                                     ComplexGradForRealFunctor<T>(),
+                                     ComplexGradForImagFunctor<T>());
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/complex_kernel_impl.h b/paddle/phi/kernels/impl/complex_kernel_impl.h
index 72b13288339..8bd78234119 100644
--- a/paddle/phi/kernels/impl/complex_kernel_impl.h
+++ b/paddle/phi/kernels/impl/complex_kernel_impl.h
@@ -15,7 +15,9 @@
 #pragma once

 // See Note [ Why still include the fluid headers? ]
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/for_range.h"

 namespace phi {
@@ -61,4 +63,45 @@ void ImagKernel(const Context& dev_ctx,
   for_range(functor);
 }

+// functors to use with ElementwiseComputeEx
+template <typename T>
+struct RealAndImagToComplexFunctor {
+  inline HOSTDEVICE phi::dtype::complex<T> operator()(const T x, const T y) {
+    return phi::dtype::complex<T>(x, y);
+  }
+};
+
+template <typename T>
+struct ImagAndRealToComplexFunctor {
+  inline HOSTDEVICE phi::dtype::complex<T> operator()(const T y, const T x) {
+    return phi::dtype::complex<T>(x, y);
+  }
+};
+
+template <typename T, typename Context>
+void ComplexKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   DenseTensor* out) {
+  using C = phi::dtype::complex<T>;
+  dev_ctx.template Alloc<C>(out);
+
+// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
+// facility functions
+#if defined(__NVCC__) || defined(__HIPCC__)
+  phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
+      dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out);
+#else
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  if (x_dims.size() >= y_dims.size()) {
+    phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
+        dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out);
+  } else {
+    phi::funcs::ElementwiseCompute<ImagAndRealToComplexFunctor<T>, T, C>(
+        dev_ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), out);
+  }
+#endif
+}
+
 }  // namespace phi
diff --git a/paddle/phi/ops/compat/complex_sig.cc b/paddle/phi/ops/compat/complex_sig.cc
index 88156677d34..da47e2c7bc7 100644
--- a/paddle/phi/ops/compat/complex_sig.cc
+++ b/paddle/phi/ops/compat/complex_sig.cc
@@ -24,7 +24,14 @@ KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
 }

+KernelSignature ComplexGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "complex_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
+}
+
 }  // namespace phi

 PD_REGISTER_ARG_MAPPING_FN(real_grad, phi::RealGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(imag_grad, phi::ImagGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(complex_grad, phi::ComplexGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_complex_op.py b/python/paddle/fluid/tests/unittests/test_complex_op.py
index 1faef17a2ad..49ad644b0ab 100644
--- a/python/paddle/fluid/tests/unittests/test_complex_op.py
+++ b/python/paddle/fluid/tests/unittests/test_complex_op.py
@@ -58,6 +58,7 @@ class TestComplexOp(OpTest):

     def setUp(self):
         self.op_type = "complex"
+        self.python_api = paddle.complex
         self.init_spec()
         x = np.random.randn(*self.x_shape).astype(self.dtype)
         y = np.random.randn(*self.y_shape).astype(self.dtype)
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index b73fe74a40b..85f8ba4aa4f 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -1701,6 +1701,9 @@ def complex(real, imag, name=None):
             # [[0.+0.j 0.+1.j 0.+2.j]
             #  [1.+0.j 1.+1.j 1.+2.j]]
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_complex(real, imag)
+
     if paddle.in_dynamic_mode():
         return paddle._C_ops.complex(real, imag)

-- 
GitLab
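A quick way to exercise the migrated operator end to end is through paddle.complex, the public API this op backs. The snippet below is a minimal smoke-test sketch, not part of the patch, assuming a Paddle build that includes this change; the second call goes through the broadcast branch of ComplexInferMeta (the GetBroadcastDimsArrays path).

    import numpy as np
    import paddle

    # Same-shape case: ComplexInferMeta keeps x's dims and promotes the
    # dtype (float32 -> complex64, via phi::dtype::ToComplex).
    real = paddle.to_tensor(np.random.randn(2, 3).astype("float32"))
    imag = paddle.to_tensor(np.random.randn(2, 3).astype("float32"))
    out = paddle.complex(real, imag)
    print(out.dtype)   # paddle.complex64
    print(out.shape)   # [2, 3]

    # Broadcast case: (2, 3) with (3,) aligns on the trailing axis, so the
    # output shape follows the GetBroadcastDimsArrays computation.
    imag_row = paddle.to_tensor(np.random.randn(3).astype("float32"))
    out2 = paddle.complex(real, imag_row)
    print(out2.shape)  # [2, 3]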
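The gradient functors registered above simply split the upstream complex gradient: ComplexGradForRealFunctor returns dout.real for X@GRAD and ComplexGradForImagFunctor returns dout.imag for Y@GRAD. A minimal sketch of that behavior through the Python API, again assuming a build with this patch:

    import paddle

    real = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
    imag = paddle.to_tensor([4.0, 5.0, 6.0], stop_gradient=False)
    out = paddle.complex(real, imag)

    # A real-valued scalar loss, so dL/d(out) = 2 + 3i for every element.
    loss = (paddle.real(out) * 2.0 + paddle.imag(out) * 3.0).sum()
    loss.backward()

    print(real.grad)  # [2., 2., 2.], i.e. the real part of out_grad
    print(imag.grad)  # [3., 3., 3.], i.e. the imaginary part of out_grad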