Unverified commit 6c07cd7e authored by chentianyu03, committed by GitHub

modify matmul Op to complex template types (#33130)

* modify matmul Op to complex template types

* remove complex64/128 head file
Parent 8259d9bf
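
The change replaces the two concrete complex types, platform::complex64 and platform::complex128, with a single class template platform::complex<T> instantiated for float and double, so type-dependent code is written once instead of twice. Below is a minimal, self-contained sketch of that pattern; the complex type shown is an illustrative stand-in, not Paddle's actual platform::complex definition.

```cpp
#include <iostream>

// Illustrative stand-in for platform::complex<T>: one template instead of
// two separate complex64/complex128 structs.
template <typename T>
struct complex_t {
  T real;
  T imag;
};

// A routine that previously needed two overloads (complex64, complex128)
// can now be written once for complex_t<T>.
template <typename T>
complex_t<T> scale(const complex_t<T>& x, T s) {
  return {x.real * s, x.imag * s};
}

int main() {
  complex_t<float> cf{1.0f, -2.0f};
  complex_t<double> cd{3.0, 4.0};
  auto rf = scale(cf, 2.0f);  // instantiates scale<float>
  auto rd = scale(cd, 0.5);   // instantiates scale<double>
  std::cout << rf.real << " " << rf.imag << "\n";  // 2 -4
  std::cout << rd.real << " " << rd.imag << "\n";  // 1.5 2
}
```

The diff hunks below apply exactly this substitution throughout the gradient accumulator, the cuBLAS/CBLAS/rocBLAS wrappers, and the matmul_v2 kernels.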
......@@ -24,8 +24,7 @@
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -200,8 +199,8 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future
PADDLE_TENSOR_ADD(platform::complex64);
PADDLE_TENSOR_ADD(platform::complex128);
PADDLE_TENSOR_ADD(platform::complex<float>);
PADDLE_TENSOR_ADD(platform::complex<double>);
#endif
#undef PADDLE_TENSOR_ADD
......
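
The body of PADDLE_TENSOR_ADD is not shown in this hunk; as a rough idea of the dispatch pattern only, a macro of this kind typically expands to one per-type branch, so adding platform::complex<float> and platform::complex<double> entries extends gradient accumulation to complex tensors. The sketch below uses hypothetical names (FakeTensor, FAKE_TENSOR_ADD) and std::complex as a stand-in; it is not Paddle's actual macro.

```cpp
#include <complex>
#include <cstddef>
#include <iostream>
#include <typeindex>
#include <typeinfo>
#include <vector>

// Hypothetical tensor: a runtime element type plus a raw buffer.
struct FakeTensor {
  std::type_index dtype;
  void* data;
  size_t len;
};

template <typename T>
void AddInPlace(FakeTensor* dst, const FakeTensor& src) {
  auto* d = static_cast<T*>(dst->data);
  auto* s = static_cast<const T*>(src.data);
  for (size_t i = 0; i < dst->len; ++i) d[i] += s[i];
}

// Each FAKE_TENSOR_ADD(T) adds one branch handling element type T.
#define FAKE_TENSOR_ADD(T)                        \
  if (dst->dtype == std::type_index(typeid(T))) { \
    AddInPlace<T>(dst, src);                      \
    return;                                       \
  }

void TensorAdd(FakeTensor* dst, const FakeTensor& src) {
  FAKE_TENSOR_ADD(float);
  FAKE_TENSOR_ADD(double);
  // With a templated complex type, the same macro covers both precisions.
  FAKE_TENSOR_ADD(std::complex<float>);
  FAKE_TENSOR_ADD(std::complex<double>);
}
#undef FAKE_TENSOR_ADD

int main() {
  std::vector<std::complex<float>> a{{1, 1}, {2, 2}}, b{{3, 0}, {0, 3}};
  FakeTensor ta{typeid(std::complex<float>), a.data(), a.size()};
  FakeTensor tb{typeid(std::complex<float>), b.data(), b.size()};
  TensorAdd(&ta, tb);
  std::cout << a[0] << " " << a[1] << "\n";  // (4,1) (2,5)
}
```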
......@@ -260,13 +260,13 @@ struct CUBlas<platform::float16> {
};
template <>
struct CUBlas<platform::complex64> {
using complex64 = platform::complex64;
struct CUBlas<platform::complex<float>> {
static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m,
int n, const complex64 *alpha, const complex64 *A, int lda,
const complex64 *B, int ldb, const complex64 *beta,
complex64 *C, int ldc) {
int n, const platform::complex<float> *alpha,
const platform::complex<float> *A, int lda,
const platform::complex<float> *B, int ldb,
const platform::complex<float> *beta,
platform::complex<float> *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemv(
handle, transa, m, n, reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex *>(A), lda,
......@@ -275,9 +275,10 @@ struct CUBlas<platform::complex64> {
reinterpret_cast<cuFloatComplex *>(C), ldc));
}
static void AXPY(cublasHandle_t handle, int n, const complex64 *alpha,
const complex64 *X, const int incX, complex64 *Y,
const int incY) {
static void AXPY(cublasHandle_t handle, int n,
const platform::complex<float> *alpha,
const platform::complex<float> *X, const int incX,
platform::complex<float> *Y, const int incY) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCaxpy(
handle, n, reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex *>(X), incX,
......@@ -287,11 +288,13 @@ struct CUBlas<platform::complex64> {
static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const complex64 *alpha, const complex64 *A,
int lda, long long int strideA, // NOLINT
const complex64 *B, // NOLINT
const platform::complex<float> *alpha,
const platform::complex<float> *A, int lda,
long long int strideA, // NOLINT
const platform::complex<float> *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const complex64 *beta, complex64 *C, int ldc,
const platform::complex<float> *beta,
platform::complex<float> *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
#if CUDA_VERSION >= 8000
......@@ -310,9 +313,11 @@ struct CUBlas<platform::complex64> {
static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const complex64 *alpha, const complex64 *A, int lda,
const complex64 *B, int ldb, const complex64 *beta,
complex64 *C, int ldc) {
const platform::complex<float> *alpha,
const platform::complex<float> *A, int lda,
const platform::complex<float> *B, int ldb,
const platform::complex<float> *beta,
platform::complex<float> *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCgemm(
handle, transa, transb, m, n, k,
reinterpret_cast<const cuFloatComplex *>(alpha),
......@@ -356,13 +361,13 @@ struct CUBlas<platform::complex64> {
};
template <>
struct CUBlas<platform::complex128> {
using complex128 = platform::complex128;
struct CUBlas<platform::complex<double>> {
static void GEMV(cublasHandle_t handle, cublasOperation_t transa, int m,
int n, const complex128 *alpha, const complex128 *A, int lda,
const complex128 *B, int ldb, const complex128 *beta,
complex128 *C, int ldc) {
int n, const platform::complex<double> *alpha,
const platform::complex<double> *A, int lda,
const platform::complex<double> *B, int ldb,
const platform::complex<double> *beta,
platform::complex<double> *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemv(
handle, transa, m, n, reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex *>(A), lda,
......@@ -371,9 +376,10 @@ struct CUBlas<platform::complex128> {
reinterpret_cast<cuDoubleComplex *>(C), ldc));
}
static void AXPY(cublasHandle_t handle, int n, const complex128 *alpha,
const complex128 *X, const int incX, complex128 *Y,
const int incY) {
static void AXPY(cublasHandle_t handle, int n,
const platform::complex<double> *alpha,
const platform::complex<double> *X, const int incX,
platform::complex<double> *Y, const int incY) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZaxpy(
handle, n, reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex *>(X), incX,
......@@ -383,11 +389,13 @@ struct CUBlas<platform::complex128> {
static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const complex128 *alpha, const complex128 *A,
int lda, long long int strideA, // NOLINT
const complex128 *B, // NOLINT
const platform::complex<double> *alpha,
const platform::complex<double> *A, int lda,
long long int strideA, // NOLINT
const platform::complex<double> *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const complex128 *beta, complex128 *C, int ldc,
const platform::complex<double> *beta,
platform::complex<double> *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
#if CUDA_VERSION >= 8000
......@@ -406,9 +414,11 @@ struct CUBlas<platform::complex128> {
static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const complex128 *alpha, const complex128 *A, int lda,
const complex128 *B, int ldb, const complex128 *beta,
complex128 *C, int ldc) {
const platform::complex<double> *alpha,
const platform::complex<double> *A, int lda,
const platform::complex<double> *B, int ldb,
const platform::complex<double> *beta,
platform::complex<double> *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZgemm(
handle, transa, transb, m, n, k,
reinterpret_cast<const cuDoubleComplex *>(alpha),
......@@ -535,9 +545,9 @@ template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::complex64 alpha, const platform::complex64 *A,
const platform::complex64 *B, platform::complex64 beta,
platform::complex64 *C) const {
platform::complex<float> alpha, const platform::complex<float> *A,
const platform::complex<float> *B, platform::complex<float> beta,
platform::complex<float> *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -565,16 +575,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
CUBlas<platform::complex64>::GEMM_EX(
CUBlas<platform::complex<float>>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_32F, ldb, A,
CUDA_C_32F, lda, &c_beta, C, CUDA_C_32F, N, CUDA_C_32F);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<platform::complex64>::GEMM(handle, cuTransB, cuTransA, N, M, K,
&c_alpha, h_B, ldb, h_A, lda, &c_beta,
h_C, N);
CUBlas<platform::complex<float>>::GEMM(handle, cuTransB, cuTransA, N, M, K,
&c_alpha, h_B, ldb, h_A, lda,
&c_beta, h_C, N);
});
#endif // CUDA_VERSION >= 8000
}
......@@ -583,9 +593,9 @@ template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::complex128 alpha, const platform::complex128 *A,
const platform::complex128 *B, platform::complex128 beta,
platform::complex128 *C) const {
platform::complex<double> alpha, const platform::complex<double> *A,
const platform::complex<double> *B, platform::complex<double> beta,
platform::complex<double> *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -614,16 +624,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
// input/output in fp16, computation in fp32, which can also be accelerated
// using tensor cores in volta GPUs.
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
CUBlas<platform::complex128>::GEMM_EX(
CUBlas<platform::complex<double>>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B, CUDA_C_64F, ldb, A,
CUDA_C_64F, lda, &c_beta, C, CUDA_C_64F, N, CUDA_C_64F);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<platform::complex128>::GEMM(handle, cuTransB, cuTransA, N, M, K,
&c_alpha, h_B, ldb, h_A, lda, &c_beta,
h_C, N);
CUBlas<platform::complex<double>>::GEMM(handle, cuTransB, cuTransA, N, M, K,
&c_alpha, h_B, ldb, h_A, lda,
&c_beta, h_C, N);
});
#endif // CUDA_VERSION >= 8000
}
......
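
The cuBLAS wrappers above hand platform::complex<T> buffers to cublasCgemv/cublasZgemm and friends by reinterpret_cast-ing them to cuFloatComplex / cuDoubleComplex. That only works because both sides store two contiguous T values with the real part first. A host-only sketch of that layout assumption follows, with hypothetical stand-in structs rather than the real Paddle and CUDA types.

```cpp
#include <iostream>

// Stand-in for platform::complex<float>: real part first, then imaginary part.
struct my_complex_f {
  float real;
  float imag;
};

// Stand-in for cuFloatComplex, which stores the same two floats as x and y.
struct lib_complex_f {
  float x;
  float y;
};

static_assert(sizeof(my_complex_f) == sizeof(lib_complex_f),
              "the reinterpret_cast pattern assumes identical size and layout");

int main() {
  my_complex_f buf[2] = {{1.0f, 2.0f}, {3.0f, -4.0f}};
  // Same idea as reinterpret_cast<const cuFloatComplex *>(A) in the wrappers.
  const lib_complex_f* as_lib = reinterpret_cast<const lib_complex_f*>(buf);
  std::cout << as_lib[0].x << " " << as_lib[0].y << "\n";  // 1 2
  std::cout << as_lib[1].x << " " << as_lib[1].y << "\n";  // 3 -4
}
```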
......@@ -23,8 +23,7 @@
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex.h"
namespace paddle {
namespace operators {
......@@ -324,11 +323,11 @@ struct CBlas<double> {
};
template <>
struct CBlas<platform::complex64> {
struct CBlas<platform::complex<float>> {
template <typename... ARGS>
static void AXPY(int n, const paddle::platform::complex64 alpha,
const paddle::platform::complex64 *X, const int incX,
paddle::platform::complex64 *Y, const int incY) {
static void AXPY(int n, const paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *X, const int incX,
paddle::platform::complex<float> *Y, const int incY) {
platform::dynload::cblas_caxpy(n, &alpha, X, incX, Y, incY);
}
......@@ -363,35 +362,35 @@ struct CBlas<platform::complex64> {
*/
template <typename... ARGS>
static void VADD(int n, const paddle::platform::complex64 *a,
const paddle::platform::complex64 *b,
paddle::platform::complex64 *y) {
static void VADD(int n, const paddle::platform::complex<float> *a,
const paddle::platform::complex<float> *b,
paddle::platform::complex<float> *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] + b[i];
}
}
template <typename... ARGS>
static void VSUB(int n, const paddle::platform::complex64 *a,
const paddle::platform::complex64 *b,
paddle::platform::complex64 *y) {
static void VSUB(int n, const paddle::platform::complex<float> *a,
const paddle::platform::complex<float> *b,
paddle::platform::complex<float> *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] - b[i];
}
}
template <typename... ARGS>
static void VMUL(int n, const paddle::platform::complex64 *a,
const paddle::platform::complex64 *b,
paddle::platform::complex64 *y) {
static void VMUL(int n, const paddle::platform::complex<float> *a,
const paddle::platform::complex<float> *b,
paddle::platform::complex<float> *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] * b[i];
}
}
template <typename... ARGS>
static void VDIV(int n, const paddle::platform::complex64 *a,
const paddle::platform::complex64 *b,
paddle::platform::complex64 *y) {
static void VDIV(int n, const paddle::platform::complex<float> *a,
const paddle::platform::complex<float> *b,
paddle::platform::complex<float> *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] / b[i];
}
......@@ -399,11 +398,11 @@ struct CBlas<platform::complex64> {
template <typename... ARGS>
static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N,
paddle::platform::complex64 alpha,
const paddle::platform::complex64 *A, int lda,
const paddle::platform::complex64 *X, int incx,
paddle::platform::complex64 beta,
paddle::platform::complex64 *Y, int incy) {
paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *A, int lda,
const paddle::platform::complex<float> *X, int incx,
paddle::platform::complex<float> beta,
paddle::platform::complex<float> *Y, int incy) {
const void *a_ = (const void *)(A);
const void *x_ = (const void *)(X);
void *y_ = static_cast<void *>(Y);
......@@ -414,11 +413,11 @@ struct CBlas<platform::complex64> {
template <typename... ARGS>
static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a,
CBLAS_TRANSPOSE trans_b, int M, int N, int K,
paddle::platform::complex64 alpha,
const paddle::platform::complex64 *A, int lda,
const paddle::platform::complex64 *B, int ldb,
paddle::platform::complex64 beta,
paddle::platform::complex64 *C, int ldc) {
paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *A, int lda,
const paddle::platform::complex<float> *B, int ldb,
paddle::platform::complex<float> beta,
paddle::platform::complex<float> *C, int ldc) {
const void *a_ = (const void *)(A);
const void *b_ = (const void *)(B);
void *c_ = static_cast<void *>(C);
......@@ -429,11 +428,12 @@ struct CBlas<platform::complex64> {
template <typename... ARGS>
static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
paddle::platform::complex64 *alpha,
const paddle::platform::complex64 **A, const int *lda,
const paddle::platform::complex64 **B, const int *ldb,
paddle::platform::complex64 *beta,
paddle::platform::complex64 **C, const int *ldc,
paddle::platform::complex<float> *alpha,
const paddle::platform::complex<float> **A,
const int *lda,
const paddle::platform::complex<float> **B,
const int *ldb, paddle::platform::complex<float> *beta,
paddle::platform::complex<float> **C, const int *ldc,
int group_count, int *group_size) {
const void **A_void = (const void **)(&(*A));
const void **B_void = (const void **)(&(*B));
......@@ -451,11 +451,11 @@ struct CBlas<platform::complex64> {
};
template <>
struct CBlas<platform::complex128> {
struct CBlas<platform::complex<double>> {
template <typename... ARGS>
static void AXPY(int n, const paddle::platform::complex128 alpha,
const paddle::platform::complex128 *X, const int incX,
paddle::platform::complex128 *Y, const int incY) {
static void AXPY(int n, const paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *X, const int incX,
paddle::platform::complex<double> *Y, const int incY) {
platform::dynload::cblas_zaxpy(n, &alpha, X, incX, Y, incY);
}
......@@ -490,35 +490,35 @@ struct CBlas<platform::complex128> {
*/
template <typename... ARGS>
static void VADD(int n, const paddle::platform::complex128 *a,
const paddle::platform::complex128 *b,
paddle::platform::complex128 *y) {
static void VADD(int n, const paddle::platform::complex<double> *a,
const paddle::platform::complex<double> *b,
paddle::platform::complex<double> *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] + b[i];
}
}
template <typename... ARGS>
static void VSUB(int n, const paddle::platform::complex128 *a,
const paddle::platform::complex128 *b,
paddle::platform::complex128 *y) {
static void VSUB(int n, const paddle::platform::complex<double> *a,
const paddle::platform::complex<double> *b,
paddle::platform::complex<double> *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] - b[i];
}
}
template <typename... ARGS>
static void VMUL(int n, const paddle::platform::complex128 *a,
const paddle::platform::complex128 *b,
paddle::platform::complex128 *y) {
static void VMUL(int n, const paddle::platform::complex<double> *a,
const paddle::platform::complex<double> *b,
paddle::platform::complex<double> *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] * b[i];
}
}
template <typename... ARGS>
static void VDIV(int n, const paddle::platform::complex128 *a,
const paddle::platform::complex128 *b,
paddle::platform::complex128 *y) {
static void VDIV(int n, const paddle::platform::complex<double> *a,
const paddle::platform::complex<double> *b,
paddle::platform::complex<double> *y) {
for (int i = 0; i < n; ++i) {
y[i] = a[i] / b[i];
}
......@@ -526,11 +526,11 @@ struct CBlas<platform::complex128> {
template <typename... ARGS>
static void GEMV(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans, int M, int N,
paddle::platform::complex128 alpha,
const paddle::platform::complex128 *A, int lda,
const paddle::platform::complex128 *X, int incx,
paddle::platform::complex128 beta,
paddle::platform::complex128 *Y, int incy) {
paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *A, int lda,
const paddle::platform::complex<double> *X, int incx,
paddle::platform::complex<double> beta,
paddle::platform::complex<double> *Y, int incy) {
const void *a_ = (const void *)(A);
const void *x_ = (const void *)(X);
void *y_ = static_cast<void *>(Y);
......@@ -541,11 +541,11 @@ struct CBlas<platform::complex128> {
template <typename... ARGS>
static void GEMM(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE trans_a,
CBLAS_TRANSPOSE trans_b, int M, int N, int K,
paddle::platform::complex128 alpha,
const paddle::platform::complex128 *A, int lda,
const paddle::platform::complex128 *B, int ldb,
paddle::platform::complex128 beta,
paddle::platform::complex128 *C, int ldc) {
paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *A, int lda,
const paddle::platform::complex<double> *B, int ldb,
paddle::platform::complex<double> beta,
paddle::platform::complex<double> *C, int ldc) {
const void *a_ = (const void *)(A);
const void *b_ = (const void *)(B);
void *c_ = static_cast<void *>(C);
......@@ -556,11 +556,13 @@ struct CBlas<platform::complex128> {
template <typename... ARGS>
static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
paddle::platform::complex128 *alpha,
const paddle::platform::complex128 **A, const int *lda,
const paddle::platform::complex128 **B, const int *ldb,
paddle::platform::complex128 *beta,
paddle::platform::complex128 **C, const int *ldc,
paddle::platform::complex<double> *alpha,
const paddle::platform::complex<double> **A,
const int *lda,
const paddle::platform::complex<double> **B,
const int *ldb,
paddle::platform::complex<double> *beta,
paddle::platform::complex<double> **C, const int *ldc,
int group_count, int *group_size) {
const void **A_void = (const void **)(&(*A));
const void **B_void = (const void **)(&(*B));
......@@ -636,76 +638,76 @@ struct CBlas<double> {
};
template <>
struct CBlas<platform::complex64> {
struct CBlas<platform::complex<float>> {
template <typename... ARGS>
static void VCOPY(ARGS... args) {
cblas_ccopy(args...);
}
template <typename... ARGS>
static void AXPY(int n, const paddle::platform::complex64 alpha,
const paddle::platform::complex64 *X, const int incX,
paddle::platform::complex64 *Y, const int incY) {
static void AXPY(int n, const paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *X, const int incX,
paddle::platform::complex<float> *Y, const int incY) {
cblas_caxpy(n, &alpha, X, incX, Y, incY);
}
template <typename... ARGS>
static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
const int M, const int N,
const paddle::platform::complex64 alpha,
const paddle::platform::complex64 *A, const int lda,
const paddle::platform::complex64 *X, const int incX,
const paddle::platform::complex64 beta,
paddle::platform::complex64 *Y, const int incY) {
const paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *A, const int lda,
const paddle::platform::complex<float> *X, const int incX,
const paddle::platform::complex<float> beta,
paddle::platform::complex<float> *Y, const int incY) {
cblas_cgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
}
template <typename... ARGS>
static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const paddle::platform::complex64 alpha,
const paddle::platform::complex64 *A, const int lda,
const paddle::platform::complex64 *B, const int ldb,
const paddle::platform::complex64 beta,
paddle::platform::complex64 *C, const int ldc) {
const int K, const paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *A, const int lda,
const paddle::platform::complex<float> *B, const int ldb,
const paddle::platform::complex<float> beta,
paddle::platform::complex<float> *C, const int ldc) {
cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
C, ldc);
}
};
template <>
struct CBlas<platform::complex128> {
struct CBlas<platform::complex<double>> {
template <typename... ARGS>
static void VCOPY(ARGS... args) {
cblas_zcopy(args...);
}
template <typename... ARGS>
static void AXPY(int n, const paddle::platform::complex128 alpha,
const paddle::platform::complex128 *X, const int incX,
paddle::platform::complex128 *Y, const int incY) {
static void AXPY(int n, const paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *X, const int incX,
paddle::platform::complex<double> *Y, const int incY) {
cblas_zaxpy(n, &alpha, X, incX, Y, incY);
}
template <typename... ARGS>
static void GEMV(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
const int M, const int N,
const paddle::platform::complex128 alpha,
const paddle::platform::complex128 *A, const int lda,
const paddle::platform::complex128 *X, const int incX,
const paddle::platform::complex128 beta,
paddle::platform::complex128 *Y, const int incY) {
const paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *A, const int lda,
const paddle::platform::complex<double> *X, const int incX,
const paddle::platform::complex<double> beta,
paddle::platform::complex<double> *Y, const int incY) {
cblas_zgemv(layout, TransA, M, N, &alpha, A, lda, X, incX, &beta, Y, incY);
}
template <typename... ARGS>
static void GEMM(const CBLAS_LAYOUT layout, const CBLAS_TRANSPOSE TransA,
const CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const paddle::platform::complex128 alpha,
const paddle::platform::complex128 *A, const int lda,
const paddle::platform::complex128 *B, const int ldb,
const paddle::platform::complex128 beta,
paddle::platform::complex128 *C, const int ldc) {
const int K, const paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *A, const int lda,
const paddle::platform::complex<double> *B, const int ldb,
const paddle::platform::complex<double> beta,
paddle::platform::complex<double> *C, const int ldc) {
cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
C, ldc);
}
......
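
The VADD/VSUB/VMUL/VDIV specializations above repeat the same element-wise loop for complex<float> and complex<double>. The loop itself is just one function template over the element type; here is a standalone sketch, with std::complex standing in for platform::complex.

```cpp
#include <complex>
#include <functional>
#include <iostream>

// One element-wise kernel instead of one copy per complex precision.
template <typename T, typename Op>
void elementwise(int n, const T* a, const T* b, T* y, Op op) {
  for (int i = 0; i < n; ++i) {
    y[i] = op(a[i], b[i]);
  }
}

int main() {
  std::complex<float> a[2] = {{1, 2}, {3, 4}};
  std::complex<float> b[2] = {{5, 6}, {7, 8}};
  std::complex<float> y[2];

  elementwise(2, a, b, y, std::plus<>{});        // VADD
  std::cout << y[0] << " " << y[1] << "\n";      // (6,8) (10,12)

  elementwise(2, a, b, y, std::multiplies<>{});  // VMUL
  std::cout << y[0] << " " << y[1] << "\n";      // (-7,16) (-11,52)
}
```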
......@@ -213,13 +213,13 @@ struct CUBlas<platform::float16> {
};
template <>
struct CUBlas<platform::complex64> {
using complex64 = platform::complex64;
struct CUBlas<platform::complex<float>> {
static void GEMV(rocblas_handle handle, rocblas_operation transa, int m,
int n, const complex64 *alpha, const complex64 *A, int lda,
const complex64 *B, int ldb, const complex64 *beta,
complex64 *C, int ldc) {
int n, const platform::complex<float> *alpha,
const platform::complex<float> *A, int lda,
const platform::complex<float> *B, int ldb,
const platform::complex<float> *beta,
platform::complex<float> *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemv(
handle, transa, m, n,
reinterpret_cast<const rocblas_float_complex *>(alpha),
......@@ -229,9 +229,10 @@ struct CUBlas<platform::complex64> {
reinterpret_cast<rocblas_float_complex *>(C), ldc));
}
static void AXPY(rocblas_handle handle, int n, const complex64 *alpha,
const complex64 *X, const int incX, complex64 *Y,
const int incY) {
static void AXPY(rocblas_handle handle, int n,
const platform::complex<float> *alpha,
const platform::complex<float> *X, const int incX,
platform::complex<float> *Y, const int incY) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_caxpy(
handle, n, reinterpret_cast<const rocblas_float_complex *>(alpha),
reinterpret_cast<const rocblas_float_complex *>(X), incX,
......@@ -241,11 +242,13 @@ struct CUBlas<platform::complex64> {
static void GEMM_STRIDED_BATCH(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const complex64 *alpha, const complex64 *A,
int lda, long long int strideA, // NOLINT
const complex64 *B, // NOLINT
const platform::complex<float> *alpha,
const platform::complex<float> *A, int lda,
long long int strideA, // NOLINT
const platform::complex<float> *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const complex64 *beta, complex64 *C, int ldc,
const platform::complex<float> *beta,
platform::complex<float> *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
PADDLE_ENFORCE_CUDA_SUCCESS(
......@@ -261,9 +264,11 @@ struct CUBlas<platform::complex64> {
static void GEMM(rocblas_handle handle, rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const complex64 *alpha, const complex64 *A, int lda,
const complex64 *B, int ldb, const complex64 *beta,
complex64 *C, int ldc) {
const platform::complex<float> *alpha,
const platform::complex<float> *A, int lda,
const platform::complex<float> *B, int ldb,
const platform::complex<float> *beta,
platform::complex<float> *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_cgemm(
handle, transa, transb, m, n, k,
reinterpret_cast<const rocblas_float_complex *>(alpha),
......@@ -293,13 +298,13 @@ struct CUBlas<platform::complex64> {
};
template <>
struct CUBlas<platform::complex128> {
using complex128 = platform::complex128;
struct CUBlas<platform::complex<double>> {
static void GEMV(rocblas_handle handle, rocblas_operation transa, int m,
int n, const complex128 *alpha, const complex128 *A, int lda,
const complex128 *B, int ldb, const complex128 *beta,
complex128 *C, int ldc) {
int n, const platform::complex<double> *alpha,
const platform::complex<double> *A, int lda,
const platform::complex<double> *B, int ldb,
const platform::complex<double> *beta,
platform::complex<double> *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemv(
handle, transa, m, n,
reinterpret_cast<const rocblas_double_complex *>(alpha),
......@@ -309,9 +314,10 @@ struct CUBlas<platform::complex128> {
reinterpret_cast<rocblas_double_complex *>(C), ldc));
}
static void AXPY(rocblas_handle handle, int n, const complex128 *alpha,
const complex128 *X, const int incX, complex128 *Y,
const int incY) {
static void AXPY(rocblas_handle handle, int n,
const platform::complex<double> *alpha,
const platform::complex<double> *X, const int incX,
platform::complex<double> *Y, const int incY) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zaxpy(
handle, n, reinterpret_cast<const rocblas_double_complex *>(alpha),
reinterpret_cast<const rocblas_double_complex *>(X), incX,
......@@ -321,11 +327,13 @@ struct CUBlas<platform::complex128> {
static void GEMM_STRIDED_BATCH(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const complex128 *alpha, const complex128 *A,
int lda, long long int strideA, // NOLINT
const complex128 *B, // NOLINT
const platform::complex<double> *alpha,
const platform::complex<double> *A, int lda,
long long int strideA, // NOLINT
const platform::complex<double> *B, // NOLINT
int ldb, long long int strideB, // NOLINT
const complex128 *beta, complex128 *C, int ldc,
const platform::complex<double> *beta,
platform::complex<double> *C, int ldc,
long long int strideC, // NOLINT
int batchCount) {
PADDLE_ENFORCE_CUDA_SUCCESS(
......@@ -341,9 +349,11 @@ struct CUBlas<platform::complex128> {
static void GEMM(rocblas_handle handle, rocblas_operation transa,
rocblas_operation transb, int m, int n, int k,
const complex128 *alpha, const complex128 *A, int lda,
const complex128 *B, int ldb, const complex128 *beta,
complex128 *C, int ldc) {
const platform::complex<double> *alpha,
const platform::complex<double> *A, int lda,
const platform::complex<double> *B, int ldb,
const platform::complex<double> *beta,
platform::complex<double> *C, int ldc) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::rocblas_zgemm(
handle, transa, transb, m, n, k,
reinterpret_cast<const rocblas_double_complex *>(alpha),
......@@ -434,9 +444,9 @@ template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::complex64 alpha, const platform::complex64 *A,
const platform::complex64 *B, platform::complex64 beta,
platform::complex64 *C) const {
platform::complex<float> alpha, const platform::complex<float> *A,
const platform::complex<float> *B, platform::complex<float> beta,
platform::complex<float> *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -461,7 +471,7 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
thrust::complex<float> c_beta = thrust::complex<float>(beta.real, beta.imag);
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
CUBlas<platform::complex64>::GEMM_EX(
CUBlas<platform::complex<float>>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B,
rocblas_datatype_f32_c, ldb, A, rocblas_datatype_f32_c, lda, &c_beta, C,
rocblas_datatype_f32_c, N, rocblas_datatype_f32_c);
......@@ -471,9 +481,9 @@ template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::complex128 alpha, const platform::complex128 *A,
const platform::complex128 *B, platform::complex128 beta,
platform::complex128 *C) const {
platform::complex<double> alpha, const platform::complex<double> *A,
const platform::complex<double> *B, platform::complex<double> beta,
platform::complex<double> *C) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
......@@ -499,7 +509,7 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
thrust::complex<double>(beta.real, beta.imag);
auto &cuda_ctx = const_cast<platform::CUDADeviceContext &>(context_);
CUBlas<platform::complex128>::GEMM_EX(
CUBlas<platform::complex<double>>::GEMM_EX(
&cuda_ctx, cuTransB, cuTransA, N, M, K, &c_alpha, B,
rocblas_datatype_f64_c, ldb, A, rocblas_datatype_f64_c, lda, &c_beta, C,
rocblas_datatype_f64_c, N, rocblas_datatype_f64_c);
......
......@@ -297,7 +297,9 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext,
namespace scatter {
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value>::type
typename std::enable_if<std::is_floating_point<T>::value ||
std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type
elementwise_add_to(BlasT<platform::CPUDeviceContext, T>* blas, size_t data_len,
const T* in, T* out) {
blas->AXPY(data_len, T(1.f), in, out);
......@@ -542,9 +544,9 @@ template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template struct MergeAdd<platform::CPUDeviceContext, float>;
template struct MergeAdd<platform::CPUDeviceContext, double>;
template struct MergeAdd<platform::CPUDeviceContext,
paddle::platform::complex64>;
paddle::platform::complex<float>>;
template struct MergeAdd<platform::CPUDeviceContext,
paddle::platform::complex128>;
paddle::platform::complex<double>>;
template struct MergeAdd<platform::CPUDeviceContext,
paddle::platform::bfloat16>;
......
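
The widened std::enable_if condition above lets elementwise_add_to take the blas->AXPY fast path for complex element types as well as real floating-point ones. Below is a self-contained sketch of the same SFINAE dispatch, with std::complex standing in for platform::complex and a plain loop standing in for AXPY.

```cpp
#include <complex>
#include <iostream>
#include <type_traits>

// Fast path: floating-point and complex element types go through "AXPY".
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value ||
                        std::is_same<T, std::complex<float>>::value ||
                        std::is_same<T, std::complex<double>>::value>::type
elementwise_add_to(size_t n, const T* in, T* out) {
  for (size_t i = 0; i < n; ++i) out[i] += in[i];  // stands in for blas->AXPY
  std::cout << "axpy path\n";
}

// Fallback for everything else (e.g. integer element types).
template <typename T>
typename std::enable_if<!(std::is_floating_point<T>::value ||
                          std::is_same<T, std::complex<float>>::value ||
                          std::is_same<T, std::complex<double>>::value)>::type
elementwise_add_to(size_t n, const T* in, T* out) {
  for (size_t i = 0; i < n; ++i) out[i] += in[i];
  std::cout << "plain path\n";
}

int main() {
  std::complex<float> in[2] = {{1, 1}, {2, -2}};
  std::complex<float> out[2] = {{0, 0}, {1, 0}};
  elementwise_add_to(2, in, out);                // picks the axpy path
  std::cout << out[0] << " " << out[1] << "\n";  // (1,1) (3,-2)

  int iin[2] = {1, 2}, iout[2] = {10, 20};
  elementwise_add_to(2, iin, iout);                // picks the plain path
  std::cout << iout[0] << " " << iout[1] << "\n";  // 11 22
}
```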
......@@ -204,15 +204,15 @@ REGISTER_OP_CPU_KERNEL(
matmul_v2, ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
matmul_v2_grad,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
paddle::platform::complex<float>>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
paddle::platform::complex<double>>);
......@@ -21,12 +21,12 @@ REGISTER_OP_CUDA_KERNEL(
matmul_v2, ops::MatMulV2Kernel<plf::CUDADeviceContext, float>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, double>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::float16>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex64>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex128>);
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex<float>>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
matmul_v2_grad, ops::MatMulV2GradKernel<plf::CUDADeviceContext, float>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, double>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex64>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex128>);
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex<float>>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex<double>>);
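
For reference, the complex matmul_v2 kernels registered above compute an ordinary matrix product over complex scalars. A naive host-side version with std::complex is shown below as a worked example only; it is unrelated to Paddle's actual kernel implementation.

```cpp
#include <complex>
#include <iostream>
#include <vector>

using cf = std::complex<float>;

// Naive row-major C = A * B, where A is MxK, B is KxN, C is MxN.
std::vector<cf> matmul(const std::vector<cf>& A, const std::vector<cf>& B,
                       int M, int K, int N) {
  std::vector<cf> C(M * N, cf(0, 0));
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        C[m * N + n] += A[m * K + k] * B[k * N + n];
  return C;
}

int main() {
  // 2x2 example: A = [[1+i, 2], [0, 1-i]], B = [[1, i], [i, 1]]
  std::vector<cf> A = {{1, 1}, {2, 0}, {0, 0}, {1, -1}};
  std::vector<cf> B = {{1, 0}, {0, 1}, {0, 1}, {1, 0}};
  auto C = matmul(A, B, 2, 2, 2);
  for (int m = 0; m < 2; ++m) {
    for (int n = 0; n < 2; ++n) std::cout << C[m * 2 + n] << " ";
    std::cout << "\n";
  }
  // Expected: (1,3) (1,1)
  //           (1,1) (1,-1)
}
```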
......@@ -483,19 +483,19 @@ struct ConjHelper {
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex64> {
struct ConjHelper<DeviceContext, paddle::platform::complex<float>> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex64>();
auto* dst_data = dst.mutable_data<paddle::platform::complex64>(
auto* src_data = src.data<paddle::platform::complex<float>>();
auto* dst_data = dst.mutable_data<paddle::platform::complex<float>>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex64)));
size_t(src.numel() * sizeof(paddle::platform::complex<float>)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex64> functor(
math::ConjFunctor<paddle::platform::complex<float>> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
......@@ -504,19 +504,19 @@ struct ConjHelper<DeviceContext, paddle::platform::complex64> {
};
template <typename DeviceContext>
struct ConjHelper<DeviceContext, paddle::platform::complex128> {
struct ConjHelper<DeviceContext, paddle::platform::complex<double>> {
explicit ConjHelper(const framework::ExecutionContext& ctx) : ctx_(ctx) {}
HOSTDEVICE void operator()(framework::Tensor& src, framework::Tensor& dst) {
dst.Resize(src.dims());
auto* src_data = src.data<paddle::platform::complex128>();
auto* dst_data = dst.mutable_data<paddle::platform::complex128>(
auto* src_data = src.data<paddle::platform::complex<double>>();
auto* dst_data = dst.mutable_data<paddle::platform::complex<double>>(
ctx_.GetPlace(),
size_t(src.numel() * sizeof(paddle::platform::complex128)));
size_t(src.numel() * sizeof(paddle::platform::complex<double>)));
platform::ForRange<DeviceContext> for_range(
ctx_.template device_context<DeviceContext>(), src.numel());
math::ConjFunctor<paddle::platform::complex128> functor(
math::ConjFunctor<paddle::platform::complex<double>> functor(
src_data, src.numel(), dst_data);
for_range(functor);
return;
......
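
The ConjHelper specializations above conjugate a complex tensor by running math::ConjFunctor over every element through platform::ForRange. A minimal host-only sketch of that functor-over-index pattern follows, with std::complex and a plain loop standing in for the Paddle types.

```cpp
#include <complex>
#include <cstdint>
#include <iostream>

// Index-based functor in the spirit of math::ConjFunctor: dst[i] = conj(src[i]).
template <typename T>
struct ConjFunctor {
  const std::complex<T>* src;
  std::complex<T>* dst;
  void operator()(int64_t i) const { dst[i] = std::conj(src[i]); }
};

// Host-only stand-in for platform::ForRange: applies the functor to [0, n).
template <typename Functor>
void for_range(int64_t n, Functor f) {
  for (int64_t i = 0; i < n; ++i) f(i);
}

int main() {
  std::complex<float> src[3] = {{1, 2}, {0, -3}, {4, 5}};
  std::complex<float> dst[3];
  for_range(3, ConjFunctor<float>{src, dst});
  for (const auto& v : dst) std::cout << v << " ";  // (1,-2) (0,3) (4,-5)
  std::cout << "\n";
}
```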