机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 1149a378 (unverified) · authored Aug 01, 2022 by zhouweiwei2014 · committed by GitHub, Aug 01, 2022
[Sparse] optimize sparse attention (#44743)
Parent: c28bb981
Showing 6 changed files with 54 additions and 101 deletions (+54 −101):

paddle/fluid/platform/dynload/cusparse.h (+1 −0)
paddle/phi/backends/dynload/cusparse.h (+1 −0)
paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h (+11 −2)
paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu (+8 −15)
paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu (+32 −83)
python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py (+1 −1)
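Taken together, the six files make one optimization: the buffer-based warp softmax (a fixed-size per-thread scratch buffer chosen by a dispatch macro) becomes a strided, in-place softmax over each CSR row, and the 1/sqrt(d) scaling moves out of the softmax kernel into the SDDMM's alpha. The computed function is unchanged; in LaTeX, with d the head dimension:

    \mathrm{out} = \operatorname{softmax}\!\left( \mathrm{mask} \odot \frac{QK^\top}{\sqrt{d}} \right) V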
paddle/fluid/platform/dynload/cusparse.h

@@ -56,6 +56,7 @@ CUSPARSE_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #if CUDA_VERSION >= 11030
 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \
+  __macro(cusparseSpMM_preprocess);       \
   __macro(cusparseSDDMM_bufferSize);      \
   __macro(cusparseSDDMM_preprocess);      \
   __macro(cusparseSDDMM);
paddle/phi/backends/dynload/cusparse.h

@@ -68,6 +68,7 @@ CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #if CUDA_VERSION >= 11030
 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \
+  __macro(cusparseSpMM_preprocess);       \
   __macro(cusparseSDDMM_bufferSize);      \
   __macro(cusparseSDDMM_preprocess);      \
   __macro(cusparseSDDMM);
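Both headers register the same new symbol. cusparseSpMM_preprocess (CUDA 11.3+) lets cuSPARSE analyze the sparse operand once before repeated SpMM launches with the same descriptors. The diff only registers the symbol and shows no call site, so the following is a hedged sketch of the documented cuSPARSE call sequence, not Paddle code (error checking elided):

#include <cuda_runtime.h>
#include <cusparse.h>

// Hedged sketch: cusparseSpMM_preprocess shares cusparseSpMM's signature.
// Descriptor creation is elided; all arguments are assumed set up elsewhere.
void SpmmWithPreprocess(cusparseHandle_t handle, cusparseSpMatDescr_t mat_a,
                        cusparseDnMatDescr_t mat_b, cusparseDnMatDescr_t mat_c,
                        const float* alpha, const float* beta,
                        cusparseSpMMAlg_t alg, int repeats) {
  size_t buffer_size = 0;
  cusparseSpMM_bufferSize(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                          CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, mat_a, mat_b,
                          beta, mat_c, CUDA_R_32F, alg, &buffer_size);
  void* buffer = nullptr;
  cudaMalloc(&buffer, buffer_size);
  // One-time analysis; its cost amortizes over the repeated launches below.
  cusparseSpMM_preprocess(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                          CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, mat_a, mat_b,
                          beta, mat_c, CUDA_R_32F, alg, buffer);
  for (int i = 0; i < repeats; ++i) {
    cusparseSpMM(handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
                 CUSPARSE_OPERATION_NON_TRANSPOSE, alpha, mat_a, mat_b, beta,
                 mat_c, CUDA_R_32F, alg, buffer);
  }
  cudaFree(buffer);
}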
paddle/phi/kernels/funcs/sparse/sparse_blas_impl.cu.h

@@ -48,6 +48,15 @@ inline cusparseOperation_t GetTransposeOperation(const bool trans) {
   }
 }
 
+inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCsrTensor& x) {
+  // TODO(zhouwei): will change to 'CUSPARSE_SPMM_CSR_ALG2' when support batch
+  return CUSPARSE_SPMM_CSR_ALG2;
+}
+
+inline cusparseSpMMAlg_t GetSpMMAlgorithm(const SparseCooTensor& x) {
+  return CUSPARSE_SPMM_ALG_DEFAULT;
+}
+
 /************* SPARSE MATRIX DESCRIPTOR (COO/CSR) ************/
 template <typename T, typename IntT>

@@ -324,7 +333,7 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
         &beta,
         out_descriptor.descriptor(),
         gpu_type,
-        CUSPARSE_SPMM_ALG_DEFAULT,
+        GetSpMMAlgorithm(mat_a),
         &buffer_size);
   });

@@ -341,7 +350,7 @@ void SparseBlas<phi::GPUContext>::SPMM(bool transa,
         &beta,
         out_descriptor.descriptor(),
         gpu_type,
-        CUSPARSE_SPMM_ALG_DEFAULT,
+        GetSpMMAlgorithm(mat_a),
         tmp_buffer_ptr);
   });
 }
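A note on the new helper: the two GetSpMMAlgorithm overloads resolve at compile time from the static type of mat_a, so CSR inputs select CUSPARSE_SPMM_CSR_ALG2 (documented by cuSPARSE as giving deterministic, bitwise-reproducible results) while COO inputs keep CUSPARSE_SPMM_ALG_DEFAULT, with no runtime branch. A minimal illustration (the tensor variables are hypothetical):

// Illustrative only; assume the tensors are initialized elsewhere.
phi::SparseCsrTensor csr_mat;
phi::SparseCooTensor coo_mat;
cusparseSpMMAlg_t csr_alg = GetSpMMAlgorithm(csr_mat);  // CUSPARSE_SPMM_CSR_ALG2
cusparseSpMMAlg_t coo_alg = GetSpMMAlgorithm(coo_mat);  // CUSPARSE_SPMM_ALG_DEFAULT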
paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu

@@ -43,21 +43,14 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
   int row_nnz = static_cast<int>(out_crows[crow_idx + 1] - out_crows[crow_idx]);
   if (row_nnz == 0) return;
 
-  int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
-  T mul_result = 0;
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
+  T mul = 0;
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    mul += out_values[row_first + idx] * dout_values[row_first + idx];
   }
-  T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);
+  T mul_sum = phi::funcs::warpReduceSum<T>(mul, 0xFFFFFFFF);
 
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    dx_values[row_first + idx] = (dout_values[row_first + idx] - sum) *
-                                 out_values[row_first + idx] / scale;
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    dx_values[row_first + idx] = (dout_values[row_first + idx] - mul_sum) *
+                                 out_values[row_first + idx] / scale;
   }
 }
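The update rule the loop implements is the standard softmax backward, restricted to the row's nonzeros. With s = softmax(x / scale):

    \frac{\partial L}{\partial x_i} \;=\; \frac{s_i}{\text{scale}} \left( \frac{\partial L}{\partial s_i} \;-\; \sum_{j} s_j \, \frac{\partial L}{\partial s_j} \right)

Here mul_sum is the warp-reduced \sum_j s_j \, \partial L / \partial s_j, so (dout - mul_sum) * out / scale matches term for term. The rewrite only changes how the 32 lanes cover the row (a blockDim.x stride instead of the kIteration/WARP_SIZE bookkeeping); the math is untouched.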
@@ -96,8 +89,8 @@ void FusedAttentionCsrGradKernel(const Context& dev_ctx,
   int N = q_dim[q_rank - 1];
   int batch_nnz = softmax.nnz() / batch_num;
 
-  dim3 grid((total_row_num + 3) / 4);
-  dim3 block(WARP_SIZE, 4);
+  dim3 grid((total_row_num + 7) / 8);
+  dim3 block(WARP_SIZE, 8);
 
   AttnSoftmaxGpuGradKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
       softmax.non_zero_crows().data<int64_t>(),
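The launch-shape change mirrors the forward kernel: one warp still owns one CSR row, but a block now packs eight rows instead of four. A quick sketch of the arithmetic (WARP_SIZE assumed to be 32; the row mapping shown is the presumed one, not visible in this hunk):

constexpr int WARP_SIZE = 32;
int total_row_num = 1000;            // example value
dim3 block(WARP_SIZE, 8);            // 32 * 8 = 256 threads per block
dim3 grid((total_row_num + 7) / 8);  // ceil(1000 / 8) = 125 blocks
// Presumed mapping inside the kernel: row = blockIdx.x * blockDim.y + threadIdx.y,
// with lanes (threadIdx.x) striding that row's nonzeros.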
paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu

@@ -26,30 +26,7 @@ limitations under the License. */
 namespace phi {
 namespace sparse {
 
-#define PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, size, HINT, ...) \
-  case size: {                                                 \
-    constexpr int HINT = size;                                 \
-    __VA_ARGS__();                                             \
-    break;                                                     \
-  }
-
-#define VISIT_ATTN_SFOTMAX(SIZE, NAME, ...)                                 \
-  [&] {                                                                     \
-    const auto& __size__ = SIZE;                                            \
-    switch (__size__) {                                                     \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 1, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 2, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 3, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 4, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 8, KBufferSize, __VA_ARGS__)    \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 12, KBufferSize, __VA_ARGS__)   \
-      PRIVATE_CASE_VISIT_ATTN_SOFTMAX(NAME, 16, KBufferSize, __VA_ARGS__)   \
-      default:                                                              \
-        PD_THROW("function " #NAME " is not implemented for columns>512 "); \
-    }                                                                       \
-  }()
-
-template <typename T, int BufferSize>
+template <typename T>
 __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
                                      const int64_t* x_cols,
                                      const T* x_values,
@@ -58,7 +35,6 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
                                      T* out_values,
                                      int M,
                                      int total_row_num,
-                                     float scale,
                                      int num_heads,
                                      int batch_nnz) {
   // out = exp(x-x_max) / sum(exp(x-x_max))
@@ -72,17 +48,10 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
   int row_nnz = static_cast<int>(x_crows[crow_idx + 1] - x_crows[crow_idx]);
   if (row_nnz == 0) return;
 
-  T buffer[BufferSize] = {0};
-  int kIteration = (row_nnz + WARP_SIZE - 1) / WARP_SIZE;
-
   T max_val = -std::numeric_limits<T>::infinity();
-  for (int i = 0; i < kIteration; ++i) {
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
     bool mask = false;
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
     int col_idx = static_cast<int>(x_cols[row_first + idx]);
 
     if (kp_mask != nullptr &&
         kp_mask[(cur_batch / num_heads) * M + col_idx] == 0) {
       mask = true;
@@ -92,37 +61,30 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
     }
 
     if (!mask) {
-      buffer[i] = x_values[row_first + idx] / scale;
-      if (buffer[i] > max_val) {
-        max_val = buffer[i];
+      T val = x_values[row_first + idx];
+      if (val > max_val) {
+        max_val = val;
       }
+      out_values[row_first + idx] = val;
+    } else {
+      // Note corner case: when all elements of the row are masked, result
+      // may be wrong because of exp('-inf' - '-inf'), just ignore now.
+      out_values[row_first + idx] = -std::numeric_limits<T>::infinity();
     }
   }
   T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
 
-  auto functor = phi::funcs::CudaExpFunctor<T>();
   T exp_sum = 0;
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    if (buffer[i]) {
-      T exp = functor(buffer[i] - row_max_val);
-      exp_sum += exp;
-      buffer[i] = exp;
-    }
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    auto functor = phi::funcs::CudaExpFunctor<T>();
+    T exp = functor(out_values[row_first + idx] - row_max_val);
+    exp_sum += exp;
+    out_values[row_first + idx] = exp;
   }
   T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
 
-  for (int i = 0; i < kIteration; ++i) {
-    int idx = threadIdx.x + i * WARP_SIZE;
-    if (idx >= row_nnz) break;
-
-    if (buffer[i]) {
-      out_values[row_first + idx] = buffer[i] / row_exp_sum;
-    } else {
-      out_values[row_first + idx] = static_cast<T>(0);
-    }
+  for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
+    out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum;
   }
 }
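The restructured kernel is easier to see in isolation. Below is a self-contained sketch of the same warp-per-row, three-pass scheme (row max, exponentiate in place, normalize), written with raw CUDA shuffles instead of the phi::funcs helpers; it assumes blockDim.x == 32 so the full-mask shuffles are legal, and it omits the kp_mask/attn_mask handling:

#include <cfloat>
#include <cstdint>
#include <cuda_runtime.h>

// Butterfly reductions across the 32 lanes of a warp.
__device__ float WarpMax(float v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFFu, v, offset));
  return v;
}

__device__ float WarpSum(float v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_xor_sync(0xFFFFFFFFu, v, offset);
  return v;
}

// One warp per CSR row; lanes stride the row's nonzeros, like the rewritten
// AttnSoftmaxGpuKernel. Example launch:
//   CsrRowSoftmax<<<(num_rows + 7) / 8, dim3(32, 8)>>>(crows, values, num_rows);
__global__ void CsrRowSoftmax(const int64_t* crows, float* values,
                              int num_rows) {
  int row = blockIdx.x * blockDim.y + threadIdx.y;
  if (row >= num_rows) return;  // uniform across the warp
  int first = static_cast<int>(crows[row]);
  int nnz = static_cast<int>(crows[row + 1]) - first;
  if (nnz == 0) return;

  // Pass 1: row maximum, for numerical stability.
  float max_val = -FLT_MAX;
  for (int i = threadIdx.x; i < nnz; i += blockDim.x)
    max_val = fmaxf(max_val, values[first + i]);
  max_val = WarpMax(max_val);

  // Pass 2: exponentiate in place (no per-thread buffer) and sum.
  float sum = 0.f;
  for (int i = threadIdx.x; i < nnz; i += blockDim.x) {
    float e = expf(values[first + i] - max_val);
    values[first + i] = e;
    sum += e;
  }
  sum = WarpSum(sum);

  // Pass 3: normalize.
  for (int i = threadIdx.x; i < nnz; i += blockDim.x)
    values[first + i] /= sum;
}

Dropping the T buffer[BufferSize] scratch array is what removes the BufferSize template parameter and the VISIT_ATTN_SFOTMAX dispatch above, and it lifts the old columns>512 limit, since a row's nonzeros no longer have to fit in a compile-time buffer.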
@@ -219,49 +181,36 @@ void FusedAttentionCsrKernel(
             "shape of 'attn_mask' must be [seq_len, seq_len]"));
   }
 
-  /* Step1: SDD Matmul, reuse */
+  /* Step1: SDD Matmul, reuse matmul */
   SparseCsrTensor sdd_result;
   EmptyLikeCsrKernel<T, Context>(dev_ctx, sparse_mask, &sdd_result);
   auto sparse_blas = phi::funcs::sparse::GetSparseBlas<Context, T>(dev_ctx);
   sparse_blas.SDDMM(false,
                     true,
-                    static_cast<T>(1),
+                    static_cast<T>(1 / std::sqrt(N)),
                     query,
                     key,
                     static_cast<T>(0),
                     &sdd_result);
 
   /* Step2: Softmax with kp_mask/attn_mask, manualy not reuse */
   EmptyLikeCsrKernel<T, Context>(dev_ctx, sdd_result, softmax);
 
-  int buffer_size;
-  if (M < 128) {
-    buffer_size = (M + 32 - 1) / 32;
-  } else {
-    buffer_size = ((M + 128 - 1) / 128) * 4;
-  }
-
-  dim3 grid((total_row_num + 3) / 4);
-  dim3 block(WARP_SIZE, 4);
+  dim3 grid((total_row_num + 7) / 8);
+  dim3 block(WARP_SIZE, 8);
 
   int batch_nnz = sdd_result.nnz() / batch_num;
-
-  VISIT_ATTN_SFOTMAX(buffer_size, "AttnSoftmaxGpuKernel", [&] {
-    AttnSoftmaxGpuKernel<T, KBufferSize><<<grid, block, 0, dev_ctx.stream()>>>(
-        sdd_result.non_zero_crows().data<int64_t>(),
-        sdd_result.non_zero_cols().data<int64_t>(),
-        sdd_result.non_zero_elements().data<T>(),
-        kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
-        attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
-        softmax->mutable_non_zero_elements()->data<T>(),
-        M,
-        total_row_num,
-        std::sqrt(N),
-        q_dim[1],
-        batch_nnz);
-  });
+  AttnSoftmaxGpuKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
+      sdd_result.non_zero_crows().data<int64_t>(),
+      sdd_result.non_zero_cols().data<int64_t>(),
+      sdd_result.non_zero_elements().data<T>(),
+      kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
+      attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
+      softmax->mutable_non_zero_elements()->data<T>(),
+      M,
+      total_row_num,
+      q_dim[1],
+      batch_nnz);
 
   /* Step3: DSD Matmul, reuse */
   softmax->set_dims(phi::make_ddim({q_dim[0], q_dim[1], q_dim[2], q_dim[2]}));
   MatmulCsrDenseKernel<T, Context>(dev_ctx, *softmax, value, out);
 #else
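The Step1/2/3 comments map directly onto the math: the SDDMM evaluates the masked, scaled scores only on the sparsity pattern, the kernel above normalizes each row in place, and the closing sparse-dense matmul applies the probabilities to V:

    S = \mathrm{mask} \odot \left( \tfrac{1}{\sqrt{N}} \, QK^\top \right), \qquad
    P = \operatorname{softmax}_{\text{row}}(S), \qquad
    \mathrm{out} = P\,V

Folding 1/\sqrt{N} into the SDDMM alpha is what let the scale argument drop out of AttnSoftmaxGpuKernel.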
python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py

@@ -37,7 +37,7 @@ def get_cuda_version():
 @unittest.skipIf(
     not core.is_compiled_with_cuda() or get_cuda_version() < 11070,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.7"
 )
 class TestSparseAttentionAPI1(unittest.TestCase):