From 6c07cd7ec23db3f6b002a3b07ef8c1f4469b02f2 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Wed, 26 May 2021 19:10:43 +0800 Subject: [PATCH] modify matmul Op to complex template types (#33130) * modify matmul Op to complex template types * remove complex64/128 head file --- .../fluid/imperative/gradient_accumulator.cc | 7 +- paddle/fluid/operators/math/blas_impl.cu.h | 106 +++++----- paddle/fluid/operators/math/blas_impl.h | 186 +++++++++--------- paddle/fluid/operators/math/blas_impl.hip.h | 94 +++++---- .../operators/math/selected_rows_functor.cc | 8 +- paddle/fluid/operators/matmul_v2_op.cc | 8 +- paddle/fluid/operators/matmul_v2_op.cu | 8 +- paddle/fluid/operators/matmul_v2_op.h | 20 +- 8 files changed, 230 insertions(+), 207 deletions(-) diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index 6b9b4117133..57657941ef8 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -24,8 +24,7 @@ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/profiler.h" @@ -200,8 +199,8 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) { PADDLE_TENSOR_ADD(double); // NOTE(chenweihang): only support complex grad tensor accumulated, // support selected rows if needed in the future - PADDLE_TENSOR_ADD(platform::complex64); - PADDLE_TENSOR_ADD(platform::complex128); + PADDLE_TENSOR_ADD(platform::complex); + PADDLE_TENSOR_ADD(platform::complex); #endif #undef PADDLE_TENSOR_ADD diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index c44c15adb13..477f3e0f6a2 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -260,13 +260,13 @@ struct CUBlas { }; template <> -struct CUBlas { - using complex64 = platform::complex64; - +struct CUBlas> { static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m, - int n, const complex64 *alpha, const complex64 *A, int lda, - const complex64 *B, int ldb, const complex64 *beta, - complex64 *C, int ldc) { + int n, const platform::complex *alpha, + const platform::complex *A, int lda, + const platform::complex *B, int ldb, + const platform::complex *beta, + platform::complex *C, int ldc) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemv( handle, transa, m, n, reinterpret_cast(alpha), reinterpret_cast(A), lda, @@ -275,9 +275,10 @@ struct CUBlas { reinterpret_cast(C), ldc)); } - static void AXPY(cublasHandle_t handle, int n, const complex64 *alpha, - const complex64 *X, const int incX, complex64 *Y, - const int incY) { + static void AXPY(cublasHandle_t handle, int n, + const platform::complex *alpha, + const platform::complex *X, const int incX, + platform::complex *Y, const int incY) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCaxpy( handle, n, reinterpret_cast(alpha), reinterpret_cast(X), incX, @@ -287,11 +288,13 @@ struct CUBlas { static void GEMM_STRIDED_BATCH(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const complex64 *alpha, const complex64 *A, - int lda, long long int strideA, // NOLINT - const complex64 *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const complex64 *beta, complex64 *C, int ldc, + const platform::complex *alpha, + const platform::complex *A, int lda, + long long int strideA, // NOLINT + const platform::complex *B, // NOLINT + int ldb, long long int strideB, // NOLINT + const platform::complex *beta, + platform::complex *C, int ldc, long long int strideC, // NOLINT int batchCount) { #if CUDA_VERSION >= 8000 @@ -310,9 +313,11 @@ struct CUBlas { static void GEMM(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const complex64 *alpha, const complex64 *A, int lda, - const complex64 *B, int ldb, const complex64 *beta, - complex64 *C, int ldc) { + const platform::complex *alpha, + const platform::complex *A, int lda, + const platform::complex *B, int ldb, + const platform::complex *beta, + platform::complex *C, int ldc) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemm( handle, transa, transb, m, n, k, reinterpret_cast(alpha), @@ -356,13 +361,13 @@ struct CUBlas { }; template <> -struct CUBlas { - using complex128 = platform::complex128; - +struct CUBlas> { static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m, - int n, const complex128 *alpha, const complex128 *A, int lda, - const complex128 *B, int ldb, const complex128 *beta, - complex128 *C, int ldc) { + int n, const platform::complex *alpha, + const platform::complex *A, int lda, + const platform::complex *B, int ldb, + const platform::complex *beta, + platform::complex *C, int ldc) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemv( handle, transa, m, n, reinterpret_cast(alpha), reinterpret_cast(A), lda, @@ -371,9 +376,10 @@ struct CUBlas { reinterpret_cast(C), ldc)); } - static void AXPY(cublasHandle_t handle, int n, const complex128 *alpha, - const complex128 *X, const int incX, complex128 *Y, - const int incY) { + static void AXPY(cublasHandle_t handle, int n, + const platform::complex *alpha, + const platform::complex *X, const int incX, + platform::complex *Y, const int incY) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZaxpy( handle, n, reinterpret_cast(alpha), reinterpret_cast(X), incX, @@ -383,11 +389,13 @@ struct CUBlas { static void GEMM_STRIDED_BATCH(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const complex128 *alpha, const complex128 *A, - int lda, long long int strideA, // NOLINT - const complex128 *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const complex128 *beta, complex128 *C, int ldc, + const platform::complex *alpha, + const platform::complex *A, int lda, + long long int strideA, // NOLINT + const platform::complex *B, // NOLINT + int ldb, long long int strideB, // NOLINT + const platform::complex *beta, + platform::complex *C, int ldc, long long int strideC, // NOLINT int batchCount) { #if CUDA_VERSION >= 8000 @@ -406,9 +414,11 @@ struct CUBlas { static void GEMM(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, - const complex128 *alpha, const complex128 *A, int lda, - const complex128 *B, int ldb, const complex128 *beta, - complex128 *C, int ldc) { + const platform::complex *alpha, + const platform::complex *A, int lda, + const platform::complex *B, int ldb, + const platform::complex *beta, + platform::complex *C, int ldc) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemm( handle, transa, transb, m, n, k, reinterpret_cast(alpha), @@ -535,9 +545,9 @@ template <> template <> inline void Blas::GEMM( CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::complex64 alpha, const platform::complex64 *A, - const platform::complex64 *B, platform::complex64 beta, - platform::complex64 *C) const { + platform::complex alpha, const platform::complex *A, + const platform::complex *B, platform::complex beta, + platform::complex *C) const { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -565,16 +575,16 @@ inline void Blas::GEMM( // input/output in fp16, computation in fp32, which can also be accelerated // using tensor cores in volta GPUs. auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX( + CUBlas>::GEMM_EX( &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_32F, ldb, A, CUDA_C_32F, lda, &c_beta, C, CUDA_C_32F, N, CUDA_C_32F); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, - &c_alpha, h_B, ldb, h_A, lda, &c_beta, - h_C, N); + CUBlas>::GEMM(handle, cuTransB, cuTransA, N, M, K, + &c_alpha, h_B, ldb, h_A, lda, + &c_beta, h_C, N); }); #endif // CUDA_VERSION >= 8000 } @@ -583,9 +593,9 @@ template <> template <> inline void Blas::GEMM( CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::complex128 alpha, const platform::complex128 *A, - const platform::complex128 *B, platform::complex128 beta, - platform::complex128 *C) const { + platform::complex alpha, const platform::complex *A, + const platform::complex *B, platform::complex beta, + platform::complex *C) const { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -614,16 +624,16 @@ inline void Blas::GEMM( // input/output in fp16, computation in fp32, which can also be accelerated // using tensor cores in volta GPUs. auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX( + CUBlas>::GEMM_EX( &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_64F, ldb, A, CUDA_C_64F, lda, &c_beta, C, CUDA_C_64F, N, CUDA_C_64F); #else // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm context_.CublasCall([&](cublasHandle_t handle) { - CUBlas::GEMM(handle, cuTransB, cuTransA, N, M, K, - &c_alpha, h_B, ldb, h_A, lda, &c_beta, - h_C, N); + CUBlas>::GEMM(handle, cuTransB, cuTransA, N, M, K, + &c_alpha, h_B, ldb, h_A, lda, + &c_beta, h_C, N); }); #endif // CUDA_VERSION >= 8000 } diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 05d42f02c10..eab513e24bc 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -23,8 +23,7 @@ #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/complex128.h" -#include "paddle/fluid/platform/complex64.h" +#include "paddle/fluid/platform/complex.h" namespace paddle { namespace operators { @@ -324,11 +323,11 @@ struct CBlas { }; template <> -struct CBlas { +struct CBlas> { template - static void AXPY(int n, const paddle::platform::complex64 alpha, - const paddle::platform::complex64 *X, const int incX, - paddle::platform::complex64 *Y, const int incY) { + static void AXPY(int n, const paddle::platform::complex alpha, + const paddle::platform::complex *X, const int incX, + paddle::platform::complex *Y, const int incY) { platform::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY); } @@ -363,35 +362,35 @@ struct CBlas { */ template - static void VADD(int n, const paddle::platform::complex64 *a, - const paddle::platform::complex64 *b, - paddle::platform::complex64 *y) { + static void VADD(int n, const paddle::platform::complex *a, + const paddle::platform::complex *b, + paddle::platform::complex *y) { for (int i = 0; i < n; ++i) { y[i] = a[i] + b[i]; } } template - static void VSUB(int n, const paddle::platform::complex64 *a, - const paddle::platform::complex64 *b, - paddle::platform::complex64 *y) { + static void VSUB(int n, const paddle::platform::complex *a, + const paddle::platform::complex *b, + paddle::platform::complex *y) { for (int i = 0; i < n; ++i) { y[i] = a[i] - b[i]; } } template - static void VMUL(int n, const paddle::platform::complex64 *a, - const paddle::platform::complex64 *b, - paddle::platform::complex64 *y) { + static void VMUL(int n, const paddle::platform::complex *a, + const paddle::platform::complex *b, + paddle::platform::complex *y) { for (int i = 0; i < n; ++i) { y[i] = a[i] * b[i]; } } template - static void VDIV(int n, const paddle::platform::complex64 *a, - const paddle::platform::complex64 *b, - paddle::platform::complex64 *y) { + static void VDIV(int n, const paddle::platform::complex *a, + const paddle::platform::complex *b, + paddle::platform::complex *y) { for (int i = 0; i < n; ++i) { y[i] = a[i] / b[i]; } @@ -399,11 +398,11 @@ struct CBlas { template static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N, - paddle::platform::complex64 alpha, - const paddle::platform::complex64 *A, int lda, - const paddle::platform::complex64 *X, int incx, - paddle::platform::complex64 beta, - paddle::platform::complex64 *Y, int incy) { + paddle::platform::complex alpha, + const paddle::platform::complex *A, int lda, + const paddle::platform::complex *X, int incx, + paddle::platform::complex beta, + paddle::platform::complex *Y, int incy) { const void *a_ = (const void *)(A); const void *x_ = (const void *)(X); void *y_ = static_cast(Y); @@ -414,11 +413,11 @@ struct CBlas { template static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, int M, int N, int K, - paddle::platform::complex64 alpha, - const paddle::platform::complex64 *A, int lda, - const paddle::platform::complex64 *B, int ldb, - paddle::platform::complex64 beta, - paddle::platform::complex64 *C, int ldc) { + paddle::platform::complex alpha, + const paddle::platform::complex *A, int lda, + const paddle::platform::complex *B, int ldb, + paddle::platform::complex beta, + paddle::platform::complex *C, int ldc) { const void *a_ = (const void *)(A); const void *b_ = (const void *)(B); void *c_ = static_cast(C); @@ -429,11 +428,12 @@ struct CBlas { template static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, - paddle::platform::complex64 *alpha, - const paddle::platform::complex64 **A, const int *lda, - const paddle::platform::complex64 **B, const int *ldb, - paddle::platform::complex64 *beta, - paddle::platform::complex64 **C, const int *ldc, + paddle::platform::complex *alpha, + const paddle::platform::complex **A, + const int *lda, + const paddle::platform::complex **B, + const int *ldb, paddle::platform::complex *beta, + paddle::platform::complex **C, const int *ldc, int group_count, int *group_size) { const void **A_void = (const void **)(&(*A)); const void **B_void = (const void **)(&(*B)); @@ -451,11 +451,11 @@ struct CBlas { }; template <> -struct CBlas { +struct CBlas> { template - static void AXPY(int n, const paddle::platform::complex128 alpha, - const paddle::platform::complex128 *X, const int incX, - paddle::platform::complex128 *Y, const int incY) { + static void AXPY(int n, const paddle::platform::complex alpha, + const paddle::platform::complex *X, const int incX, + paddle::platform::complex *Y, const int incY) { platform::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY); } @@ -490,35 +490,35 @@ struct CBlas { */ template - static void VADD(int n, const paddle::platform::complex128 *a, - const paddle::platform::complex128 *b, - paddle::platform::complex128 *y) { + static void VADD(int n, const paddle::platform::complex *a, + const paddle::platform::complex *b, + paddle::platform::complex *y) { for (int i = 0; i < n; ++i) { y[i] = a[i] + b[i]; } } template - static void VSUB(int n, const paddle::platform::complex128 *a, - const paddle::platform::complex128 *b, - paddle::platform::complex128 *y) { + static void VSUB(int n, const paddle::platform::complex *a, + const paddle::platform::complex *b, + paddle::platform::complex *y) { for (int i = 0; i < n; ++i) { y[i] = a[i] - b[i]; } } template - static void VMUL(int n, const paddle::platform::complex128 *a, - const paddle::platform::complex128 *b, - paddle::platform::complex128 *y) { + static void VMUL(int n, const paddle::platform::complex *a, + const paddle::platform::complex *b, + paddle::platform::complex *y) { for (int i = 0; i < n; ++i) { y[i] = a[i] * b[i]; } } template - static void VDIV(int n, const paddle::platform::complex128 *a, - const paddle::platform::complex128 *b, - paddle::platform::complex128 *y) { + static void VDIV(int n, const paddle::platform::complex *a, + const paddle::platform::complex *b, + paddle::platform::complex *y) { for (int i = 0; i < n; ++i) { y[i] = a[i] / b[i]; } @@ -526,11 +526,11 @@ struct CBlas { template static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N, - paddle::platform::complex128 alpha, - const paddle::platform::complex128 *A, int lda, - const paddle::platform::complex128 *X, int incx, - paddle::platform::complex128 beta, - paddle::platform::complex128 *Y, int incy) { + paddle::platform::complex alpha, + const paddle::platform::complex *A, int lda, + const paddle::platform::complex *X, int incx, + paddle::platform::complex beta, + paddle::platform::complex *Y, int incy) { const void *a_ = (const void *)(A); const void *x_ = (const void *)(X); void *y_ = static_cast(Y); @@ -541,11 +541,11 @@ struct CBlas { template static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a, CBLAS_TRANSPOSE trans_b, int M, int N, int K, - paddle::platform::complex128 alpha, - const paddle::platform::complex128 *A, int lda, - const paddle::platform::complex128 *B, int ldb, - paddle::platform::complex128 beta, - paddle::platform::complex128 *C, int ldc) { + paddle::platform::complex alpha, + const paddle::platform::complex *A, int lda, + const paddle::platform::complex *B, int ldb, + paddle::platform::complex beta, + paddle::platform::complex *C, int ldc) { const void *a_ = (const void *)(A); const void *b_ = (const void *)(B); void *c_ = static_cast(C); @@ -556,11 +556,13 @@ struct CBlas { template static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, - paddle::platform::complex128 *alpha, - const paddle::platform::complex128 **A, const int *lda, - const paddle::platform::complex128 **B, const int *ldb, - paddle::platform::complex128 *beta, - paddle::platform::complex128 **C, const int *ldc, + paddle::platform::complex *alpha, + const paddle::platform::complex **A, + const int *lda, + const paddle::platform::complex **B, + const int *ldb, + paddle::platform::complex *beta, + paddle::platform::complex **C, const int *ldc, int group_count, int *group_size) { const void **A_void = (const void **)(&(*A)); const void **B_void = (const void **)(&(*B)); @@ -636,76 +638,76 @@ struct CBlas { }; template <> -struct CBlas { +struct CBlas> { template static void VCOPY(ARGS... args) { cblas_ccopy(args...); } template - static void AXPY(int n, const paddle::platform::complex64 alpha, - const paddle::platform::complex64 *X, const int incX, - paddle::platform::complex64 *Y, const int incY) { + static void AXPY(int n, const paddle::platform::complex alpha, + const paddle::platform::complex *X, const int incX, + paddle::platform::complex *Y, const int incY) { cblas_caxpy(n, &alpha, X, incX, Y, incY); } template static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, const int M, const int N, - const paddle::platform::complex64 alpha, - const paddle::platform::complex64 *A, const int lda, - const paddle::platform::complex64 *X, const int incX, - const paddle::platform::complex64 beta, - paddle::platform::complex64 *Y, const int incY) { + const paddle::platform::complex alpha, + const paddle::platform::complex *A, const int lda, + const paddle::platform::complex *X, const int incX, + const paddle::platform::complex beta, + paddle::platform::complex *Y, const int incY) { cblas_cgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); } template static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const paddle::platform::complex64 alpha, - const paddle::platform::complex64 *A, const int lda, - const paddle::platform::complex64 *B, const int ldb, - const paddle::platform::complex64 beta, - paddle::platform::complex64 *C, const int ldc) { + const int K, const paddle::platform::complex alpha, + const paddle::platform::complex *A, const int lda, + const paddle::platform::complex *B, const int ldb, + const paddle::platform::complex beta, + paddle::platform::complex *C, const int ldc) { cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } }; template <> -struct CBlas { +struct CBlas> { template static void VCOPY(ARGS... args) { cblas_zcopy(args...); } template - static void AXPY(int n, const paddle::platform::complex128 alpha, - const paddle::platform::complex128 *X, const int incX, - paddle::platform::complex128 *Y, const int incY) { + static void AXPY(int n, const paddle::platform::complex alpha, + const paddle::platform::complex *X, const int incX, + paddle::platform::complex *Y, const int incY) { cblas_zaxpy(n, &alpha, X, incX, Y, incY); } template static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, const int M, const int N, - const paddle::platform::complex128 alpha, - const paddle::platform::complex128 *A, const int lda, - const paddle::platform::complex128 *X, const int incX, - const paddle::platform::complex128 beta, - paddle::platform::complex128 *Y, const int incY) { + const paddle::platform::complex alpha, + const paddle::platform::complex *A, const int lda, + const paddle::platform::complex *X, const int incX, + const paddle::platform::complex beta, + paddle::platform::complex *Y, const int incY) { cblas_zgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY); } template static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const paddle::platform::complex128 alpha, - const paddle::platform::complex128 *A, const int lda, - const paddle::platform::complex128 *B, const int ldb, - const paddle::platform::complex128 beta, - paddle::platform::complex128 *C, const int ldc) { + const int K, const paddle::platform::complex alpha, + const paddle::platform::complex *A, const int lda, + const paddle::platform::complex *B, const int ldb, + const paddle::platform::complex beta, + paddle::platform::complex *C, const int ldc) { cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } diff --git a/paddle/fluid/operators/math/blas_impl.hip.h b/paddle/fluid/operators/math/blas_impl.hip.h index 81110b591a1..788ebc6ad98 100644 --- a/paddle/fluid/operators/math/blas_impl.hip.h +++ b/paddle/fluid/operators/math/blas_impl.hip.h @@ -213,13 +213,13 @@ struct CUBlas { }; template <> -struct CUBlas { - using complex64 = platform::complex64; - +struct CUBlas> { static void GEMV(rocblas_handle handle, rocblas_operation transa, int m, - int n, const complex64 *alpha, const complex64 *A, int lda, - const complex64 *B, int ldb, const complex64 *beta, - complex64 *C, int ldc) { + int n, const platform::complex *alpha, + const platform::complex *A, int lda, + const platform::complex *B, int ldb, + const platform::complex *beta, + platform::complex *C, int ldc) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemv( handle, transa, m, n, reinterpret_cast(alpha), @@ -229,9 +229,10 @@ struct CUBlas { reinterpret_cast(C), ldc)); } - static void AXPY(rocblas_handle handle, int n, const complex64 *alpha, - const complex64 *X, const int incX, complex64 *Y, - const int incY) { + static void AXPY(rocblas_handle handle, int n, + const platform::complex *alpha, + const platform::complex *X, const int incX, + platform::complex *Y, const int incY) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_caxpy( handle, n, reinterpret_cast(alpha), reinterpret_cast(X), incX, @@ -241,11 +242,13 @@ struct CUBlas { static void GEMM_STRIDED_BATCH(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, int k, - const complex64 *alpha, const complex64 *A, - int lda, long long int strideA, // NOLINT - const complex64 *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const complex64 *beta, complex64 *C, int ldc, + const platform::complex *alpha, + const platform::complex *A, int lda, + long long int strideA, // NOLINT + const platform::complex *B, // NOLINT + int ldb, long long int strideB, // NOLINT + const platform::complex *beta, + platform::complex *C, int ldc, long long int strideC, // NOLINT int batchCount) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -261,9 +264,11 @@ struct CUBlas { static void GEMM(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, int k, - const complex64 *alpha, const complex64 *A, int lda, - const complex64 *B, int ldb, const complex64 *beta, - complex64 *C, int ldc) { + const platform::complex *alpha, + const platform::complex *A, int lda, + const platform::complex *B, int ldb, + const platform::complex *beta, + platform::complex *C, int ldc) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemm( handle, transa, transb, m, n, k, reinterpret_cast(alpha), @@ -293,13 +298,13 @@ struct CUBlas { }; template <> -struct CUBlas { - using complex128 = platform::complex128; - +struct CUBlas> { static void GEMV(rocblas_handle handle, rocblas_operation transa, int m, - int n, const complex128 *alpha, const complex128 *A, int lda, - const complex128 *B, int ldb, const complex128 *beta, - complex128 *C, int ldc) { + int n, const platform::complex *alpha, + const platform::complex *A, int lda, + const platform::complex *B, int ldb, + const platform::complex *beta, + platform::complex *C, int ldc) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemv( handle, transa, m, n, reinterpret_cast(alpha), @@ -309,9 +314,10 @@ struct CUBlas { reinterpret_cast(C), ldc)); } - static void AXPY(rocblas_handle handle, int n, const complex128 *alpha, - const complex128 *X, const int incX, complex128 *Y, - const int incY) { + static void AXPY(rocblas_handle handle, int n, + const platform::complex *alpha, + const platform::complex *X, const int incX, + platform::complex *Y, const int incY) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zaxpy( handle, n, reinterpret_cast(alpha), reinterpret_cast(X), incX, @@ -321,11 +327,13 @@ struct CUBlas { static void GEMM_STRIDED_BATCH(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, int k, - const complex128 *alpha, const complex128 *A, - int lda, long long int strideA, // NOLINT - const complex128 *B, // NOLINT - int ldb, long long int strideB, // NOLINT - const complex128 *beta, complex128 *C, int ldc, + const platform::complex *alpha, + const platform::complex *A, int lda, + long long int strideA, // NOLINT + const platform::complex *B, // NOLINT + int ldb, long long int strideB, // NOLINT + const platform::complex *beta, + platform::complex *C, int ldc, long long int strideC, // NOLINT int batchCount) { PADDLE_ENFORCE_CUDA_SUCCESS( @@ -341,9 +349,11 @@ struct CUBlas { static void GEMM(rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, int k, - const complex128 *alpha, const complex128 *A, int lda, - const complex128 *B, int ldb, const complex128 *beta, - complex128 *C, int ldc) { + const platform::complex *alpha, + const platform::complex *A, int lda, + const platform::complex *B, int ldb, + const platform::complex *beta, + platform::complex *C, int ldc) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemm( handle, transa, transb, m, n, k, reinterpret_cast(alpha), @@ -434,9 +444,9 @@ template <> template <> inline void Blas::GEMM( CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::complex64 alpha, const platform::complex64 *A, - const platform::complex64 *B, platform::complex64 beta, - platform::complex64 *C) const { + platform::complex alpha, const platform::complex *A, + const platform::complex *B, platform::complex beta, + platform::complex *C) const { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -461,7 +471,7 @@ inline void Blas::GEMM( thrust::complex c_beta = thrust::complex(beta.real, beta.imag); auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX( + CUBlas>::GEMM_EX( &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, rocblas_datatype_f32_c, ldb, A, rocblas_datatype_f32_c, lda, &c_beta, C, rocblas_datatype_f32_c, N, rocblas_datatype_f32_c); @@ -471,9 +481,9 @@ template <> template <> inline void Blas::GEMM( CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K, - platform::complex128 alpha, const platform::complex128 *A, - const platform::complex128 *B, platform::complex128 beta, - platform::complex128 *C) const { + platform::complex alpha, const platform::complex *A, + const platform::complex *B, platform::complex beta, + platform::complex *C) const { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -499,7 +509,7 @@ inline void Blas::GEMM( thrust::complex(beta.real, beta.imag); auto &cuda_ctx = const_cast(context_); - CUBlas::GEMM_EX( + CUBlas>::GEMM_EX( &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, rocblas_datatype_f64_c, ldb, A, rocblas_datatype_f64_c, lda, &c_beta, C, rocblas_datatype_f64_c, N, rocblas_datatype_f64_c); diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index b9a1854a661..ee405be5ae9 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -297,7 +297,9 @@ template struct SelectedRowsAddToTensor -typename std::enable_if::value>::type +typename std::enable_if::value || + std::is_same>::value || + std::is_same>::value>::type elementwise_add_to(BlasT* blas, size_t data_len, const T* in, T* out) { blas->AXPY(data_len, T(1.f), in, out); @@ -542,9 +544,9 @@ template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; template struct MergeAdd; + paddle::platform::complex>; template struct MergeAdd; + paddle::platform::complex>; template struct MergeAdd; diff --git a/paddle/fluid/operators/matmul_v2_op.cc b/paddle/fluid/operators/matmul_v2_op.cc index 6fccd3657af..82706fd4875 100644 --- a/paddle/fluid/operators/matmul_v2_op.cc +++ b/paddle/fluid/operators/matmul_v2_op.cc @@ -204,15 +204,15 @@ REGISTER_OP_CPU_KERNEL( matmul_v2, ops::MatMulV2Kernel, ops::MatMulV2Kernel, ops::MatMulV2Kernel, + paddle::platform::complex>, ops::MatMulV2Kernel); + paddle::platform::complex>); REGISTER_OP_CPU_KERNEL( matmul_v2_grad, ops::MatMulV2GradKernel, ops::MatMulV2GradKernel, ops::MatMulV2GradKernel, + paddle::platform::complex>, ops::MatMulV2GradKernel); + paddle::platform::complex>); diff --git a/paddle/fluid/operators/matmul_v2_op.cu b/paddle/fluid/operators/matmul_v2_op.cu index e819398ec9b..2176ab79dd9 100644 --- a/paddle/fluid/operators/matmul_v2_op.cu +++ b/paddle/fluid/operators/matmul_v2_op.cu @@ -21,12 +21,12 @@ REGISTER_OP_CUDA_KERNEL( matmul_v2, ops::MatMulV2Kernel, ops::MatMulV2Kernel, ops::MatMulV2Kernel, - ops::MatMulV2Kernel, - ops::MatMulV2Kernel); + ops::MatMulV2Kernel>, + ops::MatMulV2Kernel>); REGISTER_OP_CUDA_KERNEL( matmul_v2_grad, ops::MatMulV2GradKernel, ops::MatMulV2GradKernel, ops::MatMulV2GradKernel, - ops::MatMulV2GradKernel, - ops::MatMulV2GradKernel); + ops::MatMulV2GradKernel>, + ops::MatMulV2GradKernel>); diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h index ca20efaad07..6061679b288 100644 --- a/paddle/fluid/operators/matmul_v2_op.h +++ b/paddle/fluid/operators/matmul_v2_op.h @@ -483,19 +483,19 @@ struct ConjHelper { }; template -struct ConjHelper { +struct ConjHelper> { explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { dst.Resize(src.dims()); - auto* src_data = src.data(); - auto* dst_data = dst.mutable_data( + auto* src_data = src.data>(); + auto* dst_data = dst.mutable_data>( ctx_.GetPlace(), - size_t(src.numel() * sizeof(paddle::platform::complex64))); + size_t(src.numel() * sizeof(paddle::platform::complex))); platform::ForRange for_range( ctx_.template device_context(), src.numel()); - math::ConjFunctor functor( + math::ConjFunctor> functor( src_data, src.numel(), dst_data); for_range(functor); return; @@ -504,19 +504,19 @@ struct ConjHelper { }; template -struct ConjHelper { +struct ConjHelper> { explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {} HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) { dst.Resize(src.dims()); - auto* src_data = src.data(); - auto* dst_data = dst.mutable_data( + auto* src_data = src.data>(); + auto* dst_data = dst.mutable_data>( ctx_.GetPlace(), - size_t(src.numel() * sizeof(paddle::platform::complex128))); + size_t(src.numel() * sizeof(paddle::platform::complex))); platform::ForRange for_range( ctx_.template device_context(), src.numel()); - math::ConjFunctor functor( + math::ConjFunctor> functor( src_data, src.numel(), dst_data); for_range(functor); return; -- GitLab