BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Unverified commit 9c32099d
Authored Jul 06, 2022 by zhouweiwei2014; committed via GitHub on Jul 06, 2022.
[Sparse] support optional kp_mask/attn_mask of sparse attention (#44120)
Parent: 064e549b
8 changed files with 132 additions and 98 deletions (+132 -98):
paddle/phi/api/yaml/generator/sparse_api_gen.py (+10 -6)
paddle/phi/api/yaml/sparse_api.yaml (+2 -0)
paddle/phi/api/yaml/sparse_bw_api.yaml (+2 -0)
paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc (+10 -9)
paddle/phi/kernels/sparse/fused_attention_kernel.h (+10 -9)
paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu (+46 -39)
python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py (+46 -29)
python/paddle/incubate/sparse/nn/functional/transformer.py (+6 -6)
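Taken together, the change set makes the two masks optional end to end: the Python signature gains None defaults, the code generator learns to forward absent inputs as nullptr, the op registration declares them optional, and the CUDA kernel guards every dereference. Below is a minimal sketch of the call this commit makes legal; the shapes and sizes are illustrative assumptions, not taken from the diff, and the fused kernel requires a GPU build with CUDA 11.7+ (per the #if guard in the .cu diff below):

    import paddle

    # Illustrative sizes (assumptions, not taken from the diff).
    batch_size, num_heads, seq_len, head_dim = 2, 4, 128, 16

    shape = [batch_size, num_heads, seq_len, head_dim]
    query = paddle.rand(shape, dtype='float64')
    key = paddle.rand(shape, dtype='float64')
    value = paddle.rand(shape, dtype='float64')

    # Dense layout of the sparse mask: [batch_size*num_heads, seq_len, seq_len].
    mask = paddle.ones([batch_size * num_heads, seq_len, seq_len], dtype='float64')
    sp_mask = mask.to_sparse_csr()

    # Before this commit, key_padding_mask and attn_mask were mandatory;
    # now they can simply be omitted.
    out = paddle.incubate.sparse.nn.functional.attention(query, key, value, sp_mask)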
paddle/phi/api/yaml/generator/sparse_api_gen.py

@@ -111,9 +111,8 @@ class SparseAPI(ForwardAPI):
         for param in kernel_param:
             if param in input_names:
                 if param in self.optional_vars:
-                    raise ValueError(
-                        f"{self.api} : Unsupport optional input({param}) for sparse api."
-                    )
-                kernel_context_code = kernel_context_code + f"""
-  kernel_context.EmplaceBackInput({param}.impl().get());"""
+                    kernel_context_code = kernel_context_code + f"""
+  kernel_context.EmplaceBackInput({param} ? {param}->impl().get() : nullptr);"""
+                else:
+                    kernel_context_code = kernel_context_code + f"""
+  kernel_context.EmplaceBackInput({param}.impl().get());"""

@@ -170,6 +169,11 @@ class SparseAPI(ForwardAPI):
         condition_list = []
         for i, in_type in enumerate(input_types):
             if in_type == "dense":
-                condition_list.append(
-                    f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())"
-                )
+                if self.inputs['names'][i] in self.optional_vars:
+                    condition_list.append(
+                        f"(!{self.inputs['names'][i]} || phi::DenseTensor::classof({self.inputs['names'][i]}->impl().get()))"
+                    )
+                else:
+                    condition_list.append(
+                        f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())"
+                    )
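For reference, the two generator branches above emit one of the following C++ lines per input. Running the f-strings with a placeholder shows the difference (a sketch; kp_mask is just an illustrative parameter name):

    # Sketch: what the generator's two f-string branches produce.
    param = "kp_mask"

    # Optional input: forward nullptr when the argument is absent, so the
    # kernel receives an empty paddle::optional instead of a hard error.
    print(f"kernel_context.EmplaceBackInput({param} ? {param}->impl().get() : nullptr);")
    # kernel_context.EmplaceBackInput(kp_mask ? kp_mask->impl().get() : nullptr);

    # Required input: unwrap unconditionally.
    print(f"kernel_context.EmplaceBackInput({param}.impl().get());")
    # kernel_context.EmplaceBackInput(kp_mask.impl().get());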
paddle/phi/api/yaml/sparse_api.yaml

@@ -147,6 +147,8 @@
   kernel :
     func : fused_attention_csr{dense, dense, dense, sparse_csr, dense, dense -> dense, sparse_csr}
     layout : sparse_mask
     data_type : query
+  optional : key_padding_mask, attn_mask
   intermediate : softmax
   backward : fused_attention_grad
paddle/phi/api/yaml/sparse_bw_api.yaml

@@ -134,3 +134,5 @@
   output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
   kernel :
     func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
+    layout : softmax
+    data_type : query
paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc

@@ -21,13 +21,14 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
 void FusedAttentionCsrKernel(const Context& dev_ctx,
                              const DenseTensor& query,
                              const DenseTensor& key,
                              const DenseTensor& value,
                              const SparseCsrTensor& sparse_mask,
-                             const DenseTensor& key_padding_mask,
-                             const DenseTensor& attn_mask,
+                             const paddle::optional<DenseTensor>& key_padding_mask,
+                             const paddle::optional<DenseTensor>& attn_mask,
                              DenseTensor* out,
                              SparseCsrTensor* softmax) {
   PD_THROW(
paddle/phi/kernels/sparse/fused_attention_kernel.h

@@ -21,13 +21,14 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
 void FusedAttentionCsrKernel(const Context& dev_ctx,
                              const DenseTensor& query,
                              const DenseTensor& key,
                              const DenseTensor& value,
                              const SparseCsrTensor& sparse_mask,
-                             const DenseTensor& key_padding_mask,
-                             const DenseTensor& attn_mask,
+                             const paddle::optional<DenseTensor>& key_padding_mask,
+                             const paddle::optional<DenseTensor>& attn_mask,
                              DenseTensor* out,
                              SparseCsrTensor* softmax);
paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu

@@ -127,13 +127,14 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
 }
 
 template <typename T, typename Context>
 void FusedAttentionCsrKernel(const Context& dev_ctx,
                              const DenseTensor& query,
                              const DenseTensor& key,
                              const DenseTensor& value,
                              const SparseCsrTensor& sparse_mask,
-                             const DenseTensor& key_padding_mask,
-                             const DenseTensor& attn_mask,
+                             const paddle::optional<DenseTensor>& key_padding_mask,
+                             const paddle::optional<DenseTensor>& attn_mask,
                              DenseTensor* out,
                              SparseCsrTensor* softmax) {
 #if CUDA_VERSION >= 11070

@@ -183,34 +184,40 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
       phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
                                    "[batch_size*num_heads, seq_len, seq_len]"));
 
-  PADDLE_ENFORCE_EQ(key_padding_mask.dims().size(),
-                    2,
-                    phi::errors::InvalidArgument(
-                        "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
-  PADDLE_ENFORCE_EQ(key_padding_mask.dims()[0],
-                    q_dim[0],
-                    phi::errors::InvalidArgument(
-                        "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
-  PADDLE_ENFORCE_EQ(key_padding_mask.dims()[1],
-                    M,
-                    phi::errors::InvalidArgument(
-                        "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
-  PADDLE_ENFORCE_EQ(attn_mask.dims().size(),
-                    2,
-                    phi::errors::InvalidArgument(
-                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
-  PADDLE_ENFORCE_EQ(attn_mask.dims()[0],
-                    M,
-                    phi::errors::InvalidArgument(
-                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
-  PADDLE_ENFORCE_EQ(attn_mask.dims()[1],
-                    M,
-                    phi::errors::InvalidArgument(
-                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
+  const auto kp_mask_ptr = key_padding_mask.get_ptr();
+  if (kp_mask_ptr) {
+    PADDLE_ENFORCE_EQ(kp_mask_ptr->dims().size(),
+                      2,
+                      phi::errors::InvalidArgument(
+                          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+    PADDLE_ENFORCE_EQ(kp_mask_ptr->dims()[0],
+                      q_dim[0],
+                      phi::errors::InvalidArgument(
+                          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+    PADDLE_ENFORCE_EQ(kp_mask_ptr->dims()[1],
+                      M,
+                      phi::errors::InvalidArgument(
+                          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+  }
+
+  const auto attn_mask_ptr = attn_mask.get_ptr();
+  if (attn_mask_ptr) {
+    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims().size(),
+                      2,
+                      phi::errors::InvalidArgument(
+                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
+    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[0],
+                      M,
+                      phi::errors::InvalidArgument(
+                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
+    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[1],
+                      M,
+                      phi::errors::InvalidArgument(
+                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
+  }
 
   /* Step1: SDD Matmul, reuse */
   SparseCsrTensor sdd_result;

@@ -244,8 +251,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
       sdd_result.non_zero_crows().data<int64_t>(),
       sdd_result.non_zero_cols().data<int64_t>(),
       sdd_result.non_zero_elements().data<T>(),
-      key_padding_mask.data<T>(),
-      attn_mask.data<T>(),
+      kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
+      attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
       softmax->mutable_non_zero_elements()->data<T>(),
       M,
       total_row_num,
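The pattern above (fetch the raw pointer once via get_ptr(), then validate and dereference only when it is non-null) is how the kernel consumes the new paddle::optional inputs. A rough Python analogue of the same contract, purely illustrative and not Paddle API:

    from typing import Optional
    import numpy as np

    def check_masks(kp_mask: Optional[np.ndarray],
                    attn_mask: Optional[np.ndarray],
                    batch_size: int, seq_len: int) -> None:
        # Mirrors the kernel: each mask is validated only if it was supplied.
        if kp_mask is not None:
            assert kp_mask.shape == (batch_size, seq_len), \
                "shape of 'key_padding_mask' must be [batch_size, seq_len]"
        if attn_mask is not None:
            assert attn_mask.shape == (seq_len, seq_len), \
                "shape of 'attn_mask' must be [seq_len, seq_len]"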
python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py

@@ -47,6 +47,7 @@ class TestSparseAttentionAPI1(unittest.TestCase):
         self.seq_len = 128
         self.head_dim = 16
         self.dtype = 'float64'
+        self.use_mask = True
 
     def test_dygraph(self):
         with _test_eager_guard():

@@ -69,6 +70,15 @@ class TestSparseAttentionAPI1(unittest.TestCase):
             sp_mask = mask.reshape([-1, self.seq_len,
                                     self.seq_len]).to_sparse_csr()
-            kp_mask = paddle.randint(
-                0, 2, [self.batch_size, self.seq_len]).astype(self.dtype)
-            attn_mask = paddle.randint(
+            query_sp = copy.deepcopy(query)
+            key_sp = copy.deepcopy(key)
+            value_sp = copy.deepcopy(value)
+            query_sp.stop_gradient = False
+            key_sp.stop_gradient = False
+            value_sp.stop_gradient = False
+
+            if self.use_mask:
+                kp_mask = paddle.randint(
+                    0, 2, [self.batch_size, self.seq_len]).astype(self.dtype)
+                attn_mask = paddle.randint(

@@ -82,24 +92,27 @@ class TestSparseAttentionAPI1(unittest.TestCase):
             output = paddle.matmul(softmax, value)
             output.backward()
 
-            query_cp = copy.deepcopy(query)
-            key_cp = copy.deepcopy(key)
-            value_cp = copy.deepcopy(value)
-            query_cp.stop_gradient = False
-            key_cp.stop_gradient = False
-            value_cp.stop_gradient = False
-
-            output_cp = paddle.incubate.sparse.nn.functional.attention(
-                query_cp, key_cp, value_cp, sp_mask, kp_mask, attn_mask)
-            output_cp.backward()
+                output_sp = paddle.incubate.sparse.nn.functional.attention(
+                    query_sp, key_sp, value_sp, sp_mask, kp_mask, attn_mask)
+                output_sp.backward()
+            else:
+                sdd = paddle.matmul(query, key, False, True) / math.sqrt(
+                    float(self.head_dim))
+                sdd = sdd + (mask - 1.0) * 1e9
+                softmax = paddle.nn.functional.softmax(sdd)
+                output = paddle.matmul(softmax, value)
+                output.backward()
+
+                output_sp = paddle.incubate.sparse.nn.functional.attention(
+                    query_sp, key_sp, value_sp, sp_mask)
+                output_sp.backward()
 
-            self.assertTrue(np.allclose(output_cp.numpy(), output.numpy()))
-            self.assertTrue(np.allclose(query_cp.grad.numpy(),
-                                        query.grad.numpy()))
-            self.assertTrue(np.allclose(key_cp.grad.numpy(), key.grad.numpy()))
-            self.assertTrue(np.allclose(value_cp.grad.numpy(),
-                                        value.grad.numpy()))
+            self.assertTrue(np.allclose(output_sp.numpy(), output.numpy()))
+            self.assertTrue(
+                np.allclose(query_sp.grad.numpy(), query.grad.numpy()))
+            self.assertTrue(np.allclose(key_sp.grad.numpy(), key.grad.numpy()))
+            self.assertTrue(
+                np.allclose(value_sp.grad.numpy(), value.grad.numpy()))

@@ -110,6 +123,7 @@ class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
         self.seq_len = 128
         self.head_dim = 32
         self.dtype = 'float64'
+        self.use_mask = False

@@ -120,6 +134,7 @@ class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
         self.seq_len = 512
         self.head_dim = 16
         self.dtype = 'float64'
+        self.use_mask = True

@@ -130,6 +145,7 @@ class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
         self.seq_len = 512
         self.head_dim = 32
         self.dtype = 'float64'
+        self.use_mask = False

@@ -140,6 +156,7 @@ class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
         self.seq_len = 512
         self.head_dim = 64
         self.dtype = 'float64'
+        self.use_mask = True
 
 
 if __name__ == '__main__':
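The masked (use_mask=True) dense reference is collapsed in this view. Judging from the visible unmasked branch and the (mask - 1.0) * 1e9 idiom, it presumably folds both optional masks into the logits before the softmax, along these lines (an assumption, not the literal test code; the broadcasting shapes are guesses):

    import math
    import paddle

    def dense_reference(query, key, value, mask, head_dim,
                        kp_mask=None, attn_mask=None):
        # Scaled QK^T, as in the visible unmasked branch.
        sdd = paddle.matmul(query, key, False, True) / math.sqrt(float(head_dim))
        # Zeros in a mask push the logit to -1e9 so softmax ignores it.
        sdd = sdd + (mask - 1.0) * 1e9
        if kp_mask is not None:
            # [batch_size, seq_len], broadcast over heads and query rows.
            sdd = sdd + (kp_mask.reshape([kp_mask.shape[0], 1, 1, -1]) - 1.0) * 1e9
        if attn_mask is not None:
            # [seq_len, seq_len], broadcast over batch and heads.
            sdd = sdd + (attn_mask - 1.0) * 1e9
        softmax = paddle.nn.functional.softmax(sdd)
        return paddle.matmul(softmax, value)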
python/paddle/incubate/sparse/nn/functional/transformer.py

@@ -23,8 +23,8 @@ def attention(query,
               key,
               value,
               sparse_mask,
-              key_padding_mask,
-              attn_mask,
+              key_padding_mask=None,
+              attn_mask=None,
               name=None):
     """
     Note:

@@ -50,10 +50,10 @@ def attention(query,
         sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
             is `[batch_size*num_heads, seq_len, seq_len]` . `nnz` of each batch must be the same.
             dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
-        key_padding_mask(DenseTensor): The key padding mask tensor in the Attention module.
-            2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64.
-        attn_mask(DenseTensor): The attention mask tensor in the Attention module.
-            2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64.
+        key_padding_mask(DenseTensor, optional): The key padding mask tensor in the Attention module.
+            2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64.
+            Default: None.
+        attn_mask(DenseTensor, optional): The attention mask tensor in the Attention module.
+            2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64.
+            Default: None.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
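With the new defaults, the masks can be passed by keyword or left out entirely, and gradients flow either way, as the tests exercise. A short usage sketch; the sizes are illustrative assumptions, and the dtype must be float32 or float64 per the docstring:

    import paddle

    batch_size, num_heads, seq_len, head_dim = 2, 4, 128, 16  # illustrative
    shape = [batch_size, num_heads, seq_len, head_dim]
    q = paddle.rand(shape, dtype='float64')
    k = paddle.rand(shape, dtype='float64')
    v = paddle.rand(shape, dtype='float64')
    q.stop_gradient = False
    k.stop_gradient = False
    v.stop_gradient = False

    sp_mask = paddle.ones([batch_size * num_heads, seq_len, seq_len],
                          dtype='float64').to_sparse_csr()

    # 0/1 masks with the documented shapes: [batch_size, seq_len] and
    # [seq_len, seq_len].
    kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype('float64')
    attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype('float64')

    # Keyword style, with both optional masks supplied:
    out = paddle.incubate.sparse.nn.functional.attention(
        q, k, v, sp_mask, key_padding_mask=kp_mask, attn_mask=attn_mask)
    out.backward()  # gradients reach q, k, v

    # Masks omitted, now valid since both default to None:
    out = paddle.incubate.sparse.nn.functional.attention(q, k, v, sp_mask)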