From 9c32099d050713d1416fb59750761316f5f0831a Mon Sep 17 00:00:00 2001
From: Zhou Wei <1183042833@qq.com>
Date: Wed, 6 Jul 2022 17:55:56 +0800
Subject: [PATCH] [Sparse] support optional kp_mask/attn_mask of sparse
 attention (#44120)

---
 .../phi/api/yaml/generator/sparse_api_gen.py  | 16 ++--
 paddle/phi/api/yaml/sparse_api.yaml           |  2 +
 paddle/phi/api/yaml/sparse_bw_api.yaml        |  2 +
 .../sparse/cpu/fused_attention_kernel.cc      | 19 +++--
 .../kernels/sparse/fused_attention_kernel.h   | 19 +++--
 .../sparse/gpu/fused_attention_kernel.cu      | 85 ++++++++++---------
 .../test_sparse_fused_attention_op.py         | 75 +++++++++-------
 .../sparse/nn/functional/transformer.py       | 12 +--
 8 files changed, 132 insertions(+), 98 deletions(-)

diff --git a/paddle/phi/api/yaml/generator/sparse_api_gen.py b/paddle/phi/api/yaml/generator/sparse_api_gen.py
index 17eb70e5c3..69bf6950cd 100644
--- a/paddle/phi/api/yaml/generator/sparse_api_gen.py
+++ b/paddle/phi/api/yaml/generator/sparse_api_gen.py
@@ -111,9 +111,8 @@ class SparseAPI(ForwardAPI):
         for param in kernel_param:
             if param in input_names:
                 if param in self.optional_vars:
-                    raise ValueError(
-                        f"{self.api} : Unsupport optional input({param}) for sparse api."
-                    )
+                    kernel_context_code = kernel_context_code + f"""
+  kernel_context.EmplaceBackInput({param} ? {param}->impl().get() : nullptr);"""
                 else:
                     kernel_context_code = kernel_context_code + f"""
   kernel_context.EmplaceBackInput({param}.impl().get());"""
@@ -170,9 +169,14 @@ class SparseAPI(ForwardAPI):
         condition_list = []
         for i, in_type in enumerate(input_types):
             if in_type == "dense":
-                condition_list.append(
-                    f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())"
-                )
+                if self.inputs['names'][i] in self.optional_vars:
+                    condition_list.append(
+                        f"(!{self.inputs['names'][i]} || phi::DenseTensor::classof({self.inputs['names'][i]}->impl().get()))"
+                    )
+                else:
+                    condition_list.append(
+                        f"phi::DenseTensor::classof({self.inputs['names'][i]}.impl().get())"
+                    )
             else:
                 condition_list.append(
                     f"{self.inputs['names'][i]}.layout() == {sparse_type_map[in_type]}"
diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml
index a6520a0d48..68c41d50ae 100644
--- a/paddle/phi/api/yaml/sparse_api.yaml
+++ b/paddle/phi/api/yaml/sparse_api.yaml
@@ -147,6 +147,8 @@
   kernel :
     func : fused_attention_csr{dense, dense, dense, sparse_csr, dense, dense -> dense, sparse_csr}
     layout : sparse_mask
+    data_type: query
+  optional : key_padding_mask, attn_mask
   intermediate : softmax
   backward: fused_attention_grad

diff --git a/paddle/phi/api/yaml/sparse_bw_api.yaml b/paddle/phi/api/yaml/sparse_bw_api.yaml
index 5296d1b870..0ca9c9daa9 100644
--- a/paddle/phi/api/yaml/sparse_bw_api.yaml
+++ b/paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -134,3 +134,5 @@
   output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
   kernel :
     func : fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
+    layout : softmax
+    data_type: query
diff --git a/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc b/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
index 6c652c6a8c..11c9e2d5c2 100644
--- a/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
+++ b/paddle/phi/kernels/sparse/cpu/fused_attention_kernel.cc
@@ -21,15 +21,16 @@ namespace phi {
 namespace sparse {

 template <typename T, typename Context>
-void FusedAttentionCsrKernel(const Context& dev_ctx,
-                             const DenseTensor& query,
-                             const DenseTensor& key,
-                             const DenseTensor& value,
-                             const SparseCsrTensor& sparse_mask,
-                             const DenseTensor& key_padding_mask,
-                             const DenseTensor& attn_mask,
-                             DenseTensor* out,
-                             SparseCsrTensor* softmax) {
+void FusedAttentionCsrKernel(
+    const Context& dev_ctx,
+    const DenseTensor& query,
+    const DenseTensor& key,
+    const DenseTensor& value,
+    const SparseCsrTensor& sparse_mask,
+    const paddle::optional<DenseTensor>& key_padding_mask,
+    const paddle::optional<DenseTensor>& attn_mask,
+    DenseTensor* out,
+    SparseCsrTensor* softmax) {
   PD_THROW(
       "Not support CPU kernel of 'sparse.nn.functional.fused_attention' now");
 }
diff --git a/paddle/phi/kernels/sparse/fused_attention_kernel.h b/paddle/phi/kernels/sparse/fused_attention_kernel.h
index feff9d72e6..340fdce019 100644
--- a/paddle/phi/kernels/sparse/fused_attention_kernel.h
+++ b/paddle/phi/kernels/sparse/fused_attention_kernel.h
@@ -21,15 +21,16 @@ namespace phi {
 namespace sparse {

 template <typename T, typename Context>
-void FusedAttentionCsrKernel(const Context& dev_ctx,
-                             const DenseTensor& query,
-                             const DenseTensor& key,
-                             const DenseTensor& value,
-                             const SparseCsrTensor& sparse_mask,
-                             const DenseTensor& key_padding_mask,
-                             const DenseTensor& attn_mask,
-                             DenseTensor* out,
-                             SparseCsrTensor* softmax);
+void FusedAttentionCsrKernel(
+    const Context& dev_ctx,
+    const DenseTensor& query,
+    const DenseTensor& key,
+    const DenseTensor& value,
+    const SparseCsrTensor& sparse_mask,
+    const paddle::optional<DenseTensor>& key_padding_mask,
+    const paddle::optional<DenseTensor>& attn_mask,
+    DenseTensor* out,
+    SparseCsrTensor* softmax);

 }  // namespace sparse
 }  // namespace phi
diff --git a/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
index 9a7e55d2d6..46412d57f1 100644
--- a/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu
@@ -127,15 +127,16 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
 }

 template <typename T, typename Context>
-void FusedAttentionCsrKernel(const Context& dev_ctx,
-                             const DenseTensor& query,
-                             const DenseTensor& key,
-                             const DenseTensor& value,
-                             const SparseCsrTensor& sparse_mask,
-                             const DenseTensor& key_padding_mask,
-                             const DenseTensor& attn_mask,
-                             DenseTensor* out,
-                             SparseCsrTensor* softmax) {
+void FusedAttentionCsrKernel(
+    const Context& dev_ctx,
+    const DenseTensor& query,
+    const DenseTensor& key,
+    const DenseTensor& value,
+    const SparseCsrTensor& sparse_mask,
+    const paddle::optional<DenseTensor>& key_padding_mask,
+    const paddle::optional<DenseTensor>& attn_mask,
+    DenseTensor* out,
+    SparseCsrTensor* softmax) {
 #if CUDA_VERSION >= 11070
   /* Check Shape */
   auto q_dim = query.dims();
@@ -183,34 +184,40 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
       phi::errors::InvalidArgument("dense shape of 'sparse_mask' must be "
                                    "[batch_size*num_heads, seq_len, seq_len]"));

-  PADDLE_ENFORCE_EQ(
-      key_padding_mask.dims().size(),
-      2,
-      phi::errors::InvalidArgument(
-          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
-  PADDLE_ENFORCE_EQ(
-      key_padding_mask.dims()[0],
-      q_dim[0],
-      phi::errors::InvalidArgument(
-          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
-  PADDLE_ENFORCE_EQ(
-      key_padding_mask.dims()[1],
-      M,
-      phi::errors::InvalidArgument(
-          "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
-
-  PADDLE_ENFORCE_EQ(attn_mask.dims().size(),
-                    2,
-                    phi::errors::InvalidArgument(
-                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
-  PADDLE_ENFORCE_EQ(attn_mask.dims()[0],
-                    M,
-                    phi::errors::InvalidArgument(
-                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
-  PADDLE_ENFORCE_EQ(attn_mask.dims()[1],
-                    M,
-                    phi::errors::InvalidArgument(
-                        "shape of 'attn_mask' must be [seq_len, seq_len]"));
+  const auto kp_mask_ptr = key_padding_mask.get_ptr();
+  if (kp_mask_ptr) {
+    PADDLE_ENFORCE_EQ(
+        kp_mask_ptr->dims().size(),
+        2,
+        phi::errors::InvalidArgument(
+            "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+    PADDLE_ENFORCE_EQ(
+        kp_mask_ptr->dims()[0],
+        q_dim[0],
+        phi::errors::InvalidArgument(
+            "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+    PADDLE_ENFORCE_EQ(
+        kp_mask_ptr->dims()[1],
+        M,
+        phi::errors::InvalidArgument(
+            "shape of 'key_padding_mask' must be [batch_size, seq_len]"));
+  }
+
+  const auto attn_mask_ptr = attn_mask.get_ptr();
+  if (attn_mask_ptr) {
+    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims().size(),
+                      2,
+                      phi::errors::InvalidArgument(
+                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
+    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[0],
+                      M,
+                      phi::errors::InvalidArgument(
+                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
+    PADDLE_ENFORCE_EQ(attn_mask_ptr->dims()[1],
+                      M,
+                      phi::errors::InvalidArgument(
+                          "shape of 'attn_mask' must be [seq_len, seq_len]"));
+  }

   /* Step1: SDD Matmul, reuse */
   SparseCsrTensor sdd_result;
@@ -244,8 +251,8 @@ void FusedAttentionCsrKernel(const Context& dev_ctx,
         sdd_result.non_zero_crows().data<int64_t>(),
         sdd_result.non_zero_cols().data<int64_t>(),
         sdd_result.non_zero_elements().data<T>(),
-        key_padding_mask.data<T>(),
-        attn_mask.data<T>(),
+        kp_mask_ptr ? kp_mask_ptr->data<T>() : nullptr,
+        attn_mask_ptr ? attn_mask_ptr->data<T>() : nullptr,
         softmax->mutable_non_zero_elements()->data<T>(),
         M,
         total_row_num,
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py
index e34f890cc5..0383247886 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_fused_attention_op.py
@@ -47,6 +47,7 @@ class TestSparseAttentionAPI1(unittest.TestCase):
         self.seq_len = 128
         self.head_dim = 16
         self.dtype = 'float64'
+        self.use_mask = True

     def test_dygraph(self):
         with _test_eager_guard():
@@ -69,37 +70,49 @@ class TestSparseAttentionAPI1(unittest.TestCase):
             sp_mask = mask.reshape([-1, self.seq_len,
                                     self.seq_len]).to_sparse_csr()

-            kp_mask = paddle.randint(
-                0, 2, [self.batch_size, self.seq_len]).astype(self.dtype)
-            attn_mask = paddle.randint(
-                0, 2, [self.seq_len, self.seq_len]).astype(self.dtype)
-
-            sdd = paddle.matmul(query, key, False, True) / math.sqrt(
-                float(self.head_dim))
-            sdd = sdd + (
-                (mask * kp_mask.unsqueeze([1, 2]) * attn_mask) - 1.0) * 1e9
-            softmax = paddle.nn.functional.softmax(sdd)
-            output = paddle.matmul(softmax, value)
-            output.backward()
-
-            query_cp = copy.deepcopy(query)
-            key_cp = copy.deepcopy(key)
-            value_cp = copy.deepcopy(value)
-
-            query_cp.stop_gradient = False
-            key_cp.stop_gradient = False
-            value_cp.stop_gradient = False
-
-            output_cp = paddle.incubate.sparse.nn.functional.attention(
-                query_cp, key_cp, value_cp, sp_mask, kp_mask, attn_mask)
-            output_cp.backward()
-
-            self.assertTrue(np.allclose(output_cp.numpy(), output.numpy()))
+            query_sp = copy.deepcopy(query)
+            key_sp = copy.deepcopy(key)
+            value_sp = copy.deepcopy(value)
+
+            query_sp.stop_gradient = False
+            key_sp.stop_gradient = False
+            value_sp.stop_gradient = False
+
+            if self.use_mask:
+                kp_mask = paddle.randint(
+                    0, 2, [self.batch_size, self.seq_len]).astype(self.dtype)
+                attn_mask = paddle.randint(
+                    0, 2, [self.seq_len, self.seq_len]).astype(self.dtype)
+
+                sdd = paddle.matmul(query, key, False, True) / math.sqrt(
+                    float(self.head_dim))
+                sdd = sdd + (
+                    (mask * kp_mask.unsqueeze([1, 2]) * attn_mask) - 1.0) * 1e9
+                softmax = paddle.nn.functional.softmax(sdd)
+                output = paddle.matmul(softmax, value)
+                output.backward()
+
+                output_sp = paddle.incubate.sparse.nn.functional.attention(
+                    query_sp, key_sp, value_sp, sp_mask, kp_mask, attn_mask)
+                output_sp.backward()
+            else:
+                sdd = paddle.matmul(query, key, False, True) / math.sqrt(
+                    float(self.head_dim))
+                sdd = sdd + (mask - 1.0) * 1e9
+                softmax = paddle.nn.functional.softmax(sdd)
+                output = paddle.matmul(softmax, value)
+                output.backward()
+
+                output_sp = paddle.incubate.sparse.nn.functional.attention(
+                    query_sp, key_sp, value_sp, sp_mask)
+                output_sp.backward()
+
+            self.assertTrue(np.allclose(output_sp.numpy(), output.numpy()))
             self.assertTrue(
-                np.allclose(query_cp.grad.numpy(), query.grad.numpy()))
-            self.assertTrue(np.allclose(key_cp.grad.numpy(), key.grad.numpy()))
+                np.allclose(query_sp.grad.numpy(), query.grad.numpy()))
+            self.assertTrue(np.allclose(key_sp.grad.numpy(), key.grad.numpy()))
             self.assertTrue(
-                np.allclose(value_cp.grad.numpy(), value.grad.numpy()))
+                np.allclose(value_sp.grad.numpy(), value.grad.numpy()))


 class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
@@ -110,6 +123,7 @@ class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
         self.seq_len = 128
         self.head_dim = 32
         self.dtype = 'float64'
+        self.use_mask = False


 class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
@@ -120,6 +134,7 @@ class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
         self.seq_len = 512
         self.head_dim = 16
         self.dtype = 'float64'
+        self.use_mask = True


 class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
@@ -130,6 +145,7 @@ class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
         self.seq_len = 512
         self.head_dim = 32
         self.dtype = 'float64'
+        self.use_mask = False


 class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
@@ -140,6 +156,7 @@ class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
         self.seq_len = 512
         self.head_dim = 64
         self.dtype = 'float64'
+        self.use_mask = True


 if __name__ == '__main__':
diff --git a/python/paddle/incubate/sparse/nn/functional/transformer.py b/python/paddle/incubate/sparse/nn/functional/transformer.py
index 3429eecccd..f69714700b 100644
--- a/python/paddle/incubate/sparse/nn/functional/transformer.py
+++ b/python/paddle/incubate/sparse/nn/functional/transformer.py
@@ -23,8 +23,8 @@ def attention(query,
               key,
               value,
               sparse_mask,
-              key_padding_mask,
-              attn_mask,
+              key_padding_mask=None,
+              attn_mask=None,
               name=None):
     """
     Note:
@@ -50,10 +50,10 @@ def attention(query,
         sparse_mask(SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
            is `[batch_size*num_heads, seq_len, seq_len]` . `nnz` of each batch must be the same.
            dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
-        key_padding_mask(DenseTensor): The key padding mask tensor in the Attention module.
-            2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64.
-        attn_mask(DenseTensor):The attention mask tensor in the Attention module.
-            2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64.
+        key_padding_mask(DenseTensor, optional): The key padding mask tensor in the Attention module.
+            2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64. Default: None.
+        attn_mask(DenseTensor, optional): The attention mask tensor in the Attention module.
+            2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64. Default: None.
         name(str, optional): The default value is None. Normally there is no need for user
             to set this property. For more information, please refer to
             :ref:`api_guide_Name`.
--
GitLab
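Usage note (not part of the patch): a minimal sketch of calling the updated Python API with and without the now-optional masks, adapted from the unit test above. The tensor sizes, the all-ones layout mask, and the random 0/1 masks are illustrative assumptions; the underlying kernel still requires a GPU build with CUDA 11.7+ and dynamic-graph (eager) mode.

    import paddle

    # Illustrative sizes only (assumed, not taken from the patch).
    batch_size, num_heads, seq_len, head_dim = 2, 4, 128, 16
    shape = [batch_size, num_heads, seq_len, head_dim]

    query = paddle.rand(shape, dtype='float64')
    key = paddle.rand(shape, dtype='float64')
    value = paddle.rand(shape, dtype='float64')
    query.stop_gradient = False
    key.stop_gradient = False
    value.stop_gradient = False

    # CSR layout mask with dense shape [batch_size*num_heads, seq_len, seq_len];
    # an all-ones mask keeps nnz identical across batches, as the API requires.
    mask = paddle.ones([batch_size, num_heads, seq_len, seq_len], dtype='float64')
    sp_mask = mask.reshape([-1, seq_len, seq_len]).to_sparse_csr()

    # After this patch, key_padding_mask/attn_mask may be omitted ...
    out = paddle.incubate.sparse.nn.functional.attention(query, key, value, sp_mask)

    # ... or passed explicitly, as before.
    kp_mask = paddle.randint(0, 2, [batch_size, seq_len]).astype('float64')
    attn_mask = paddle.randint(0, 2, [seq_len, seq_len]).astype('float64')
    out_masked = paddle.incubate.sparse.nn.functional.attention(
        query, key, value, sp_mask, kp_mask, attn_mask)
    out_masked.backward()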