Unverified Commit 9c32099d authored by zhouweiwei2014, committed by GitHub

[Sparse] support optional kp_mask/attn_mask of sparse attention (#44120)

Parent 064e549b
@@ -111,9 +111,8 @@ class SparseAPI(ForwardAPI):
         for param in kernel_param:
             if param in input_names:
                 if param in self.optional_vars:
-                    raise ValueError(
-                        f"{self.api} : Unsupport optional input({param}) for sparse api."
-                    )
+                    kernel_context_code = kernel_context_code + f"""
+  kernel_context.EmplaceBackInput({param} ? {param}->impl().get() : nullptr);"""
                 else:
                     kernel_context_code = kernel_context_code + f"""
   kernel_context.EmplaceBackInput({param}.impl().get());"""
@@ -170,9 +169,14 @@ class SparseAPI(ForwardAPI):
         condition_list = []
         for i, in_type in enumerate(input_types):
             if in_type == "dense":
-                condition_list.append(
-                    f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())"
-                )
+                if self.inputs['names'][i] in self.optional_vars:
+                    condition_list.append(
+                        f"(!{self.inputs['names'][i]} || phi::DenseTensor::classof({self.inputs['names'][i]}->impl().get()))"
+                    )
+                else:
+                    condition_list.append(
+                        f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())"
+                    )
             else:
                 condition_list.append(
                     f"{self.inputs['names'][i]}.layout() == {sparse_type_map[in_type]}"
...
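Note (not part of the diff): a minimal Python sketch of the strings the updated generator now emits for an optional versus a required dense input; the parameter name is only an example.

# Sketch of the code-gen behaviour above; `param` is an illustrative name.
# Optional inputs now produce a null-safe EmplaceBackInput call and a
# dispatch condition that tolerates an empty optional.
param = "key_padding_mask"

optional_input = (
    f"kernel_context.EmplaceBackInput("
    f"{param} ? {param}->impl().get() : nullptr);")
required_input = f"kernel_context.EmplaceBackInput({param}.impl().get());"

optional_condition = (
    f"(!{param} || phi::DenseTensor::classof({param}->impl().get()))")
required_condition = f"phi::DenseTensor::classof({param}.impl().get())"

print(optional_input)
print(required_input)
print(optional_condition)
print(required_condition)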
@@ -147,6 +147,8 @@
   kernel :
     func : fused_attention_csr{dense, dense, dense, sparse_csr, dense, dense -> dense, sparse_csr}
     layout : sparse_mask
+    data_type: query
+  optional : key_padding_mask, attn_mask
   intermediate : softmax
   backward: fused_attention_grad
...
@@ -134,3 +134,5 @@
   output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
   kernel :
     func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
+    layout : softmax
+    data_type: query
@@ -21,15 +21,16 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void FusedAttentionCsrKernel(const Context& dev_ctx,
-                             const DenseTensor& query,
-                             const DenseTensor& key,
-                             const DenseTensor& value,
-                             const SparseCsrTensor& sparse_mask,
-                             const DenseTensor& key_padding_mask,
-                             const DenseTensor& attn_mask,
-                             DenseTensor* out,
-                             SparseCsrTensor* softmax) {
+void FusedAttentionCsrKernel(
+    const Context& dev_ctx,
+    const DenseTensor& query,
+    const DenseTensor& key,
+    const DenseTensor& value,
+    const SparseCsrTensor& sparse_mask,
+    const paddle::optional<DenseTensor>& key_padding_mask,
+    const paddle::optional<DenseTensor>& attn_mask,
+    DenseTensor* out,
+    SparseCsrTensor* softmax) {
   PD_THROW(
       "Not support CPU kernel of 'sparse.nn.functional.fused_attention' now");
 }
...
@@ -21,15 +21,16 @@ namespace phi {
 namespace sparse {
 
 template <typename T, typename Context>
-void FusedAttentionCsrKernel(const Context& dev_ctx,
-                             const DenseTensor& query,
-                             const DenseTensor& key,
-                             const DenseTensor& value,
-                             const SparseCsrTensor& sparse_mask,
-                             const DenseTensor& key_padding_mask,
-                             const DenseTensor& attn_mask,
-                             DenseTensor* out,
-                             SparseCsrTensor* softmax);
+void FusedAttentionCsrKernel(
+    const Context& dev_ctx,
+    const DenseTensor& query,
+    const DenseTensor& key,
+    const DenseTensor& value,
+    const SparseCsrTensor& sparse_mask,
+    const paddle::optional<DenseTensor>& key_padding_mask,
+    const paddle::optional<DenseTensor>& attn_mask,
+    DenseTensor* out,
+    SparseCsrTensor* softmax);
 }  // namespace sparse
 }  // namespace phi
@@ -127,15 +127,16 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
 }
 
 template <typename T, typename Context>
-void FusedAttentionCsrKernel(const Context& dev_ctx,
-                             const DenseTensor& query,
-                             const DenseTensor& key,
-                             const DenseTensor& value,
-                             const SparseCsrTensor& sparse_mask,
-                             const DenseTensor& key_padding_mask,
-                             const DenseTensor& attn_mask,
-                             DenseTensor* out,
-                             SparseCsrTensor* softmax) {
+void FusedAttentionCsrKernel(
+    const Context& dev_ctx,
+    const DenseTensor& query,
+    const DenseTensor& key,
+    const DenseTensor& value,
+    const SparseCsrTensor& sparse_mask,
+    const paddle::optional<DenseTensor>& key_padding_mask,
+    const paddle::optional<DenseTensor>& attn_mask,
+    DenseTensor* out,
+    SparseCsrTensor* softmax) {
 #if CUDA_VERSION >= 11070
   /* Check Shape */
   auto q_dim = query.dims();
@@ -183,34 +184,40 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
       phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
                                    "[batch_size*num_heads, seq_len, seq_len]"));
 
-  PADDLE_ENFORCE_EQ(
-      key_padding_mask.dims().size(),
-      2,
-      phi::errors::InvalidArgument(
-          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
-  PADDLE_ENFORCE_EQ(
-      key_padding_mask.dims()[0],
-      q_dim[0],
-      phi::errors::InvalidArgument(
-          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
-  PADDLE_ENFORCE_EQ(
-      key_padding_mask.dims()[1],
-      M,
-      phi::errors::InvalidArgument(
-          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
-
-  PADDLE_ENFORCE_EQ(attn_mask.dims().size(),
-                    2,
-                    phi::errors::InvalidArgument(
-                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
-  PADDLE_ENFORCE_EQ(attn_mask.dims()[0],
-                    M,
-                    phi::errors::InvalidArgument(
-                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
-  PADDLE_ENFORCE_EQ(attn_mask.dims()[1],
-                    M,
-                    phi::errors::InvalidArgument(
-                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
+  const auto kp_mask_ptr = key_padding_mask.get_ptr();
+  if (kp_mask_ptr) {
+    PADDLE_ENFORCE_EQ(
+        kp_mask_ptr->dims().size(),
+        2,
+        phi::errors::InvalidArgument(
+            "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+    PADDLE_ENFORCE_EQ(
+        kp_mask_ptr->dims()[0],
+        q_dim[0],
+        phi::errors::InvalidArgument(
+            "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+    PADDLE_ENFORCE_EQ(
+        kp_mask_ptr->dims()[1],
+        M,
+        phi::errors::InvalidArgument(
+            "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+  }
+
+  const auto attn_mask_ptr = attn_mask.get_ptr();
+  if (attn_mask_ptr) {
+    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims().size(),
+                      2,
+                      phi::errors::InvalidArgument(
+                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
+    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[0],
+                      M,
+                      phi::errors::InvalidArgument(
+                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
+    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[1],
+                      M,
+                      phi::errors::InvalidArgument(
+                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
+  }
 
   /* Step1: SDD Matmul, reuse */
   SparseCsrTensor sdd_result;
@@ -244,8 +251,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
       sdd_result.non_zero_crows().data<int64_t>(),
       sdd_result.non_zero_cols().data<int64_t>(),
       sdd_result.non_zero_elements().data<T>(),
-      key_padding_mask.data<T>(),
-      attn_mask.data<T>(),
+      kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
+      attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
      softmax->mutable_non_zero_elements()->data<T>(),
      M,
      total_row_num,
...
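Note (not part of the diff): a dense-equivalent NumPy sketch, for intuition only, of what the fused kernel computes when the optional masks are supplied. It mirrors the reference computation used in the unit test below; all shapes and names are illustrative.

# query/key/value: [batch, heads, seq, head_dim]; layout_mask: 0/1 tensor of
# shape [batch, heads, seq, seq]; kp_mask: [batch, seq]; attn_mask: [seq, seq].
import numpy as np

def dense_reference_attention(query, key, value, layout_mask,
                              kp_mask=None, attn_mask=None):
    head_dim = query.shape[-1]
    # Scaled QK^T (the kernel restricts this to the sparse layout).
    sdd = query @ np.swapaxes(key, -1, -2) / np.sqrt(float(head_dim))
    combined = layout_mask
    if kp_mask is not None:
        # Broadcast over heads and query positions; masks key positions.
        combined = combined * kp_mask[:, None, None, :]
    if attn_mask is not None:
        combined = combined * attn_mask
    # Zero entries of the combined mask are pushed to -1e9 before the softmax.
    sdd = sdd + (combined - 1.0) * 1e9
    exp = np.exp(sdd - sdd.max(axis=-1, keepdims=True))
    softmax = exp / exp.sum(axis=-1, keepdims=True)
    return softmax @ value

q = k = v = np.random.rand(2, 4, 8, 16)
layout = np.ones((2, 4, 8, 8))
out = dense_reference_attention(q, k, v, layout)  # both masks omitted
out_masked = dense_reference_attention(
    q, k, v, layout, kp_mask=np.ones((2, 8)), attn_mask=np.ones((8, 8)))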
@@ -47,6 +47,7 @@ class TestSparseAttentionAPI1(unittest.TestCase):
         self.seq_len = 128
         self.head_dim = 16
         self.dtype = 'float64'
+        self.use_mask = True
 
     def test_dygraph(self):
         with _test_eager_guard():
@@ -69,37 +70,49 @@ class TestSparseAttentionAPI1(unittest.TestCase):
             sp_mask = mask.reshape([-1, self.seq_len,
                                     self.seq_len]).to_sparse_csr()
 
-            kp_mask = paddle.randint(
-                0, 2, [self.batch_size, self.seq_len]).astype(self.dtype)
-            attn_mask = paddle.randint(
-                0, 2, [self.seq_len, self.seq_len]).astype(self.dtype)
-
-            sdd = paddle.matmul(query, key, False, True) / math.sqrt(
-                float(self.head_dim))
-            sdd = sdd + (
-                (mask * kp_mask.unsqueeze([1, 2]) * attn_mask) - 1.0) * 1e9
-            softmax = paddle.nn.functional.softmax(sdd)
-            output = paddle.matmul(softmax, value)
-            output.backward()
-
-            query_cp = copy.deepcopy(query)
-            key_cp = copy.deepcopy(key)
-            value_cp = copy.deepcopy(value)
-
-            query_cp.stop_gradient = False
-            key_cp.stop_gradient = False
-            value_cp.stop_gradient = False
-
-            output_cp = paddle.incubate.sparse.nn.functional.attention(
-                query_cp, key_cp, value_cp, sp_mask, kp_mask, attn_mask)
-            output_cp.backward()
-
-            self.assertTrue(np.allclose(output_cp.numpy(), output.numpy()))
+            query_sp = copy.deepcopy(query)
+            key_sp = copy.deepcopy(key)
+            value_sp = copy.deepcopy(value)
+
+            query_sp.stop_gradient = False
+            key_sp.stop_gradient = False
+            value_sp.stop_gradient = False
+
+            if self.use_mask:
+                kp_mask = paddle.randint(
+                    0, 2, [self.batch_size, self.seq_len]).astype(self.dtype)
+                attn_mask = paddle.randint(
+                    0, 2, [self.seq_len, self.seq_len]).astype(self.dtype)
+
+                sdd = paddle.matmul(query, key, False, True) / math.sqrt(
+                    float(self.head_dim))
+                sdd = sdd + (
+                    (mask * kp_mask.unsqueeze([1, 2]) * attn_mask) - 1.0) * 1e9
+                softmax = paddle.nn.functional.softmax(sdd)
+                output = paddle.matmul(softmax, value)
+                output.backward()
+
+                output_sp = paddle.incubate.sparse.nn.functional.attention(
+                    query_sp, key_sp, value_sp, sp_mask, kp_mask, attn_mask)
+                output_sp.backward()
+            else:
+                sdd = paddle.matmul(query, key, False, True) / math.sqrt(
+                    float(self.head_dim))
+                sdd = sdd + (mask - 1.0) * 1e9
+                softmax = paddle.nn.functional.softmax(sdd)
+                output = paddle.matmul(softmax, value)
+                output.backward()
+                output_sp = paddle.incubate.sparse.nn.functional.attention(
+                    query_sp, key_sp, value_sp, sp_mask)
+                output_sp.backward()
+            self.assertTrue(np.allclose(output_sp.numpy(), output.numpy()))
             self.assertTrue(
-                np.allclose(query_cp.grad.numpy(), query.grad.numpy()))
+                np.allclose(query_sp.grad.numpy(), query.grad.numpy()))
-            self.assertTrue(np.allclose(key_cp.grad.numpy(), key.grad.numpy()))
+            self.assertTrue(np.allclose(key_sp.grad.numpy(), key.grad.numpy()))
             self.assertTrue(
-                np.allclose(value_cp.grad.numpy(), value.grad.numpy()))
+                np.allclose(value_sp.grad.numpy(), value.grad.numpy()))
 
 class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
@@ -110,6 +123,7 @@ class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
         self.seq_len = 128
         self.head_dim = 32
         self.dtype = 'float64'
+        self.use_mask = False
 
 class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
@@ -120,6 +134,7 @@ class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
         self.seq_len = 512
         self.head_dim = 16
         self.dtype = 'float64'
+        self.use_mask = True
 
 class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
@@ -130,6 +145,7 @@ class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
         self.seq_len = 512
         self.head_dim = 32
         self.dtype = 'float64'
+        self.use_mask = False
 
 class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
@@ -140,6 +156,7 @@ class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
         self.seq_len = 512
         self.head_dim = 64
         self.dtype = 'float64'
+        self.use_mask = True
 
 if __name__ == '__main__':
...
@@ -23,8 +23,8 @@ def attention(query,
               key,
               value,
               sparse_mask,
-              key_padding_mask,
-              attn_mask,
+              key_padding_mask=None,
+              attn_mask=None,
               name=None):
     """
     Note:
@@ -50,10 +50,10 @@ def attention(query,
         sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
            is `[batch_size*num_heads, seq_len, seq_len]` . `nnz` of each batch must be the same.
            dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
-        key_padding_mask(DenseTensor): The key padding mask tensor in the Attention module.
-            2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64.
-        attn_mask(DenseTensor): The attention mask tensor in the Attention module.
-            2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64.
+        key_padding_mask(DenseTensor, optional): The key padding mask tensor in the Attention module.
+            2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64. Default: None.
+        attn_mask(DenseTensor, optional): The attention mask tensor in the Attention module.
+            2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64. Default: None.
         name(str, optional): The default value is None. Normally there is no need for user
            to set this property. For more information, please refer to
            :ref:`api_guide_Name`.
...
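Note (not part of the diff): an assumed usage sketch of the updated Python API with both masks now defaulting to None. Shapes are illustrative, and the fused kernel needs a CUDA build (CUDA >= 11.7), since the CPU kernel only raises.

import paddle

# Illustrative sizes only.
batch_size, num_heads, seq_len, head_dim = 1, 2, 8, 16
shape = [batch_size, num_heads, seq_len, head_dim]

query = paddle.rand(shape, dtype='float32')
key = paddle.rand(shape, dtype='float32')
value = paddle.rand(shape, dtype='float32')

# Sparse layout with dense shape [batch_size*num_heads, seq_len, seq_len].
sp_mask = paddle.ones([batch_size * num_heads, seq_len, seq_len],
                      dtype='float32').to_sparse_csr()

# Both masks are optional after this change.
out = paddle.incubate.sparse.nn.functional.attention(query, key, value, sp_mask)

kp_mask = paddle.ones([batch_size, seq_len], dtype='float32')
attn_mask = paddle.ones([seq_len, seq_len], dtype='float32')
out_masked = paddle.incubate.sparse.nn.functional.attention(
    query, key, value, sp_mask, kp_mask, attn_mask)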