From 5fa44c34c7c0349fc470e4f8161b4536023e8483 Mon Sep 17 00:00:00 2001
From: chentianyu03
Date: Tue, 25 May 2021 10:08:57 +0800
Subject: [PATCH] modify Ops to complex template (#33041)

* modify conj, real, imag OP to complex template
* replace with complex template to dot Op
* replace with complex template to Abs Op
* add support for complex64 and complex128
---
 paddle/fluid/operators/abs_op.cc | 12 +--
 paddle/fluid/operators/abs_op.cu | 12 +--
 paddle/fluid/operators/conj_op.cc | 4 +-
 paddle/fluid/operators/conj_op.cu | 7 +-
 paddle/fluid/operators/dot_op.cc | 10 +-
 paddle/fluid/operators/dot_op.cu | 20 ++--
 paddle/fluid/operators/imag_op.cc | 8 +-
 paddle/fluid/operators/imag_op.cu | 8 +-
 .../fluid/operators/math/complex_functors.h | 101 ++++++++++--------
 paddle/fluid/operators/real_op.cc | 8 +-
 paddle/fluid/operators/real_op.cu | 8 +-
 11 files changed, 103 insertions(+), 95 deletions(-)

diff --git a/paddle/fluid/operators/abs_op.cc b/paddle/fluid/operators/abs_op.cc
index 5c431ce77dc..796425a132b 100644
--- a/paddle/fluid/operators/abs_op.cc
+++ b/paddle/fluid/operators/abs_op.cc
@@ -164,9 +164,9 @@ REGISTER_OP_CPU_KERNEL(
 ops::AbsKernel, ops::AbsKernel, ops::AbsKernel,
+ paddle::platform::complex>,
 ops::AbsKernel);
+ paddle::platform::complex>);
 REGISTER_OP_CPU_KERNEL(
 abs_grad, ops::AbsGradKernel,
@@ -174,9 +174,9 @@ REGISTER_OP_CPU_KERNEL(
 ops::AbsGradKernel, ops::AbsGradKernel, ops::AbsGradKernel, ops::AbsGradKernel,
+ paddle::platform::complex>,
 ops::AbsGradKernel);
+ paddle::platform::complex>);
 REGISTER_OP_CPU_KERNEL(
 abs_grad_grad,
 ops::AbsDoubleGradKernel, ops::AbsDoubleGradKernel, ops::AbsDoubleGradKernel,
+ paddle::platform::complex>,
 ops::AbsDoubleGradKernel);
+ paddle::platform::complex>);

diff --git a/paddle/fluid/operators/abs_op.cu b/paddle/fluid/operators/abs_op.cu
index a29670b415d..d03de7a4562 100644
--- a/paddle/fluid/operators/abs_op.cu
+++ b/paddle/fluid/operators/abs_op.cu
@@ -70,8 +70,8 @@ REGISTER_OP_CUDA_KERNEL(
 ops::AbsKernel, ops::AbsKernel, ops::AbsKernel,
- ops::AbsKernel,
- ops::AbsKernel);
+ ops::AbsKernel>,
+ ops::AbsKernel>);
 REGISTER_OP_CUDA_KERNEL(
 abs_grad, ops::AbsGradKernel,
@@ -79,8 +79,8 @@ REGISTER_OP_CUDA_KERNEL(
 ops::AbsGradKernel, ops::AbsGradKernel, ops::AbsGradKernel,
- ops::AbsGradKernel,
- ops::AbsGradKernel);
+ ops::AbsGradKernel>,
+ ops::AbsGradKernel>);
 REGISTER_OP_CUDA_KERNEL(
 abs_grad_grad, ops::AbsDoubleGradKernel,
 ops::AbsDoubleGradKernel, ops::AbsDoubleGradKernel, ops::AbsDoubleGradKernel,
- ops::AbsDoubleGradKernel,
- ops::AbsDoubleGradKernel);
+ ops::AbsDoubleGradKernel>,
+ ops::AbsDoubleGradKernel>);

diff --git a/paddle/fluid/operators/conj_op.cc b/paddle/fluid/operators/conj_op.cc
index 3afe4f1e3d1..4d801bc003e 100644
--- a/paddle/fluid/operators/conj_op.cc
+++ b/paddle/fluid/operators/conj_op.cc
@@ -78,9 +78,9 @@ REGISTER_OPERATOR(conj, ops::ConjOp, ops::ConjOpMaker,
 REGISTER_OP_CPU_KERNEL(
 conj, ops::ConjKernel,
+ paddle::platform::complex>,
 ops::ConjKernel,
+ paddle::platform::complex>,
 ops::ConjKernel, ops::ConjKernel, ops::ConjKernel,

diff --git a/paddle/fluid/operators/conj_op.cu b/paddle/fluid/operators/conj_op.cu
index 601caeb5055..d04024d70a8 100644
--- a/paddle/fluid/operators/conj_op.cu
+++ b/paddle/fluid/operators/conj_op.cu
@@ -13,15 +13,14 @@
 // limitations under the License.
#include "paddle/fluid/operators/conj_op.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( conj, ops::ConjKernel, + paddle::platform::complex>, ops::ConjKernel, + paddle::platform::complex>, ops::ConjKernel, ops::ConjKernel, ops::ConjKernel, diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc index 26f12e8f9e3..31acd971811 100644 --- a/paddle/fluid/operators/dot_op.cc +++ b/paddle/fluid/operators/dot_op.cc @@ -33,7 +33,7 @@ class DotOp : public framework::OperatorWithKernel { "Output(Out) of DotOp should not be null.")); auto x_dims = ctx->GetInputDim("X"); - auto x_rank = (size_t)x_dims.size(); + auto x_rank = static_cast(x_dims.size()); PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank, platform::errors::PreconditionNotMet( "ShapeError: The dimensions of input tensor X (%s) " @@ -154,15 +154,15 @@ REGISTER_OP_CPU_KERNEL( ops::DotKernel, ops::DotKernel, ops::DotKernel, + paddle::platform::complex>, ops::DotKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( dot_grad, ops::DotGradKernel, ops::DotGradKernel, ops::DotGradKernel, ops::DotGradKernel, ops::DotGradKernel, + paddle::platform::complex>, ops::DotGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/dot_op.cu b/paddle/fluid/operators/dot_op.cu index 2d259ba1fbc..49f27e1ffb1 100644 --- a/paddle/fluid/operators/dot_op.cu +++ b/paddle/fluid/operators/dot_op.cu @@ -22,12 +22,14 @@ REGISTER_OP_CUDA_KERNEL( ops::DotKernel, ops::DotKernel, ops::DotKernel, - ops::DotKernel, - ops::DotKernel); -REGISTER_OP_CUDA_KERNEL( - dot_grad, ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel, - ops::DotGradKernel); + ops::DotKernel>, + ops::DotKernel>); +REGISTER_OP_CUDA_KERNEL(dot_grad, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel>, + ops::DotGradKernel>); diff --git a/paddle/fluid/operators/imag_op.cc b/paddle/fluid/operators/imag_op.cc index 899025ae709..6a195bb9400 100644 --- a/paddle/fluid/operators/imag_op.cc +++ b/paddle/fluid/operators/imag_op.cc @@ -96,11 +96,11 @@ REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker, REGISTER_OPERATOR(imag_grad, ops::ImagGradOp); REGISTER_OP_CPU_KERNEL(imag, ops::ImagKernel, + paddle::platform::complex>, ops::ImagKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL(imag_grad, ops::ImagGradKernel, + paddle::platform::complex>, ops::ImagGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/imag_op.cu b/paddle/fluid/operators/imag_op.cu index a7a3b136821..9cfb2ef7f2f 100644 --- a/paddle/fluid/operators/imag_op.cu +++ b/paddle/fluid/operators/imag_op.cu @@ -18,11 +18,11 @@ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL(imag, ops::ImagKernel, + paddle::platform::complex>, ops::ImagKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL(imag_grad, ops::ImagGradKernel, + paddle::platform::complex>, ops::ImagGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/math/complex_functors.h b/paddle/fluid/operators/math/complex_functors.h index 0e8aed40f6e..f5302566778 100644 --- a/paddle/fluid/operators/math/complex_functors.h +++ b/paddle/fluid/operators/math/complex_functors.h @@ -16,8 +16,7 @@ limitations under the License. 
 */
 #include
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/hostdevice.h"
 namespace paddle {
@@ -66,7 +65,10 @@ using select_t = typename select::type;
 template
 using Real = select_t::value, float>,
- cond::value, double>, T>;
+ cond::value, double>,
+ cond>::value, float>,
+ cond>::value, double>,
+ T>;
 template
 using Complex = typename std::enable_if::value>::type;
@@ -76,14 +78,18 @@ template
 using NoComplex = typename std::enable_if::value>::type;
 template
-using EnableComplex =
- typename std::enable_if::value ||
- std::is_same::value>::type;
+using EnableComplex = typename std::enable_if<
+ std::is_same::value ||
+ std::is_same::value ||
+ std::is_same>::value ||
+ std::is_same>::value>::type;
 template
 using DisableComplex = typename std::enable_if<
 !std::is_same::value &&
- !std::is_same::value>::type;
+ !std::is_same::value &&
+ !std::is_same>::value &&
+ !std::is_same>::value>::type;
 template
 struct RealFunctor;
@@ -173,44 +179,45 @@ struct AbsGradFunctor {
 };
 template <>
-struct AbsGradFunctor {
- AbsGradFunctor(const float* dout, const paddle::platform::complex64* x,
- paddle::platform::complex64* output, int64_t numel)
+struct AbsGradFunctor> {
+ AbsGradFunctor(const float* dout, const paddle::platform::complex* x,
+ paddle::platform::complex* output, int64_t numel)
 : dout_(dout), x_(x), output_(output), numel_(numel) {}
 HOSTDEVICE void operator()(int64_t idx) const {
- if (x_[idx] == paddle::platform::complex64(0)) {
- output_[idx] = paddle::platform::complex64(0);
+ if (x_[idx] == paddle::platform::complex(0)) {
+ output_[idx] = paddle::platform::complex(0);
 } else {
- output_[idx] = paddle::platform::complex64(dout_[idx]) *
- (x_[idx] / paddle::platform::complex64(abs(x_[idx])));
+ output_[idx] = paddle::platform::complex(dout_[idx]) *
+ (x_[idx] / paddle::platform::complex(abs(x_[idx])));
 }
 }
 const float* dout_;
- const paddle::platform::complex64* x_;
- paddle::platform::complex64* output_;
+ const paddle::platform::complex* x_;
+ paddle::platform::complex* output_;
 int64_t numel_;
 };
 template <>
-struct AbsGradFunctor {
- AbsGradFunctor(const double* dout, const paddle::platform::complex128* x,
- paddle::platform::complex128* output, int64_t numel)
+struct AbsGradFunctor> {
+ AbsGradFunctor(const double* dout, const paddle::platform::complex* x,
+ paddle::platform::complex* output, int64_t numel)
 : dout_(dout), x_(x), output_(output), numel_(numel) {}
 HOSTDEVICE void operator()(int64_t idx) const {
- if (x_[idx] == paddle::platform::complex128(0)) {
- output_[idx] = paddle::platform::complex128(0);
+ if (x_[idx] == paddle::platform::complex(0)) {
+ output_[idx] = paddle::platform::complex(0);
 } else {
- output_[idx] = paddle::platform::complex128(dout_[idx]) *
- (x_[idx] / paddle::platform::complex128(abs(x_[idx])));
+ output_[idx] =
+ paddle::platform::complex(dout_[idx]) *
+ (x_[idx] / paddle::platform::complex(abs(x_[idx])));
 }
 }
 const double* dout_;
- const paddle::platform::complex128* x_;
- paddle::platform::complex128* output_;
+ const paddle::platform::complex* x_;
+ paddle::platform::complex* output_;
 int64_t numel_;
 };
@@ -234,46 +241,46 @@ struct AbsGradGradFunctor {
 };
 template <>
-struct AbsGradGradFunctor {
- AbsGradGradFunctor(const paddle::platform::complex128* ddx,
- const paddle::platform::complex128* x,
- paddle::platform::complex128* output, int64_t numel)
+struct AbsGradGradFunctor> {
+ AbsGradGradFunctor(const
 paddle::platform::complex* ddx,
+ const paddle::platform::complex* x,
+ paddle::platform::complex* output, int64_t numel)
 : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
 HOSTDEVICE void operator()(int64_t idx) const {
- if (x_[idx] == paddle::platform::complex128(0)) {
- output_[idx] = paddle::platform::complex128(0);
+ if (x_[idx] == paddle::platform::complex(0)) {
+ output_[idx] = paddle::platform::complex(0);
 } else {
- output_[idx] = paddle::platform::complex128(ddx_[idx]) * x_[idx] /
- paddle::platform::complex128(abs(x_[idx]));
+ output_[idx] = paddle::platform::complex(ddx_[idx]) * x_[idx] /
+ paddle::platform::complex(abs(x_[idx]));
 }
 }
- const paddle::platform::complex128* ddx_;
- const paddle::platform::complex128* x_;
- paddle::platform::complex128* output_;
+ const paddle::platform::complex* ddx_;
+ const paddle::platform::complex* x_;
+ paddle::platform::complex* output_;
 int64_t numel_;
 };
 template <>
-struct AbsGradGradFunctor {
- AbsGradGradFunctor(const paddle::platform::complex64* ddx,
- const paddle::platform::complex64* x,
- paddle::platform::complex64* output, int64_t numel)
+struct AbsGradGradFunctor> {
+ AbsGradGradFunctor(const paddle::platform::complex* ddx,
+ const paddle::platform::complex* x,
+ paddle::platform::complex* output, int64_t numel)
 : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
 HOSTDEVICE void operator()(int64_t idx) const {
- if (x_[idx] == paddle::platform::complex64(0)) {
- output_[idx] = paddle::platform::complex64(0);
+ if (x_[idx] == paddle::platform::complex(0)) {
+ output_[idx] = paddle::platform::complex(0);
 } else {
- output_[idx] = paddle::platform::complex64(ddx_[idx]) * x_[idx] /
- paddle::platform::complex64(abs(x_[idx]));
+ output_[idx] = paddle::platform::complex(ddx_[idx]) * x_[idx] /
+ paddle::platform::complex(abs(x_[idx]));
 }
 }
- const paddle::platform::complex64* ddx_;
- const paddle::platform::complex64* x_;
- paddle::platform::complex64* output_;
+ const paddle::platform::complex* ddx_;
+ const paddle::platform::complex* x_;
+ paddle::platform::complex* output_;
 int64_t numel_;
 };
 template

diff --git a/paddle/fluid/operators/real_op.cc b/paddle/fluid/operators/real_op.cc
index 5f667999ee6..1174e72a76b 100644
--- a/paddle/fluid/operators/real_op.cc
+++ b/paddle/fluid/operators/real_op.cc
@@ -95,11 +95,11 @@ REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker,
 REGISTER_OPERATOR(real_grad, ops::RealGradOp);
 REGISTER_OP_CPU_KERNEL(real, ops::RealKernel,
+ paddle::platform::complex>,
 ops::RealKernel);
+ paddle::platform::complex>);
 REGISTER_OP_CPU_KERNEL(real_grad, ops::RealGradKernel,
+ paddle::platform::complex>,
 ops::RealGradKernel);
+ paddle::platform::complex>);

diff --git a/paddle/fluid/operators/real_op.cu b/paddle/fluid/operators/real_op.cu
index b3d0855111b..9bfb2878a62 100644
--- a/paddle/fluid/operators/real_op.cu
+++ b/paddle/fluid/operators/real_op.cu
@@ -18,11 +18,11 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(real, ops::RealKernel,
+ paddle::platform::complex>,
 ops::RealKernel);
+ paddle::platform::complex>);
 REGISTER_OP_CUDA_KERNEL(real_grad, ops::RealGradKernel,
+ paddle::platform::complex>,
 ops::RealGradKernel);
+ paddle::platform::complex>);
--
GitLab
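
The core pattern in this patch is collapsing the separate complex64 / complex128 code paths into a single paddle::platform::complex<T> template, so each functor and kernel registration is written once and instantiated for both precisions. The sketch below illustrates that idea outside the Paddle tree; it is an assumption-laden example rather than the project's code: it substitutes std::complex for paddle::platform::complex, and the is_complex trait stands in for the EnableComplex helper in complex_functors.h.

#include <complex>
#include <iostream>
#include <type_traits>
#include <vector>

// Trait that is true only for std::complex<T>; it plays the role of the
// EnableComplex/DisableComplex helpers used in complex_functors.h.
template <typename T>
struct is_complex : std::false_type {};
template <typename T>
struct is_complex<std::complex<T>> : std::true_type {};

// One conjugate functor covers both precisions. Before this patch the
// equivalent Paddle code needed one specialization per concrete type
// (complex64 and complex128).
template <typename T,
          typename std::enable_if<is_complex<T>::value, int>::type = 0>
void Conj(const std::vector<T>& in, std::vector<T>* out) {
  out->resize(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    (*out)[i] = std::conj(in[i]);
  }
}

int main() {
  std::vector<std::complex<float>> x{{1.0f, 2.0f}, {3.0f, -4.0f}};
  std::vector<std::complex<float>> y;
  Conj(x, &y);  // the same template works unchanged for std::complex<double>
  for (const auto& v : y) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}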
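
The edits to Real<T>, EnableComplex, and DisableComplex in complex_functors.h extend compile-time type selection to the new template. The snippet below is a hedged sketch of the same mapping (complex<float> to float, complex<double> to double, anything else unchanged), written with std::conditional and std::complex instead of the patch's select_t/cond helpers and paddle::platform::complex so it compiles on its own.

#include <complex>
#include <type_traits>

// Maps a complex element type to its underlying real type, mirroring the
// Real<T> alias touched by this patch; non-complex types pass through.
template <typename T>
using Real = typename std::conditional<
    std::is_same<T, std::complex<float>>::value, float,
    typename std::conditional<
        std::is_same<T, std::complex<double>>::value, double, T>::type>::type;

static_assert(std::is_same<Real<std::complex<float>>, float>::value,
              "complex<float> maps to float");
static_assert(std::is_same<Real<std::complex<double>>, double>::value,
              "complex<double> maps to double");
static_assert(std::is_same<Real<double>, double>::value,
              "real types are unchanged");

int main() { return 0; }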