Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9f76d050
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9f76d050
编写于
6月 02, 2023
作者:
Z
Zhang Zheng
提交者:
GitHub
6月 02, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize perf of broadcast matmul (#54126)
* Optimize perf of broadcast matmul * support more dtype
上级
fa7ba041
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
271 addition
and
7 deletion
+271
-7
paddle/phi/kernels/funcs/blas/blas_impl.cu.h
paddle/phi/kernels/funcs/blas/blas_impl.cu.h
+271
-7
未找到文件。
paddle/phi/kernels/funcs/blas/blas_impl.cu.h
浏览文件 @
9f76d050
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <thrust/device_vector.h>
#include "gflags/gflags.h"
#include "glog/logging.h"
...
...
@@ -58,6 +59,16 @@ struct CUBlas<float> {
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
cublasSgemv
(
args
...));
}
template
<
typename
...
ARGS
>
static
void
GEMM_BATCH
(
ARGS
...
args
)
{
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
cublasSgemmBatched
(
args
...));
#else
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"SgemmBatched is not supported on cuda <= 7.5"
));
#endif
}
template
<
typename
...
ARGS
>
static
void
GEMM_STRIDED_BATCH
(
ARGS
...
args
)
{
#if CUDA_VERSION >= 8000
...
...
@@ -178,6 +189,16 @@ struct CUBlas<double> {
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
cublasDgemv
(
args
...));
}
template
<
typename
...
ARGS
>
static
void
GEMM_BATCH
(
ARGS
...
args
)
{
#if CUDA_VERSION >= 8000
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
cublasDgemmBatched
(
args
...));
#else
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"DgemmBatched is not supported on cuda <= 7.5"
));
#endif
}
template
<
typename
...
ARGS
>
static
void
GEMM_STRIDED_BATCH
(
ARGS
...
args
)
{
#if CUDA_VERSION >= 8000
...
...
@@ -261,6 +282,67 @@ struct CUBlas<phi::dtype::float16> {
ldc
));
}
static
void
GEMM_BATCH
(
phi
::
GPUContext
*
dev_ctx
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
int
m
,
int
n
,
int
k
,
const
float
*
alpha
,
const
float16
**
A
,
cudaDataType_t
Atype
,
int
lda
,
const
float16
**
B
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
float16
**
C
,
cudaDataType_t
Ctype
,
int
ldc
,
int
batchCount
,
cudaDataType_t
computeType
)
{
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
bool
use_tensor_op_math
=
dev_ctx
->
tensor_core_available
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
#endif // CUDA_VERSION >= 9000
thrust
::
device_vector
<
const
void
*>
A_ptr
(
A
,
A
+
batchCount
);
thrust
::
device_vector
<
const
void
*>
B_ptr
(
B
,
B
+
batchCount
);
thrust
::
device_vector
<
void
*>
C_ptr
(
C
,
C
+
batchCount
);
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
cublasGemmBatchedEx
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A_ptr
.
data
().
get
(),
Atype
,
lda
,
B_ptr
.
data
().
get
(),
Btype
,
ldb
,
beta
,
C_ptr
.
data
().
get
(),
Ctype
,
ldc
,
batchCount
,
computeType
,
algo
));
});
#else
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"cublasGemmBatchedEx is not supported on cuda <= 7.5"
));
#endif
}
static
void
GEMM_STRIDED_BATCH
(
cublasHandle_t
handle
,
cublasOperation_t
transa
,
cublasOperation_t
transb
,
...
...
@@ -1672,6 +1754,96 @@ void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
}
}
template
<
>
template
<
>
inline
void
Blas
<
phi
::
GPUContext
>::
BatchedGEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
double
alpha
,
const
double
**
A
,
const
double
**
B
,
double
beta
,
double
**
C
,
int
batchCount
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
thrust
::
device_vector
<
const
double
*>
A_ptr
(
A
,
A
+
batchCount
);
thrust
::
device_vector
<
const
double
*>
B_ptr
(
B
,
B
+
batchCount
);
thrust
::
device_vector
<
double
*>
C_ptr
(
C
,
C
+
batchCount
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
double
>::
GEMM_BATCH
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B_ptr
.
data
().
get
(),
ldb
,
A_ptr
.
data
().
get
(),
lda
,
&
beta
,
C_ptr
.
data
().
get
(),
ldc
,
batchCount
);
});
}
template
<
>
template
<
>
inline
void
Blas
<
phi
::
GPUContext
>::
BatchedGEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
float
alpha
,
const
float
**
A
,
const
float
**
B
,
float
beta
,
float
**
C
,
int
batchCount
)
const
{
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
thrust
::
device_vector
<
const
float
*>
A_ptr
(
A
,
A
+
batchCount
);
thrust
::
device_vector
<
const
float
*>
B_ptr
(
B
,
B
+
batchCount
);
thrust
::
device_vector
<
float
*>
C_ptr
(
C
,
C
+
batchCount
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
float
>::
GEMM_BATCH
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B_ptr
.
data
().
get
(),
ldb
,
A_ptr
.
data
().
get
(),
lda
,
&
beta
,
C_ptr
.
data
().
get
(),
ldc
,
batchCount
);
});
}
template
<
>
template
<
>
inline
void
Blas
<
phi
::
GPUContext
>::
BatchedGEMM
(
CBLAS_TRANSPOSE
transA
,
...
...
@@ -1685,10 +1857,45 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
phi
::
dtype
::
float16
beta
,
phi
::
dtype
::
float16
**
C
,
int
batchCount
)
const
{
for
(
int
k
=
0
;
k
<
batchCount
;
++
k
)
{
this
->
template
GEMM
<
phi
::
dtype
::
float16
>(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
[
k
],
B
[
k
],
beta
,
C
[
k
]);
}
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE_GE
(
context_
.
GetComputeCapability
(),
53
,
phi
::
errors
::
InvalidArgument
(
"cublas fp16 gemm requires GPU compute capability >= 53,"
"but received %d"
,
context_
.
GetComputeCapability
()));
float
f_alpha
=
static_cast
<
float
>
(
alpha
);
float
f_beta
=
static_cast
<
float
>
(
beta
);
auto
&
cuda_ctx
=
const_cast
<
phi
::
GPUContext
&>
(
context_
);
CUBlas
<
phi
::
dtype
::
float16
>::
GEMM_BATCH
(
&
cuda_ctx
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
f_alpha
,
B
,
CUDA_R_16F
,
ldb
,
A
,
CUDA_R_16F
,
lda
,
&
f_beta
,
C
,
CUDA_R_16F
,
ldc
,
batchCount
,
CUDA_R_32F
);
}
template
<
>
...
...
@@ -1704,10 +1911,67 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
phi
::
dtype
::
bfloat16
beta
,
phi
::
dtype
::
bfloat16
**
C
,
int
batchCount
)
const
{
for
(
int
k
=
0
;
k
<
batchCount
;
++
k
)
{
this
->
template
GEMM
<
phi
::
dtype
::
bfloat16
>(
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
[
k
],
B
[
k
],
beta
,
C
[
k
]);
#if CUDA_VERSION >= 11000
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasOperation_t
cuTransB
=
(
transB
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
PADDLE_ENFORCE_GE
(
context_
.
GetComputeCapability
(),
80
,
phi
::
errors
::
InvalidArgument
(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d"
,
context_
.
GetComputeCapability
()));
float
f_alpha
=
static_cast
<
float
>
(
alpha
);
float
f_beta
=
static_cast
<
float
>
(
beta
);
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DEFAULT
;
bool
use_tensor_op_math
=
context_
.
tensor_core_available
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
thrust
::
device_vector
<
const
void
*>
A_ptr
(
A
,
A
+
batchCount
);
thrust
::
device_vector
<
const
void
*>
B_ptr
(
B
,
B
+
batchCount
);
thrust
::
device_vector
<
void
*>
C_ptr
(
C
,
C
+
batchCount
);
context_
.
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
cublasGemmBatchedEx
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
f_alpha
,
B_ptr
.
data
().
get
(),
CUDA_R_16BF
,
ldb
,
A_ptr
.
data
().
get
(),
CUDA_R_16BF
,
lda
,
&
f_beta
,
C_ptr
.
data
().
get
(),
CUDA_R_16BF
,
ldc
,
batchCount
,
CUDA_R_32F
,
algo
));
});
#else
// raise error
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"cublasGemmBatchedEx with bfloat16 is not supported on cuda <= 11"
));
#endif // CUDA_VERSION >= 11000
}
template
<
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录