From dbc08d69921ff8e405ecffd0b2cc36cc25af1054 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Tue, 25 May 2021 20:04:01 +0800 Subject: [PATCH] modify complex template for elementwise ops (#33071) * modify complex template for elementwise ops * modify mul, div grad struct * add complex template for CudaShuffleDownSync CudaShuffleXorSync funcs and fix the bug when delete cuda<9000 * fix shuffle func args bug * fix shuffle func args bug * fix shuffle func args bug --- .../elementwise/elementwise_add_op.cc | 20 +++--- .../elementwise/elementwise_add_op.cu | 21 +++--- .../elementwise/elementwise_div_op.cc | 15 ++-- .../elementwise/elementwise_div_op.cu | 58 ++++++++------- .../elementwise/elementwise_div_op.h | 48 ++++--------- .../elementwise/elementwise_mul_op.cc | 15 ++-- .../elementwise/elementwise_mul_op.cu | 45 ++++++------ .../elementwise/elementwise_mul_op.h | 48 ++++--------- .../elementwise/elementwise_sub_op.cc | 16 ++--- .../elementwise/elementwise_sub_op.cu | 15 ++-- paddle/fluid/platform/cuda_device_function.h | 72 ++++++++++++------- 11 files changed, 180 insertions(+), 193 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc index b551629169..67e2e3a1e9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -20,8 +20,8 @@ limitations under the License. */ namespace paddle { namespace platform { -struct complex128; -struct complex64; +template +struct complex; } // namespace platform } // namespace paddle @@ -135,9 +135,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, + paddle::platform::complex>, ops::ElementwiseAddKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( elementwise_add_grad, ops::ElementwiseAddGradKernel, @@ -145,9 +145,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, + paddle::platform::complex>, ops::ElementwiseAddGradKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( elementwise_add_grad_grad, ops::ElementwiseAddDoubleGradKernel, ops::ElementwiseAddDoubleGradKernel, + paddle::platform::complex>, ops::ElementwiseAddDoubleGradKernel); + paddle::platform::complex>); // A specialization elementwise_add operator, used in gradient accumulation with // inplace addto. @@ -178,9 +178,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, + paddle::platform::complex>, ops::ElementwiseAddKernel); + paddle::platform::complex>); REGISTER_OP_VERSION(elementwise_add) .AddCheckpoint( diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index a4b97301a2..37e5fa5a20 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -141,8 +140,8 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, - ops::ElementwiseAddKernel, - ops::ElementwiseAddKernel); + ops::ElementwiseAddKernel>, + ops::ElementwiseAddKernel>); REGISTER_OP_CUDA_KERNEL( elementwise_add_grad, ops::ElementwiseAddGradKernel, @@ -150,8 +149,10 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel, - ops::ElementwiseAddGradKernel); + ops::ElementwiseAddGradKernel>, + ops::ElementwiseAddGradKernel>); REGISTER_OP_CUDA_KERNEL( elementwise_add_grad_grad, ops::ElementwiseAddDoubleGradKernel, @@ -160,9 +161,9 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseAddDoubleGradKernel, ops::ElementwiseAddDoubleGradKernel, ops::ElementwiseAddDoubleGradKernel, + plat::complex>, ops::ElementwiseAddDoubleGradKernel); + plat::complex>); REGISTER_OP_CUDA_KERNEL( grad_add, ops::ElementwiseAddKernel, @@ -170,5 +171,5 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, ops::ElementwiseAddKernel, - ops::ElementwiseAddKernel, - ops::ElementwiseAddKernel); + ops::ElementwiseAddKernel>, + ops::ElementwiseAddKernel>); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 0252e6dfff..9a899ec11b 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -17,8 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -135,9 +134,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, + paddle::platform::complex>, ops::ElementwiseDivKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( elementwise_div_grad, ops::ElementwiseDivGradKernel, @@ -145,9 +144,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, + paddle::platform::complex>, ops::ElementwiseDivGradKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( elementwise_div_grad_grad, @@ -160,9 +159,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseDivDoubleGradKernel, ops::ElementwiseDivDoubleGradKernel, + paddle::platform::complex>, ops::ElementwiseDivDoubleGradKernel); + paddle::platform::complex>); REGISTER_OP_VERSION(elementwise_div) .AddCheckpoint( diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index 0cf9294c9d..b10ed57af9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -14,8 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -76,18 +75,21 @@ static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y, } template <> -__global__ void SimpleElemwiseDivGradCUDAKernel( - const paddle::platform::complex64* x, const paddle::platform::complex64* y, - const paddle::platform::complex64* out, - const paddle::platform::complex64* dout, int64_t size, - paddle::platform::complex64* dx, paddle::platform::complex64* dy) { +__global__ void +SimpleElemwiseDivGradCUDAKernel>( + const paddle::platform::complex* x, + const paddle::platform::complex* y, + const paddle::platform::complex* out, + const paddle::platform::complex* dout, int64_t size, + paddle::platform::complex* dx, + paddle::platform::complex* dy) { int col = blockIdx.x * blockDim.x + threadIdx.x; while (col < size) { - paddle::platform::complex64 o = dout[col]; - paddle::platform::complex64 y_conj(y[col].real, -y[col].imag); - paddle::platform::complex64 out_div_y_conj((out[col] / y[col]).real, - -(out[col] / y[col]).imag); + paddle::platform::complex o = dout[col]; + paddle::platform::complex y_conj(y[col].real, -y[col].imag); + paddle::platform::complex out_div_y_conj((out[col] / y[col]).real, + -(out[col] / y[col]).imag); dx[col] = o / y_conj; dy[col] = -o * out_div_y_conj; col += blockDim.x * gridDim.x; @@ -95,19 +97,21 @@ __global__ void SimpleElemwiseDivGradCUDAKernel( } template <> -__global__ void SimpleElemwiseDivGradCUDAKernel( - const paddle::platform::complex128* x, - const paddle::platform::complex128* y, - const paddle::platform::complex128* out, - const paddle::platform::complex128* dout, int64_t size, - paddle::platform::complex128* dx, paddle::platform::complex128* dy) { +__global__ void +SimpleElemwiseDivGradCUDAKernel>( + const paddle::platform::complex* x, + const paddle::platform::complex* y, + const paddle::platform::complex* out, + const paddle::platform::complex* dout, int64_t size, + paddle::platform::complex* dx, + paddle::platform::complex* dy) { int col = blockIdx.x * blockDim.x + threadIdx.x; while (col < size) { - paddle::platform::complex128 o = dout[col]; - paddle::platform::complex128 y_conj(y[col].real, -y[col].imag); - paddle::platform::complex128 out_div_y_conj((out[col] / y[col]).real, - -(out[col] / y[col]).imag); + paddle::platform::complex o = dout[col]; + paddle::platform::complex y_conj(y[col].real, -y[col].imag); + paddle::platform::complex out_div_y_conj((out[col] / y[col]).real, + -(out[col] / y[col]).imag); dx[col] = o / y_conj; dy[col] = -o * out_div_y_conj; col += blockDim.x * gridDim.x; @@ -145,9 +149,9 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, ops::ElementwiseDivKernel, + paddle::platform::complex>, ops::ElementwiseDivKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( elementwise_div_grad, ops::ElementwiseDivGradKernel, @@ -157,9 +161,9 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, ops::ElementwiseDivGradKernel, + paddle::platform::complex>, ops::ElementwiseDivGradKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( elementwise_div_grad_grad, ops::ElementwiseDivDoubleGradKernel, ops::ElementwiseDivDoubleGradKernel, + paddle::platform::complex>, ops::ElementwiseDivDoubleGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 0be8d934b1..a0b9633acb 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -74,23 +74,13 @@ struct DivGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout / y; } }; -template <> -struct DivGradDX { - HOSTDEVICE paddle::platform::complex64 operator()( - paddle::platform::complex64 x, paddle::platform::complex64 y, - paddle::platform::complex64 out, paddle::platform::complex64 dout) const { - paddle::platform::complex64 y_conj(y.real, -y.imag); - return dout / y_conj; - } -}; - -template <> -struct DivGradDX { - HOSTDEVICE paddle::platform::complex128 operator()( - paddle::platform::complex128 x, paddle::platform::complex128 y, - paddle::platform::complex128 out, - paddle::platform::complex128 dout) const { - paddle::platform::complex128 y_conj(y.real, -y.imag); +template +struct DivGradDX> { + HOSTDEVICE paddle::platform::complex operator()( + paddle::platform::complex x, paddle::platform::complex y, + paddle::platform::complex out, + paddle::platform::complex dout) const { + paddle::platform::complex y_conj(y.real, -y.imag); return dout / y_conj; } }; @@ -102,23 +92,13 @@ struct DivGradDY { } }; -template <> -struct DivGradDY { - HOSTDEVICE paddle::platform::complex64 operator()( - paddle::platform::complex64 x, paddle::platform::complex64 y, - paddle::platform::complex64 out, paddle::platform::complex64 dout) const { - paddle::platform::complex64 out_div_y_conj((out / y).real, -(out / y).imag); - return -dout * out_div_y_conj; - } -}; - -template <> -struct DivGradDY { - HOSTDEVICE paddle::platform::complex128 operator()( - paddle::platform::complex128 x, paddle::platform::complex128 y, - paddle::platform::complex128 out, - paddle::platform::complex128 dout) const { - paddle::platform::complex128 out_div_y_conj((out / y).real, +template +struct DivGradDY> { + HOSTDEVICE paddle::platform::complex operator()( + paddle::platform::complex x, paddle::platform::complex y, + paddle::platform::complex out, + paddle::platform::complex dout) const { + paddle::platform::complex out_div_y_conj((out / y).real, -(out / y).imag); return -dout * out_div_y_conj; } diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 6bf296f0e0..0045f00ecc 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -16,8 +16,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -134,9 +133,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, + paddle::platform::complex>, ops::ElementwiseMulKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( elementwise_mul_grad, ops::ElementwiseMulGradKernel, @@ -144,9 +143,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, + paddle::platform::complex>, ops::ElementwiseMulGradKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( elementwise_mul_grad_grad, ops::ElementwiseMulDoubleGradKernel, ops::ElementwiseMulDoubleGradKernel, + paddle::platform::complex>, ops::ElementwiseMulDoubleGradKernel); + paddle::platform::complex>); REGISTER_OP_VERSION(elementwise_mul) .AddCheckpoint( diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index e01b5eb5fb..8fd4609c3a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -14,8 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -76,31 +75,31 @@ static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y, } template <> -__global__ void SimpleElemwiseMulGradCUDAKernel( - const plat::complex64* x, const plat::complex64* y, - const plat::complex64* out, const plat::complex64* dout, int64_t size, - plat::complex64* dx, plat::complex64* dy) { +__global__ void SimpleElemwiseMulGradCUDAKernel>( + const plat::complex* x, const plat::complex* y, + const plat::complex* out, const plat::complex* dout, + int64_t size, plat::complex* dx, plat::complex* dy) { int col = blockIdx.x * blockDim.x + threadIdx.x; while (col < size) { - plat::complex64 o = dout[col]; - dx[col] = plat::complex64(y[col].real, -y[col].imag) * o; - dy[col] = plat::complex64(x[col].real, -x[col].imag) * o; + plat::complex o = dout[col]; + dx[col] = plat::complex(y[col].real, -y[col].imag) * o; + dy[col] = plat::complex(x[col].real, -x[col].imag) * o; col += blockDim.x * gridDim.x; } } template <> -__global__ void SimpleElemwiseMulGradCUDAKernel( - const plat::complex128* x, const plat::complex128* y, - const plat::complex128* out, const plat::complex128* dout, int64_t size, - plat::complex128* dx, plat::complex128* dy) { +__global__ void SimpleElemwiseMulGradCUDAKernel>( + const plat::complex* x, const plat::complex* y, + const plat::complex* out, const plat::complex* dout, + int64_t size, plat::complex* dx, plat::complex* dy) { int col = blockIdx.x * blockDim.x + threadIdx.x; while (col < size) { - plat::complex128 o = dout[col]; - dx[col] = plat::complex128(y[col].real, -y[col].imag) * o; - dy[col] = plat::complex128(x[col].real, -x[col].imag) * o; + plat::complex o = dout[col]; + dx[col] = plat::complex(y[col].real, -y[col].imag) * o; + dy[col] = plat::complex(x[col].real, -x[col].imag) * o; col += blockDim.x * gridDim.x; } } @@ -133,8 +132,8 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel, - ops::ElementwiseMulKernel); + ops::ElementwiseMulKernel>, + ops::ElementwiseMulKernel>); REGISTER_OP_CUDA_KERNEL( elementwise_mul_grad, ops::ElementwiseMulGradKernel, @@ -142,8 +141,10 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel); + ops::ElementwiseMulGradKernel>, + ops::ElementwiseMulGradKernel>); REGISTER_OP_CUDA_KERNEL( elementwise_mul_grad_grad, ops::ElementwiseMulDoubleGradKernel, @@ -152,6 +153,6 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMulDoubleGradKernel, ops::ElementwiseMulDoubleGradKernel, ops::ElementwiseMulDoubleGradKernel, + plat::complex>, ops::ElementwiseMulDoubleGradKernel); + plat::complex>); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 46a00268e4..10e6949164 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -132,23 +132,13 @@ struct MulGradDX { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } }; -template <> -struct MulGradDX { - HOSTDEVICE paddle::platform::complex64 operator()( - paddle::platform::complex64 x, paddle::platform::complex64 y, - paddle::platform::complex64 out, paddle::platform::complex64 dout) const { - paddle::platform::complex64 y_conj(y.real, -y.imag); - return dout * y_conj; - } -}; - -template <> -struct MulGradDX { - HOSTDEVICE paddle::platform::complex128 operator()( - paddle::platform::complex128 x, paddle::platform::complex128 y, - paddle::platform::complex128 out, - paddle::platform::complex128 dout) const { - paddle::platform::complex128 y_conj(y.real, -y.imag); +template +struct MulGradDX> { + HOSTDEVICE paddle::platform::complex operator()( + paddle::platform::complex x, paddle::platform::complex y, + paddle::platform::complex out, + paddle::platform::complex dout) const { + paddle::platform::complex y_conj(y.real, -y.imag); return dout * y_conj; } }; @@ -158,23 +148,13 @@ struct MulGradDY { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } }; -template <> -struct MulGradDY { - HOSTDEVICE paddle::platform::complex64 operator()( - paddle::platform::complex64 x, paddle::platform::complex64 y, - paddle::platform::complex64 out, paddle::platform::complex64 dout) const { - paddle::platform::complex64 x_conj(x.real, -x.imag); - return dout * x_conj; - } -}; - -template <> -struct MulGradDY { - HOSTDEVICE paddle::platform::complex128 operator()( - paddle::platform::complex128 x, paddle::platform::complex128 y, - paddle::platform::complex128 out, - paddle::platform::complex128 dout) const { - paddle::platform::complex128 x_conj(x.real, -x.imag); +template +struct MulGradDY> { + HOSTDEVICE paddle::platform::complex operator()( + paddle::platform::complex x, paddle::platform::complex y, + paddle::platform::complex out, + paddle::platform::complex dout) const { + paddle::platform::complex x_conj(x.real, -x.imag); return dout * x_conj; } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc index 1951ed7f5d..84aa189b89 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc @@ -20,8 +20,8 @@ limitations under the License. */ namespace paddle { namespace platform { -struct complex128; -struct complex64; +template +struct complex; } // namespace platform } // namespace paddle @@ -134,9 +134,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, + paddle::platform::complex>, ops::ElementwiseSubKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( elementwise_sub_grad, ops::ElementwiseSubGradKernel, @@ -144,9 +144,9 @@ REGISTER_OP_CPU_KERNEL( ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, + paddle::platform::complex>, ops::ElementwiseSubGradKernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( elementwise_sub_grad_grad, ops::ElementwiseSubDoubleGradKernel, ops::ElementwiseSubDoubleGradKernel, + paddle::platform::complex>, ops::ElementwiseSubDoubleGradKernel); + paddle::platform::complex>); REGISTER_OP_VERSION(elementwise_sub) .AddCheckpoint( diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 192999fd2a..19cbbb7bf0 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -14,8 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; @@ -103,9 +102,9 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, ops::ElementwiseSubKernel, + paddle::platform::complex>, ops::ElementwiseSubKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( elementwise_sub_grad, ops::ElementwiseSubGradKernel, @@ -115,9 +114,9 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, ops::ElementwiseSubGradKernel, + paddle::platform::complex>, ops::ElementwiseSubGradKernel); + paddle::platform::complex>); REGISTER_OP_CUDA_KERNEL( elementwise_sub_grad_grad, ops::ElementwiseSubDoubleGradKernel, ops::ElementwiseSubDoubleGradKernel, + paddle::platform::complex>, ops::ElementwiseSubDoubleGradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h index dde9531e59..4095720f71 100644 --- a/paddle/fluid/platform/cuda_device_function.h +++ b/paddle/fluid/platform/cuda_device_function.h @@ -16,8 +16,7 @@ limitations under the License. */ // NOTE(): support float16 to half in header file. #define PADDLE_CUDA_FP16 -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -82,28 +81,52 @@ __forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, T val, #endif } -// CUDA 9.0 have native compatible float16 shfl_down #if defined(PADDLE_WITH_HIP) template <> __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, float16 val, int delta, int width) { -#ifdef PADDLE_WITH_HIP return float16(__shfl_down(static_cast(val), static_cast(delta), width)); -#else - return float16( - __shfl_down(static_cast(val), static_cast(delta), width)); -#endif } + +template <> +__forceinline__ __device__ paddle::platform::complex CudaShuffleDownSync( + unsigned mask, paddle::platform::complex val, int delta, int width) { + float real = __shfl_down(val.real, delta, width); + float imag = __shfl_down(val.imag, delta, width); + return paddle::platform::complex(real, imag); +} + +template <> +__forceinline__ __device__ paddle::platform::complex +CudaShuffleDownSync(unsigned mask, paddle::platform::complex val, + int delta, int width) { + double real = __shfl_down(val.real, delta, width); + double imag = __shfl_down(val.imag, delta, width); + return paddle::platform::complex(real, imag); +} + template <> __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask, float16 val, int width) { -#ifdef PADDLE_WITH_HIP return float16(__shfl_xor(static_cast(val), width)); -#else - return float16(__shfl_xor(static_cast(val), width)); -#endif +} + +template <> +__forceinline__ __device__ paddle::platform::complex CudaShuffleXorSync( + unsigned mask, paddle::platform::complex val, int width) { + float real = __shfl_xor(val.real, width); + float imag = __shfl_xor(val.imag, width); + return paddle::platform::complex(real, imag); +} + +template <> +__forceinline__ __device__ paddle::platform::complex CudaShuffleXorSync( + unsigned mask, paddle::platform::complex val, int width) { + double real = __shfl_xor(val.real, width); + double imag = __shfl_xor(val.imag, width); + return paddle::platform::complex(real, imag); } #else template <> @@ -115,25 +138,26 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask, } template <> -__forceinline__ __device__ paddle::platform::complex64 CudaShuffleDownSync( - unsigned mask, paddle::platform::complex64 val, int delta, int width) { +__forceinline__ __device__ paddle::platform::complex CudaShuffleDownSync( + unsigned mask, paddle::platform::complex val, int delta, int width) { float real = static_cast(__shfl_down_sync( mask, static_cast(val.real), static_cast(delta), width)); float imag = static_cast(__shfl_down_sync( mask, static_cast(val.imag), static_cast(delta), width)); - return paddle::platform::complex64(real, imag); + return paddle::platform::complex(real, imag); } template <> -__forceinline__ __device__ paddle::platform::complex128 CudaShuffleDownSync( - unsigned mask, paddle::platform::complex128 val, int delta, int width) { +__forceinline__ __device__ paddle::platform::complex +CudaShuffleDownSync(unsigned mask, paddle::platform::complex val, + int delta, int width) { double real = static_cast( __shfl_down_sync(mask, static_cast(val.real), static_cast(delta), width)); double imag = static_cast( __shfl_down_sync(mask, static_cast(val.imag), static_cast(delta), width)); - return paddle::platform::complex128(real, imag); + return paddle::platform::complex(real, imag); } template <> @@ -143,23 +167,23 @@ __forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask, } template <> -__forceinline__ __device__ paddle::platform::complex64 CudaShuffleXorSync( - unsigned mask, paddle::platform::complex64 val, int width) { +__forceinline__ __device__ paddle::platform::complex CudaShuffleXorSync( + unsigned mask, paddle::platform::complex val, int width) { float real = static_cast( __shfl_xor_sync(mask, static_cast(val.real), width)); float imag = static_cast( __shfl_xor_sync(mask, static_cast(val.imag), width)); - return paddle::platform::complex64(real, imag); + return paddle::platform::complex(real, imag); } template <> -__forceinline__ __device__ paddle::platform::complex128 CudaShuffleXorSync( - unsigned mask, paddle::platform::complex128 val, int width) { +__forceinline__ __device__ paddle::platform::complex CudaShuffleXorSync( + unsigned mask, paddle::platform::complex val, int width) { double real = static_cast( __shfl_xor_sync(mask, static_cast(val.real), width)); double imag = static_cast( __shfl_xor_sync(mask, static_cast(val.imag), width)); - return paddle::platform::complex128(real, imag); + return paddle::platform::complex(real, imag); } #endif -- GitLab