未验证 提交 692a9632 编写于 作者: H huangjiyi 提交者: GitHub

rm "paddle/fluid/platform/dynload/cublas.h" in phi (#47778)

上级 ccb47076
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#pragma once #pragma once
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
...@@ -32,34 +32,34 @@ template <> ...@@ -32,34 +32,34 @@ template <>
struct CUBlas<float> { struct CUBlas<float> {
template <typename... ARGS> template <typename... ARGS>
static void GEMM(ARGS... args) { static void GEMM(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSgemm(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSaxpy(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSaxpy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void SCAL(ARGS... args) { static void SCAL(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSscal(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSscal(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void VCOPY(ARGS... args) { static void VCOPY(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasScopy(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasScopy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMV(ARGS... args) { static void GEMV(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSgemv(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemv(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) { static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cublasSgemmStridedBatched(args...)); phi::dynload::cublasSgemmStridedBatched(args...));
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"SgemmStridedBatched is not supported on cuda <= 7.5")); "SgemmStridedBatched is not supported on cuda <= 7.5"));
...@@ -93,24 +93,23 @@ struct CUBlas<float> { ...@@ -93,24 +93,23 @@ struct CUBlas<float> {
VLOG(5) << "use_tensor_op_math: " VLOG(5) << "use_tensor_op_math: "
<< (dev_ctx->tensor_core_available() ? "True" : "False"); << (dev_ctx->tensor_core_available() ? "True" : "False");
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmEx(handle,
paddle::platform::dynload::cublasSgemmEx(handle, transa,
transa, transb,
transb, m,
m, n,
n, k,
k, alpha,
alpha, A,
A, Atype,
Atype, lda,
lda, B,
B, Btype,
Btype, ldb,
ldb, beta,
beta, C,
C, Ctype,
Ctype, ldc));
ldc));
}); });
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
...@@ -120,37 +119,32 @@ struct CUBlas<float> { ...@@ -120,37 +119,32 @@ struct CUBlas<float> {
template <typename... ARGS> template <typename... ARGS>
static void TRSM(ARGS... args) { static void TRSM(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasStrsm(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRF_BATCH(ARGS... args) { static void GETRF_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrfBatched(args...));
paddle::platform::dynload::cublasSgetrfBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRI_BATCH(ARGS... args) { static void GETRI_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetriBatched(args...));
paddle::platform::dynload::cublasSgetriBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void MATINV_BATCH(ARGS... args) { static void MATINV_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSmatinvBatched(args...));
paddle::platform::dynload::cublasSmatinvBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRS_BATCH(ARGS... args) { static void GETRS_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrsBatched(args...));
paddle::platform::dynload::cublasSgetrsBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) { static void TRSM_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsmBatched(args...));
paddle::platform::dynload::cublasStrsmBatched(args...));
} }
}; };
...@@ -158,34 +152,34 @@ template <> ...@@ -158,34 +152,34 @@ template <>
struct CUBlas<double> { struct CUBlas<double> {
template <typename... ARGS> template <typename... ARGS>
static void GEMM(ARGS... args) { static void GEMM(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDgemm(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDaxpy(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDaxpy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void SCAL(ARGS... args) { static void SCAL(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDscal(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDscal(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void VCOPY(ARGS... args) { static void VCOPY(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDcopy(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDcopy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMV(ARGS... args) { static void GEMV(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDgemv(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemv(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) { static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cublasDgemmStridedBatched(args...)); phi::dynload::cublasDgemmStridedBatched(args...));
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"DgemmStridedBatched is not supported on cuda <= 7.5")); "DgemmStridedBatched is not supported on cuda <= 7.5"));
...@@ -200,37 +194,32 @@ struct CUBlas<double> { ...@@ -200,37 +194,32 @@ struct CUBlas<double> {
template <typename... ARGS> template <typename... ARGS>
static void TRSM(ARGS... args) { static void TRSM(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDtrsm(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRF_BATCH(ARGS... args) { static void GETRF_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrfBatched(args...));
paddle::platform::dynload::cublasDgetrfBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRI_BATCH(ARGS... args) { static void GETRI_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetriBatched(args...));
paddle::platform::dynload::cublasDgetriBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void MATINV_BATCH(ARGS... args) { static void MATINV_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDmatinvBatched(args...));
paddle::platform::dynload::cublasDmatinvBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRS_BATCH(ARGS... args) { static void GETRS_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrsBatched(args...));
paddle::platform::dynload::cublasDgetrsBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) { static void TRSM_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsmBatched(args...));
paddle::platform::dynload::cublasDtrsmBatched(args...));
} }
}; };
...@@ -252,21 +241,21 @@ struct CUBlas<phi::dtype::float16> { ...@@ -252,21 +241,21 @@ struct CUBlas<phi::dtype::float16> {
const float16 *beta, const float16 *beta,
float16 *C, float16 *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasHgemm( PADDLE_ENFORCE_GPU_SUCCESS(
handle, phi::dynload::cublasHgemm(handle,
transa, transa,
transb, transb,
m, m,
n, n,
k, k,
reinterpret_cast<const __half *>(alpha), reinterpret_cast<const __half *>(alpha),
reinterpret_cast<const __half *>(A), reinterpret_cast<const __half *>(A),
lda, lda,
reinterpret_cast<const __half *>(B), reinterpret_cast<const __half *>(B),
ldb, ldb,
reinterpret_cast<const __half *>(beta), reinterpret_cast<const __half *>(beta),
reinterpret_cast<__half *>(C), reinterpret_cast<__half *>(C),
ldc)); ldc));
} }
static void GEMM_STRIDED_BATCH(cublasHandle_t handle, static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
...@@ -288,26 +277,25 @@ struct CUBlas<phi::dtype::float16> { ...@@ -288,26 +277,25 @@ struct CUBlas<phi::dtype::float16> {
long long int strideC, // NOLINT long long int strideC, // NOLINT
int batchCount) { int batchCount) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasHgemmStridedBatched(
paddle::platform::dynload::cublasHgemmStridedBatched( handle,
handle, transa,
transa, transb,
transb, m,
m, n,
n, k,
k, reinterpret_cast<const __half *>(alpha),
reinterpret_cast<const __half *>(alpha), reinterpret_cast<const __half *>(A),
reinterpret_cast<const __half *>(A), lda,
lda, strideA,
strideA, reinterpret_cast<const __half *>(B),
reinterpret_cast<const __half *>(B), ldb,
ldb, strideB,
strideB, reinterpret_cast<const __half *>(beta),
reinterpret_cast<const __half *>(beta), reinterpret_cast<__half *>(C),
reinterpret_cast<__half *>(C), ldc,
ldc, strideC,
strideC, batchCount));
batchCount));
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"HgemmStridedBatched is not supported on cuda <= 7.5")); "HgemmStridedBatched is not supported on cuda <= 7.5"));
...@@ -347,26 +335,25 @@ struct CUBlas<phi::dtype::float16> { ...@@ -347,26 +335,25 @@ struct CUBlas<phi::dtype::float16> {
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
paddle::platform::dynload::cublasGemmEx(handle, transa,
transa, transb,
transb, m,
m, n,
n, k,
k, alpha,
alpha, A,
A, Atype,
Atype, lda,
lda, B,
B, Btype,
Btype, ldb,
ldb, beta,
beta, C,
C, Ctype,
Ctype, ldc,
ldc, computeType,
computeType, algo));
algo));
}); });
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
...@@ -389,7 +376,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -389,7 +376,7 @@ struct CUBlas<phi::dtype::complex<float>> {
const phi::dtype::complex<float> *beta, const phi::dtype::complex<float> *beta,
phi::dtype::complex<float> *C, phi::dtype::complex<float> *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCgemv( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemv(
handle, handle,
transa, transa,
m, m,
...@@ -411,7 +398,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -411,7 +398,7 @@ struct CUBlas<phi::dtype::complex<float>> {
const int incX, const int incX,
phi::dtype::complex<float> *Y, phi::dtype::complex<float> *Y,
const int incY) { const int incY) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCaxpy( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCaxpy(
handle, handle,
n, n,
reinterpret_cast<const cuFloatComplex *>(alpha), reinterpret_cast<const cuFloatComplex *>(alpha),
...@@ -440,26 +427,25 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -440,26 +427,25 @@ struct CUBlas<phi::dtype::complex<float>> {
long long int strideC, // NOLINT long long int strideC, // NOLINT
int batchCount) { int batchCount) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemmStridedBatched(
paddle::platform::dynload::cublasCgemmStridedBatched( handle,
handle, transa,
transa, transb,
transb, m,
m, n,
n, k,
k, reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex *>(alpha), reinterpret_cast<const cuFloatComplex *>(A),
reinterpret_cast<const cuFloatComplex *>(A), lda,
lda, strideA,
strideA, reinterpret_cast<const cuFloatComplex *>(B),
reinterpret_cast<const cuFloatComplex *>(B), ldb,
ldb, strideB,
strideB, reinterpret_cast<const cuFloatComplex *>(beta),
reinterpret_cast<const cuFloatComplex *>(beta), reinterpret_cast<cuFloatComplex *>(C),
reinterpret_cast<cuFloatComplex *>(C), ldc,
ldc, strideC,
strideC, batchCount));
batchCount));
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"CgemmStridedBatched is not supported on cuda <= 7.5")); "CgemmStridedBatched is not supported on cuda <= 7.5"));
...@@ -480,7 +466,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -480,7 +466,7 @@ struct CUBlas<phi::dtype::complex<float>> {
const phi::dtype::complex<float> *beta, const phi::dtype::complex<float> *beta,
phi::dtype::complex<float> *C, phi::dtype::complex<float> *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCgemm( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemm(
handle, handle,
transa, transa,
transb, transb,
...@@ -509,7 +495,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -509,7 +495,7 @@ struct CUBlas<phi::dtype::complex<float>> {
int lda, int lda,
phi::dtype::complex<float> *B, phi::dtype::complex<float> *B,
int ldb) { int ldb) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCtrsm( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsm(
handle, handle,
side, side,
uplo, uplo,
...@@ -557,26 +543,25 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -557,26 +543,25 @@ struct CUBlas<phi::dtype::complex<float>> {
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
paddle::platform::dynload::cublasGemmEx(handle, transa,
transa, transb,
transb, m,
m, n,
n, k,
k, alpha,
alpha, A,
A, Atype,
Atype, lda,
lda, B,
B, Btype,
Btype, ldb,
ldb, beta,
beta, C,
C, Ctype,
Ctype, ldc,
ldc, computeType,
computeType, algo));
algo));
}); });
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
...@@ -597,7 +582,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -597,7 +582,7 @@ struct CUBlas<phi::dtype::complex<float>> {
phi::dtype::complex<float> **B, phi::dtype::complex<float> **B,
int ldb, int ldb,
int batch_size) { int batch_size) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCtrsmBatched( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsmBatched(
handle, handle,
side, side,
uplo, uplo,
...@@ -628,7 +613,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -628,7 +613,7 @@ struct CUBlas<phi::dtype::complex<double>> {
const phi::dtype::complex<double> *beta, const phi::dtype::complex<double> *beta,
phi::dtype::complex<double> *C, phi::dtype::complex<double> *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZgemv( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemv(
handle, handle,
transa, transa,
m, m,
...@@ -650,7 +635,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -650,7 +635,7 @@ struct CUBlas<phi::dtype::complex<double>> {
const int incX, const int incX,
phi::dtype::complex<double> *Y, phi::dtype::complex<double> *Y,
const int incY) { const int incY) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZaxpy( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZaxpy(
handle, handle,
n, n,
reinterpret_cast<const cuDoubleComplex *>(alpha), reinterpret_cast<const cuDoubleComplex *>(alpha),
...@@ -680,26 +665,25 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -680,26 +665,25 @@ struct CUBlas<phi::dtype::complex<double>> {
long long int strideC, // NOLINT long long int strideC, // NOLINT
int batchCount) { int batchCount) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemmStridedBatched(
paddle::platform::dynload::cublasZgemmStridedBatched( handle,
handle, transa,
transa, transb,
transb, m,
m, n,
n, k,
k, reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex *>(alpha), reinterpret_cast<const cuDoubleComplex *>(A),
reinterpret_cast<const cuDoubleComplex *>(A), lda,
lda, strideA,
strideA, reinterpret_cast<const cuDoubleComplex *>(B),
reinterpret_cast<const cuDoubleComplex *>(B), ldb,
ldb, strideB,
strideB, reinterpret_cast<const cuDoubleComplex *>(beta),
reinterpret_cast<const cuDoubleComplex *>(beta), reinterpret_cast<cuDoubleComplex *>(C),
reinterpret_cast<cuDoubleComplex *>(C), ldc,
ldc, strideC,
strideC, batchCount));
batchCount));
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"CgemmStridedBatched is not supported on cuda <= 7.5")); "CgemmStridedBatched is not supported on cuda <= 7.5"));
...@@ -720,7 +704,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -720,7 +704,7 @@ struct CUBlas<phi::dtype::complex<double>> {
const phi::dtype::complex<double> *beta, const phi::dtype::complex<double> *beta,
phi::dtype::complex<double> *C, phi::dtype::complex<double> *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZgemm( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemm(
handle, handle,
transa, transa,
transb, transb,
...@@ -749,7 +733,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -749,7 +733,7 @@ struct CUBlas<phi::dtype::complex<double>> {
int lda, int lda,
phi::dtype::complex<double> *B, phi::dtype::complex<double> *B,
int ldb) { int ldb) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZtrsm( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsm(
handle, handle,
side, side,
uplo, uplo,
...@@ -777,7 +761,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -777,7 +761,7 @@ struct CUBlas<phi::dtype::complex<double>> {
phi::dtype::complex<double> **B, phi::dtype::complex<double> **B,
int ldb, int ldb,
int batch_size) { int batch_size) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZtrsmBatched( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsmBatched(
handle, handle,
side, side,
uplo, uplo,
...@@ -826,26 +810,25 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -826,26 +810,25 @@ struct CUBlas<phi::dtype::complex<double>> {
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
paddle::platform::dynload::cublasGemmEx(handle, transa,
transa, transb,
transb, m,
m, n,
n, k,
k, alpha,
alpha, A,
A, Atype,
Atype, lda,
lda, B,
B, Btype,
Btype, ldb,
ldb, beta,
beta, C,
C, Ctype,
Ctype, ldc,
ldc, computeType,
computeType, algo));
algo));
}); });
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
...@@ -1039,26 +1022,25 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -1039,26 +1022,25 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
paddle::platform::dynload::cublasGemmEx(handle, cuTransB,
cuTransB, cuTransA,
cuTransA, N,
N, M,
M, K,
K, &h_alpha,
&h_alpha, B,
B, CUDA_R_16BF,
CUDA_R_16BF, ldb,
ldb, A,
A, CUDA_R_16BF,
CUDA_R_16BF, lda,
lda, &h_beta,
&h_beta, C,
C, CUDA_R_16BF,
CUDA_R_16BF, N,
N, CUDA_R_32F,
CUDA_R_32F, algo));
algo));
}); });
#else #else
// raise error // raise error
...@@ -1476,29 +1458,29 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, ...@@ -1476,29 +1458,29 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cublasGemmStridedBatchedEx(handle, phi::dynload::cublasGemmStridedBatchedEx(handle,
cuTransB, cuTransB,
cuTransA, cuTransA,
N, N,
M, M,
K, K,
a, a,
B, B,
fp, fp,
ldb, ldb,
strideB, strideB,
A, A,
fp, fp,
lda, lda,
strideA, strideA,
b, b,
C, C,
fp, fp,
ldc, ldc,
strideC, strideC,
batchCount, batchCount,
compute_type, compute_type,
algo)); algo));
}); });
} else { } else {
#endif // CUDA_VERSION >= 9010 #endif // CUDA_VERSION >= 9010
...@@ -1568,30 +1550,29 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, ...@@ -1568,30 +1550,29 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cublasGemmStridedBatchedEx( phi::dynload::cublasGemmStridedBatchedEx(handle,
handle, cuTransB,
cuTransB, cuTransA,
cuTransA, N,
N, M,
M, K,
K, &h_alpha,
&h_alpha, B,
B, CUDA_R_16BF,
CUDA_R_16BF, ldb,
ldb, strideB,
strideB, A,
A, CUDA_R_16BF,
CUDA_R_16BF, lda,
lda, strideA,
strideA, &h_beta,
&h_beta, C,
C, CUDA_R_16BF,
CUDA_R_16BF, ldc,
ldc, strideC,
strideC, batchCount,
batchCount, CUBLAS_COMPUTE_32F,
CUBLAS_COMPUTE_32F, algo));
algo));
}); });
#else #else
// raise error // raise error
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册