未验证 提交 692a9632 编写于 作者: H huangjiyi 提交者: GitHub

Remove "paddle/fluid/platform/dynload/cublas.h" from phi, using "paddle/phi/backends/dynload/cublas.h" and the phi::dynload namespace instead (#47778)

上级 ccb47076
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#pragma once #pragma once
#include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
...@@ -32,34 +32,34 @@ template <> ...@@ -32,34 +32,34 @@ template <>
struct CUBlas<float> { struct CUBlas<float> {
template <typename... ARGS> template <typename... ARGS>
static void GEMM(ARGS... args) { static void GEMM(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSgemm(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSaxpy(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSaxpy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void SCAL(ARGS... args) { static void SCAL(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSscal(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSscal(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void VCOPY(ARGS... args) { static void VCOPY(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasScopy(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasScopy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMV(ARGS... args) { static void GEMV(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasSgemv(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemv(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) { static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cublasSgemmStridedBatched(args...)); phi::dynload::cublasSgemmStridedBatched(args...));
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"SgemmStridedBatched is not supported on cuda <= 7.5")); "SgemmStridedBatched is not supported on cuda <= 7.5"));
...@@ -93,8 +93,7 @@ struct CUBlas<float> { ...@@ -93,8 +93,7 @@ struct CUBlas<float> {
VLOG(5) << "use_tensor_op_math: " VLOG(5) << "use_tensor_op_math: "
<< (dev_ctx->tensor_core_available() ? "True" : "False"); << (dev_ctx->tensor_core_available() ? "True" : "False");
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmEx(handle,
paddle::platform::dynload::cublasSgemmEx(handle,
transa, transa,
transb, transb,
m, m,
...@@ -120,37 +119,32 @@ struct CUBlas<float> { ...@@ -120,37 +119,32 @@ struct CUBlas<float> {
template <typename... ARGS> template <typename... ARGS>
static void TRSM(ARGS... args) { static void TRSM(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasStrsm(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRF_BATCH(ARGS... args) { static void GETRF_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrfBatched(args...));
paddle::platform::dynload::cublasSgetrfBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRI_BATCH(ARGS... args) { static void GETRI_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetriBatched(args...));
paddle::platform::dynload::cublasSgetriBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void MATINV_BATCH(ARGS... args) { static void MATINV_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSmatinvBatched(args...));
paddle::platform::dynload::cublasSmatinvBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRS_BATCH(ARGS... args) { static void GETRS_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgetrsBatched(args...));
paddle::platform::dynload::cublasSgetrsBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) { static void TRSM_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasStrsmBatched(args...));
paddle::platform::dynload::cublasStrsmBatched(args...));
} }
}; };
...@@ -158,34 +152,34 @@ template <> ...@@ -158,34 +152,34 @@ template <>
struct CUBlas<double> { struct CUBlas<double> {
template <typename... ARGS> template <typename... ARGS>
static void GEMM(ARGS... args) { static void GEMM(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDgemm(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void AXPY(ARGS... args) { static void AXPY(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDaxpy(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDaxpy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void SCAL(ARGS... args) { static void SCAL(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDscal(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDscal(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void VCOPY(ARGS... args) { static void VCOPY(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDcopy(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDcopy(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMV(ARGS... args) { static void GEMV(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDgemv(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemv(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) { static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cublasDgemmStridedBatched(args...)); phi::dynload::cublasDgemmStridedBatched(args...));
#else #else
PADDLE_THROW(phi::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"DgemmStridedBatched is not supported on cuda <= 7.5")); "DgemmStridedBatched is not supported on cuda <= 7.5"));
...@@ -200,37 +194,32 @@ struct CUBlas<double> { ...@@ -200,37 +194,32 @@ struct CUBlas<double> {
template <typename... ARGS> template <typename... ARGS>
static void TRSM(ARGS... args) { static void TRSM(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasDtrsm(args...)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsm(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRF_BATCH(ARGS... args) { static void GETRF_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrfBatched(args...));
paddle::platform::dynload::cublasDgetrfBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRI_BATCH(ARGS... args) { static void GETRI_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetriBatched(args...));
paddle::platform::dynload::cublasDgetriBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void MATINV_BATCH(ARGS... args) { static void MATINV_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDmatinvBatched(args...));
paddle::platform::dynload::cublasDmatinvBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void GETRS_BATCH(ARGS... args) { static void GETRS_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgetrsBatched(args...));
paddle::platform::dynload::cublasDgetrsBatched(args...));
} }
template <typename... ARGS> template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) { static void TRSM_BATCH(ARGS... args) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDtrsmBatched(args...));
paddle::platform::dynload::cublasDtrsmBatched(args...));
} }
}; };
...@@ -252,8 +241,8 @@ struct CUBlas<phi::dtype::float16> { ...@@ -252,8 +241,8 @@ struct CUBlas<phi::dtype::float16> {
const float16 *beta, const float16 *beta,
float16 *C, float16 *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasHgemm( PADDLE_ENFORCE_GPU_SUCCESS(
handle, phi::dynload::cublasHgemm(handle,
transa, transa,
transb, transb,
m, m,
...@@ -288,8 +277,7 @@ struct CUBlas<phi::dtype::float16> { ...@@ -288,8 +277,7 @@ struct CUBlas<phi::dtype::float16> {
long long int strideC, // NOLINT long long int strideC, // NOLINT
int batchCount) { int batchCount) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasHgemmStridedBatched(
paddle::platform::dynload::cublasHgemmStridedBatched(
handle, handle,
transa, transa,
transb, transb,
...@@ -347,8 +335,7 @@ struct CUBlas<phi::dtype::float16> { ...@@ -347,8 +335,7 @@ struct CUBlas<phi::dtype::float16> {
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
paddle::platform::dynload::cublasGemmEx(handle,
transa, transa,
transb, transb,
m, m,
...@@ -389,7 +376,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -389,7 +376,7 @@ struct CUBlas<phi::dtype::complex<float>> {
const phi::dtype::complex<float> *beta, const phi::dtype::complex<float> *beta,
phi::dtype::complex<float> *C, phi::dtype::complex<float> *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCgemv( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemv(
handle, handle,
transa, transa,
m, m,
...@@ -411,7 +398,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -411,7 +398,7 @@ struct CUBlas<phi::dtype::complex<float>> {
const int incX, const int incX,
phi::dtype::complex<float> *Y, phi::dtype::complex<float> *Y,
const int incY) { const int incY) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCaxpy( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCaxpy(
handle, handle,
n, n,
reinterpret_cast<const cuFloatComplex *>(alpha), reinterpret_cast<const cuFloatComplex *>(alpha),
...@@ -440,8 +427,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -440,8 +427,7 @@ struct CUBlas<phi::dtype::complex<float>> {
long long int strideC, // NOLINT long long int strideC, // NOLINT
int batchCount) { int batchCount) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemmStridedBatched(
paddle::platform::dynload::cublasCgemmStridedBatched(
handle, handle,
transa, transa,
transb, transb,
...@@ -480,7 +466,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -480,7 +466,7 @@ struct CUBlas<phi::dtype::complex<float>> {
const phi::dtype::complex<float> *beta, const phi::dtype::complex<float> *beta,
phi::dtype::complex<float> *C, phi::dtype::complex<float> *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCgemm( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgemm(
handle, handle,
transa, transa,
transb, transb,
...@@ -509,7 +495,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -509,7 +495,7 @@ struct CUBlas<phi::dtype::complex<float>> {
int lda, int lda,
phi::dtype::complex<float> *B, phi::dtype::complex<float> *B,
int ldb) { int ldb) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCtrsm( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsm(
handle, handle,
side, side,
uplo, uplo,
...@@ -557,8 +543,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -557,8 +543,7 @@ struct CUBlas<phi::dtype::complex<float>> {
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
paddle::platform::dynload::cublasGemmEx(handle,
transa, transa,
transb, transb,
m, m,
...@@ -597,7 +582,7 @@ struct CUBlas<phi::dtype::complex<float>> { ...@@ -597,7 +582,7 @@ struct CUBlas<phi::dtype::complex<float>> {
phi::dtype::complex<float> **B, phi::dtype::complex<float> **B,
int ldb, int ldb,
int batch_size) { int batch_size) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasCtrsmBatched( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCtrsmBatched(
handle, handle,
side, side,
uplo, uplo,
...@@ -628,7 +613,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -628,7 +613,7 @@ struct CUBlas<phi::dtype::complex<double>> {
const phi::dtype::complex<double> *beta, const phi::dtype::complex<double> *beta,
phi::dtype::complex<double> *C, phi::dtype::complex<double> *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZgemv( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemv(
handle, handle,
transa, transa,
m, m,
...@@ -650,7 +635,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -650,7 +635,7 @@ struct CUBlas<phi::dtype::complex<double>> {
const int incX, const int incX,
phi::dtype::complex<double> *Y, phi::dtype::complex<double> *Y,
const int incY) { const int incY) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZaxpy( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZaxpy(
handle, handle,
n, n,
reinterpret_cast<const cuDoubleComplex *>(alpha), reinterpret_cast<const cuDoubleComplex *>(alpha),
...@@ -680,8 +665,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -680,8 +665,7 @@ struct CUBlas<phi::dtype::complex<double>> {
long long int strideC, // NOLINT long long int strideC, // NOLINT
int batchCount) { int batchCount) {
#if CUDA_VERSION >= 8000 #if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemmStridedBatched(
paddle::platform::dynload::cublasZgemmStridedBatched(
handle, handle,
transa, transa,
transb, transb,
...@@ -720,7 +704,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -720,7 +704,7 @@ struct CUBlas<phi::dtype::complex<double>> {
const phi::dtype::complex<double> *beta, const phi::dtype::complex<double> *beta,
phi::dtype::complex<double> *C, phi::dtype::complex<double> *C,
int ldc) { int ldc) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZgemm( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgemm(
handle, handle,
transa, transa,
transb, transb,
...@@ -749,7 +733,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -749,7 +733,7 @@ struct CUBlas<phi::dtype::complex<double>> {
int lda, int lda,
phi::dtype::complex<double> *B, phi::dtype::complex<double> *B,
int ldb) { int ldb) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZtrsm( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsm(
handle, handle,
side, side,
uplo, uplo,
...@@ -777,7 +761,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -777,7 +761,7 @@ struct CUBlas<phi::dtype::complex<double>> {
phi::dtype::complex<double> **B, phi::dtype::complex<double> **B,
int ldb, int ldb,
int batch_size) { int batch_size) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cublasZtrsmBatched( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZtrsmBatched(
handle, handle,
side, side,
uplo, uplo,
...@@ -826,8 +810,7 @@ struct CUBlas<phi::dtype::complex<double>> { ...@@ -826,8 +810,7 @@ struct CUBlas<phi::dtype::complex<double>> {
#endif // CUDA_VERSION >= 9000 #endif // CUDA_VERSION >= 9000
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
paddle::platform::dynload::cublasGemmEx(handle,
transa, transa,
transb, transb,
m, m,
...@@ -1039,8 +1022,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA, ...@@ -1039,8 +1022,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
paddle::platform::dynload::cublasGemmEx(handle,
cuTransB, cuTransB,
cuTransA, cuTransA,
N, N,
...@@ -1476,7 +1458,7 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, ...@@ -1476,7 +1458,7 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cublasGemmStridedBatchedEx(handle, phi::dynload::cublasGemmStridedBatchedEx(handle,
cuTransB, cuTransB,
cuTransA, cuTransA,
N, N,
...@@ -1568,8 +1550,7 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA, ...@@ -1568,8 +1550,7 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cublasGemmStridedBatchedEx( phi::dynload::cublasGemmStridedBatchedEx(handle,
handle,
cuTransB, cuTransB,
cuTransA, cuTransA,
N, N,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册