Unverified commit 1a304e6c, authored by Chen Weihang, committed by GitHub

[Complex] Add support for complex grad accumulated (#29889)

* add support for complex grad accumulated

* add unittest for coverage

* update test dtype

* remove useless blank line
Parent c7acad9f
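
The snippet below is not part of the diff; it is a minimal, illustrative sketch of the behavior this commit enables in dygraph mode, modeled on the unit test added at the end of the change. The names (x, jj, s) and the 4x4 shape are arbitrary, and it assumes a paddle build that includes this patch.

import numpy as np
import paddle

# A real-valued leaf tensor whose gradient will be accumulated from
# several complex sub-expressions during backward.
x = paddle.to_tensor(
    np.random.random((4, 4)).astype('float32'), stop_gradient=False)
jj = paddle.to_tensor(np.array([1j]).astype(np.complex64))

# x enters the loss through more than one path, so its (complex-valued)
# gradients have to be added together by the gradient accumulator.
s = paddle.sum(x + x * jj)
loss = (s * s.conj()).real()
loss.backward()

print(x.grad)
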
@@ -25,6 +25,8 @@
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
@@ -161,6 +163,10 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
  PADDLE_TENSOR_ADD(float);
  PADDLE_TENSOR_ADD(double);
  // NOTE(chenweihang): only complex grad tensor accumulation is supported here,
  // add SelectedRows support if it is needed in the future
  PADDLE_TENSOR_ADD(platform::complex64);
  PADDLE_TENSOR_ADD(platform::complex128);
#undef PADDLE_TENSOR_ADD
@@ -275,6 +275,15 @@ struct CUBlas<platform::complex64> {
        reinterpret_cast<cuFloatComplex *>(C), ldc));
  }

  static void AXPY(cublasHandle_t handle, int n, const complex64 *alpha,
                   const complex64 *X, const int incX, complex64 *Y,
                   const int incY) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCaxpy(
        handle, n, reinterpret_cast<const cuFloatComplex *>(alpha),
        reinterpret_cast<const cuFloatComplex *>(X), incX,
        reinterpret_cast<cuFloatComplex *>(Y), incY));
  }

  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb, int m, int n, int k,
@@ -362,6 +371,15 @@ struct CUBlas<platform::complex128> {
        reinterpret_cast<cuDoubleComplex *>(C), ldc));
  }

  static void AXPY(cublasHandle_t handle, int n, const complex128 *alpha,
                   const complex128 *X, const int incX, complex128 *Y,
                   const int incY) {
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZaxpy(
        handle, n, reinterpret_cast<const cuDoubleComplex *>(alpha),
        reinterpret_cast<const cuDoubleComplex *>(X), incX,
        reinterpret_cast<cuDoubleComplex *>(Y), incY));
  }

  static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                 cublasOperation_t transa,
                                 cublasOperation_t transb, int m, int n, int k,
@@ -295,6 +295,13 @@ struct CBlas<double> {
template <>
struct CBlas<platform::complex64> {
  template <typename... ARGS>
  static void AXPY(int n, const paddle::platform::complex64 alpha,
                   const paddle::platform::complex64 *X, const int incX,
                   paddle::platform::complex64 *Y, const int incY) {
    platform::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY);
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
    platform::dynload::cblas_ccopy(args...);
@@ -415,6 +422,13 @@ struct CBlas<platform::complex64> {
template <>
struct CBlas<platform::complex128> {
  template <typename... ARGS>
  static void AXPY(int n, const paddle::platform::complex128 alpha,
                   const paddle::platform::complex128 *X, const int incX,
                   paddle::platform::complex128 *Y, const int incY) {
    platform::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY);
  }

  template <typename... ARGS>
  static void VCOPY(ARGS... args) {
    platform::dynload::cblas_zcopy(args...);
@@ -598,11 +612,6 @@ struct CBlas<platform::complex64> {
    cblas_ccopy(args...);
  }

  template <typename... ARGS>
  static void VADD(ARGS... args) {
    vcAdd(args...);
  }

  template <typename... ARGS>
  static void AXPY(int n, const paddle::platform::complex64 alpha,
                   const paddle::platform::complex64 *X, const int incX,
@@ -641,11 +650,6 @@ struct CBlas<platform::complex128> {
    cblas_zcopy(args...);
  }

  template <typename... ARGS>
  static void VADD(ARGS... args) {
    vzAdd(args...);
  }

  template <typename... ARGS>
  static void AXPY(int n, const paddle::platform::complex128 alpha,
                   const paddle::platform::complex128 *X, const int incX,
@@ -18,6 +18,8 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"

namespace paddle {
namespace operators {
@@ -548,6 +550,10 @@ template struct MergeAdd<platform::CPUDeviceContext, int>;
template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template struct MergeAdd<platform::CPUDeviceContext, float>;
template struct MergeAdd<platform::CPUDeviceContext, double>;
template struct MergeAdd<platform::CPUDeviceContext,
                         paddle::platform::complex64>;
template struct MergeAdd<platform::CPUDeviceContext,
                         paddle::platform::complex128>;
template struct MergeAverage<platform::CPUDeviceContext, int>;
template struct MergeAverage<platform::CPUDeviceContext, int64_t>;
@@ -448,6 +448,8 @@ template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template struct MergeAdd<platform::CUDADeviceContext, platform::float16>;
template struct MergeAdd<platform::CUDADeviceContext, platform::complex64>;
template struct MergeAdd<platform::CUDADeviceContext, platform::complex128>;

template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
@@ -23,4 +23,6 @@ using CUDAReduceSumGradKernel =
REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<float>,
                        CUDAReduceSumGradKernel<double>,
                        CUDAReduceSumGradKernel<int>,
                        CUDAReduceSumGradKernel<int64_t>);
                        CUDAReduceSumGradKernel<int64_t>,
                        CUDAReduceSumGradKernel<paddle::platform::complex64>,
                        CUDAReduceSumGradKernel<paddle::platform::complex128>);
@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once
#include <cuda.h>
#include <stdio.h>
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
@@ -126,9 +128,22 @@ CUDA_ATOMIC_WRAPPER(Add, float16) {
    return ret;
  }
}
#endif

// complex64/complex128 store the real part immediately followed by the
// imaginary part, so a complex atomic add can be performed as two
// independent scalar atomic adds on the two components.
CUDA_ATOMIC_WRAPPER(Add, complex64) {
  float *real = reinterpret_cast<float *>(address);
  float *imag = real + 1;
  return complex64(CudaAtomicAdd(real, val.real),
                   CudaAtomicAdd(imag, val.imag));
}

CUDA_ATOMIC_WRAPPER(Add, complex128) {
  double *real = reinterpret_cast<double *>(address);
  double *imag = real + 1;
  return complex128(CudaAtomicAdd(real, val.real),
                    CudaAtomicAdd(imag, val.imag));
}

// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
@@ -55,6 +55,8 @@ extern void *cublas_dso_handle;
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
  __macro(cublasSaxpy_v2);                \
  __macro(cublasDaxpy_v2);                \
  __macro(cublasCaxpy_v2);                \
  __macro(cublasZaxpy_v2);                \
  __macro(cublasSscal_v2);                \
  __macro(cublasDscal_v2);                \
  __macro(cublasScopy_v2);                \
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest
import numpy as np

import paddle
import paddle.fluid.core as core


class Optimization_ex1(paddle.nn.Layer):
    def __init__(self,
                 shape,
                 dtype,
                 param_attr=paddle.nn.initializer.Uniform(
                     low=-5., high=5.)):
        super(Optimization_ex1, self).__init__()

        self.theta0 = self.create_parameter(
            shape=shape, attr=param_attr, dtype=dtype, is_bias=False)
        self.theta1 = self.create_parameter(
            shape=shape, attr=param_attr, dtype=dtype, is_bias=False)
        self.A = paddle.to_tensor(
            np.random.random((4, 4)).astype(dtype) + np.random.random((4, 4))
            .astype(dtype) * 1j)
        self.B = paddle.to_tensor(
            np.random.random((4, 4)).astype(dtype) + np.random.random(
                (4, 4)).astype(dtype) * 1j,
            stop_gradient=False)
        print(self.A)

    def forward(self, mode=1):
        jj = paddle.to_tensor(np.array([1j]).astype(np.complex64))
        if mode == 1:
            # run all calculations in one step
            loss = paddle.sum(self.A + (self.theta0 + self.theta1 * jj)) * (
                paddle.sum(self.A + (self.theta0 + self.theta1 * jj)).conj())
            return loss.real()
        elif mode == 2:
            # run in two steps
            self.theta = self.theta0 + self.theta1 * jj
            loss = paddle.sum(self.A + self.theta) * (
                paddle.sum(self.A + self.theta).conj())
            return loss.real()
        elif mode == 3:
            # run without parameters
            loss = paddle.sum(self.A + self.B) * (
                paddle.sum(self.A + self.B).conj())
            return loss.real()
        else:
            raise NotImplementedError


class TestComplexGradAccumulated(unittest.TestCase):
    def setUp(self):
        self.devices = ['cpu']
        if core.is_compiled_with_cuda():
            self.devices.append('gpu')
        self.dtypes = ['float32', 'float64']
        self.theta_size = [4, 4]

    def run_backward(self, device, dtype, mode):
        paddle.set_device(device)

        myLayer = Optimization_ex1(self.theta_size, dtype)

        loss = myLayer(mode)
        loss.backward()

    def test_case_one_step(self):
        for dev in self.devices:
            for dtype in self.dtypes:
                self.run_backward(dev, dtype, 1)

    def test_case_two_step(self):
        for dev in self.devices:
            for dtype in self.dtypes:
                self.run_backward(dev, dtype, 2)

    def test_case_non_param(self):
        for dev in self.devices:
            for dtype in self.dtypes:
                self.run_backward(dev, dtype, 3)


if __name__ == '__main__':
    unittest.main()