未验证 提交 9f76d050 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize perf of broadcast matmul (#54126)

* Optimize perf of broadcast matmul

* support more dtype
上级 fa7ba041
......@@ -14,6 +14,7 @@
#pragma once
#include <thrust/device_vector.h>
#include "gflags/gflags.h"
#include "glog/logging.h"
......@@ -58,6 +59,16 @@ struct CUBlas<float> {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemv(args...));
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasSgemmBatched(args...));
#else
PADDLE_THROW(phi::errors::Unimplemented(
"SgemmBatched is not supported on cuda <= 7.5"));
#endif
}
template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
......@@ -178,6 +189,16 @@ struct CUBlas<double> {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemv(args...));
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasDgemmBatched(args...));
#else
PADDLE_THROW(phi::errors::Unimplemented(
"DgemmBatched is not supported on cuda <= 7.5"));
#endif
}
template <typename... ARGS>
static void GEMM_STRIDED_BATCH(ARGS... args) {
#if CUDA_VERSION >= 8000
......@@ -261,6 +282,67 @@ struct CUBlas<phi::dtype::float16> {
ldc));
}
static void GEMM_BATCH(phi::GPUContext *dev_ctx,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha,
const float16 **A,
cudaDataType_t Atype,
int lda,
const float16 **B,
cudaDataType_t Btype,
int ldb,
const float *beta,
float16 **C,
cudaDataType_t Ctype,
int ldc,
int batchCount,
cudaDataType_t computeType) {
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
bool use_tensor_op_math = dev_ctx->tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: "
<< (use_tensor_op_math ? "True" : "False");
#endif // CUDA_VERSION >= 9000
thrust::device_vector<const void *> A_ptr(A, A + batchCount);
thrust::device_vector<const void *> B_ptr(B, B + batchCount);
thrust::device_vector<void *> C_ptr(C, C + batchCount);
dev_ctx->TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cublasGemmBatchedEx(handle,
transa,
transb,
m,
n,
k,
alpha,
A_ptr.data().get(),
Atype,
lda,
B_ptr.data().get(),
Btype,
ldb,
beta,
C_ptr.data().get(),
Ctype,
ldc,
batchCount,
computeType,
algo));
});
#else
PADDLE_THROW(phi::errors::Unimplemented(
"cublasGemmBatchedEx is not supported on cuda <= 7.5"));
#endif
}
static void GEMM_STRIDED_BATCH(cublasHandle_t handle,
cublasOperation_t transa,
cublasOperation_t transb,
......@@ -1672,6 +1754,96 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
}
}
template <>
template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int M,
int N,
int K,
double alpha,
const double **A,
const double **B,
double beta,
double **C,
int batchCount) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
thrust::device_vector<const double *> A_ptr(A, A + batchCount);
thrust::device_vector<const double *> B_ptr(B, B + batchCount);
thrust::device_vector<double *> C_ptr(C, C + batchCount);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<double>::GEMM_BATCH(handle,
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B_ptr.data().get(),
ldb,
A_ptr.data().get(),
lda,
&beta,
C_ptr.data().get(),
ldc,
batchCount);
});
}
template <>
template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int M,
int N,
int K,
float alpha,
const float **A,
const float **B,
float beta,
float **C,
int batchCount) const {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
thrust::device_vector<const float *> A_ptr(A, A + batchCount);
thrust::device_vector<const float *> B_ptr(B, B + batchCount);
thrust::device_vector<float *> C_ptr(C, C + batchCount);
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<float>::GEMM_BATCH(handle,
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B_ptr.data().get(),
ldb,
A_ptr.data().get(),
lda,
&beta,
C_ptr.data().get(),
ldc,
batchCount);
});
}
template <>
template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
......@@ -1685,10 +1857,45 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
phi::dtype::float16 beta,
phi::dtype::float16 **C,
int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<phi::dtype::float16>(
transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
}
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
53,
phi::errors::InvalidArgument(
"cublas fp16 gemm requires GPU compute capability >= 53,"
"but received %d",
context_.GetComputeCapability()));
float f_alpha = static_cast<float>(alpha);
float f_beta = static_cast<float>(beta);
auto &cuda_ctx = const_cast<phi::GPUContext &>(context_);
CUBlas<phi::dtype::float16>::GEMM_BATCH(&cuda_ctx,
cuTransB,
cuTransA,
N,
M,
K,
&f_alpha,
B,
CUDA_R_16F,
ldb,
A,
CUDA_R_16F,
lda,
&f_beta,
C,
CUDA_R_16F,
ldc,
batchCount,
CUDA_R_32F);
}
template <>
......@@ -1704,10 +1911,67 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
phi::dtype::bfloat16 beta,
phi::dtype::bfloat16 **C,
int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<phi::dtype::bfloat16>(
transA, transB, M, N, K, alpha, A[k], B[k], beta, C[k]);
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
phi::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));
float f_alpha = static_cast<float>(alpha);
float f_beta = static_cast<float>(beta);
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
thrust::device_vector<const void *> A_ptr(A, A + batchCount);
thrust::device_vector<const void *> B_ptr(B, B + batchCount);
thrust::device_vector<void *> C_ptr(C, C + batchCount);
context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cublasGemmBatchedEx(handle,
cuTransB,
cuTransA,
N,
M,
K,
&f_alpha,
B_ptr.data().get(),
CUDA_R_16BF,
ldb,
A_ptr.data().get(),
CUDA_R_16BF,
lda,
&f_beta,
C_ptr.data().get(),
CUDA_R_16BF,
ldc,
batchCount,
CUDA_R_32F,
algo));
});
#else
// raise error
PADDLE_THROW(phi::errors::Unimplemented(
"cublasGemmBatchedEx with bfloat16 is not supported on cuda <= 11"));
#endif // CUDA_VERSION >= 11000
}
template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册