diff --git a/paddle/fluid/operators/complex_op.cc b/paddle/fluid/operators/complex_op.cc
index d6d93fe958118769b303ae95b374a6d81bbbb941..778f5831c0fbb1bfb4687f241d2373dd69f8e5dd 100644
--- a/paddle/fluid/operators/complex_op.cc
+++ b/paddle/fluid/operators/complex_op.cc
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/complex_op.h"
-
-#include <vector>
-
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/infermeta/binary.h"
 
 namespace paddle {
 namespace operators {
@@ -59,36 +59,6 @@ class ComplexOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "complex");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "complex");
-
-    if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
-      ctx->ShareDim("X", /*->*/ "Out");
-      // NOTE(chenfeiyu): lod & broadcasting is intrinsically contradictory
-      // so tensors with lod are not supported here
-    } else {
-      auto x_dims = ctx->GetInputDim("X");
-      auto y_dims = ctx->GetInputDim("Y");
-      int max_dim = std::max(x_dims.size(), y_dims.size());
-
-      // start align axis
-      int axis = std::abs(x_dims.size() - y_dims.size());
-      std::vector<int> x_dims_array(max_dim);
-      std::vector<int> y_dims_array(max_dim);
-      std::vector<int> out_dims_array(max_dim);
-      GetBroadcastDimsArrays(x_dims,
-                             y_dims,
-                             x_dims_array.data(),
-                             y_dims_array.data(),
-                             out_dims_array.data(),
-                             max_dim,
-                             axis);
-      ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -101,25 +71,6 @@ class ComplexGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex_grad");
-    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_complex_gradgrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   framework::GradVarName("Out"),
-                   "complex_grad");
-
-    auto x_grad_name = framework::GradVarName("X");
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->ShareDim("X", /*->*/ x_grad_name);
-    }
-
-    auto y_grad_name = framework::GradVarName("Y");
-    if (ctx->HasOutput(y_grad_name)) {
-      ctx->ShareDim("Y", /*->*/ y_grad_name);
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
@@ -135,18 +86,21 @@ class ComplexGradOp : public framework::OperatorWithKernel {
 
 namespace ops = paddle::operators;
 
+DECLARE_INFER_SHAPE_FUNCTOR(complex,
+                            ComplexInferShapeFunctor,
+                            PD_INFER_META(phi::ComplexInferMeta));
+
 REGISTER_OPERATOR(complex,
                   ops::ComplexOp,
                   ops::ComplexOpMaker,
                   ops::ComplexGradOpMaker<paddle::framework::OpDesc>,
-                  ops::ComplexGradOpMaker<paddle::imperative::OpBase>);
-
-REGISTER_OPERATOR(complex_grad, ops::ComplexGradOp);
+                  ops::ComplexGradOpMaker<paddle::imperative::OpBase>,
+                  ComplexInferShapeFunctor);
 
-REGISTER_OP_CPU_KERNEL(complex,
-                       ops::ComplexKernel<phi::CPUContext, float>,
-                       ops::ComplexKernel<phi::CPUContext, double>);
+DECLARE_INFER_SHAPE_FUNCTOR(complex_grad,
+                            ComplexGradInferShapeFunctor,
+                            PD_INFER_META(phi::ComplexGradInferMeta));
 
-REGISTER_OP_CPU_KERNEL(complex_grad,
-                       ops::ComplexGradKernel<phi::CPUContext, float>,
-                       ops::ComplexGradKernel<phi::CPUContext, double>);
+REGISTER_OPERATOR(complex_grad,
+                  ops::ComplexGradOp,
+                  ComplexGradInferShapeFunctor);
diff --git a/paddle/fluid/operators/complex_op.cu b/paddle/fluid/operators/complex_op.cu
deleted file mode 100644
index c9bc2d459e73bcec50f49b680dce0ac25741ce8e..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/complex_op.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/operators/complex_op.h"
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_CUDA_KERNEL(
-    complex,
-    ops::ComplexKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ComplexKernel<paddle::platform::CUDADeviceContext, double>);
-
-REGISTER_OP_CUDA_KERNEL(
-    complex_grad,
-    ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/complex_op.h b/paddle/fluid/operators/complex_op.h
deleted file mode 100644
index 5fb19b46ec6a03419dd6b562163424d1a0e6dc89..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/complex_op.h
+++ /dev/null
@@ -1,123 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
-#include "paddle/fluid/platform/complex.h"
-#include "paddle/phi/kernels/funcs/complex_functors.h"
-
-namespace paddle {
-namespace operators {
-
-// functors to use with ElementwiseComputeEx
-template <typename T>
-struct RealAndImagToComplexFunctor {
-  inline HOSTDEVICE platform::complex<T> operator()(const T x, const T y) {
-    return platform::complex<T>(x, y);
-  }
-};
-
-template <typename T>
-struct ImagAndRealToComplexFunctor {
-  inline HOSTDEVICE platform::complex<T> operator()(const T y, const T x) {
-    return platform::complex<T>(x, y);
-  }
-};
-
-template <typename T>
-struct ComplexGradForRealFunctor {
-  inline HOSTDEVICE T operator()(const T x,
-                                 const T y,
-                                 const platform::complex<T> out,
-                                 const platform::complex<T> dout) {
-    return dout.real;
-  }
-};
-
-template <typename T>
-struct ComplexGradForImagFunctor {
-  inline HOSTDEVICE T operator()(const T x,
-                                 const T y,
-                                 const platform::complex<T> out,
-                                 const platform::complex<T> dout) {
-    return dout.imag;
-  }
-};
-
-template <typename DeviceContext, typename T>
-class ComplexKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const auto* x = ctx.Input<framework::Tensor>("X");
-    const auto* y = ctx.Input<framework::Tensor>("Y");
-    auto* z = ctx.Output<framework::Tensor>("Out");
-
-    using C = platform::complex<T>;
-    z->mutable_data<C>(ctx.GetPlace());
-
-// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
-// facility functions
-#if defined(__NVCC__) || defined(__HIPCC__)
-    ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
-        ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
-#else
-    auto x_dims = x->dims();
-    auto y_dims = y->dims();
-    if (x_dims.size() >= y_dims.size()) {
-      ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
-          ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
-    } else {
-      ElementwiseComputeEx<ImagAndRealToComplexFunctor<T>, DeviceContext, T, C>(
-          ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), z);
-    }
-#endif
-  }
-};
-
-template <typename DeviceContext, typename T>
-class ComplexGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    using Tensor = framework::Tensor;
-
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    using C = platform::complex<T>;
-
-    // skip out in a hacky way
-    auto* out = dout;
-    ElemwiseGradCompute<DeviceContext,
-                        T,
-                        ComplexGradForRealFunctor<T>,
-                        ComplexGradForImagFunctor<T>,
-                        C>(ctx,
-                           *x,
-                           *y,
-                           *out,
-                           *dout,
-                           /*axis*/ -1,
-                           dx,
-                           dy,
-                           ComplexGradForRealFunctor<T>(),
-                           ComplexGradForImagFunctor<T>());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index aa86c0f34db55adb961565b2eaa64c137fac8ec7..cd01c23641010bab055dc9cbe955697009ec2f0c 100644
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -342,6 +342,15 @@
     func : clip
   backward : clip_grad
 
+- api : complex
+  args : (Tensor x, Tensor y)
+  output : Tensor
+  infer_meta :
+    func : ComplexInferMeta
+  kernel :
+    func : complex
+  backward : complex_grad
+
 - api : concat
   args : (Tensor[] x, Scalar(int64_t) axis)
   output : Tensor
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index f01598e64342017e3831a7b338c3de77b91b51f6..b4972c68a6477bbba4db8bac7884d78dd774daeb 100644
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -306,6 +306,16 @@
   backward : clip_double_grad
   inplace : (out_grad -> x_grad)
 
+- backward_api : complex_grad
+  forward : complex (Tensor x, Tensor y) -> Tensor(out)
+  args : (Tensor x, Tensor y, Tensor out_grad)
+  output : Tensor(x_grad), Tensor(y_grad)
+  infer_meta :
+    func : ComplexGradInferMeta
+  kernel :
+    func : complex_grad
+    data_type : x
+
 - backward_api : concat_double_grad
   forward : concat_grad (Tensor[] x, Tensor grad_out, Scalar axis) -> Tensor[](grad_x)
   args : (Tensor[] grad_x_grad, Scalar axis = 0)
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index f59ea5549bd71409743176c471d7f633f62b7ca7..dd2d1eb482c8edcaddc2c4e24cddb2a43a091d2f 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -83,6 +83,23 @@ void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
   x_grad->set_dtype(out_grad.dtype());
 }
 
+void ComplexGradInferMeta(const MetaTensor& x,
+                          const MetaTensor& y,
+                          const MetaTensor& dout,
+                          MetaTensor* dx,
+                          MetaTensor* dy) {
+  auto x_dims = x.dims();
+  if (dx) {
+    dx->set_dims(x_dims);
+    dx->set_dtype(x.dtype());
+  }
+  auto y_dims = y.dims();
+  if (dy) {
+    dy->set_dims(y_dims);
+    dy->set_dtype(y.dtype());
+  }
+}
+
 void ConvTransposeGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& filter,
                                 const MetaTensor& dout,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 0e7ed640d8ffb55b622084bbb44cfbbe0879ba51..6a4eba74b47beec1b6d22b4face65728ba36813e 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -42,6 +42,12 @@ void ChannelShuffleGradInferMeta(const MetaTensor& out_grad,
                                  const std::string& data_format,
                                  MetaTensor* x_grad);
 
+void ComplexGradInferMeta(const MetaTensor& x,
+                          const MetaTensor& y,
+                          const MetaTensor& dout,
+                          MetaTensor* dx,
+                          MetaTensor* dy);
+
 void ConvTransposeGradInferMeta(const MetaTensor& x,
                                 const MetaTensor& filter,
                                 const MetaTensor& dout,
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 269286d76d9545182b0f94e3b86ad00bf59c7b9f..460b0a9e1bdc41f343fb547c333e8103f99ccf4f 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/layout.h"
+#include "paddle/phi/common/type_traits.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/kernels/cpu/conv_util.h"
@@ -358,6 +359,37 @@ void CompareAllInferMeta(const MetaTensor& x,
   out->set_dtype(DataType::BOOL);
 }
 
+void ComplexInferMeta(const MetaTensor& x,
+                      const MetaTensor& y,
+                      MetaTensor* out) {
+  if (x.dims() == y.dims()) {
+    auto sizes = vectorize(x.dims());
+    out->set_dims(phi::make_ddim(sizes));
+    out->set_dtype(dtype::ToComplex(x.dtype()));
+    // NOTE(chenfeiyu): lod & broadcasting is intrinsically contradictory
+    // so tensors with lod are not supported here
+  } else {
+    auto x_dims = x.dims();
+    auto y_dims = y.dims();
+    int max_dim = std::max(x_dims.size(), y_dims.size());
+
+    // start align axis
+    int axis = std::abs(x_dims.size() - y_dims.size());
+    std::vector<int> x_dims_array(max_dim);
+    std::vector<int> y_dims_array(max_dim);
+    std::vector<int> out_dims_array(max_dim);
+    phi::funcs::GetBroadcastDimsArrays(x_dims,
+                                       y_dims,
+                                       x_dims_array.data(),
+                                       y_dims_array.data(),
+                                       out_dims_array.data(),
+                                       max_dim,
+                                       axis);
+    out->set_dims(phi::make_ddim(out_dims_array));
+    out->set_dtype(dtype::ToComplex(x.dtype()));
+  }
+}
+
 void ConvInferMeta(const MetaTensor& input,
                    const MetaTensor& filter,
                    const std::vector<int>& strides,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 9709edf63ccc07cb36246e45e734d09c028434f0..12922ed536add74e5e27d6899347f66dce25f826 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -74,6 +74,10 @@ void CompareInferMeta(const MetaTensor& x,
                       int axis,
                       MetaTensor* out);
 
+void ComplexInferMeta(const MetaTensor& x,
+                      const MetaTensor& y,
+                      MetaTensor* out);
+
 void ConvInferMeta(const MetaTensor& input,
                    const MetaTensor& filter,
                    const std::vector<int>& strides,
diff --git a/paddle/phi/kernels/complex_grad_kernel.h b/paddle/phi/kernels/complex_grad_kernel.h
index be13e2826ea81455fd811143dde02f2d11cfdae2..91c47538e958d4450a2ad08a89f13f8990571200 100644
--- a/paddle/phi/kernels/complex_grad_kernel.h
+++ b/paddle/phi/kernels/complex_grad_kernel.h
@@ -28,4 +28,12 @@ void ImagGradKernel(const Context& dev_ctx,
                     const DenseTensor& dout,
                     DenseTensor* dx);
 
+template <typename T, typename Context>
+void ComplexGradKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       const DenseTensor& dout,
+                       DenseTensor* dx,
+                       DenseTensor* dy);
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/complex_kernel.h b/paddle/phi/kernels/complex_kernel.h
index 07f93f9b926f174c374bbc20b7c655a65732423f..ad66b890b3d5ab70aabaf9911b726a4b9d2261ef 100644
--- a/paddle/phi/kernels/complex_kernel.h
+++ b/paddle/phi/kernels/complex_kernel.h
@@ -30,6 +30,12 @@ void RealKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
 template <typename T, typename Context>
 void ImagKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
 
+template <typename T, typename Context>
+void ComplexKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   DenseTensor* out);
+
 // If T is complex
 template <
     typename T,
diff --git a/paddle/phi/kernels/cpu/complex_grad_kernel.cc b/paddle/phi/kernels/cpu/complex_grad_kernel.cc
index 11b7a05834607ca6d316ae40b485b937acf34b0f..049022f01e7c044bb5726e1440655eae767d87b1 100644
--- a/paddle/phi/kernels/cpu/complex_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/complex_grad_kernel.cc
@@ -31,3 +31,8 @@ PD_REGISTER_KERNEL(imag_grad,
                    phi::ImagGradKernel,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(
+    complex_grad, CPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) {
+  kernel->InputAt(2).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
+}
diff --git a/paddle/phi/kernels/cpu/complex_kernel.cc b/paddle/phi/kernels/cpu/complex_kernel.cc
index bef0b7b747a420f501057d506a7bfe2f1b34971c..9e6c72ae7c16a589e50bcb63381bdc1346c6266b 100644
--- a/paddle/phi/kernels/cpu/complex_kernel.cc
+++ b/paddle/phi/kernels/cpu/complex_kernel.cc
@@ -49,3 +49,8 @@ PD_REGISTER_KERNEL(imag,
                    phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
 }
+
+PD_REGISTER_KERNEL(
+    complex, CPU, ALL_LAYOUT, phi::ComplexKernel, float, double) {
+  kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
+}
diff --git a/paddle/phi/kernels/gpu/complex_grad_kernel.cu b/paddle/phi/kernels/gpu/complex_grad_kernel.cu
index 450b32291c4bc8b76f3fe6dcdad5e5e00dbad409..e9fd5e1fa5834e7137438cba8a570223fb2c9889 100644
--- a/paddle/phi/kernels/gpu/complex_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/complex_grad_kernel.cu
@@ -31,3 +31,8 @@ PD_REGISTER_KERNEL(real_grad,
                    phi::RealGradKernel,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
+PD_REGISTER_KERNEL(
+    complex_grad, GPU, ALL_LAYOUT, phi::ComplexGradKernel, float, double) {
+  kernel->InputAt(2).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
+}
diff --git a/paddle/phi/kernels/gpu/complex_kernel.cu b/paddle/phi/kernels/gpu/complex_kernel.cu
index d0ee78202b0604de877edbc709fc565e25737a2b..5c5bf104128d335f325778e742ff74c0c10fbc0d 100644
--- a/paddle/phi/kernels/gpu/complex_kernel.cu
+++ b/paddle/phi/kernels/gpu/complex_kernel.cu
@@ -50,3 +50,8 @@ PD_REGISTER_KERNEL(imag,
                    phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
 }
+
+PD_REGISTER_KERNEL(
+    complex, GPU, ALL_LAYOUT, phi::ComplexKernel, float, double) {
+  kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
+}
diff --git a/paddle/phi/kernels/impl/complex_grad_kernel_impl.h b/paddle/phi/kernels/impl/complex_grad_kernel_impl.h
index 03896a2353dda5b89876d46115eea59fa907a8ac..f7366b32e1105b2448fc00cfc965da38443331dc 100644
--- a/paddle/phi/kernels/impl/complex_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/complex_grad_kernel_impl.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/elementwise_grad_base.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
 
 namespace phi {
@@ -47,4 +48,51 @@ void ImagGradKernel(const Context& dev_ctx,
   for_range(functor);
 }
 
+template <typename T>
+struct ComplexGradForRealFunctor {
+  inline HOSTDEVICE T operator()(const T x,
+                                 const T y,
+                                 const phi::dtype::complex<T> out,
+                                 const phi::dtype::complex<T> dout) {
+    return dout.real;
+  }
+};
+
+template <typename T>
+struct ComplexGradForImagFunctor {
+  inline HOSTDEVICE T operator()(const T x,
+                                 const T y,
+                                 const phi::dtype::complex<T> out,
+                                 const phi::dtype::complex<T> dout) {
+    return dout.imag;
+  }
+};
+
+template <typename T, typename Context>
+void ComplexGradKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& y,
+                       const DenseTensor& dout,
+                       DenseTensor* dx,
+                       DenseTensor* dy) {
+  using C = phi::dtype::complex<T>;
+
+  // skip out in a hacky way
+  auto out = dout;
+  phi::funcs::ElemwiseGradCompute<Context,
+                                  T,
+                                  ComplexGradForRealFunctor<T>,
+                                  ComplexGradForImagFunctor<T>,
+                                  C>(dev_ctx,
+                                     x,
+                                     y,
+                                     out,
+                                     dout,
+                                     /*axis*/ -1,
+                                     dx,
+                                     dy,
+                                     ComplexGradForRealFunctor<T>(),
+                                     ComplexGradForImagFunctor<T>());
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/impl/complex_kernel_impl.h b/paddle/phi/kernels/impl/complex_kernel_impl.h
index 72b13288339797ae5dd6b21862ed403dd1719a65..8bd78234119649c9dae50467ea21cdfb4598fb58 100644
--- a/paddle/phi/kernels/impl/complex_kernel_impl.h
+++ b/paddle/phi/kernels/impl/complex_kernel_impl.h
@@ -15,7 +15,9 @@
 #pragma once
 
 // See Note [ Why still include the fluid headers? ]
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/elementwise_base.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
 
 namespace phi {
@@ -61,4 +63,45 @@ void ImagKernel(const Context& dev_ctx,
   for_range(functor);
 }
 
+// functors to use with ElementwiseComputeEx
+template <typename T>
+struct RealAndImagToComplexFunctor {
+  inline HOSTDEVICE phi::dtype::complex<T> operator()(const T x, const T y) {
+    return phi::dtype::complex<T>(x, y);
+  }
+};
+
+template <typename T>
+struct ImagAndRealToComplexFunctor {
+  inline HOSTDEVICE phi::dtype::complex<T> operator()(const T y, const T x) {
+    return phi::dtype::complex<T>(x, y);
+  }
+};
+
+template <typename T, typename Context>
+void ComplexKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& y,
+                   DenseTensor* out) {
+  using C = phi::dtype::complex<T>;
+  dev_ctx.template Alloc<C>(out);
+
+// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
+// facility functions
+#if defined(__NVCC__) || defined(__HIPCC__)
+  phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
+      dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out);
+#else
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  if (x_dims.size() >= y_dims.size()) {
+    phi::funcs::ElementwiseCompute<RealAndImagToComplexFunctor<T>, T, C>(
+        dev_ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), out);
+  } else {
+    phi::funcs::ElementwiseCompute<ImagAndRealToComplexFunctor<T>, T, C>(
+        dev_ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), out);
+  }
+#endif
+}
+
 }  // namespace phi
diff --git a/paddle/phi/ops/compat/complex_sig.cc b/paddle/phi/ops/compat/complex_sig.cc
index 88156677d34df82dd7fd68a910efbf9a8ac459d3..da47e2c7bc7506a74e2b0a3f42665fd35544dd6a 100644
--- a/paddle/phi/ops/compat/complex_sig.cc
+++ b/paddle/phi/ops/compat/complex_sig.cc
@@ -24,7 +24,14 @@ KernelSignature ImagGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
   return KernelSignature("imag_grad", {"Out@GRAD"}, {}, {"X@GRAD"});
 }
 
+KernelSignature ComplexGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "complex_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
+}
+
 }  // namespace phi
 
 PD_REGISTER_ARG_MAPPING_FN(real_grad, phi::RealGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(imag_grad, phi::ImagGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(complex_grad, phi::ComplexGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_complex_op.py b/python/paddle/fluid/tests/unittests/test_complex_op.py
index 1faef17a2ade35dae4314056a5c04ac4a3aa2e41..49ad644b0ab753d31342f2d6ad12ec03b7c343cc 100644
--- a/python/paddle/fluid/tests/unittests/test_complex_op.py
+++ b/python/paddle/fluid/tests/unittests/test_complex_op.py
@@ -58,6 +58,7 @@ class TestComplexOp(OpTest):
 
     def setUp(self):
         self.op_type = "complex"
+        self.python_api = paddle.complex
         self.init_spec()
         x = np.random.randn(*self.x_shape).astype(self.dtype)
         y = np.random.randn(*self.y_shape).astype(self.dtype)
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index b73fe74a40ba210bb1b73e87e0c89d5924aecc57..85f8ba4aa4f459460d3b33bdc54a7c132a4a7600 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -1701,6 +1701,9 @@ def complex(real, imag, name=None):
             # [[0.+0.j 0.+1.j 0.+2.j]
             #  [1.+0.j 1.+1.j 1.+2.j]]
     """
+    if in_dygraph_mode():
+        return _C_ops.final_state_complex(real, imag)
+
     if paddle.in_dynamic_mode():
         return paddle._C_ops.complex(real, imag)
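
Usage sketch (not part of the patch, assuming a build with this change applied): `paddle.complex` is the public entry point touched in creation.py; output shapes follow the broadcast branch of `phi::ComplexInferMeta`, and the output dtype is widened by `dtype::ToComplex` (float32 -> complex64, float64 -> complex128).

    import numpy as np
    import paddle

    # Same-shape inputs: Out shares X's shape; float32 pairs give complex64.
    real = paddle.to_tensor(np.random.randn(2, 3).astype("float32"))
    imag = paddle.to_tensor(np.random.randn(2, 3).astype("float32"))
    out = paddle.complex(real, imag)
    print(out.shape, out.dtype)  # [2, 3] paddle.complex64

    # Different shapes take the broadcast branch: (2, 3) and (3,) align on the
    # trailing axis, mirroring GetBroadcastDimsArrays in ComplexInferMeta.
    col = paddle.to_tensor(np.random.randn(2, 3).astype("float64"))
    row = paddle.to_tensor(np.random.randn(3).astype("float64"))
    print(paddle.complex(col, row).shape)  # [2, 3], dtype complex128

On the backward side, the registered `complex_grad` routes Re(dOut) to X@GRAD and Im(dOut) to Y@GRAD (see ComplexGradForRealFunctor / ComplexGradForImagFunctor), with ElemwiseGradCompute reducing over any broadcast axes.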