diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index 66c4d1c5f55ab99e56763876f4730ea948388baf..bc38e3b59b64495419937bf185132b1dbacb3c51 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -25,6 +25,8 @@ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" @@ -161,6 +163,10 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) { PADDLE_TENSOR_ADD(float); PADDLE_TENSOR_ADD(double); + // NOTE(chenweihang): only support complex grad tensor accumulated, + // support selected rows if needed in the future + PADDLE_TENSOR_ADD(platform::complex64); + PADDLE_TENSOR_ADD(platform::complex128); #undef PADDLE_TENSOR_ADD diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 53e07d2ba4e9259caf5b1ec8260ca9f7ff2273a8..c44c15adb13caf9be401c3174e68e229d1eea745 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -275,6 +275,15 @@ struct CUBlas { reinterpret_cast(C), ldc)); } + static void AXPY(cublasHandle_t handle, int n, const complex64 *alpha, + const complex64 *X, const int incX, complex64 *Y, + const int incY) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCaxpy( + handle, n, reinterpret_cast(alpha), + reinterpret_cast(X), incX, + reinterpret_cast(Y), incY)); + } + static void GEMM_STRIDED_BATCH(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, @@ -362,6 +371,15 @@ struct CUBlas { reinterpret_cast(C), ldc)); } + static void AXPY(cublasHandle_t handle, int n, const complex128 *alpha, + const complex128 *X, const int incX, complex128 *Y, + const int incY) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZaxpy( + handle, n, reinterpret_cast(alpha), + reinterpret_cast(X), incX, + reinterpret_cast(Y), incY)); + } + static void GEMM_STRIDED_BATCH(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 32aced7619c41e942801bf0712dddd8d48ccb24c..5ccdeabf96bf318bee126608f3c9b37862439443 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -295,6 +295,13 @@ struct CBlas { template <> struct CBlas { + template + static void AXPY(int n, const paddle::platform::complex64 alpha, + const paddle::platform::complex64 *X, const int incX, + paddle::platform::complex64 *Y, const int incY) { + platform::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY); + } + template static void VCOPY(ARGS... args) { platform::dynload::cblas_ccopy(args...); @@ -415,6 +422,13 @@ struct CBlas { template <> struct CBlas { + template + static void AXPY(int n, const paddle::platform::complex128 alpha, + const paddle::platform::complex128 *X, const int incX, + paddle::platform::complex128 *Y, const int incY) { + platform::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY); + } + template static void VCOPY(ARGS... args) { platform::dynload::cblas_zcopy(args...); @@ -598,11 +612,6 @@ struct CBlas { cblas_ccopy(args...); } - template - static void VADD(ARGS... args) { - vcAdd(args...); - } - template static void AXPY(int n, const paddle::platform::complex64 alpha, const paddle::platform::complex64 *X, const int incX, @@ -641,11 +650,6 @@ struct CBlas { cblas_zcopy(args...); } - template - static void VADD(ARGS... args) { - vzAdd(args...); - } - template static void AXPY(int n, const paddle::platform::complex128 alpha, const paddle::platform::complex128 *X, const int incX, diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index c2595beb0cb4dc37104a91ac8a2647c7d787c5c5..21b60119dcacfe93e60c9a74d9bdc6b1c0723bf7 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -18,6 +18,8 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" namespace paddle { namespace operators { @@ -548,6 +550,10 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; template struct MergeAverage; template struct MergeAverage; diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 35bd02ad35b71eb7deb3299490fa545ef8b23dc6..26e9a0de606babfc325de58ba73404191751411c 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -448,6 +448,8 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; template __global__ void UpdateToTensorKernel(const T* selected_rows, diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu index 0d689d710a19103cf667a76e592dfba9571cae5c..f2bee6dddc39ec965966e4964c954e5fb1441bf5 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu @@ -23,4 +23,6 @@ using CUDAReduceSumGradKernel = REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, CUDAReduceSumGradKernel, - CUDAReduceSumGradKernel); + CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel, + CUDAReduceSumGradKernel); diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h index 4d9673e9646dedbd001eabcd70d4d34aecaa10b5..72430a3f75323c2d81d0e83b95361e601080c423 100644 --- a/paddle/fluid/platform/cuda_primitives.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -15,6 +15,8 @@ limitations under the License. */ #pragma once #include #include +#include "paddle/fluid/platform/complex128.h" +#include "paddle/fluid/platform/complex64.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -126,9 +128,22 @@ CUDA_ATOMIC_WRAPPER(Add, float16) { return ret; } } - #endif +CUDA_ATOMIC_WRAPPER(Add, complex64) { + float *real = reinterpret_cast(address); + float *imag = real + 1; + return complex64(CudaAtomicAdd(real, val.real), + CudaAtomicAdd(imag, val.imag)); +} + +CUDA_ATOMIC_WRAPPER(Add, complex128) { + double *real = reinterpret_cast(address); + double *imag = real + 1; + return complex128(CudaAtomicAdd(real, val.real), + CudaAtomicAdd(imag, val.imag)); +} + // For atomicMax USE_CUDA_ATOMIC(Max, int); USE_CUDA_ATOMIC(Max, unsigned int); diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index 66032075f29836760fea0756f6d5a42d9153c461..96e16894c78c659a6173e6c0a5f57bdfa4e80827 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -55,6 +55,8 @@ extern void *cublas_dso_handle; #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ __macro(cublasSaxpy_v2); \ __macro(cublasDaxpy_v2); \ + __macro(cublasCaxpy_v2); \ + __macro(cublasZaxpy_v2); \ __macro(cublasSscal_v2); \ __macro(cublasDscal_v2); \ __macro(cublasScopy_v2); \ diff --git a/python/paddle/fluid/tests/unittests/test_complex_grad_accumulated.py b/python/paddle/fluid/tests/unittests/test_complex_grad_accumulated.py new file mode 100644 index 0000000000000000000000000000000000000000..106b9fe15a331a7c53793a06684e46c659cde10a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_complex_grad_accumulated.py @@ -0,0 +1,101 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle + +import paddle.fluid.core as core + + +class Optimization_ex1(paddle.nn.Layer): + def __init__(self, + shape, + dtype, + param_attr=paddle.nn.initializer.Uniform( + low=-5., high=5.)): + super(Optimization_ex1, self).__init__() + + self.theta0 = self.create_parameter( + shape=shape, attr=param_attr, dtype=dtype, is_bias=False) + self.theta1 = self.create_parameter( + shape=shape, attr=param_attr, dtype=dtype, is_bias=False) + self.A = paddle.to_tensor( + np.random.random((4, 4)).astype(dtype) + np.random.random((4, 4)) + .astype(dtype) * 1j) + self.B = paddle.to_tensor( + np.random.random((4, 4)).astype(dtype) + np.random.random( + (4, 4)).astype(dtype) * 1j, + stop_gradient=False) + print(self.A) + + def forward(self, mode=1): + jj = paddle.to_tensor(np.array([1j]).astype(np.complex64)) + if mode == 1: + # run all calc in one step + loss = paddle.sum(self.A + (self.theta0 + self.theta1 * jj)) * ( + paddle.sum(self.A + (self.theta0 + self.theta1 * jj)).conj()) + return loss.real() + elif mode == 2: + # run in two step + self.theta = self.theta0 + self.theta1 * jj + loss = paddle.sum(self.A + self.theta) * ( + paddle.sum(self.A + self.theta).conj()) + return loss.real() + elif mode == 3: + # run without param + loss = paddle.sum(self.A + self.B) * ( + paddle.sum(self.A + self.B).conj()) + return loss.real() + else: + raise NotImplementedError + + +class TestComplexGradAccumulated(unittest.TestCase): + def setUp(self): + self.devices = ['cpu'] + if core.is_compiled_with_cuda(): + self.devices.append('gpu') + self.dtypes = ['float32', 'float64'] + self.theta_size = [4, 4] + + def run_backward(self, device, dtype, mode): + paddle.set_device(device) + + myLayer = Optimization_ex1(self.theta_size, dtype) + + loss = myLayer(mode) + loss.backward() + + def test_case_one_step(self): + for dev in self.devices: + for dtype in self.dtypes: + self.run_backward(dev, dtype, 1) + + def test_case_two_step(self): + for dev in self.devices: + for dtype in self.dtypes: + self.run_backward(dev, dtype, 2) + + def test_case_non_param(self): + for dev in self.devices: + for dtype in self.dtypes: + self.run_backward(dev, dtype, 3) + + +if __name__ == '__main__': + unittest.main()