Unverified commit 9c32099d, authored by zhouweiwei2014 and committed by GitHub

[Sparse] support optional kp_mask/attn_mask of sparse attention (#44120)

Parent 064e549b
......@@ -111,9 +111,8 @@ class SparseAPI(ForwardAPI):
for param in kernel_param:
if param in input_names:
if param in self.optional_vars:
raise ValueError(
f"{self.api} : Unsupport optional input({param}) for sparse api."
)
kernel_context_code = kernel_context_code + f"""
kernel_context.EmplaceBackInput({param} ? {param}->impl().get() : nullptr);"""
else:
kernel_context_code = kernel_context_code + f"""
kernel_context.EmplaceBackInput({param}.impl().get());"""
......@@ -170,6 +169,11 @@ class SparseAPI(ForwardAPI):
condition_list = []
for i, in_type in enumerate(input_types):
if in_type == "dense":
if self.inputs['names'][i] in self.optional_vars:
condition_list.append(
f"(!{self.inputs['names'][i]} || phi::DenseTensor::classof({self.inputs['names'][i]}->impl().get()))"
)
else:
condition_list.append(
f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())"
)
......
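The two generator hunks above change what C++ the sparse API code generator emits for a kernel input that is declared optional: instead of raising an error, it now forwards a possibly-null pointer, and the dense-dispatch condition treats an absent optional tensor as acceptable. A minimal Python sketch of the emitted strings, using made-up input names (not part of the patch):

# Illustrative mirror of the generator branches above; "attn_mask" is
# assumed here to be registered as an optional dense input.
optional_vars = {"key_padding_mask", "attn_mask"}

def emplace_line(param):
    # Optional inputs may be absent, so the generated C++ passes nullptr
    # when the optional holds no tensor.
    if param in optional_vars:
        return (f"kernel_context.EmplaceBackInput("
                f"{param} ? {param}->impl().get() : nullptr);")
    return f"kernel_context.EmplaceBackInput({param}.impl().get());"

def dense_condition(param):
    # An absent optional input should not block the dense-kernel branch.
    if param in optional_vars:
        return f"(!{param} || phi::DenseTensor::classof({param}->impl().get()))"
    return f"phi::DenseTensor::classof({param}.impl().get())"

print(emplace_line("attn_mask"))
print(dense_condition("attn_mask"))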
......@@ -147,6 +147,8 @@
kernel :
func : fused_attention_csr{dense, dense, dense, sparse_csr, dense, dense -> dense, sparse_csr}
layout : sparse_mask
data_type: query
optional : key_padding_mask, attn_mask
intermediate : softmax
backward: fused_attention_grad
......
......@@ -134,3 +134,5 @@
output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
kernel :
func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
layout : softmax
data_type: query
......@@ -21,13 +21,14 @@ namespace phi {
namespace sparse {
template <typename T, typename Context>
void FusedAttentionCsrKernel(const Context& dev_ctx,
void FusedAttentionCsrKernel(
const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& sparse_mask,
const DenseTensor& key_padding_mask,
const DenseTensor& attn_mask,
const paddle::optional<DenseTensor>& key_padding_mask,
const paddle::optional<DenseTensor>& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax) {
PD_THROW(
......
......@@ -21,13 +21,14 @@ namespace phi {
namespace sparse {
template <typename T, typename Context>
void FusedAttentionCsrKernel(const Context& dev_ctx,
void FusedAttentionCsrKernel(
const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& sparse_mask,
const DenseTensor& key_padding_mask,
const DenseTensor& attn_mask,
const paddle::optional<DenseTensor>& key_padding_mask,
const paddle::optional<DenseTensor>& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax);
......
......@@ -127,13 +127,14 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
}
template <typename T, typename Context>
void FusedAttentionCsrKernel(const Context& dev_ctx,
void FusedAttentionCsrKernel(
const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const SparseCsrTensor& sparse_mask,
const DenseTensor& key_padding_mask,
const DenseTensor& attn_mask,
const paddle::optional<DenseTensor>& key_padding_mask,
const paddle::optional<DenseTensor>& attn_mask,
DenseTensor* out,
SparseCsrTensor* softmax) {
#if CUDA_VERSION >= 11070
......@@ -183,34 +184,40 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
"[batch_size*num_heads, seq_len, seq_len]"));
const auto kp_mask_ptr = key_padding_mask.get_ptr();
if (kp_mask_ptr) {
PADDLE_ENFORCE_EQ(
key_padding_mask.dims().size(),
kp_mask_ptr->dims().size(),
2,
phi::errors::InvalidArgument(
"shape of 'key_padding_mask' must be [batch_size, seq_len]"));
PADDLE_ENFORCE_EQ(
key_padding_mask.dims()[0],
kp_mask_ptr->dims()[0],
q_dim[0],
phi::errors::InvalidArgument(
"shape of 'key_padding_mask' must be [batch_size, seq_len]"));
PADDLE_ENFORCE_EQ(
key_padding_mask.dims()[1],
kp_mask_ptr->dims()[1],
M,
phi::errors::InvalidArgument(
"shape of 'key_padding_mask' must be [batch_size, seq_len]"));
}
PADDLE_ENFORCE_EQ(attn_mask.dims().size(),
const auto attn_mask_ptr = attn_mask.get_ptr();
if (attn_mask_ptr) {
PADDLE_ENFORCE_EQ(attn_mask_ptr->dims().size(),
2,
phi::errors::InvalidArgument(
"shape of 'attn_mask' must be [seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(attn_mask.dims()[0],
PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[0],
M,
phi::errors::InvalidArgument(
"shape of 'attn_mask' must be [seq_len, seq_len]"));
PADDLE_ENFORCE_EQ(attn_mask.dims()[1],
PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[1],
M,
phi::errors::InvalidArgument(
"shape of 'attn_mask' must be [seq_len, seq_len]"));
}
/* Step1: SDD Matmul, reuse */
SparseCsrTensor sdd_result;
......@@ -244,8 +251,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
sdd_result.non_zero_crows().data<int64_t>(),
sdd_result.non_zero_cols().data<int64_t>(),
sdd_result.non_zero_elements().data<T>(),
key_padding_mask.data<T>(),
attn_mask.data<T>(),
kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
softmax->mutable_non_zero_elements()->data<T>(),
M,
total_row_num,
......
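The declarations and the GPU kernel above switch both masks to paddle::optional&lt;DenseTensor&gt;, fetch them via get_ptr(), and run the shape checks only when a mask is actually supplied; the shape contract itself is unchanged (key_padding_mask is [batch_size, seq_len], attn_mask is [seq_len, seq_len]). A small illustrative Python mirror of those checks, using a hypothetical helper name:

def check_mask_shapes(batch_size, seq_len, key_padding_mask=None, attn_mask=None):
    # Both masks are optional; each is validated only if it was passed in.
    if key_padding_mask is not None:
        assert list(key_padding_mask.shape) == [batch_size, seq_len], (
            "shape of 'key_padding_mask' must be [batch_size, seq_len]")
    if attn_mask is not None:
        assert list(attn_mask.shape) == [seq_len, seq_len], (
            "shape of 'attn_mask' must be [seq_len, seq_len]")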
......@@ -47,6 +47,7 @@ class TestSparseAttentionAPI1(unittest.TestCase):
self.seq_len = 128
self.head_dim = 16
self.dtype = 'float64'
self.use_mask = True
def test_dygraph(self):
with _test_eager_guard():
......@@ -69,6 +70,15 @@ class TestSparseAttentionAPI1(unittest.TestCase):
sp_mask = mask.reshape([-1, self.seq_len,
self.seq_len]).to_sparse_csr()
query_sp = copy.deepcopy(query)
key_sp = copy.deepcopy(key)
value_sp = copy.deepcopy(value)
query_sp.stop_gradient = False
key_sp.stop_gradient = False
value_sp.stop_gradient = False
if self.use_mask:
kp_mask = paddle.randint(
0, 2, [self.batch_size, self.seq_len]).astype(self.dtype)
attn_mask = paddle.randint(
......@@ -82,24 +92,27 @@ class TestSparseAttentionAPI1(unittest.TestCase):
output = paddle.matmul(softmax, value)
output.backward()
query_cp = copy.deepcopy(query)
key_cp = copy.deepcopy(key)
value_cp = copy.deepcopy(value)
query_cp.stop_gradient = False
key_cp.stop_gradient = False
value_cp.stop_gradient = False
output_sp = paddle.incubate.sparse.nn.functional.attention(
query_sp, key_sp, value_sp, sp_mask, kp_mask, attn_mask)
output_sp.backward()
else:
sdd = paddle.matmul(query, key, False, True) / math.sqrt(
float(self.head_dim))
sdd = sdd + (mask - 1.0) * 1e9
softmax = paddle.nn.functional.softmax(sdd)
output = paddle.matmul(softmax, value)
output.backward()
output_cp = paddle.incubate.sparse.nn.functional.attention(
query_cp, key_cp, value_cp, sp_mask, kp_mask, attn_mask)
output_cp.backward()
output_sp = paddle.incubate.sparse.nn.functional.attention(
query_sp, key_sp, value_sp, sp_mask)
output_sp.backward()
self.assertTrue(np.allclose(output_cp.numpy(), output.numpy()))
self.assertTrue(np.allclose(output_sp.numpy(), output.numpy()))
self.assertTrue(
np.allclose(query_cp.grad.numpy(), query.grad.numpy()))
self.assertTrue(np.allclose(key_cp.grad.numpy(), key.grad.numpy()))
np.allclose(query_sp.grad.numpy(), query.grad.numpy()))
self.assertTrue(np.allclose(key_sp.grad.numpy(), key.grad.numpy()))
self.assertTrue(
np.allclose(value_cp.grad.numpy(), value.grad.numpy()))
np.allclose(value_sp.grad.numpy(), value.grad.numpy()))
class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
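In the no-mask branch above, the test builds a dense reference: scaled dot-product attention in which positions missing from the 0/1 layout mask are pushed toward minus infinity before the softmax. A standalone NumPy sketch of that reference (illustrative only; the test itself uses Paddle tensors and autograd):

import math
import numpy as np

def dense_reference(query, key, value, mask, head_dim):
    # query/key/value: [..., seq_len, head_dim]; mask: 0/1 layout
    # broadcastable to [..., seq_len, seq_len].
    sdd = query @ np.swapaxes(key, -1, -2) / math.sqrt(float(head_dim))
    sdd = sdd + (mask - 1.0) * 1e9               # masked positions get ~ -1e9
    sdd = sdd - sdd.max(axis=-1, keepdims=True)  # numerically stable softmax
    probs = np.exp(sdd)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs @ value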
......@@ -110,6 +123,7 @@ class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
self.seq_len = 128
self.head_dim = 32
self.dtype = 'float64'
self.use_mask = False
class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
......@@ -120,6 +134,7 @@ class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
self.seq_len = 512
self.head_dim = 16
self.dtype = 'float64'
self.use_mask = True
class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
......@@ -130,6 +145,7 @@ class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
self.seq_len = 512
self.head_dim = 32
self.dtype = 'float64'
self.use_mask = False
class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
......@@ -140,6 +156,7 @@ class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
self.seq_len = 512
self.head_dim = 64
self.dtype = 'float64'
self.use_mask = True
if __name__ == '__main__':
......
......@@ -23,8 +23,8 @@ def attention(query,
key,
value,
sparse_mask,
key_padding_mask,
attn_mask,
key_padding_mask=None,
attn_mask=None,
name=None):
"""
Note:
......@@ -50,10 +50,10 @@ def attention(query,
sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
is `[batch_size*num_heads, seq_len, seq_len]` . `nnz` of each batch must be the same.
dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
key_padding_mask(DenseTensor): The key padding mask tensor in the Attention module.
2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64.
attn_mask(DenseTensor):The attention mask tensor in the Attention module.
2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64.
key_padding_mask(DenseTensor, optional): The key padding mask tensor in the Attention module.
2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64. Default: None.
attn_mask(DenseTensor, optional): The attention mask tensor in the Attention module.
2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64. Default: None.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
......
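With key_padding_mask and attn_mask now defaulting to None, the functional API can be called with only the sparse layout. A minimal usage sketch under assumed shapes (mirrors the unit test above; requires a GPU build, since the CPU kernel is still a stub):

import paddle

batch_size, num_heads, seq_len, head_dim = 1, 1, 128, 16
shape = [batch_size, num_heads, seq_len, head_dim]

query = paddle.rand(shape, dtype='float64')
key = paddle.rand(shape, dtype='float64')
value = paddle.rand(shape, dtype='float64')

# 0/1 layout with dense shape [batch_size*num_heads, seq_len, seq_len], stored as CSR.
mask = paddle.randint(0, 2, [batch_size, num_heads, seq_len, seq_len]).astype('float64')
sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()

# Both masks omitted: now a valid call.
out = paddle.incubate.sparse.nn.functional.attention(query, key, value, sp_mask)

# With the optional masks supplied: [batch_size, seq_len] and [seq_len, seq_len].
kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype('float64')
attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype('float64')
out = paddle.incubate.sparse.nn.functional.attention(
    query, key, value, sp_mask, kp_mask, attn_mask)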