Unverified commit 31e874b1, authored by F Feiyu Chan, committed by GitHub

add complex op (#37918)

* add complex op and `paddle.complex`.
Parent a3bd6fc0
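As a quick illustration of the new user-facing API, here is a minimal sketch (mirroring the docstring example added in this commit) that builds a complex tensor from real and imaginary parts, with broadcasting:

import paddle

x = paddle.arange(2, dtype=paddle.float32).unsqueeze(-1)  # shape [2, 1]
y = paddle.arange(3, dtype=paddle.float32)                # shape [3]
z = paddle.complex(x, y)  # broadcast to shape [2, 3], dtype complex64
print(z.numpy())
# [[0.+0.j 0.+1.j 0.+2.j]
#  [1.+0.j 1.+1.j 1.+2.j]]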
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/complex_op.h"
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.cc"
namespace paddle {
namespace operators {
class ComplexOpMaker : public framework::OpProtoAndCheckerMaker {
protected:
void Make() override {
AddInput("X", "(Tensor), real part of complex_op");
AddInput("Y", "(Tensor), image part of complex_op");
AddOutput("Out", "(Tensor), output of complex_op");
AddComment(R"DOC(
Complex Operator.
Return a complex tensor given the real and imaginary part tensors.
)DOC");
}
};
template <typename T>
class ComplexGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("complex_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
// op->SetInput("Out", this->Output("Out"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
class ComplexOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "complex");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "complex");
if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
ctx->ShareDim("X", /*->*/ "Out");
// NOTE(chenfeiyu): lod & broadcasting are intrinsically contradictory,
// so tensors with lod are not supported here
} else {
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int max_dim = std::max(x_dims.size(), y_dims.size());
// axis at which the lower-rank input aligns with the higher-rank one
// (numpy-style trailing-dimension alignment)
int axis = std::abs(x_dims.size() - y_dims.size());
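      // e.g. X: [10, 3, 1, 4] and Y: [100, 1] give axis = 2, so Y is aligned
      // with the trailing dimensions of X and Out gets shape [10, 3, 100, 4].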
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
details::GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(), max_dim, axis);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
class ComplexGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "complex_grad");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_complex_gradgrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "complex_grad");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->ShareDim("X", /*->*/ x_grad_name);
}
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(y_grad_name)) {
ctx->ShareDim("Y", /*->*/ y_grad_name);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
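    // dOut is complex, but the grad kernels are registered on the
    // corresponding real type (float/double), so convert the dtype back.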
auto computation_dtype = framework::ToRealType(
OperatorWithKernel::IndicateVarDataType(ctx, out_grad_name));
return framework::OpKernelType(computation_dtype, ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(complex, ops::ComplexOp, ops::ComplexOpMaker,
ops::ComplexGradOpMaker<paddle::framework::OpDesc>,
ops::ComplexGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(complex_grad, ops::ComplexGradOp);
REGISTER_OP_CPU_KERNEL(
complex, ops::ComplexKernel<paddle::platform::CPUDeviceContext, float>,
ops::ComplexKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
complex_grad,
ops::ComplexGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ComplexGradKernel<paddle::platform::CPUDeviceContext, double>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/complex_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
complex, ops::ComplexKernel<paddle::platform::CUDADeviceContext, float>,
ops::ComplexKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
complex_grad,
ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ComplexGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle {
namespace operators {
// functors to use with ElementwiseComputeEx
template <typename T>
struct RealAndImagToComplexFunctor {
inline HOSTDEVICE platform::complex<T> operator()(const T& x, const T& y) {
return platform::complex<T>(x, y);
}
};
template <typename T>
struct ImagAndRealToComplexFunctor {
inline HOSTDEVICE platform::complex<T> operator()(const T& y, const T& x) {
return platform::complex<T>(x, y);
}
};
template <typename T>
struct ComplexGradForRealFunctor {
inline HOSTDEVICE T operator()(const T x, const T y,
const platform::complex<T> out,
const platform::complex<T> dout) {
return dout.real;
}
};
template <typename T>
struct ComplexGradForImagFunctor {
inline HOSTDEVICE T operator()(const T x, const T y,
const platform::complex<T> out,
const platform::complex<T> dout) {
return dout.imag;
}
};
template <typename DeviceContext, typename T>
class ComplexKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* x = ctx.Input<framework::Tensor>("X");
const auto* y = ctx.Input<framework::Tensor>("Y");
auto* z = ctx.Output<framework::Tensor>("Out");
using C = platform::complex<T>;
z->mutable_data<C>(ctx.GetPlace());
// NOTE(chenfeiyu): be careful of the caveats of calling elementwise-related
// facility functions
#if defined(__NVCC__) || defined(__HIPCC__)
ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
#else
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<RealAndImagToComplexFunctor<T>, DeviceContext, T, C>(
ctx, x, y, /*axis*/ -1, RealAndImagToComplexFunctor<T>(), z);
} else {
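      // ImagAndRealToComplexFunctor takes its arguments in (imag, real) order
      // and swaps them back, so X still supplies the real part and Y the
      // imaginary part when Y has the higher rank.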
ElementwiseComputeEx<ImagAndRealToComplexFunctor<T>, DeviceContext, T, C>(
ctx, x, y, /*axis*/ -1, ImagAndRealToComplexFunctor<T>(), z);
}
#endif
}
};
template <typename DeviceContext, typename T>
class ComplexGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
using C = platform::complex<T>;
// skip out in a hacky way
auto* out = dout;
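    // The gradient functors below only read dout, so the forward output does
    // not need to be saved; aliasing out to dout merely satisfies the
    // ElemwiseGradCompute signature.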
ElemwiseGradCompute<DeviceContext, T, ComplexGradForRealFunctor<T>,
ComplexGradForImagFunctor<T>, C>(
ctx, *x, *y, *out, *dout, /*axis*/ -1, dx, dy,
ComplexGradForRealFunctor<T>(), ComplexGradForImagFunctor<T>());
}
};
} // namespace operators
} // namespace paddle
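The diff below threads a new `Tout` template parameter (defaulting to `T`) through the elementwise gradient helpers in elementwise_op_function.h, so that `out`/`dout` can hold complex values while the inputs and the computed gradients stay real. Conceptually the backward rule is the one implemented by `ref_complex_grad` in the test added by this commit; a minimal NumPy sketch, with shapes borrowed from TestComplexOpBroadcast1:

import numpy as np

# out = x + 1j * y, so dx is the real part of dout and dy the imaginary part,
# each sum-reduced over the axes where that input was broadcast.
x = np.random.randn(10, 3, 1, 4)
y = np.random.randn(100, 1)
dout = np.random.randn(10, 3, 100, 4) + 1j * np.random.randn(10, 3, 100, 4)
dx = dout.real.sum(axis=2, keepdims=True)            # shape [10, 3, 1, 4]
dy = dout.imag.sum(axis=(0, 1, 3)).reshape(y.shape)  # shape [100, 1]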
@@ -169,7 +169,7 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
is_xsize_larger);
}
template <typename T, typename DX_OP, typename DY_OP>
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonGradBroadcastCPU(
const framework::Tensor &x, const framework::Tensor &y,
const framework::Tensor &out, const framework::Tensor &dout,
@@ -179,8 +179,8 @@ void CommonGradBroadcastCPU(
std::vector<int> index_array(max_dim, 0);
const T *x_data = x.data<T>();
const T *y_data = y.data<T>();
const T *out_data = out.data<T>();
const T *dout_data = dout.data<T>();
const Tout *out_data = out.data<Tout>();
const Tout *dout_data = dout.data<Tout>();
T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
if (dx_data != nullptr) {
@@ -240,9 +240,9 @@ inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs,
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T, typename DX_OP, typename DY_OP>
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w,
bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = blockIdx.x;
int i = threadIdx.x;
@@ -291,9 +291,9 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
// suppose use 2D block is fast because more parallel
// and memory coalesced
template <typename T, typename DX_OP, typename DY_OP>
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w,
bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
__shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
@@ -369,12 +369,12 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
}
}
template <typename T, typename DX_OP>
template <typename T, typename DX_OP, typename Tout = T>
__global__ void CommonGradBroadcastCUDAKernel(
const int *x_strides_array, const int *y_strides_array,
const int *out_dims_array, const int *y_strides_order,
const int *y_dims_order, const T *x, const T *y, const T *out,
const T *dout, T *dx, int out_size, int max_dim, int thread_num,
const int *y_dims_order, const T *x, const T *y, const Tout *out,
const Tout *dout, T *dx, int out_size, int max_dim, int thread_num,
DX_OP dx_op) {
T val(0);
int i = blockIdx.x;
@@ -408,9 +408,9 @@ __global__ void CommonGradBroadcastCUDAKernel(
}
}
template <typename T, typename DY_OP>
template <typename T, typename DY_OP, typename Tout = T>
static __global__ void CommonGradBroadcast1CUDAKernelHeight(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w,
DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) {
int j = blockIdx.x;
int i = threadIdx.x;
@@ -454,9 +454,9 @@ static __global__ void CommonGradBroadcast1CUDAKernelHeight(
}
}
template <typename T, typename DY_OP>
template <typename T, typename DY_OP, typename Tout = T>
static __global__ void FastCommonGradBroadcastCUDAKernelHeight(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w,
DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) {
__shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
@@ -528,9 +528,9 @@ static __global__ void FastCommonGradBroadcastCUDAKernelHeight(
}
}
template <typename T, typename DY_OP, typename DX_OP>
template <typename T, typename DY_OP, typename DX_OP, typename Tout = T>
static __global__ void FastCommonGradBroadcastAllCUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
const T *x, const T *y, const Tout *out, const Tout *dout, int pre, int n,
int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int tid = threadIdx.x;
int bid = blockIdx.x;
@@ -581,9 +581,9 @@ static __global__ void FastCommonGradBroadcastAllCUDAKernel(
}
}
template <typename T, typename OP>
template <typename T, typename OP, typename Tout = T>
static __global__ void FastCommonGradBroadcastOneCUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
const T *x, const T *y, const Tout *out, const Tout *dout, int pre, int n,
int post, int y_pre, int y_n, int y_post, bool is_xsize, OP op, T *dd) {
int tid = threadIdx.x;
int bid = blockIdx.x;
@@ -669,7 +669,7 @@ static inline bool CheckContiguousDims(const std::vector<int> &broadcast_pos) {
return true;
}
template <typename T, typename DX_OP, typename DY_OP>
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonGradBroadcastCUDA(
const framework::Tensor &x, const framework::Tensor &y,
const framework::Tensor &out, const framework::Tensor &dout,
@@ -680,8 +680,8 @@ void CommonGradBroadcastCUDA(
auto cplace = platform::CPUPlace();
const T *x_data = x.data<T>();
const T *y_data = y.data<T>();
const T *out_data = out.data<T>();
const T *dout_data = dout.data<T>();
const Tout *out_data = out.data<Tout>();
const Tout *dout_data = dout.data<Tout>();
T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
@@ -1045,7 +1045,7 @@ void CommonGradBroadcastCUDA(
memory::Copy(gplace, x_dims_order_gpu, cplace, x_dims_order.data(), bytes,
ctx.stream());
CommonGradBroadcastCUDAKernel<
T, DX_OP><<<x_blocks, x_block_size, 0, ctx.stream()>>>(
T, DX_OP, Tout><<<x_blocks, x_block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
@@ -1062,7 +1062,7 @@ void CommonGradBroadcastCUDA(
memory::Copy(gplace, y_dims_order_gpu, cplace, y_dims_order.data(), bytes,
ctx.stream());
CommonGradBroadcastCUDAKernel<
T, DY_OP><<<y_blocks, y_block_size, 0, ctx.stream()>>>(
T, DY_OP, Tout><<<y_blocks, y_block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
y_strides_order_gpu, y_dims_order_gpu, x_data, y_data, out_data,
dout_data, dy_data, out_size, max_dim, y_threads, dy_op);
@@ -1138,12 +1138,12 @@ class TransformFunctor {
bool is_xsize_larger_;
};
template <typename T, typename DX_OP, typename DY_OP>
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
struct ElemwiseGradNoBroadcast {
const T *x_;
const T *y_;
const T *out_;
const T *dout_;
const Tout *out_;
const Tout *dout_;
HOSTDEVICE void operator()(size_t i) {
if (dx_ != nullptr) {
@@ -1160,9 +1160,9 @@ struct ElemwiseGradNoBroadcast {
T *dy_;
};
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
const T *dout, int h, int w,
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const Tout *out,
const Tout *dout, int h, int w,
bool is_xsize_larger, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
if (is_xsize_larger) {
@@ -1206,11 +1206,12 @@ static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T, typename DX_OP, typename DY_OP>
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, const T *x,
const T *y, const T *out, const T *dout,
int h, int w, bool is_xsize_larger,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
const T *y, const Tout *out,
const Tout *dout, int h, int w,
bool is_xsize_larger, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
// For small case use 1D block
constexpr int half_walf = 16;
if (w < half_walf || h < half_walf) {
@@ -1229,11 +1230,11 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, const T *x,
#endif
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
const T *dout, int pre, int n, int post,
bool is_xsize_larger, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const Tout *out,
const Tout *dout, int pre, int n,
int post, bool is_xsize_larger,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
if (is_xsize_larger) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
@@ -1278,9 +1279,9 @@ static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
}
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T, typename DX_OP, typename DY_OP>
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
const T *x, const T *y, const Tout *out, const Tout *dout, int pre, int n,
int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int tid = threadIdx.x;
int j = blockIdx.x;
@@ -1345,12 +1346,12 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
}
}
template <typename T, typename DX_OP, typename DY_OP>
template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, const T *x,
const T *y, const T *out, const T *dout,
int pre, int n, int post,
bool is_xsize_larger, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
const T *y, const Tout *out,
const Tout *dout, int pre, int n,
int post, bool is_xsize_larger,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
int gird_size = n;
ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
@@ -1359,7 +1360,8 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, const T *x,
#endif
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
typename Tout = T>
void CommonElementwiseBroadcastBackward(
const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
const framework::DDim &y_dims, const framework::Tensor &x,
@@ -1387,14 +1389,14 @@ void CommonElementwiseBroadcastBackward(
if (platform::is_gpu_place(ctx.GetPlace())) {
#if defined(__NVCC__) || defined(__HIPCC__)
CommonGradBroadcastCUDA<T, DX_OP, DY_OP>(
CommonGradBroadcastCUDA<T, DX_OP, DY_OP, Tout>(
x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CUDADeviceContext>(), dx_op,
dy_op);
#endif
} else {
CommonGradBroadcastCPU<T, DX_OP, DY_OP>(
CommonGradBroadcastCPU<T, DX_OP, DY_OP, Tout>(
x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CPUDeviceContext>(), dx_op,
@@ -1402,7 +1404,8 @@ void CommonElementwiseBroadcastBackward(
}
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
typename Tout = T>
void ElemwiseGradComputeNoBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim, const framework::Tensor &x,
@@ -1417,13 +1420,14 @@ void ElemwiseGradComputeNoBroadcast(
platform::ForRange<DeviceContext> for_range(
ctx.device_context<DeviceContext>(), N);
#endif // !_WIN32
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP, Tout>{
x.data<T>(), y.data<T>(), out.data<Tout>(), dout.data<Tout>(), dx_op,
dy_op, dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
typename Tout = T>
void ElemwiseGradComputeWithBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
const framework::DDim &y_dims, const framework::Tensor &x,
@@ -1463,7 +1467,7 @@ void ElemwiseGradComputeWithBroadcast(
}
// special case for common backward implementation.
if (is_run_common_broadcast) {
CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP, Tout>(
ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
return;
}
@@ -1472,14 +1476,14 @@ void ElemwiseGradComputeWithBroadcast(
#if defined(__NVCC__) || defined(__HIPCC__)
ElemwiseGradBroadcast1CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, is_xsize_larger,
dx_op, dy_op,
y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n,
is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
ElemwiseGradBroadcast1CPU(
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
x.data<T>(), y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n,
is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
@@ -1489,15 +1493,15 @@ void ElemwiseGradComputeWithBroadcast(
#if defined(__NVCC__) || defined(__HIPCC__)
ElemwiseGradBroadcast2CUDA(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n, post,
is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
ElemwiseGradBroadcast2CPU(
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
is_xsize_larger, dx_op, dy_op,
x.data<T>(), y.data<T>(), out.data<Tout>(), dout.data<Tout>(), pre, n,
post, is_xsize_larger, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
@@ -1521,7 +1525,8 @@ void CommonElementwiseBroadcastForward(
axis, is_xsize_larger);
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
typename Tout = T>
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor &x, const framework::Tensor &y,
const framework::Tensor &out,
@@ -1531,10 +1536,10 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
const framework::DDim &x_dim = x.dims();
const framework::DDim &y_dim = y.dims();
if (x.dims() == y.dims()) {
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP, Tout>(
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else {
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP, Tout>(
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}
}
@@ -88,6 +88,7 @@ from .tensor.creation import meshgrid # noqa: F401
from .tensor.creation import empty # noqa: F401
from .tensor.creation import empty_like # noqa: F401
from .tensor.creation import assign # noqa: F401
from .tensor.creation import complex # noqa: F401
from .tensor.linalg import matmul # noqa: F401
from .tensor.linalg import dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
@@ -446,6 +447,7 @@ __all__ = [ # noqa
'shape',
'real',
'imag',
'complex',
'reciprocal',
'rand',
'less_equal',
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
from paddle.fluid import dygraph
from paddle import static
paddle.enable_static()
def ref_complex(x, y):
return x + 1j * y
def ref_complex_grad(x, y, dout):
out = x + 1j * y
out_rank = out.ndim
delta_rank_x = out_rank - x.ndim
delta_rank_y = out_rank - y.ndim
dx_reduce_axes = []
dy_reduce_axes = []
for i in range(out_rank):
if i < delta_rank_x or dout.shape[i] > x.shape[i - delta_rank_x]:
dx_reduce_axes.append(i)
if i < delta_rank_y or dout.shape[i] > y.shape[i - delta_rank_y]:
dy_reduce_axes.append(i)
dx = np.sum(dout.real, axis=tuple(dx_reduce_axes)).reshape(x.shape)
dy = np.sum(dout.imag, axis=tuple(dy_reduce_axes)).reshape(y.shape)
return (dx, dy)
class TestComplexOp(OpTest):
def init_spec(self):
self.x_shape = [10, 10]
self.y_shape = [10, 10]
self.dtype = "float64"
def setUp(self):
self.op_type = "complex"
self.init_spec()
x = np.random.randn(*self.x_shape).astype(self.dtype)
y = np.random.randn(*self.y_shape).astype(self.dtype)
out_ref = ref_complex(x, y)
self.out_grad = np.random.randn(*self.x_shape).astype(self.dtype) \
+ 1j * np.random.randn(*self.y_shape).astype(self.dtype)
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': out_ref}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
dout = self.out_grad
dx, dy = ref_complex_grad(self.inputs['X'], self.inputs['Y'],
self.out_grad)
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[dx, dy],
user_defined_grad_outputs=[dout])
def test_check_grad_ignore_x(self):
dout = self.out_grad
dx, dy = ref_complex_grad(self.inputs['X'], self.inputs['Y'],
self.out_grad)
self.assertTupleEqual(dx.shape, tuple(self.x_shape))
self.assertTupleEqual(dy.shape, tuple(self.y_shape))
self.check_grad(
['Y'],
'Out',
no_grad_set=set('X'),
user_defined_grads=[dy],
user_defined_grad_outputs=[dout])
def test_check_grad_ignore_y(self):
dout = self.out_grad
dx, dy = ref_complex_grad(self.inputs['X'], self.inputs['Y'],
self.out_grad)
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
user_defined_grads=[dx],
user_defined_grad_outputs=[dout])
class TestComplexOpBroadcast1(TestComplexOp):
def init_spec(self):
self.x_shape = [10, 3, 1, 4]
self.y_shape = [100, 1]
self.dtype = "float64"
class TestComplexOpBroadcast2(TestComplexOp):
def init_spec(self):
self.x_shape = [100, 1]
self.y_shape = [10, 3, 1, 4]
self.dtype = "float32"
class TestComplexOpBroadcast3(TestComplexOp):
def init_spec(self):
self.x_shape = [1, 100]
self.y_shape = [100]
self.dtype = "float32"
class TestComplexAPI(unittest.TestCase):
def setUp(self):
self.x = np.random.randn(10, 10)
self.y = np.random.randn(10, 10)
self.out = ref_complex(self.x, self.y)
def test_dygraph(self):
with dygraph.guard():
x = paddle.to_tensor(self.x)
y = paddle.to_tensor(self.y)
out_np = paddle.complex(x, y).numpy()
self.assertTrue(np.allclose(self.out, out_np))
def test_static(self):
mp, sp = static.Program(), static.Program()
with static.program_guard(mp, sp):
x = static.data("x", shape=[10, 10], dtype="float64")
y = static.data("y", shape=[10, 10], dtype="float64")
out = paddle.complex(x, y)
exe = static.Executor()
exe.run(sp)
[out_np] = exe.run(mp,
feed={"x": self.x,
"y": self.y},
fetch_list=[out])
self.assertTrue(np.allclose(self.out, out_np))
if __name__ == "__main__":
unittest.main()
@@ -68,6 +68,7 @@ NEED_TO_FIX_OP_LIST = [
'rank_loss',
'sequence_conv',
'smooth_l1_loss',
'spectral_norm'
'spectral_norm',
'complex',
]
# yapf: enable
@@ -33,6 +33,7 @@ from .creation import tril # noqa: F401
from .creation import meshgrid # noqa: F401
from .creation import empty # noqa: F401
from .creation import empty_like # noqa: F401
from .creation import complex # noqa: F401
from .linalg import matmul # noqa: F401
from .linalg import dot # noqa: F401
from .linalg import norm # noqa: F401
@@ -27,6 +27,7 @@ from ..fluid.layers import core
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype
from ..fluid.framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard, OpProtoHolder
from paddle.tensor.attribute import _complex_to_real_dtype, _real_to_complex_dtype
# TODO: define functions to get create a tensor
from ..fluid.layers import linspace # noqa: F401
import paddle
@@ -1250,3 +1251,46 @@ def _memcpy(input, place=None, output=None):
outputs={'Out': [output]},
attrs=attrs)
return output
def complex(real, imag, name=None):
"""Return a compelx tensor given the real and image component.
Args:
real (Tensor): The real component. The data type should be 'float32' or 'float64'.
imag (Tensor): The imaginary component. The data type should be the same as ``real``.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The output tensor. The data type is 'complex64' or 'complex128', with the same precision as ``real`` and ``imag``.
**Note**:
``paddle.complex`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Examples:
.. code-block:: python
import paddle
x = paddle.arange(2, dtype=paddle.float32).unsqueeze(-1)
y = paddle.arange(3, dtype=paddle.float32)
z = paddle.complex(x, y)
print(z.numpy())
# [[0.+0.j 0.+1.j 0.+2.j]
# [1.+0.j 1.+1.j 1.+2.j]]
"""
if in_dygraph_mode():
return paddle._C_ops.complex(real, imag)
check_variable_and_dtype(real, 'real', ['float32', 'float64'], 'complex')
check_variable_and_dtype(imag, 'imag', ['float32', 'float64'], 'complex')
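    # In static graph mode, insert a `complex` op; the output dtype is the
    # complex counterpart of the real inputs (float32 -> complex64,
    # float64 -> complex128).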
op_type = "complex"
helper = LayerHelper(op_type, **locals())
inputs = {"X": real, "Y": imag}
out = helper.create_variable_for_type_inference(
dtype=_real_to_complex_dtype(real.dtype))
outputs = {"Out": out}
attrs = {}
helper.append_op(type=op_type, inputs=inputs, attrs=attrs, outputs=outputs)
return out