Unverified commit 6c07cd7e authored by chentianyu03, committed by GitHub

modify matmul Op to complex template types (#33130)

* modify matmul Op to complex template types

* remove complex64/128 header files
Parent 8259d9bf
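
This change replaces the two concrete complex types (`platform::complex64`, `platform::complex128`) with a single class template `platform::complex<T>` along the matmul path. A minimal sketch of the idea, not taken from this commit (the real definition lives in `paddle/fluid/platform/complex.h` and additionally provides arithmetic operators and device qualifiers):

```cpp
// Sketch only: one template instead of two hand-written complex types.
namespace paddle {
namespace platform {

template <typename T>
struct complex {
  T real;
  T imag;

  complex() = default;
  complex(T r, T i) : real(r), imag(i) {}
};

// Old spelling            -> new spelling
// platform::complex64     -> platform::complex<float>
// platform::complex128    -> platform::complex<double>

}  // namespace platform
}  // namespace paddle
```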
@@ -24,8 +24,7 @@
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/profiler.h"
@@ -200,8 +199,8 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
   PADDLE_TENSOR_ADD(double);
   // NOTE(chenweihang): only support complex grad tensor accumulated,
   // support selected rows if needed in the future
-  PADDLE_TENSOR_ADD(platform::complex64);
-  PADDLE_TENSOR_ADD(platform::complex128);
+  PADDLE_TENSOR_ADD(platform::complex<float>);
+  PADDLE_TENSOR_ADD(platform::complex<double>);
 #endif
 #undef PADDLE_TENSOR_ADD
......
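The `PADDLE_TENSOR_ADD(...)` invocations above gain the two complex instantiations. The macro body itself is not part of this diff; a hedged sketch of the usual shape of such a type-dispatch macro (`TensorAddImpl` and `data_type` are placeholder names, not Paddle API):

```cpp
// Illustrative only: each PADDLE_TENSOR_ADD(T) line typically expands into one
// "if the gradient tensor holds element type T, accumulate as T and return"
// branch inside TensorAdd().
#define PADDLE_TENSOR_ADD(T)                                       \
  if (data_type == framework::DataTypeTrait<T>::DataType()) {     \
    TensorAddImpl<T>(src_tensor, dst_tensor); /* hypothetical */   \
    return;                                                        \
  }
```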
@@ -260,13 +260,13 @@ struct CUBlas<platform::float16> {
 };

 template <>
-struct CUBlas<platform::complex64> {
-  using complex64 = platform::complex64;
+struct CUBlas<platform::complex<float>> {
   static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m,
-                   int n, const complex64 *alpha, const complex64 *A, int lda,
-                   const complex64 *B, int ldb, const complex64 *beta,
-                   complex64 *C, int ldc) {
+                   int n, const platform::complex<float> *alpha,
+                   const platform::complex<float> *A, int lda,
+                   const platform::complex<float> *B, int ldb,
+                   const platform::complex<float> *beta,
+                   platform::complex<float> *C, int ldc) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemv(
         handle, transa, m, n, reinterpret_cast<const cuFloatComplex *>(alpha),
         reinterpret_cast<const cuFloatComplex *>(A), lda,
@@ -275,9 +275,10 @@ struct CUBlas<platform::complex64> {
         reinterpret_cast<cuFloatComplex *>(C), ldc));
   }

-  static void AXPY(cublasHandle_t handle, int n, const complex64 *alpha,
-                   const complex64 *X, const int incX, complex64 *Y,
-                   const int incY) {
+  static void AXPY(cublasHandle_t handle, int n,
+                   const platform::complex<float> *alpha,
+                   const platform::complex<float> *X, const int incX,
+                   platform::complex<float> *Y, const int incY) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCaxpy(
         handle, n, reinterpret_cast<const cuFloatComplex *>(alpha),
         reinterpret_cast<const cuFloatComplex *>(X), incX,
@@ -287,11 +288,13 @@ struct CUBlas<platform::complex64> {
   static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb, int m, int n, int k,
-                                 const complex64 *alpha, const complex64 *A,
-                                 int lda, long long int strideA,  // NOLINT
-                                 const complex64 *B,              // NOLINT
-                                 int ldb, long long int strideB,  // NOLINT
-                                 const complex64 *beta, complex64 *C, int ldc,
+                                 const platform::complex<float> *alpha,
+                                 const platform::complex<float> *A, int lda,
+                                 long long int strideA,              // NOLINT
+                                 const platform::complex<float> *B,  // NOLINT
+                                 int ldb, long long int strideB,     // NOLINT
+                                 const platform::complex<float> *beta,
+                                 platform::complex<float> *C, int ldc,
                                  long long int strideC,  // NOLINT
                                  int batchCount) {
 #if CUDA_VERSION >= 8000
@@ -310,9 +313,11 @@ struct CUBlas<platform::complex64> {

   static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
                    cublasOperation_t transb, int m, int n, int k,
-                   const complex64 *alpha, const complex64 *A, int lda,
-                   const complex64 *B, int ldb, const complex64 *beta,
-                   complex64 *C, int ldc) {
+                   const platform::complex<float> *alpha,
+                   const platform::complex<float> *A, int lda,
+                   const platform::complex<float> *B, int ldb,
+                   const platform::complex<float> *beta,
+                   platform::complex<float> *C, int ldc) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemm(
         handle, transa, transb, m, n, k,
         reinterpret_cast<const cuFloatComplex *>(alpha),
@@ -356,13 +361,13 @@ struct CUBlas<platform::complex64> {
 };

 template <>
-struct CUBlas<platform::complex128> {
-  using complex128 = platform::complex128;
+struct CUBlas<platform::complex<double>> {
   static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m,
-                   int n, const complex128 *alpha, const complex128 *A, int lda,
-                   const complex128 *B, int ldb, const complex128 *beta,
-                   complex128 *C, int ldc) {
+                   int n, const platform::complex<double> *alpha,
+                   const platform::complex<double> *A, int lda,
+                   const platform::complex<double> *B, int ldb,
+                   const platform::complex<double> *beta,
+                   platform::complex<double> *C, int ldc) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemv(
         handle, transa, m, n, reinterpret_cast<const cuDoubleComplex *>(alpha),
         reinterpret_cast<const cuDoubleComplex *>(A), lda,
@@ -371,9 +376,10 @@ struct CUBlas<platform::complex128> {
         reinterpret_cast<cuDoubleComplex *>(C), ldc));
   }

-  static void AXPY(cublasHandle_t handle, int n, const complex128 *alpha,
-                   const complex128 *X, const int incX, complex128 *Y,
-                   const int incY) {
+  static void AXPY(cublasHandle_t handle, int n,
+                   const platform::complex<double> *alpha,
+                   const platform::complex<double> *X, const int incX,
+                   platform::complex<double> *Y, const int incY) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZaxpy(
         handle, n, reinterpret_cast<const cuDoubleComplex *>(alpha),
         reinterpret_cast<const cuDoubleComplex *>(X), incX,
@@ -383,11 +389,13 @@ struct CUBlas<platform::complex128> {
   static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
                                  cublasOperation_t transa,
                                  cublasOperation_t transb, int m, int n, int k,
-                                 const complex128 *alpha, const complex128 *A,
-                                 int lda, long long int strideA,  // NOLINT
-                                 const complex128 *B,             // NOLINT
-                                 int ldb, long long int strideB,  // NOLINT
-                                 const complex128 *beta, complex128 *C, int ldc,
+                                 const platform::complex<double> *alpha,
+                                 const platform::complex<double> *A, int lda,
+                                 long long int strideA,               // NOLINT
+                                 const platform::complex<double> *B,  // NOLINT
+                                 int ldb, long long int strideB,      // NOLINT
+                                 const platform::complex<double> *beta,
+                                 platform::complex<double> *C, int ldc,
                                  long long int strideC,  // NOLINT
                                  int batchCount) {
 #if CUDA_VERSION >= 8000
@@ -406,9 +414,11 @@ struct CUBlas<platform::complex128> {

   static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
                    cublasOperation_t transb, int m, int n, int k,
-                   const complex128 *alpha, const complex128 *A, int lda,
-                   const complex128 *B, int ldb, const complex128 *beta,
-                   complex128 *C, int ldc) {
+                   const platform::complex<double> *alpha,
+                   const platform::complex<double> *A, int lda,
+                   const platform::complex<double> *B, int ldb,
+                   const platform::complex<double> *beta,
+                   platform::complex<double> *C, int ldc) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemm(
         handle, transa, transb, m, n, k,
         reinterpret_cast<const cuDoubleComplex *>(alpha),
@@ -535,9 +545,9 @@ template <>
 template <>
 inline void Blas<platform::CUDADeviceContext>::GEMM(
     CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
-    platform::complex64 alpha, const platform::complex64 *A,
-    const platform::complex64 *B, platform::complex64 beta,
-    platform::complex64 *C) const {
+    platform::complex<float> alpha, const platform::complex<float> *A,
+    const platform::complex<float> *B, platform::complex<float> beta,
+    platform::complex<float> *C) const {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -565,16 +575,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
   // input/output in fp16, computation in fp32, which can also be accelerated
   // using tensor cores in volta GPUs.
   auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
-  CUBlas<platform::complex64>::GEMM_EX(
+  CUBlas<platform::complex<float>>::GEMM_EX(
       &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_32F, ldb, A,
       CUDA_C_32F, lda, &c_beta, C, CUDA_C_32F, N, CUDA_C_32F);
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
   context_.CublasCall([&](cublasHandle_t handle) {
-    CUBlas<platform::complex64>::GEMM(handle, cuTransB, cuTransA, N, M, K,
-                                      &c_alpha, h_B, ldb, h_A, lda, &c_beta,
-                                      h_C, N);
+    CUBlas<platform::complex<float>>::GEMM(handle, cuTransB, cuTransA, N, M, K,
+                                           &c_alpha, h_B, ldb, h_A, lda,
+                                           &c_beta, h_C, N);
   });
 #endif  // CUDA_VERSION >= 8000
 }
@@ -583,9 +593,9 @@ template <>
 template <>
 inline void Blas<platform::CUDADeviceContext>::GEMM(
     CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
-    platform::complex128 alpha, const platform::complex128 *A,
-    const platform::complex128 *B, platform::complex128 beta,
-    platform::complex128 *C) const {
+    platform::complex<double> alpha, const platform::complex<double> *A,
+    const platform::complex<double> *B, platform::complex<double> beta,
+    platform::complex<double> *C) const {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -614,16 +624,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
   // input/output in fp16, computation in fp32, which can also be accelerated
   // using tensor cores in volta GPUs.
   auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
-  CUBlas<platform::complex128>::GEMM_EX(
+  CUBlas<platform::complex<double>>::GEMM_EX(
       &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_64F, ldb, A,
      CUDA_C_64F, lda, &c_beta, C, CUDA_C_64F, N, CUDA_C_64F);
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
   context_.CublasCall([&](cublasHandle_t handle) {
-    CUBlas<platform::complex128>::GEMM(handle, cuTransB, cuTransA, N, M, K,
-                                       &c_alpha, h_B, ldb, h_A, lda, &c_beta,
-                                       h_C, N);
+    CUBlas<platform::complex<double>>::GEMM(handle, cuTransB, cuTransA, N, M, K,
+                                            &c_alpha, h_B, ldb, h_A, lda,
+                                            &c_beta, h_C, N);
   });
 #endif  // CUDA_VERSION >= 8000
 }
......
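The wrappers above cast `platform::complex<float>*`/`platform::complex<double>*` straight to `cuFloatComplex*`/`cuDoubleComplex*`, which is only sound if the types share layout (two packed scalars, real first). A hedged sketch of the kind of compile-time checks such casts rely on; these asserts are illustrative and not part of this diff (the actual guarantees belong to the `complex.h` implementation):

```cpp
#include <cuComplex.h>
#include "paddle/fluid/platform/complex.h"

// Sanity checks behind the reinterpret_casts used in the CUBlas wrappers:
// size and alignment of the template type must match the cuBLAS complex types.
static_assert(sizeof(paddle::platform::complex<float>) == sizeof(cuFloatComplex),
              "complex<float> must match cuFloatComplex in size");
static_assert(alignof(paddle::platform::complex<float>) == alignof(cuFloatComplex),
              "complex<float> must match cuFloatComplex in alignment");
static_assert(sizeof(paddle::platform::complex<double>) == sizeof(cuDoubleComplex),
              "complex<double> must match cuDoubleComplex in size");
```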
@@ -23,8 +23,7 @@
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/bfloat16.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"

 namespace paddle {
 namespace operators {
@@ -324,11 +323,11 @@ struct CBlas<double> {
 };

 template <>
-struct CBlas<platform::complex64> {
+struct CBlas<platform::complex<float>> {
   template <typename... ARGS>
-  static void AXPY(int n, const paddle::platform::complex64 alpha,
-                   const paddle::platform::complex64 *X, const int incX,
-                   paddle::platform::complex64 *Y, const int incY) {
+  static void AXPY(int n, const paddle::platform::complex<float> alpha,
+                   const paddle::platform::complex<float> *X, const int incX,
+                   paddle::platform::complex<float> *Y, const int incY) {
     platform::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY);
   }
@@ -363,35 +362,35 @@ struct CBlas<platform::complex64> {
   */

   template <typename... ARGS>
-  static void VADD(int n, const paddle::platform::complex64 *a,
-                   const paddle::platform::complex64 *b,
-                   paddle::platform::complex64 *y) {
+  static void VADD(int n, const paddle::platform::complex<float> *a,
+                   const paddle::platform::complex<float> *b,
+                   paddle::platform::complex<float> *y) {
     for (int i = 0; i < n; ++i) {
       y[i] = a[i] + b[i];
     }
   }

   template <typename... ARGS>
-  static void VSUB(int n, const paddle::platform::complex64 *a,
-                   const paddle::platform::complex64 *b,
-                   paddle::platform::complex64 *y) {
+  static void VSUB(int n, const paddle::platform::complex<float> *a,
+                   const paddle::platform::complex<float> *b,
+                   paddle::platform::complex<float> *y) {
     for (int i = 0; i < n; ++i) {
       y[i] = a[i] - b[i];
     }
   }

   template <typename... ARGS>
-  static void VMUL(int n, const paddle::platform::complex64 *a,
-                   const paddle::platform::complex64 *b,
-                   paddle::platform::complex64 *y) {
+  static void VMUL(int n, const paddle::platform::complex<float> *a,
+                   const paddle::platform::complex<float> *b,
+                   paddle::platform::complex<float> *y) {
     for (int i = 0; i < n; ++i) {
       y[i] = a[i] * b[i];
     }
   }

   template <typename... ARGS>
-  static void VDIV(int n, const paddle::platform::complex64 *a,
-                   const paddle::platform::complex64 *b,
-                   paddle::platform::complex64 *y) {
+  static void VDIV(int n, const paddle::platform::complex<float> *a,
+                   const paddle::platform::complex<float> *b,
+                   paddle::platform::complex<float> *y) {
     for (int i = 0; i < n; ++i) {
       y[i] = a[i] / b[i];
     }
@@ -399,11 +398,11 @@ struct CBlas<platform::complex64> {
   template <typename... ARGS>
   static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N,
-                   paddle::platform::complex64 alpha,
-                   const paddle::platform::complex64 *A, int lda,
-                   const paddle::platform::complex64 *X, int incx,
-                   paddle::platform::complex64 beta,
-                   paddle::platform::complex64 *Y, int incy) {
+                   paddle::platform::complex<float> alpha,
+                   const paddle::platform::complex<float> *A, int lda,
+                   const paddle::platform::complex<float> *X, int incx,
+                   paddle::platform::complex<float> beta,
+                   paddle::platform::complex<float> *Y, int incy) {
     const void *a_ = (const void *)(A);
     const void *x_ = (const void *)(X);
     void *y_ = static_cast<void *>(Y);
@@ -414,11 +413,11 @@ struct CBlas<platform::complex64> {
   template <typename... ARGS>
   static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a,
                    CBLAS_TRANSPOSE trans_b, int M, int N, int K,
-                   paddle::platform::complex64 alpha,
-                   const paddle::platform::complex64 *A, int lda,
-                   const paddle::platform::complex64 *B, int ldb,
-                   paddle::platform::complex64 beta,
-                   paddle::platform::complex64 *C, int ldc) {
+                   paddle::platform::complex<float> alpha,
+                   const paddle::platform::complex<float> *A, int lda,
+                   const paddle::platform::complex<float> *B, int ldb,
+                   paddle::platform::complex<float> beta,
+                   paddle::platform::complex<float> *C, int ldc) {
     const void *a_ = (const void *)(A);
     const void *b_ = (const void *)(B);
     void *c_ = static_cast<void *>(C);
@@ -429,11 +428,12 @@ struct CBlas<platform::complex64> {
   template <typename... ARGS>
   static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
                          CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
-                         paddle::platform::complex64 *alpha,
-                         const paddle::platform::complex64 **A, const int *lda,
-                         const paddle::platform::complex64 **B, const int *ldb,
-                         paddle::platform::complex64 *beta,
-                         paddle::platform::complex64 **C, const int *ldc,
+                         paddle::platform::complex<float> *alpha,
+                         const paddle::platform::complex<float> **A,
+                         const int *lda,
+                         const paddle::platform::complex<float> **B,
+                         const int *ldb, paddle::platform::complex<float> *beta,
+                         paddle::platform::complex<float> **C, const int *ldc,
                          int group_count, int *group_size) {
     const void **A_void = (const void **)(&(*A));
     const void **B_void = (const void **)(&(*B));
@@ -451,11 +451,11 @@ struct CBlas<platform::complex64> {
 };

 template <>
-struct CBlas<platform::complex128> {
+struct CBlas<platform::complex<double>> {
   template <typename... ARGS>
-  static void AXPY(int n, const paddle::platform::complex128 alpha,
-                   const paddle::platform::complex128 *X, const int incX,
-                   paddle::platform::complex128 *Y, const int incY) {
+  static void AXPY(int n, const paddle::platform::complex<double> alpha,
+                   const paddle::platform::complex<double> *X, const int incX,
+                   paddle::platform::complex<double> *Y, const int incY) {
     platform::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY);
   }
@@ -490,35 +490,35 @@ struct CBlas<platform::complex128> {
   */

   template <typename... ARGS>
-  static void VADD(int n, const paddle::platform::complex128 *a,
-                   const paddle::platform::complex128 *b,
-                   paddle::platform::complex128 *y) {
+  static void VADD(int n, const paddle::platform::complex<double> *a,
+                   const paddle::platform::complex<double> *b,
+                   paddle::platform::complex<double> *y) {
     for (int i = 0; i < n; ++i) {
       y[i] = a[i] + b[i];
     }
   }

   template <typename... ARGS>
-  static void VSUB(int n, const paddle::platform::complex128 *a,
-                   const paddle::platform::complex128 *b,
-                   paddle::platform::complex128 *y) {
+  static void VSUB(int n, const paddle::platform::complex<double> *a,
+                   const paddle::platform::complex<double> *b,
+                   paddle::platform::complex<double> *y) {
     for (int i = 0; i < n; ++i) {
       y[i] = a[i] - b[i];
     }
   }

   template <typename... ARGS>
-  static void VMUL(int n, const paddle::platform::complex128 *a,
-                   const paddle::platform::complex128 *b,
-                   paddle::platform::complex128 *y) {
+  static void VMUL(int n, const paddle::platform::complex<double> *a,
+                   const paddle::platform::complex<double> *b,
+                   paddle::platform::complex<double> *y) {
     for (int i = 0; i < n; ++i) {
       y[i] = a[i] * b[i];
     }
   }

   template <typename... ARGS>
-  static void VDIV(int n, const paddle::platform::complex128 *a,
-                   const paddle::platform::complex128 *b,
-                   paddle::platform::complex128 *y) {
+  static void VDIV(int n, const paddle::platform::complex<double> *a,
+                   const paddle::platform::complex<double> *b,
+                   paddle::platform::complex<double> *y) {
     for (int i = 0; i < n; ++i) {
       y[i] = a[i] / b[i];
     }
@@ -526,11 +526,11 @@ struct CBlas<platform::complex128> {
   template <typename... ARGS>
   static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N,
-                   paddle::platform::complex128 alpha,
-                   const paddle::platform::complex128 *A, int lda,
-                   const paddle::platform::complex128 *X, int incx,
-                   paddle::platform::complex128 beta,
-                   paddle::platform::complex128 *Y, int incy) {
+                   paddle::platform::complex<double> alpha,
+                   const paddle::platform::complex<double> *A, int lda,
+                   const paddle::platform::complex<double> *X, int incx,
+                   paddle::platform::complex<double> beta,
+                   paddle::platform::complex<double> *Y, int incy) {
     const void *a_ = (const void *)(A);
     const void *x_ = (const void *)(X);
     void *y_ = static_cast<void *>(Y);
@@ -541,11 +541,11 @@ struct CBlas<platform::complex128> {
   template <typename... ARGS>
   static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a,
                    CBLAS_TRANSPOSE trans_b, int M, int N, int K,
-                   paddle::platform::complex128 alpha,
-                   const paddle::platform::complex128 *A, int lda,
-                   const paddle::platform::complex128 *B, int ldb,
-                   paddle::platform::complex128 beta,
-                   paddle::platform::complex128 *C, int ldc) {
+                   paddle::platform::complex<double> alpha,
+                   const paddle::platform::complex<double> *A, int lda,
+                   const paddle::platform::complex<double> *B, int ldb,
+                   paddle::platform::complex<double> beta,
+                   paddle::platform::complex<double> *C, int ldc) {
     const void *a_ = (const void *)(A);
     const void *b_ = (const void *)(B);
     void *c_ = static_cast<void *>(C);
@@ -556,11 +556,13 @@ struct CBlas<platform::complex128> {
   template <typename... ARGS>
   static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
                          CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
-                         paddle::platform::complex128 *alpha,
-                         const paddle::platform::complex128 **A, const int *lda,
-                         const paddle::platform::complex128 **B, const int *ldb,
-                         paddle::platform::complex128 *beta,
-                         paddle::platform::complex128 **C, const int *ldc,
+                         paddle::platform::complex<double> *alpha,
+                         const paddle::platform::complex<double> **A,
+                         const int *lda,
+                         const paddle::platform::complex<double> **B,
+                         const int *ldb,
+                         paddle::platform::complex<double> *beta,
+                         paddle::platform::complex<double> **C, const int *ldc,
                          int group_count, int *group_size) {
     const void **A_void = (const void **)(&(*A));
     const void **B_void = (const void **)(&(*B));
@@ -636,76 +638,76 @@ struct CBlas<double> {
 };

 template <>
-struct CBlas<platform::complex64> {
+struct CBlas<platform::complex<float>> {
   template <typename... ARGS>
   static void VCOPY(ARGS... args) {
     cblas_ccopy(args...);
   }

   template <typename... ARGS>
-  static void AXPY(int n, const paddle::platform::complex64 alpha,
-                   const paddle::platform::complex64 *X, const int incX,
-                   paddle::platform::complex64 *Y, const int incY) {
+  static void AXPY(int n, const paddle::platform::complex<float> alpha,
+                   const paddle::platform::complex<float> *X, const int incX,
+                   paddle::platform::complex<float> *Y, const int incY) {
     cblas_caxpy(n, &alpha, X, incX, Y, incY);
   }

   template <typename... ARGS>
   static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
                    const int M, const int N,
-                   const paddle::platform::complex64 alpha,
-                   const paddle::platform::complex64 *A, const int lda,
-                   const paddle::platform::complex64 *X, const int incX,
-                   const paddle::platform::complex64 beta,
-                   paddle::platform::complex64 *Y, const int incY) {
+                   const paddle::platform::complex<float> alpha,
+                   const paddle::platform::complex<float> *A, const int lda,
+                   const paddle::platform::complex<float> *X, const int incX,
+                   const paddle::platform::complex<float> beta,
+                   paddle::platform::complex<float> *Y, const int incY) {
     cblas_cgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
   }

   template <typename... ARGS>
   static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
                    const CBLAS_TRANSPOSE TransB, const int M, const int N,
-                   const int K, const paddle::platform::complex64 alpha,
-                   const paddle::platform::complex64 *A, const int lda,
-                   const paddle::platform::complex64 *B, const int ldb,
-                   const paddle::platform::complex64 beta,
-                   paddle::platform::complex64 *C, const int ldc) {
+                   const int K, const paddle::platform::complex<float> alpha,
+                   const paddle::platform::complex<float> *A, const int lda,
+                   const paddle::platform::complex<float> *B, const int ldb,
+                   const paddle::platform::complex<float> beta,
+                   paddle::platform::complex<float> *C, const int ldc) {
     cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
                 C, ldc);
   }
 };

 template <>
-struct CBlas<platform::complex128> {
+struct CBlas<platform::complex<double>> {
   template <typename... ARGS>
   static void VCOPY(ARGS... args) {
     cblas_zcopy(args...);
   }

   template <typename... ARGS>
-  static void AXPY(int n, const paddle::platform::complex128 alpha,
-                   const paddle::platform::complex128 *X, const int incX,
-                   paddle::platform::complex128 *Y, const int incY) {
+  static void AXPY(int n, const paddle::platform::complex<double> alpha,
+                   const paddle::platform::complex<double> *X, const int incX,
+                   paddle::platform::complex<double> *Y, const int incY) {
     cblas_zaxpy(n, &alpha, X, incX, Y, incY);
   }

   template <typename... ARGS>
   static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
                    const int M, const int N,
-                   const paddle::platform::complex128 alpha,
-                   const paddle::platform::complex128 *A, const int lda,
-                   const paddle::platform::complex128 *X, const int incX,
-                   const paddle::platform::complex128 beta,
-                   paddle::platform::complex128 *Y, const int incY) {
+                   const paddle::platform::complex<double> alpha,
+                   const paddle::platform::complex<double> *A, const int lda,
+                   const paddle::platform::complex<double> *X, const int incX,
+                   const paddle::platform::complex<double> beta,
+                   paddle::platform::complex<double> *Y, const int incY) {
     cblas_zgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
   }

   template <typename... ARGS>
   static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
                    const CBLAS_TRANSPOSE TransB, const int M, const int N,
-                   const int K, const paddle::platform::complex128 alpha,
-                   const paddle::platform::complex128 *A, const int lda,
-                   const paddle::platform::complex128 *B, const int ldb,
-                   const paddle::platform::complex128 beta,
-                   paddle::platform::complex128 *C, const int ldc) {
+                   const int K, const paddle::platform::complex<double> alpha,
+                   const paddle::platform::complex<double> *A, const int lda,
+                   const paddle::platform::complex<double> *B, const int ldb,
+                   const paddle::platform::complex<double> beta,
+                   paddle::platform::complex<double> *C, const int ldc) {
     cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
                 C, ldc);
   }
......
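The wrapped `cblas_cgemm`/`cblas_zgemm` calls compute C = alpha·op(A)·op(B) + beta·C in complex arithmetic. As a reference for what that means, a naive, hedged sketch of the same computation for row-major, non-transposed inputs using `std::complex` (purely illustrative; not how the BLAS library implements it):

```cpp
#include <complex>

// Naive reference for C = alpha * A * B + beta * C with row-major,
// non-transposed M x K (A) and K x N (B) matrices. Illustrative only.
template <typename T>
void naive_complex_gemm(int M, int N, int K, std::complex<T> alpha,
                        const std::complex<T> *A, const std::complex<T> *B,
                        std::complex<T> beta, std::complex<T> *C) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      std::complex<T> acc(0, 0);
      for (int k = 0; k < K; ++k) {
        // Plain complex multiply-accumulate; GEMM applies no conjugation.
        acc += A[i * K + k] * B[k * N + j];
      }
      C[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
}
```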
@@ -213,13 +213,13 @@ struct CUBlas<platform::float16> {
 };

 template <>
-struct CUBlas<platform::complex64> {
-  using complex64 = platform::complex64;
+struct CUBlas<platform::complex<float>> {
   static void GEMV(rocblas_handle handle, rocblas_operation transa, int m,
-                   int n, const complex64 *alpha, const complex64 *A, int lda,
-                   const complex64 *B, int ldb, const complex64 *beta,
-                   complex64 *C, int ldc) {
+                   int n, const platform::complex<float> *alpha,
+                   const platform::complex<float> *A, int lda,
+                   const platform::complex<float> *B, int ldb,
+                   const platform::complex<float> *beta,
+                   platform::complex<float> *C, int ldc) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemv(
         handle, transa, m, n,
         reinterpret_cast<const rocblas_float_complex *>(alpha),
@@ -229,9 +229,10 @@ struct CUBlas<platform::complex64> {
         reinterpret_cast<rocblas_float_complex *>(C), ldc));
   }

-  static void AXPY(rocblas_handle handle, int n, const complex64 *alpha,
-                   const complex64 *X, const int incX, complex64 *Y,
-                   const int incY) {
+  static void AXPY(rocblas_handle handle, int n,
+                   const platform::complex<float> *alpha,
+                   const platform::complex<float> *X, const int incX,
+                   platform::complex<float> *Y, const int incY) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_caxpy(
         handle, n, reinterpret_cast<const rocblas_float_complex *>(alpha),
         reinterpret_cast<const rocblas_float_complex *>(X), incX,
@@ -241,11 +242,13 @@ struct CUBlas<platform::complex64> {
   static void GEMM_STRIDED_BATCH(rocblas_handle handle,
                                  rocblas_operation transa,
                                  rocblas_operation transb, int m, int n, int k,
-                                 const complex64 *alpha, const complex64 *A,
-                                 int lda, long long int strideA,  // NOLINT
-                                 const complex64 *B,              // NOLINT
-                                 int ldb, long long int strideB,  // NOLINT
-                                 const complex64 *beta, complex64 *C, int ldc,
+                                 const platform::complex<float> *alpha,
+                                 const platform::complex<float> *A, int lda,
+                                 long long int strideA,              // NOLINT
+                                 const platform::complex<float> *B,  // NOLINT
+                                 int ldb, long long int strideB,     // NOLINT
+                                 const platform::complex<float> *beta,
+                                 platform::complex<float> *C, int ldc,
                                  long long int strideC,  // NOLINT
                                  int batchCount) {
     PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -261,9 +264,11 @@ struct CUBlas<platform::complex64> {

   static void GEMM(rocblas_handle handle, rocblas_operation transa,
                    rocblas_operation transb, int m, int n, int k,
-                   const complex64 *alpha, const complex64 *A, int lda,
-                   const complex64 *B, int ldb, const complex64 *beta,
-                   complex64 *C, int ldc) {
+                   const platform::complex<float> *alpha,
+                   const platform::complex<float> *A, int lda,
+                   const platform::complex<float> *B, int ldb,
+                   const platform::complex<float> *beta,
+                   platform::complex<float> *C, int ldc) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemm(
         handle, transa, transb, m, n, k,
         reinterpret_cast<const rocblas_float_complex *>(alpha),
@@ -293,13 +298,13 @@ struct CUBlas<platform::complex64> {
 };

 template <>
-struct CUBlas<platform::complex128> {
-  using complex128 = platform::complex128;
+struct CUBlas<platform::complex<double>> {
   static void GEMV(rocblas_handle handle, rocblas_operation transa, int m,
-                   int n, const complex128 *alpha, const complex128 *A, int lda,
-                   const complex128 *B, int ldb, const complex128 *beta,
-                   complex128 *C, int ldc) {
+                   int n, const platform::complex<double> *alpha,
+                   const platform::complex<double> *A, int lda,
+                   const platform::complex<double> *B, int ldb,
+                   const platform::complex<double> *beta,
+                   platform::complex<double> *C, int ldc) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemv(
         handle, transa, m, n,
         reinterpret_cast<const rocblas_double_complex *>(alpha),
@@ -309,9 +314,10 @@ struct CUBlas<platform::complex128> {
         reinterpret_cast<rocblas_double_complex *>(C), ldc));
   }

-  static void AXPY(rocblas_handle handle, int n, const complex128 *alpha,
-                   const complex128 *X, const int incX, complex128 *Y,
-                   const int incY) {
+  static void AXPY(rocblas_handle handle, int n,
+                   const platform::complex<double> *alpha,
+                   const platform::complex<double> *X, const int incX,
+                   platform::complex<double> *Y, const int incY) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zaxpy(
         handle, n, reinterpret_cast<const rocblas_double_complex *>(alpha),
         reinterpret_cast<const rocblas_double_complex *>(X), incX,
@@ -321,11 +327,13 @@ struct CUBlas<platform::complex128> {
   static void GEMM_STRIDED_BATCH(rocblas_handle handle,
                                  rocblas_operation transa,
                                  rocblas_operation transb, int m, int n, int k,
-                                 const complex128 *alpha, const complex128 *A,
-                                 int lda, long long int strideA,  // NOLINT
-                                 const complex128 *B,             // NOLINT
-                                 int ldb, long long int strideB,  // NOLINT
-                                 const complex128 *beta, complex128 *C, int ldc,
+                                 const platform::complex<double> *alpha,
+                                 const platform::complex<double> *A, int lda,
+                                 long long int strideA,               // NOLINT
+                                 const platform::complex<double> *B,  // NOLINT
+                                 int ldb, long long int strideB,      // NOLINT
+                                 const platform::complex<double> *beta,
+                                 platform::complex<double> *C, int ldc,
                                  long long int strideC,  // NOLINT
                                  int batchCount) {
     PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -341,9 +349,11 @@ struct CUBlas<platform::complex128> {

   static void GEMM(rocblas_handle handle, rocblas_operation transa,
                    rocblas_operation transb, int m, int n, int k,
-                   const complex128 *alpha, const complex128 *A, int lda,
-                   const complex128 *B, int ldb, const complex128 *beta,
-                   complex128 *C, int ldc) {
+                   const platform::complex<double> *alpha,
+                   const platform::complex<double> *A, int lda,
+                   const platform::complex<double> *B, int ldb,
+                   const platform::complex<double> *beta,
+                   platform::complex<double> *C, int ldc) {
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemm(
         handle, transa, transb, m, n, k,
         reinterpret_cast<const rocblas_double_complex *>(alpha),
@@ -434,9 +444,9 @@ template <>
 template <>
 inline void Blas<platform::CUDADeviceContext>::GEMM(
     CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
-    platform::complex64 alpha, const platform::complex64 *A,
-    const platform::complex64 *B, platform::complex64 beta,
-    platform::complex64 *C) const {
+    platform::complex<float> alpha, const platform::complex<float> *A,
+    const platform::complex<float> *B, platform::complex<float> beta,
+    platform::complex<float> *C) const {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -461,7 +471,7 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
   thrust::complex<float> c_beta = thrust::complex<float>(beta.real, beta.imag);

   auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
-  CUBlas<platform::complex64>::GEMM_EX(
+  CUBlas<platform::complex<float>>::GEMM_EX(
       &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B,
       rocblas_datatype_f32_c, ldb, A, rocblas_datatype_f32_c, lda, &c_beta, C,
       rocblas_datatype_f32_c, N, rocblas_datatype_f32_c);
@@ -471,9 +481,9 @@ template <>
 template <>
 inline void Blas<platform::CUDADeviceContext>::GEMM(
     CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
-    platform::complex128 alpha, const platform::complex128 *A,
-    const platform::complex128 *B, platform::complex128 beta,
-    platform::complex128 *C) const {
+    platform::complex<double> alpha, const platform::complex<double> *A,
+    const platform::complex<double> *B, platform::complex<double> beta,
+    platform::complex<double> *C) const {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   int lda = (transA == CblasNoTrans) ? K : M;
@@ -499,7 +509,7 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
       thrust::complex<double>(beta.real, beta.imag);
   auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
-  CUBlas<platform::complex128>::GEMM_EX(
+  CUBlas<platform::complex<double>>::GEMM_EX(
       &cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B,
       rocblas_datatype_f64_c, ldb, A, rocblas_datatype_f64_c, lda, &c_beta, C,
      rocblas_datatype_f64_c, N, rocblas_datatype_f64_c);
......
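Both the cuBLAS and rocBLAS `GEMM_EX` paths convert the `alpha`/`beta` scalars to `thrust::complex` before the call, as the `thrust::complex<float> c_alpha/c_beta` context lines show. A hedged sketch of that conversion pattern; `ToThrust` is an illustrative helper name, not something defined in this diff:

```cpp
#include <thrust/complex.h>
#include "paddle/fluid/platform/complex.h"

// Sketch: plain value conversion, relying on the public real/imag members
// of platform::complex<T>.
template <typename T>
thrust::complex<T> ToThrust(const paddle::platform::complex<T> &v) {
  return thrust::complex<T>(v.real, v.imag);
}
```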
@@ -297,7 +297,9 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
 namespace scatter {

 template <typename T>
-typename std::enable_if<std::is_floating_point<T>::value>::type
+typename std::enable_if<std::is_floating_point<T>::value ||
+                        std::is_same<T, platform::complex<float>>::value ||
+                        std::is_same<T, platform::complex<double>>::value>::type
 elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
                    const T* in, T* out) {
   blas->AXPY(data_len, T(1.f), in, out);
@@ -542,9 +544,9 @@ template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
 template struct MergeAdd<platform::CPUDeviceContext, float>;
 template struct MergeAdd<platform::CPUDeviceContext, double>;
 template struct MergeAdd<platform::CPUDeviceContext,
-                         paddle::platform::complex64>;
+                         paddle::platform::complex<float>>;
 template struct MergeAdd<platform::CPUDeviceContext,
-                         paddle::platform::complex128>;
+                         paddle::platform::complex<double>>;
 template struct MergeAdd<platform::CPUDeviceContext,
                          paddle::platform::bfloat16>;
......
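The widened `enable_if` above routes floating-point and complex element types through `blas->AXPY`, while the remaining types presumably fall back to a plain loop overload elsewhere in the file (not shown in this diff). A hedged sketch of that SFINAE dispatch, reusing the `BlasT`/`CPUDeviceContext` names from the hunk; the fallback overload here is illustrative:

```cpp
#include <type_traits>

// Types accepted by the BLAS fast path after this change.
template <typename T>
constexpr bool kUsesBlasAxpy =
    std::is_floating_point<T>::value ||
    std::is_same<T, paddle::platform::complex<float>>::value ||
    std::is_same<T, paddle::platform::complex<double>>::value;

// Fast path: accumulate via AXPY, i.e. out += 1 * in.
template <typename T>
typename std::enable_if<kUsesBlasAxpy<T>>::type
elementwise_add_to_sketch(BlasT<platform::CPUDeviceContext, T>* blas,
                          size_t data_len, const T* in, T* out) {
  blas->AXPY(data_len, T(1.f), in, out);
}

// Fallback for other element types: a plain elementwise loop.
template <typename T>
typename std::enable_if<!kUsesBlasAxpy<T>>::type
elementwise_add_to_sketch(BlasT<platform::CPUDeviceContext, T>* /*blas*/,
                          size_t data_len, const T* in, T* out) {
  for (size_t i = 0; i < data_len; ++i) out[i] += in[i];
}
```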
@@ -204,15 +204,15 @@ REGISTER_OP_CPU_KERNEL(
     matmul_v2, ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, float>,
     ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, double>,
     ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex64>,
+                        paddle::platform::complex<float>>,
     ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex128>);
+                        paddle::platform::complex<double>>);

 REGISTER_OP_CPU_KERNEL(
     matmul_v2_grad,
     ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
-                            paddle::platform::complex64>,
+                            paddle::platform::complex<float>>,
     ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
-                            paddle::platform::complex128>);
+                            paddle::platform::complex<double>>);
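With the CPU registrations above, `matmul_v2` and its gradient accept `complex<float>` and `complex<double>` tensors. As a reminder of the scalar arithmetic involved, a tiny host-side illustration of one complex multiply (the same `*` operator the `VMUL` helper uses earlier in this diff); the function name is just for the example:

```cpp
#include "paddle/fluid/platform/complex.h"

void ComplexMulExample() {
  paddle::platform::complex<float> a(1.0f, 2.0f);   // 1 + 2i
  paddle::platform::complex<float> b(3.0f, -1.0f);  // 3 - 1i
  auto c = a * b;  // (1*3 - 2*(-1)) + (1*(-1) + 2*3)i = 5 + 5i
}
```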
@@ -21,12 +21,12 @@ REGISTER_OP_CUDA_KERNEL(
     matmul_v2, ops::MatMulV2Kernel<plf::CUDADeviceContext, float>,
     ops::MatMulV2Kernel<plf::CUDADeviceContext, double>,
     ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::float16>,
-    ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex64>,
-    ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex128>);
+    ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex<float>>,
+    ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex<double>>);

 REGISTER_OP_CUDA_KERNEL(
     matmul_v2_grad, ops::MatMulV2GradKernel<plf::CUDADeviceContext, float>,
     ops::MatMulV2GradKernel<plf::CUDADeviceContext, double>,
     ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>,
-    ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex64>,
-    ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex128>);
+    ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex<float>>,
+    ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex<double>>);
@@ -483,19 +483,19 @@ struct ConjHelper {
 };

 template <typename DeviceContext>
-struct ConjHelper<DeviceContext, paddle::platform::complex64> {
+struct ConjHelper<DeviceContext, paddle::platform::complex<float>> {
   explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}

   HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
     dst.Resize(src.dims());
-    auto* src_data = src.data<paddle::platform::complex64>();
-    auto* dst_data = dst.mutable_data<paddle::platform::complex64>(
+    auto* src_data = src.data<paddle::platform::complex<float>>();
+    auto* dst_data = dst.mutable_data<paddle::platform::complex<float>>(
         ctx_.GetPlace(),
-        size_t(src.numel() * sizeof(paddle::platform::complex64)));
+        size_t(src.numel() * sizeof(paddle::platform::complex<float>)));

     platform::ForRange<DeviceContext> for_range(
         ctx_.template device_context<DeviceContext>(), src.numel());
-    math::ConjFunctor<paddle::platform::complex64> functor(
+    math::ConjFunctor<paddle::platform::complex<float>> functor(
         src_data, src.numel(), dst_data);
     for_range(functor);
     return;
@@ -504,19 +504,19 @@ struct ConjHelper<DeviceContext, paddle::platform::complex64> {
 };

 template <typename DeviceContext>
-struct ConjHelper<DeviceContext, paddle::platform::complex128> {
+struct ConjHelper<DeviceContext, paddle::platform::complex<double>> {
   explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}

   HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
     dst.Resize(src.dims());
-    auto* src_data = src.data<paddle::platform::complex128>();
-    auto* dst_data = dst.mutable_data<paddle::platform::complex128>(
+    auto* src_data = src.data<paddle::platform::complex<double>>();
+    auto* dst_data = dst.mutable_data<paddle::platform::complex<double>>(
         ctx_.GetPlace(),
-        size_t(src.numel() * sizeof(paddle::platform::complex128)));
+        size_t(src.numel() * sizeof(paddle::platform::complex<double>)));

     platform::ForRange<DeviceContext> for_range(
         ctx_.template device_context<DeviceContext>(), src.numel());
-    math::ConjFunctor<paddle::platform::complex128> functor(
+    math::ConjFunctor<paddle::platform::complex<double>> functor(
         src_data, src.numel(), dst_data);
     for_range(functor);
     return;
......
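The `ConjHelper` specializations above materialize the elementwise conjugate of a tensor before the backward GEMMs. That is consistent with the usual complex matmul gradient (for Z = X·Y and a real-valued loss, dX = dZ·conj(Y)ᵀ and dY = conj(X)ᵀ·dZ under the Wirtinger convention); the diff only shows the conjugation helper, so treat that formula as background rather than something the commit states. A hedged sketch of what a `ConjFunctor`-style functor computes (the real `math::ConjFunctor` is defined elsewhere in the codebase):

```cpp
#include <cstddef>
#include "paddle/fluid/platform/complex.h"

// Sketch: ForRange invokes operator() once per element index; the functor
// writes the complex conjugate of in[idx] into out[idx].
template <typename T>
struct ConjSketch {
  const paddle::platform::complex<T>* in;
  paddle::platform::complex<T>* out;

  void operator()(size_t idx) const {
    out[idx] = paddle::platform::complex<T>(in[idx].real, -in[idx].imag);
  }
};
```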