diff --git a/paddle/fluid/operators/complex_op.cc b/paddle/fluid/operators/complex_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6ad8a121f783aef4f392f4f9919736553d8f9a8f
--- /dev/null
+++ b/paddle/fluid/operators/complex_op.cc
@@ -0,0 +1,144 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/complex_op.h"
+
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/common_infer_shape_functions.cc"
+
+namespace paddle {
+namespace operators {
+
+class ComplexOpMaker : public framework::OpProtoAndCheckerMaker {
+ protected:
+  void Make() override {
+    AddInput("X", "(Tensor), real part of complex_op");
+    AddInput("Y", "(Tensor), imaginary part of complex_op");
+    AddOutput("Out", "(Tensor), output of complex_op");
+    AddComment(R"DOC(
+Complex Operator.
+
+Return a complex tensor given the real and imaginary tensors.
+
+)DOC");
+  }
+};
+
+template <typename T>
+class ComplexGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("complex_grad");
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Y", this->Input("Y"));
+    // op->SetInput("Out", this->Output("Out"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+class ComplexOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex");
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "complex");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "complex");
+
+    if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
+      ctx->ShareDim("X", /*->*/ "Out");
+      // NOTE(chenfeiyu): lod & broadcasting is intrinsically contradictory
+      // so tensors with lod are not supported here
+    } else {
+      auto x_dims = ctx->GetInputDim("X");
+      auto y_dims = ctx->GetInputDim("Y");
+      int max_dim = std::max(x_dims.size(), y_dims.size());
+
+      // start align axis
+      int axis = std::abs(x_dims.size() - y_dims.size());
+      std::vector<int> x_dims_array(max_dim);
+      std::vector<int> y_dims_array(max_dim);
+      std::vector<int> out_dims_array(max_dim);
+      details::GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
+                                      y_dims_array.data(),
+                                      out_dims_array.data(), max_dim, axis);
+      ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    return framework::OpKernelType(data_type, ctx.GetPlace());
+  }
+};
+
+class ComplexGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex_grad");
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "complex_grad");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "complex_grad");
+
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->ShareDim("X", /*->*/ x_grad_name);
+    }
+
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->ShareDim("Y", /*->*/ y_grad_name);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto out_grad_name = framework::GradVarName("Out");
+    auto computation_dtype = framework::ToRealType(
+        OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name));
+    return framework::OpKernelType(computation_dtype, ctx.GetPlace());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(complex, ops::ComplexOp, ops::ComplexOpMaker,
+                  ops::ComplexGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ComplexGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(complex_grad, ops::ComplexGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    complex, ops::ComplexKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ComplexKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    complex_grad,
+    ops::ComplexGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ComplexGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/complex_op.cu b/paddle/fluid/operators/complex_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c3f8186c93be48a27bbb05bba62ba525f99756ff
--- /dev/null
+++ b/paddle/fluid/operators/complex_op.cu
@@ -0,0 +1,27 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/complex_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    complex, ops::ComplexKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ComplexKernel<paddle::platform::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    complex_grad,
+    ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, double>);
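Note: ComplexOp::InferShape above falls back to NumPy-style broadcasting when the two inputs have different shapes, with the alignment axis set to the rank difference. A standalone sketch of that shape rule (illustrative only; the helper name below is not part of the patch):

    import numpy as np

    def broadcast_out_shape(x_shape, y_shape):
        # Right-align the shapes, then take the larger extent per axis;
        # an extent of 1 is compatible with anything.
        rank = max(len(x_shape), len(y_shape))
        xs = [1] * (rank - len(x_shape)) + list(x_shape)
        ys = [1] * (rank - len(y_shape)) + list(y_shape)
        out = []
        for a, b in zip(xs, ys):
            assert a == b or a == 1 or b == 1, "incompatible dims"
            out.append(max(a, b))
        return out

    # Matches TestComplexOpBroadcast1 below: [10, 3, 1, 4] with [100, 1]
    assert broadcast_out_shape([10, 3, 1, 4], [100, 1]) == [10, 3, 100, 4]
    assert tuple(broadcast_out_shape([10, 3, 1, 4], [100, 1])) == \
        np.broadcast(np.empty([10, 3, 1, 4]), np.empty([100, 1])).shape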
diff --git a/paddle/fluid/operators/complex_op.h b/paddle/fluid/operators/complex_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..c6ae46f5a828fdc565d8af7eac857d84a5c0d46c
--- /dev/null
+++ b/paddle/fluid/operators/complex_op.h
@@ -0,0 +1,111 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/platform/complex.h"
+
+namespace paddle {
+namespace operators {
+
+// functors to use with ElementwiseComputeEx
+template <typename T>
+struct RealAndImagToComplexFunctor {
+  inline HOSTDEVICE platform::complex<T> operator()(const T& x, const T& y) {
+    return platform::complex<T>(x, y);
+  }
+};
+
+template <typename T>
+struct ImagAndRealToComplexFunctor {
+  inline HOSTDEVICE platform::complex<T> operator()(const T& y, const T& x) {
+    return platform::complex<T>(x, y);
+  }
+};
+
+template <typename T>
+struct ComplexGradForRealFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T y,
+                                 const platform::complex<T> out,
+                                 const platform::complex<T> dout) {
+    return dout.real;
+  }
+};
+
+template <typename T>
+struct ComplexGradForImagFunctor {
+  inline HOSTDEVICE T operator()(const T x, const T y,
+                                 const platform::complex<T> out,
+                                 const platform::complex<T> dout) {
+    return dout.imag;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ComplexKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* x = ctx.Input<framework::Tensor>("X");
+    const auto* y = ctx.Input<framework::Tensor>("Y");
+    auto* z = ctx.Output<framework::Tensor>("Out");
+
+    using C = platform::complex<T>;
+    z->mutable_data<C>(ctx.GetPlace());
+
+// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
+// facility functions
+#if defined(__NVCC__) || defined(__HIPCC__)
+    ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
+        ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
+#else
+    auto x_dims = x->dims();
+    auto y_dims = y->dims();
+    if (x_dims.size() >= y_dims.size()) {
+      ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
+          ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
+    } else {
+      ElementwiseComputeEx<ImagAndRealToComplexFunctor<T>, DeviceContext, T, C>(
+          ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), z);
+    }
+#endif
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ComplexGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    using Tensor = framework::Tensor;
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    using C = platform::complex<T>;
+
+    // skip out in a hacky way
+    auto* out = dout;
+    ElemwiseGradCompute<DeviceContext, T, ComplexGradForRealFunctor<T>,
+                        ComplexGradForImagFunctor<T>, C>(
+        ctx, *x, *y, *out, *dout, /*axis*/ -1, dx, dy,
+        ComplexGradForRealFunctor<T>(), ComplexGradForImagFunctor<T>());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
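Note: the functors in complex_op.h carry the whole math of the op. Forward builds out = x + i*y elementwise; the backward functors simply take the real/imaginary part of the incoming gradient, and ElemwiseGradCompute reduces over any broadcast axes. A minimal NumPy illustration (not part of the patch):

    import numpy as np

    x = np.array([1.0, 2.0])
    y = np.array([3.0, 4.0])
    out = x + 1j * y                        # what ComplexKernel produces

    dout = np.array([0.5 + 2.0j, -1.0 + 0.25j])
    dx = dout.real                          # ComplexGradForRealFunctor
    dy = dout.imag                          # ComplexGradForImagFunctor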
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index 4cb9fa6467623f09997593b8befb74426922b58e..9700ca3584de8d153f2164abe06482adac36601d 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -169,7 +169,7 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
                                is_xsize_larger);
 }
 
-template <typename T, typename DX_OP, typename DY_OP>
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 void CommonGradBroadcastCPU(
     const framework::Tensor &x, const framework::Tensor &y,
     const framework::Tensor &out, const framework::Tensor &dout,
@@ -179,8 +179,8 @@ void CommonGradBroadcastCPU(
   std::vector<int> index_array(max_dim, 0);
   const T *x_data = x.data<T>();
   const T *y_data = y.data<T>();
-  const T *out_data = out.data<T>();
-  const T *dout_data = dout.data<T>();
+  const Tout *out_data = out.data<Tout>();
+  const Tout *dout_data = dout.data<Tout>();
   T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
   T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
   if (dx_data != nullptr) {
@@ -240,9 +240,9 @@ inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs,
 }
 
 #if defined(__NVCC__) || defined(__HIPCC__)
-template <typename T, typename DX_OP, typename DY_OP>
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 static __global__ void ElemwiseGradBroadcast1CUDAKernel(
-    const T *x, const T *y, const T *out, const T *dout, int h, int w,
+    const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w,
     bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   int j = blockIdx.x;
   int i = threadIdx.x;
@@ -291,9 +291,9 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
 
 // suppose use 2D block is fast because more parallel
 // and memory coalesced
-template <typename T, typename DX_OP, typename DY_OP>
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
-    const T *x, const T *y, const T *out, const T *dout, int h, int w,
+    const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w,
     bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
@@ -369,12 +369,12 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
   }
 }
 
-template <typename T, typename DX_OP>
+template <typename T, typename DX_OP, typename Tout = T>
 __global__ void CommonGradBroadcastCUDAKernel(
     const int *x_strides_array, const int *y_strides_array,
     const int *out_dims_array, const int *y_strides_order,
-    const int *y_dims_order, const T *x, const T *y, const T *out,
-    const T *dout, T *dx, int out_size, int max_dim, int thread_num,
+    const int *y_dims_order, const T *x, const T *y, const Tout *out,
+    const Tout *dout, T *dx, int out_size, int max_dim, int thread_num,
     DX_OP dx_op) {
   T val(0);
   int i = blockIdx.x;
@@ -408,9 +408,9 @@ __global__ void CommonGradBroadcastCUDAKernel(
   }
 }
 
-template <typename T, typename DY_OP>
+template <typename T, typename DY_OP, typename Tout = T>
 static __global__ void CommonGradBroadcast1CUDAKernelHeight(
-    const T *x, const T *y, const T *out, const T *dout, int h, int w,
+    const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w,
     DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) {
   int j = blockIdx.x;
   int i = threadIdx.x;
@@ -454,9 +454,9 @@ static __global__ void CommonGradBroadcast1CUDAKernelHeight(
   }
 }
 
-template <typename T, typename DY_OP>
+template <typename T, typename DY_OP, typename Tout = T>
 static __global__ void FastCommonGradBroadcastCUDAKernelHeight(
-    const T *x, const T *y, const T *out, const T *dout, int h, int w,
+    const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w,
     DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) {
   __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
@@ -528,9 +528,9 @@ static __global__ void FastCommonGradBroadcastCUDAKernelHeight(
   }
 }
 
-template <typename T, typename DX_OP, typename DY_OP>
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 static __global__ void FastCommonGradBroadcastAllCUDAKernel(
-    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
+    const T *x, const T *y, const Tout *out, const Tout *dout, int pre, int n,
     int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   int tid = threadIdx.x;
   int bid = blockIdx.x;
@@ -581,9 +581,9 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel(
   }
 }
 
-template <typename T, typename OP>
+template <typename T, typename OP, typename Tout = T>
 static __global__ void FastCommonGradBroadcastOneCUDAKernel(
-    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
+    const T *x, const T *y, const Tout *out, const Tout *dout, int pre, int n,
     int post, int y_pre, int y_n, int y_post, bool is_xsize, OP op, T *dd) {
   int tid = threadIdx.x;
   int bid = blockIdx.x;
@@ -669,7 +669,7 @@ static inline bool CheckContiguousDims(const std::vector<int> &broadcast_pos) {
   return true;
 }
 
-template <typename T, typename DX_OP, typename DY_OP>
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 void CommonGradBroadcastCUDA(
     const framework::Tensor &x, const framework::Tensor &y,
     const framework::Tensor &out, const framework::Tensor &dout,
@@ -680,8 +680,8 @@ void CommonGradBroadcastCUDA(
   auto cplace = platform::CPUPlace();
   const T *x_data = x.data<T>();
   const T *y_data = y.data<T>();
-  const T *out_data = out.data<T>();
-  const T *dout_data = dout.data<T>();
+  const Tout *out_data = out.data<Tout>();
+  const Tout *dout_data = dout.data<Tout>();
   T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
   T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
 
@@ -1045,7 +1045,7 @@ void CommonGradBroadcastCUDA(
       memory::Copy(gplace, x_dims_order_gpu, cplace, x_dims_order.data(),
                    bytes, ctx.stream());
       CommonGradBroadcastCUDAKernel<
-          T, DX_OP><<<x_blocks, x_block_size, 0, ctx.stream()>>>(
+          T, DX_OP, Tout><<<x_blocks, x_block_size, 0, ctx.stream()>>>(
           x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
           x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
           dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
@@ -1062,7 +1062,7 @@ void CommonGradBroadcastCUDA(
       memory::Copy(gplace, y_dims_order_gpu, cplace, y_dims_order.data(),
                    bytes, ctx.stream());
       CommonGradBroadcastCUDAKernel<
-          T, DY_OP><<<y_blocks, y_block_size, 0, ctx.stream()>>>(
+          T, DY_OP, Tout><<<y_blocks, y_block_size, 0, ctx.stream()>>>(
           x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
           y_strides_order_gpu, y_dims_order_gpu, x_data, y_data, out_data,
           dout_data, dy_data, out_size, max_dim, y_threads, dy_op);
@@ -1138,12 +1138,12 @@ class TransformFunctor {
   bool is_xsize_larger_;
 };
 
-template <typename T, typename DX_OP, typename DY_OP>
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 struct ElemwiseGradNoBroadcast {
   const T *x_;
   const T *y_;
-  const T *out_;
-  const T *dout_;
+  const Tout *out_;
+  const Tout *dout_;
 
   HOSTDEVICE void operator()(size_t i) {
     if (dx_ != nullptr) {
@@ -1160,9 +1160,9 @@ struct ElemwiseGradNoBroadcast {
   T *dy_;
 };
 
-template <typename T, typename DX_OP, typename DY_OP>
-static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
-                                      const T *dout, int h, int w,
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
+static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const Tout *out,
+                                      const Tout *dout, int h, int w,
                                       bool is_xsize_larger, DX_OP dx_op,
                                       DY_OP dy_op, T *dx, T *dy) {
   if (is_xsize_larger) {
@@ -1206,11 +1206,12 @@ static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
 
 #if defined(__NVCC__) || defined(__HIPCC__)
-template <typename T, typename DX_OP, typename DY_OP>
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, const T *x,
-                                       const T *y, const T *out, const T *dout,
-                                       int h, int w, bool is_xsize_larger,
-                                       DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
+                                       const T *y, const Tout *out,
+                                       const Tout *dout, int h, int w,
+                                       bool is_xsize_larger, DX_OP dx_op,
+                                       DY_OP dy_op, T *dx, T *dy) {
   // For small case use 1D block
   constexpr int half_walf = 16;
   if (w < half_walf || h < half_walf) {
@@ -1229,11 +1230,11 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, const T *x,
 
 #endif
 
-template <typename T, typename DX_OP, typename DY_OP>
-static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
-                                      const T *dout, int pre, int n, int post,
-                                      bool is_xsize_larger, DX_OP dx_op,
-                                      DY_OP dy_op, T *dx, T *dy) {
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
+static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const Tout *out,
+                                      const Tout *dout, int pre, int n,
+                                      int post, bool is_xsize_larger,
+                                      DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   if (is_xsize_larger) {
     for (int i = 0; i < pre; ++i) {
       for (int j = 0; j < n; ++j) {
@@ -1278,9 +1279,9 @@ static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
 }
 
 #if defined(__NVCC__) || defined(__HIPCC__)
-template <typename T, typename DX_OP, typename DY_OP>
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 static __global__ void ElemwiseGradBroadcast2CUDAKernel(
-    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
+    const T *x, const T *y, const Tout *out, const Tout *dout, int pre, int n,
     int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   int tid = threadIdx.x;
   int j = blockIdx.x;
@@ -1345,12 +1346,12 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
   }
 }
 
-template <typename T, typename DX_OP, typename DY_OP>
+template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
 static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, const T *x,
-                                       const T *y, const T *out, const T *dout,
-                                       int pre, int n, int post,
-                                       bool is_xsize_larger, DX_OP dx_op,
-                                       DY_OP dy_op, T *dx, T *dy) {
+                                       const T *y, const Tout *out,
+                                       const Tout *dout, int pre, int n,
+                                       int post, bool is_xsize_larger,
+                                       DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
   int gird_size = n;
   ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
@@ -1359,7 +1360,8 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, const T *x,
 
 #endif
 
-template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
+          typename Tout = T>
 void CommonElementwiseBroadcastBackward(
     const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
     const framework::DDim &y_dims, const framework::Tensor &x,
@@ -1387,14 +1389,14 @@ void CommonElementwiseBroadcastBackward(
 
   if (platform::is_gpu_place(ctx.GetPlace())) {
 #if defined(__NVCC__) || defined(__HIPCC__)
-    CommonGradBroadcastCUDA<T, DX_OP, DY_OP>(
+    CommonGradBroadcastCUDA<T, DX_OP, DY_OP, Tout>(
         x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
         out_dims_array.data(), max_dim,
         ctx.template device_context<platform::CUDADeviceContext>(), dx_op,
        dy_op);
 #endif
   } else {
-    CommonGradBroadcastCPU<T, DX_OP, DY_OP>(
+    CommonGradBroadcastCPU<T, DX_OP, DY_OP, Tout>(
         x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
         out_dims_array.data(), max_dim,
         ctx.template device_context<platform::CPUDeviceContext>(), dx_op,
@@ -1402,7 +1404,8 @@ void CommonElementwiseBroadcastBackward(
   }
 }
 
-template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
+          typename Tout = T>
 void ElemwiseGradComputeNoBroadcast(
     const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
     const framework::DDim &y_dim, const framework::Tensor &x,
@@ -1417,13 +1420,14 @@ void ElemwiseGradComputeNoBroadcast(
   platform::ForRange<DeviceContext> for_range(
       ctx.device_context(), N);
 #endif  // !_WIN32
-  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
-      x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
-      dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
+  for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP, Tout>{
+      x.data<T>(), y.data<T>(), out.data<Tout>(), dout.data<Tout>(), dx_op,
+      dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
      dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
 }
 
-template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
+          typename Tout = T>
 void ElemwiseGradComputeWithBroadcast(
     const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
     const framework::DDim &y_dims, const framework::Tensor &x,
@@ -1463,7 +1467,7 @@ void ElemwiseGradComputeWithBroadcast(
   }
   // special case for common backward implementation.
   if (is_run_common_broadcast) {
-    CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
+    CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP, Tout>(
         ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
     return;
   }
@@ -1472,14 +1476,14 @@ void ElemwiseGradComputeWithBroadcast(
 #if defined(__NVCC__) || defined(__HIPCC__)
       ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
-          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, is_xsize_larger,
-          dx_op, dy_op,
+          y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n,
+          is_xsize_larger, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
 #endif
     } else {
       ElemwiseGradBroadcast1CPU(
-          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
+          x.data<T>(), y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n,
           is_xsize_larger, dx_op, dy_op,
           dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
           dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
@@ -1489,15 +1493,15 @@ void ElemwiseGradComputeWithBroadcast(
 #if defined(__NVCC__) || defined(__HIPCC__)
       ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
-          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
+          y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n, post,
          is_xsize_larger, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
 #endif
     } else {
       ElemwiseGradBroadcast2CPU(
-          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
-          is_xsize_larger, dx_op, dy_op,
+          x.data<T>(), y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n,
+          post, is_xsize_larger, dx_op, dy_op,
           dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
           dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
     }
@@ -1521,7 +1525,8 @@ void CommonElementwiseBroadcastForward(
       axis, is_xsize_larger);
 }
 
-template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
+template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
+          typename Tout = T>
 void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
                          const framework::Tensor &x, const framework::Tensor &y,
                          const framework::Tensor &out,
@@ -1531,10 +1536,10 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
   const framework::DDim &x_dim = x.dims();
   const framework::DDim &y_dim = y.dims();
   if (x.dims() == y.dims()) {
-    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP, Tout>(
         ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
   } else {
-    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
+    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP, Tout>(
         ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
   }
 }
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 7e1198bac515fbe27fb6f7519de2fc82a5e1d663..e441e8f18c85899bcd124b43ca225c2d0a359ff5 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -88,6 +88,7 @@ from .tensor.creation import meshgrid  # noqa: F401
 from .tensor.creation import empty  # noqa: F401
 from .tensor.creation import empty_like  # noqa: F401
 from .tensor.creation import assign  # noqa: F401
+from .tensor.creation import complex  # noqa: F401
 from .tensor.linalg import matmul  # noqa: F401
 from .tensor.linalg import dot  # noqa: F401
 from .tensor.linalg import norm  # noqa: F401
@@ -446,6 +447,7 @@ __all__ = [  # noqa
     'shape',
     'real',
     'imag',
+    'complex',
     'reciprocal',
     'rand',
     'less_equal',
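Note on the elementwise_op_function.h changes above: the gradient broadcast helpers used to assume that out/dout share the input element type T. For this op the inputs (and the gradients w.r.t. them) are real while out/dout are complex, hence the new Tout template parameter, defaulting to T so existing callers are unaffected. A small dtype illustration in NumPy terms (illustrative only):

    import numpy as np

    x = np.zeros(3, dtype=np.float32)   # T: real input
    out = x + 1j * x                    # Tout: complex64 output
    dout = np.ones_like(out)            # upstream gradient, also complex64
    dx = dout.real                      # gradient w.r.t. X stays float32
    assert out.dtype == np.complex64 and dx.dtype == np.float32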
diff --git a/python/paddle/fluid/tests/unittests/test_complex_op.py b/python/paddle/fluid/tests/unittests/test_complex_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..15302a772a53004712e7a15f704513da67b2c33f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_complex_op.py
@@ -0,0 +1,156 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+import paddle
+from paddle.fluid import dygraph
+from paddle import static
+paddle.enable_static()
+
+
+def ref_complex(x, y):
+    return x + 1j * y
+
+
+def ref_complex_grad(x, y, dout):
+    out = x + 1j * y
+    out_rank = out.ndim
+    delta_rank_x = out_rank - x.ndim
+    delta_rank_y = out_rank - y.ndim
+
+    dx_reduce_axes = []
+    dy_reduce_axes = []
+
+    for i in range(out_rank):
+        if i < delta_rank_x or dout.shape[i] > x.shape[i - delta_rank_x]:
+            dx_reduce_axes.append(i)
+        if i < delta_rank_y or dout.shape[i] > y.shape[i - delta_rank_y]:
+            dy_reduce_axes.append(i)
+    dx = np.sum(dout.real, axis=tuple(dx_reduce_axes)).reshape(x.shape)
+    dy = np.sum(dout.imag, axis=tuple(dy_reduce_axes)).reshape(y.shape)
+    return (dx, dy)
+
+
+class TestComplexOp(OpTest):
+    def init_spec(self):
+        self.x_shape = [10, 10]
+        self.y_shape = [10, 10]
+        self.dtype = "float64"
+
+    def setUp(self):
+        self.op_type = "complex"
+        self.init_spec()
+        x = np.random.randn(*self.x_shape).astype(self.dtype)
+        y = np.random.randn(*self.y_shape).astype(self.dtype)
+        out_ref = ref_complex(x, y)
+        self.out_grad = np.random.randn(*self.x_shape).astype(self.dtype) \
+            + 1j * np.random.randn(*self.y_shape).astype(self.dtype)
+        self.inputs = {'X': x, 'Y': y}
+        self.outputs = {'Out': out_ref}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        dout = self.out_grad
+        dx, dy = ref_complex_grad(self.inputs['X'], self.inputs['Y'],
+                                  self.out_grad)
+        self.check_grad(
+            ['X', 'Y'],
+            'Out',
+            user_defined_grads=[dx, dy],
+            user_defined_grad_outputs=[dout])
+
+    def test_check_grad_ignore_x(self):
+        dout = self.out_grad
+        dx, dy = ref_complex_grad(self.inputs['X'], self.inputs['Y'],
+                                  self.out_grad)
+        self.assertTupleEqual(dx.shape, tuple(self.x_shape))
+        self.assertTupleEqual(dy.shape, tuple(self.y_shape))
+        self.check_grad(
+            ['Y'],
+            'Out',
+            no_grad_set=set('X'),
+            user_defined_grads=[dy],
+            user_defined_grad_outputs=[dout])
+
+    def test_check_grad_ignore_y(self):
+        dout = self.out_grad
+        dx, dy = ref_complex_grad(self.inputs['X'], self.inputs['Y'],
+                                  self.out_grad)
+        self.check_grad(
+            ['X'],
+            'Out',
+            no_grad_set=set('Y'),
+            user_defined_grads=[dx],
+            user_defined_grad_outputs=[dout])
+
+
+class TestComplexOpBroadcast1(TestComplexOp):
+    def init_spec(self):
+        self.x_shape = [10, 3, 1, 4]
+        self.y_shape = [100, 1]
+        self.dtype = "float64"
+
+
+class TestComplexOpBroadcast2(TestComplexOp):
+    def init_spec(self):
+        self.x_shape = [100, 1]
+        self.y_shape = [10, 3, 1, 4]
+        self.dtype = "float32"
+
+
+class TestComplexOpBroadcast3(TestComplexOp):
+    def init_spec(self):
+        self.x_shape = [1, 100]
+        self.y_shape = [100]
+        self.dtype = "float32"
+
+
+class TestComplexAPI(unittest.TestCase):
+    def setUp(self):
+        self.x = np.random.randn(10, 10)
+        self.y = np.random.randn(10, 10)
+        self.out = ref_complex(self.x, self.y)
+
+    def test_dygraph(self):
+        with dygraph.guard():
+            x = paddle.to_tensor(self.x)
+            y = paddle.to_tensor(self.y)
+            out_np = paddle.complex(x, y).numpy()
+        self.assertTrue(np.allclose(self.out, out_np))
+
+    def test_static(self):
+        mp, sp = static.Program(), static.Program()
+        with static.program_guard(mp, sp):
+            x = static.data("x", shape=[10, 10], dtype="float64")
+            y = static.data("y", shape=[10, 10], dtype="float64")
+            out = paddle.complex(x, y)
+
+        exe = static.Executor()
+        exe.run(sp)
+        [out_np] = exe.run(mp,
+                           feed={"x": self.x,
+                                 "y": self.y},
+                           fetch_list=[out])
+        self.assertTrue(np.allclose(self.out, out_np))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
index 725ad4e93824f178bd0593e80c8c8f3475a73049..d5f4cef5b8759f807419af267a3deb613aef876a 100644
--- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py
@@ -68,6 +68,7 @@ NEED_TO_FIX_OP_LIST = [
     'rank_loss',
     'sequence_conv',
     'smooth_l1_loss',
-    'spectral_norm'
+    'spectral_norm',
+    'complex',
 ]
 # yapf: enable
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 424cbbe4f2d9dbed08a59b6d67a8d66c1ad53d41..95a64fcdd694fe1452e9085f0c0472dcbcf18962 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -33,6 +33,7 @@ from .creation import tril  # noqa: F401
 from .creation import meshgrid  # noqa: F401
 from .creation import empty  # noqa: F401
 from .creation import empty_like  # noqa: F401
+from .creation import complex  # noqa: F401
 from .linalg import matmul  # noqa: F401
 from .linalg import dot  # noqa: F401
 from .linalg import norm  # noqa: F401
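Note: creation.py below imports _real_to_complex_dtype from paddle.tensor.attribute, which is not shown in this diff. A hypothetical sketch of the mapping such a helper presumably performs (the function name and branches here are assumptions, for illustration only):

    import paddle

    def _real_to_complex_dtype_sketch(dtype):
        # float32 -> complex64, float64 -> complex128, otherwise unchanged
        if dtype == paddle.float32:
            return paddle.complex64
        elif dtype == paddle.float64:
            return paddle.complex128
        return dtype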
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index 324934aa0aa68f107be70d821841efeeb2643a9b..8a376884063f71821641efc4fe2b3db0918d2e08 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -27,6 +27,7 @@ from ..fluid.layers import core
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
 from ..fluid.framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard, OpProtoHolder
+from paddle.tensor.attribute import _complex_to_real_dtype, _real_to_complex_dtype
 # TODO: define functions to get create a tensor
 from ..fluid.layers import linspace  # noqa: F401
 import paddle
@@ -1250,3 +1251,46 @@ def _memcpy(input, place=None, output=None):
         outputs={'Out': [output]},
         attrs=attrs)
     return output
+
+
+def complex(real, imag, name=None):
+    """Return a complex tensor given the real and imaginary component.
+
+    Args:
+        real (Tensor): The real component. The data type should be 'float32' or 'float64'.
+        imag (Tensor): The imaginary component. The data type should be the same as ``real``.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        Tensor: The output tensor. The data type is 'complex64' or 'complex128', with the same precision as ``real`` and ``imag``.
+
+    **Note**:
+        ``paddle.complex`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            x = paddle.arange(2, dtype=paddle.float32).unsqueeze(-1)
+            y = paddle.arange(3, dtype=paddle.float32)
+            z = paddle.complex(x, y)
+            print(z.numpy())
+
+            # [[0.+0.j 0.+1.j 0.+2.j]
+            #  [1.+0.j 1.+1.j 1.+2.j]]
+    """
+    if in_dygraph_mode():
+        return paddle._C_ops.complex(real, imag)
+
+    check_variable_and_dtype(real, 'real', ['float32', 'float64'], 'complex')
+    check_variable_and_dtype(imag, 'imag', ['float32', 'float64'], 'complex')
+
+    op_type = "complex"
+    helper = LayerHelper(op_type, **locals())
+    inputs = {"X": real, "Y": imag}
+    out = helper.create_variable_for_type_inference(
+        dtype=_real_to_complex_dtype(real.dtype))
+    outputs = {"Out": out}
+    attrs = {}
+    helper.append_op(type=op_type, inputs=inputs, attrs=attrs, outputs=outputs)
+    return out
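A short end-to-end dygraph example of the new API, combining paddle.complex with the existing paddle.real/paddle.imag (illustrative usage, not part of the patch):

    import paddle

    paddle.disable_static()
    x = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
    y = paddle.to_tensor([3.0, 4.0], stop_gradient=False)
    z = paddle.complex(x, y)                       # complex64: [1+3j, 2+4j]

    loss = paddle.sum(paddle.real(z)) + 2.0 * paddle.sum(paddle.imag(z))
    loss.backward()
    print(x.grad)   # all ones: d(loss)/dx = Re(d(loss)/dz)
    print(y.grad)   # all twos: d(loss)/dy = Im(d(loss)/dz)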