From fa463b901f1e06dfbb5fda172cb13af0de49a580 Mon Sep 17 00:00:00 2001 From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com> Date: Thu, 16 Dec 2021 10:36:28 +0800 Subject: [PATCH] Add sparse_attention mask ,test=develop (#37973) Add key_padding_mask and attn_mask in sparse_attention Api 1.Key padding mask is a tensor with dimensions [batch_size, seq_len], and attention mask is a tensor with dimensions [seq_len, seq_len]. The data types of the two masks are consistent with Q, K, and V, which are float32 or float64. If the value in Mask is 0, it means that the position needs to be masked. 2.The changed files are mainly paddle/fluid/operators/sparse_attention_op.cu and python/paddle/fluid/tests/unittests/test_sparse_attention_op.py. sparse_attention has three parts: sddmm, softmax, and dsd. Adding the mask operation only needs to modify the softmax. It has no effect on the other two parts. In addition, in order to test the mask function, related tests has been added. --- paddle/fluid/operators/sparse_attention_op.cc | 8 + paddle/fluid/operators/sparse_attention_op.cu | 145 ++++++++-- paddle/fluid/pybind/op_function_generator.h | 2 + .../unittests/test_sparse_attention_op.py | 268 +++++++++++++++--- .../paddle/nn/functional/sparse_attention.py | 36 ++- 5 files changed, 390 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/operators/sparse_attention_op.cc b/paddle/fluid/operators/sparse_attention_op.cc index 9b6bc1b629..a6534543a6 100644 --- a/paddle/fluid/operators/sparse_attention_op.cc +++ b/paddle/fluid/operators/sparse_attention_op.cc @@ -43,6 +43,14 @@ class SparseAttentionOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor, default: Tensor), The input tensor of columns in " "CSR sparse format, " "whose dimension : `[batch_size, num_heads, sparse_nnz_num]`."); + AddInput("KeyPaddingMask", + "(Tensor), The input tensor of key padding mask" + "whose dimension : `[batch_size, target_len]`.") + .AsDispensable(); + AddInput("AttnMask", + "(Tensor), The input tensor of attention mask" + "whose dimension : `[target_len, target_len]`.") + .AsDispensable(); AddOutput( "Out", "(Tensor), The output tensor of result in attention, " diff --git a/paddle/fluid/operators/sparse_attention_op.cu b/paddle/fluid/operators/sparse_attention_op.cu index 88ee8999c5..b937de1bc8 100644 --- a/paddle/fluid/operators/sparse_attention_op.cu +++ b/paddle/fluid/operators/sparse_attention_op.cu @@ -72,24 +72,32 @@ __global__ void BlockSparseSoftmaxForward(T* softmax, const T* src, T scale, const int cur_block_nnz = layout_rowptr[cur_block_row + 1] - layout_rowptr[cur_block_row]; - T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize]; - T attndata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize]; - - // read kp mask - T cur_kp_mask = (kp_mask == nullptr) ? 0 : kp_mask[cur_row]; + T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize] = {0}; + T attndata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize] = {0}; // read tensor data, attn mask const int iter = (cur_block_nnz + WarpSize - 1) / WarpSize; const T* srcptr = src + layout_rowptr[cur_block_row]; - T* attnptr = nullptr; - if (attn_mask != nullptr) { - const T* attnptr = attn_mask + cur_block_row * num_rows; - } + + const T* attnptr = (attn_mask == nullptr) + ? nullptr + : (attn_mask + cur_block_row * num_rows); + // the coloumn start index in current row const int* colindex = layout_colindex + layout_rowptr[cur_block_row]; for (int j = 0; j < iter; j++) { int cur_block_col = j * WarpSize + threadIdx.x; int cur_reg_index = j; if (cur_block_col < cur_block_nnz) { + // read kp mask + T cur_kp_mask; + if ((kp_mask != nullptr) && + std::abs(kp_mask[colindex[cur_block_col]]) < + std::numeric_limits::epsilon()) { + cur_kp_mask = -std::numeric_limits::infinity(); + } else { + cur_kp_mask = 0; + } + // do mask operation if ((attnptr != nullptr) && std::abs(attnptr[colindex[cur_block_col]]) < std::numeric_limits::epsilon()) { @@ -197,21 +205,61 @@ template void SparseSoftmaxForward(const platform::CUDADeviceContext& ctx, const Tensor* offset, const Tensor* columns, Tensor* input, Tensor* output, const int blocksize, - const int num_rows, const int num_cols) { + const int num_rows, const int num_cols, + const Tensor* key_padding_mask, + const Tensor* attn_mask) { const int* offset_data = offset->data(); const int* columns_data = columns->data(); T* input_data = input->data(); T* output_data = output->data(); + // Add mask + const T* key_padding_mask_data = + (key_padding_mask != nullptr) ? key_padding_mask->data() : nullptr; + const T* attn_mask_data = + (attn_mask != nullptr) ? attn_mask->data() : nullptr; const int block_size = 1; dim3 blocks(32, 4, 1); int grid = (num_rows * block_size + 3) / 4; T scaling = static_cast(1.0) / sqrt(static_cast(num_cols)); - const int block_nnz_max = 256; - BlockSparseSoftmaxForward<<>>( - output_data, input_data, scaling, nullptr, nullptr, offset_data, - columns_data, num_rows); + if (num_cols <= 4) { + BlockSparseSoftmaxForward<<>>( + output_data, input_data, scaling, key_padding_mask_data, attn_mask_data, + offset_data, columns_data, num_rows); + } else if (num_cols > 4 && num_cols <= 8) { + BlockSparseSoftmaxForward<<>>( + output_data, input_data, scaling, key_padding_mask_data, attn_mask_data, + offset_data, columns_data, num_rows); + } else if (num_cols > 8 && num_cols <= 16) { + BlockSparseSoftmaxForward<<>>( + output_data, input_data, scaling, key_padding_mask_data, attn_mask_data, + offset_data, columns_data, num_rows); + } else if (num_cols > 16 && num_cols <= 32) { + BlockSparseSoftmaxForward<<>>( + output_data, input_data, scaling, key_padding_mask_data, attn_mask_data, + offset_data, columns_data, num_rows); + } else if (num_cols > 32 && num_cols <= 64) { + BlockSparseSoftmaxForward<<>>( + output_data, input_data, scaling, key_padding_mask_data, attn_mask_data, + offset_data, columns_data, num_rows); + } else if (num_cols > 64 && num_cols <= 128) { + BlockSparseSoftmaxForward<<>>( + output_data, input_data, scaling, key_padding_mask_data, attn_mask_data, + offset_data, columns_data, num_rows); + } else if (num_cols > 128 && num_cols <= 256) { + BlockSparseSoftmaxForward<<>>( + output_data, input_data, scaling, key_padding_mask_data, attn_mask_data, + offset_data, columns_data, num_rows); + } else if (num_cols > 256 && num_cols <= 512) { + BlockSparseSoftmaxForward<<>>( + output_data, input_data, scaling, key_padding_mask_data, attn_mask_data, + offset_data, columns_data, num_rows); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The head_dim of query in sparse_attention op should less or equal " + "512")); + } } template @@ -231,10 +279,43 @@ void SparseSoftmaxBackward(const platform::CUDADeviceContext& ctx, int grid = (num_rows * block_size + 3) / 4; T scaling = static_cast(1.0) / sqrt(static_cast(num_cols)); - const int block_nnz_max = 256; - BlockSparseSoftmaxBackward<<>>( - dx_data, dout_data, out_data, scaling, offset_data, columns_data, - num_rows); + if (num_cols <= 4) { + BlockSparseSoftmaxBackward<<>>( + dx_data, dout_data, out_data, scaling, offset_data, columns_data, + num_rows); + } else if (num_cols > 4 && num_cols <= 8) { + BlockSparseSoftmaxBackward<<>>( + dx_data, dout_data, out_data, scaling, offset_data, columns_data, + num_rows); + } else if (num_cols > 8 && num_cols <= 16) { + BlockSparseSoftmaxBackward<<>>( + dx_data, dout_data, out_data, scaling, offset_data, columns_data, + num_rows); + } else if (num_cols > 16 && num_cols <= 32) { + BlockSparseSoftmaxBackward<<>>( + dx_data, dout_data, out_data, scaling, offset_data, columns_data, + num_rows); + } else if (num_cols > 32 && num_cols <= 64) { + BlockSparseSoftmaxBackward<<>>( + dx_data, dout_data, out_data, scaling, offset_data, columns_data, + num_rows); + } else if (num_cols > 64 && num_cols <= 128) { + BlockSparseSoftmaxBackward<<>>( + dx_data, dout_data, out_data, scaling, offset_data, columns_data, + num_rows); + } else if (num_cols > 128 && num_cols <= 256) { + BlockSparseSoftmaxBackward<<>>( + dx_data, dout_data, out_data, scaling, offset_data, columns_data, + num_rows); + } else if (num_cols > 256 && num_cols <= 512) { + BlockSparseSoftmaxBackward<<>>( + dx_data, dout_data, out_data, scaling, offset_data, columns_data, + num_rows); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The head_dim of query in sparse_attention op should less or equal " + "512")); + } } using VarType = framework::proto::VarType; @@ -408,6 +489,12 @@ class SparseAttentionCUDAKernel : public framework::OpKernel { sparse_dot_sdd_ptr->mutable_data(ctx.GetPlace()); auto softmax_ptr = ctx.Output("Softmax"); softmax_ptr->mutable_data(ctx.GetPlace()); + // add Mask + auto* key_padding_mask = ctx.HasInput("KeyPaddingMask") + ? ctx.Input("KeyPaddingMask") + : nullptr; + auto* attn_mask = + ctx.HasInput("AttnMask") ? ctx.Input("AttnMask") : nullptr; auto output = *output_ptr; auto result_sdd = *sparse_dot_sdd_ptr; @@ -435,9 +522,25 @@ class SparseAttentionCUDAKernel : public framework::OpKernel { &offset_lists[i], &columns_lists[i], &result_sdd_lists[i], M, N, false, true); - SparseSoftmaxForward( - dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i], - &result_softmax_lists[i], 1, M, N); + if (key_padding_mask != nullptr && attn_mask != nullptr) { + SparseSoftmaxForward( + dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i], + &result_softmax_lists[i], 1, M, N, + key_padding_mask + (i / num_heads) * M, attn_mask); + } else if (key_padding_mask != nullptr && attn_mask == nullptr) { + SparseSoftmaxForward( + dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i], + &result_softmax_lists[i], 1, M, N, + key_padding_mask + (i / num_heads) * M, nullptr); + } else if (key_padding_mask == nullptr && attn_mask != nullptr) { + SparseSoftmaxForward( + dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i], + &result_softmax_lists[i], 1, M, N, nullptr, attn_mask); + } else { + SparseSoftmaxForward( + dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i], + &result_softmax_lists[i], 1, M, N, nullptr, nullptr); + } DotDsd(dev_ctx, &offset_lists[i], &columns_lists[i], &result_softmax_lists[i], &value_lists[i], diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 7000097e0a..3e1c5b736f 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -71,6 +71,8 @@ std::map> op_ins_map = { {"adamw", {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow", "MasterParam"}}, + {"sparse_attention", + {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}}, }; // NOTE(zhiqiu): Like op_ins_map. diff --git a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py index cce4742f16..c016a482f3 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py @@ -23,6 +23,7 @@ import paddle.fluid.framework as framework import paddle.nn.functional as F import os import re +import copy def get_cuda_version(): @@ -37,12 +38,47 @@ def get_cuda_version(): return -1 -def softmax(x): - max = np.max(x, axis=1, keepdims=True) - e_x = np.exp(x - max) - sum = np.sum(e_x, axis=1, keepdims=True) - f_x = e_x / sum - return f_x +def masked_fill(x): + row, col = x.shape[0], x.shape[1] + for i in range(row): + for j in range(col): + if x[i][j] == 0: + x[i][j] = float('-inf') + return x + + +def init_mask(x): + row, col = x.shape[0], x.shape[1] + for i in range(row): + for j in range(col): + if x[i][j] == 0 and (j < 0.8 * col): + x[i][j] = 1 + return x + + +def softmax(x, kp_mask=None, attn_mask=None, bsz=None): + if kp_mask is None and attn_mask is None: + max = np.max(x, axis=1, keepdims=True) + e_x = np.exp(x - max) + sum = np.sum(e_x, axis=1, keepdims=True) + f_x = e_x / sum + return f_x + else: + # kp_mask + current_kp_mask = kp_mask[bsz] + row = current_kp_mask.shape[0] + current_kp_mask = np.expand_dims(current_kp_mask, 0).repeat(row, axis=0) + # attn_mask + current_attn_mask = copy.deepcopy(attn_mask) + current_attn_mask = masked_fill(current_attn_mask) + current_kp_mask = masked_fill(current_kp_mask) + x = x + current_kp_mask + x = x + current_attn_mask + max = np.max(x, axis=1, keepdims=True) + e_x = np.exp(x - max) + sum = np.sum(e_x, axis=1, keepdims=True) + f_x = e_x / sum + return f_x def get_csr_value(mat, layout, nnz): @@ -57,7 +93,14 @@ def get_csr_value(mat, layout, nnz): return value -def ref_sparse_attention(q, k, v, offset, columns): +def ref_sparse_attention(q, + k, + v, + offset, + columns, + kp_mask=None, + attn_mask=None, + bsz=None): row, col, nnz = q.shape[0], q.shape[1], columns.shape[0] mat = np.zeros((row, row)) for cur_row in range(row): @@ -74,13 +117,23 @@ def ref_sparse_attention(q, k, v, offset, columns): for j in range(row): if mat[i][j] == 0: a[i][j] = float('-inf') - b = softmax(a) + # softmax + if kp_mask is None and attn_mask is None: + b = softmax(a) + else: + b = softmax(a, kp_mask=kp_mask, attn_mask=attn_mask, bsz=bsz) b_value = get_csr_value(b, mat, nnz) result = np.dot(b, v) return result, a_value, b_value -def ref_batch_sparse_attention(q, k, v, offset, columns): +def ref_batch_sparse_attention(q, + k, + v, + offset, + columns, + kp_mask=None, + attn_mask=None): batch_size, num_heads, row, col = q.shape nnz = columns.shape[2] result = np.zeros((batch_size, num_heads, row, col)) @@ -90,8 +143,19 @@ def ref_batch_sparse_attention(q, k, v, offset, columns): for j in range(num_heads): cur_q, cur_k, cur_v, = q[i][j], k[i][j], v[i][j] cur_offset, cur_columns = offset[i][j], columns[i][j] - cur_result, cur_sdd, cur_softmax = ref_sparse_attention( - cur_q, cur_k, cur_v, cur_offset, cur_columns) + if kp_mask is None and attn_mask is None: + cur_result, cur_sdd, cur_softmax = ref_sparse_attention( + cur_q, cur_k, cur_v, cur_offset, cur_columns) + else: + cur_result, cur_sdd, cur_softmax = ref_sparse_attention( + cur_q, + cur_k, + cur_v, + cur_offset, + cur_columns, + kp_mask=kp_mask, + attn_mask=attn_mask, + bsz=i) result[i][j] = cur_result result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax return result, result_sdd, result_softmax @@ -133,9 +197,10 @@ def init_csr_format(batch_size, num_heads, rows, blocksize): ) class TestSparseAttentionOp(OpTest): def config(self): - self.shape = (1, 1, 16, 8) - self.blocksize = 2 + self.shape = (1, 1, 16, 16) + self.blocksize = 4 self.dtype = "float64" + self.use_mask = True def setUp(self): paddle.enable_static() @@ -145,21 +210,52 @@ class TestSparseAttentionOp(OpTest): self.q = np.random.random(self.shape).astype(self.dtype) self.k = np.random.random(self.shape).astype(self.dtype) self.v = np.random.random(self.shape).astype(self.dtype) + # init CSR tensor offset, columns = init_csr_format(self.shape[0], self.shape[1], self.shape[2], self.blocksize) self.offset = offset.astype('int32') self.columns = columns.astype('int32') - - result, result_sdd, result_softmax = ref_batch_sparse_attention( - self.q, self.k, self.v, self.offset, self.columns) - - self.inputs = { - 'Q': self.q, - 'K': self.k, - 'V': self.v, - 'Offset': self.offset, - 'Columns': self.columns - } + # init mask tensor + key_padding_mask_shape = (self.shape[0], self.shape[2]) + attn_mask_shape = (self.shape[2], self.shape[2]) + key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape) + attn_mask = np.random.randint(0, 2, size=attn_mask_shape) + key_padding_mask = init_mask(key_padding_mask) + attn_mask = init_mask(attn_mask) + + self.key_padding_mask = key_padding_mask.astype(self.dtype) + self.attn_mask = attn_mask.astype(self.dtype) + if self.use_mask == True: + result, result_sdd, result_softmax = ref_batch_sparse_attention( + self.q, + self.k, + self.v, + self.offset, + self.columns, + kp_mask=self.key_padding_mask, + attn_mask=self.attn_mask) + else: + result, result_sdd, result_softmax = ref_batch_sparse_attention( + self.q, self.k, self.v, self.offset, self.columns) + + if self.use_mask == True: + self.inputs = { + 'Q': self.q, + 'K': self.k, + 'V': self.v, + 'Offset': self.offset, + 'Columns': self.columns, + 'KeyPaddingMask': self.key_padding_mask, + 'AttnMask': self.attn_mask, + } + else: + self.inputs = { + 'Q': self.q, + 'K': self.k, + 'V': self.v, + 'Offset': self.offset, + 'Columns': self.columns, + } self.outputs = { 'Out': result.astype(self.dtype), 'SparseDotSdd': result_sdd.astype(self.dtype), @@ -180,6 +276,7 @@ class TestSparseAttentionOpFp32Test(TestSparseAttentionOp): self.shape = (1, 1, 8, 16) self.blocksize = 2 self.dtype = "float32" + self.use_mask = False class TestSparseAttentionOpShapeTest(TestSparseAttentionOp): @@ -187,6 +284,7 @@ class TestSparseAttentionOpShapeTest(TestSparseAttentionOp): self.shape = (2, 2, 32, 8) self.blocksize = 8 self.dtype = "float64" + self.use_mask = False @unittest.skipIf( @@ -199,6 +297,7 @@ class TestSparseAttentionAPI(unittest.TestCase): self.shape = (1, 1, 8, 4) self.blocksize = 2 self.dtype = 'float64' + self.use_mask = True def test_static_graph(self): paddle.enable_static() @@ -219,7 +318,25 @@ class TestSparseAttentionAPI(unittest.TestCase): name="Offset", shape=offset_shape, dtype="int32") columns = paddle.static.data( name="Columns", shape=columns_shape, dtype="int32") - Out = F.sparse_attention(Q, K, V, offset, columns) + key_padding_mask_shape = (self.shape[0], self.shape[2]) + attn_mask_shape = (self.shape[2], self.shape[2]) + if self.use_mask == True: + key_padding_mask = paddle.static.data( + name="KeyPaddingMask", + shape=key_padding_mask_shape, + dtype=self.dtype) + attn_mask = paddle.static.data( + name="AttnMask", shape=attn_mask_shape, dtype=self.dtype) + Out = F.sparse_attention( + Q, + K, + V, + offset, + columns, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask) + else: + Out = F.sparse_attention(Q, K, V, offset, columns) Q_np = np.random.random(self.shape).astype(self.dtype) K_np = np.random.random(self.shape).astype(self.dtype) @@ -229,17 +346,46 @@ class TestSparseAttentionAPI(unittest.TestCase): offset_np = offset_np.astype('int32') columns_np = columns_np.astype('int32') + # init mask tensor + key_padding_mask_np = np.random.randint( + 0, 2, size=key_padding_mask_shape) + attn_mask_np = np.random.randint(0, 2, size=attn_mask_shape) + key_padding_mask_np = init_mask(key_padding_mask_np) + attn_mask_np = init_mask(attn_mask_np) + key_padding_mask_np = key_padding_mask_np.astype(self.dtype) + attn_mask_np = attn_mask_np.astype(self.dtype) + exe = fluid.Executor(self.place) - fetches_result = exe.run(feed={ - "Q": Q_np, - "K": K_np, - "V": V_np, - "Offset": offset_np, - "Columns": columns_np - }, - fetch_list=[Out]) - expected_result, __, __ = ref_batch_sparse_attention( - Q_np, K_np, V_np, offset_np, columns_np) + if self.use_mask == True: + fetches_result = exe.run(feed={ + "Q": Q_np, + "K": K_np, + "V": V_np, + "Offset": offset_np, + "Columns": columns_np, + 'KeyPaddingMask': key_padding_mask_np, + 'AttnMask': attn_mask_np + }, + fetch_list=[Out]) + expected_result, __, __ = ref_batch_sparse_attention( + Q_np, + K_np, + V_np, + offset_np, + columns_np, + kp_mask=key_padding_mask_np, + attn_mask=attn_mask_np) + else: + fetches_result = exe.run(feed={ + "Q": Q_np, + "K": K_np, + "V": V_np, + "Offset": offset_np, + "Columns": columns_np + }, + fetch_list=[Out]) + expected_result, __, __ = ref_batch_sparse_attention( + Q_np, K_np, V_np, offset_np, columns_np) self.assertTrue( np.allclose( @@ -254,20 +400,51 @@ class TestSparseAttentionAPI(unittest.TestCase): query = np.random.random(self.shape).astype(self.dtype) key = np.random.random(self.shape).astype(self.dtype) value = np.random.random(self.shape).astype(self.dtype) + # init mask tensor + key_padding_mask_shape = (self.shape[0], self.shape[2]) + attn_mask_shape = (self.shape[2], self.shape[2]) + key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape) + attn_mask = np.random.randint(0, 2, size=attn_mask_shape) + key_padding_mask = init_mask(key_padding_mask) + attn_mask = init_mask(attn_mask) + key_padding_mask = key_padding_mask.astype(self.dtype) + attn_mask = attn_mask.astype(self.dtype) paddle_query = paddle.to_tensor(query, place=self.place) paddle_key = paddle.to_tensor(key, place=self.place) paddle_value = paddle.to_tensor(value, place=self.place) paddle_offset = paddle.to_tensor(offset, place=self.place) paddle_colunmns = paddle.to_tensor(columns, place=self.place) - - paddle_result = F.sparse_attention(paddle_query, paddle_key, - paddle_value, paddle_offset, - paddle_colunmns) - - numpy_result, __, __ = ref_batch_sparse_attention(query, key, value, - offset, columns) - numpy_result = numpy_result.astype(self.dtype) + paddle_kp_mask = paddle.to_tensor(key_padding_mask, place=self.place) + paddle_attn_mask = paddle.to_tensor(attn_mask, place=self.place) + + if self.use_mask == True: + paddle_result = F.sparse_attention( + paddle_query, + paddle_key, + paddle_value, + paddle_offset, + paddle_colunmns, + key_padding_mask=paddle_kp_mask, + attn_mask=paddle_attn_mask) + + numpy_result, __, __ = ref_batch_sparse_attention( + query, + key, + value, + offset, + columns, + kp_mask=key_padding_mask, + attn_mask=attn_mask) + numpy_result = numpy_result.astype(self.dtype) + else: + paddle_result = F.sparse_attention(paddle_query, paddle_key, + paddle_value, paddle_offset, + paddle_colunmns) + + numpy_result, __, __ = ref_batch_sparse_attention(query, key, value, + offset, columns) + numpy_result = numpy_result.astype(self.dtype) self.assertTrue( np.allclose( @@ -280,6 +457,7 @@ class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI): self.shape = (2, 2, 8, 4) self.blocksize = 2 self.dtype = 'float32' + self.use_mask = False class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI): @@ -288,6 +466,7 @@ class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI): self.shape = (2, 2, 64, 32) self.blocksize = 2 self.dtype = 'float64' + self.use_mask = False class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI): @@ -296,6 +475,7 @@ class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI): self.shape = (2, 1, 64, 32) self.blocksize = 2 self.dtype = 'float64' + self.use_mask = False class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI): @@ -304,6 +484,7 @@ class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI): self.shape = (4, 4, 128, 32) self.blocksize = 8 self.dtype = 'float64' + self.use_mask = False class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI): @@ -312,6 +493,7 @@ class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI): self.shape = (3, 3, 35, 15) self.blocksize = 3 self.dtype = 'float64' + self.use_mask = False if __name__ == '__main__': diff --git a/python/paddle/nn/functional/sparse_attention.py b/python/paddle/nn/functional/sparse_attention.py index b98e8142f4..c39fcb8554 100644 --- a/python/paddle/nn/functional/sparse_attention.py +++ b/python/paddle/nn/functional/sparse_attention.py @@ -25,6 +25,8 @@ def sparse_attention(query, value, sparse_csr_offset, sparse_csr_columns, + key_padding_mask=None, + attn_mask=None, name=None): r""" This operator sparsify the Attention matrix in Transformer module @@ -68,6 +70,14 @@ def sparse_attention(query, 3-D tensor with shape: [batch_size, num_heads, sparse_nnz]. The dtype should be int32. + key_padding_mask(Tensor, optional):The key padding mask tensor in the Attention module. + 2-D tensor with shape: [batch_size, seq_len]. + The dtype can be float32 and float64. + A value of 0 means that the position is masked. + attn_mask(Tensor, optional):The attention mask tensor in the Attention module. + 2-D tensor with shape: [seq_len, seq_len]. + The dtype can be float32 and float64. + A value of 0 means that the position is masked. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -83,7 +93,7 @@ def sparse_attention(query, # required: skiptest import paddle import numpy as np - + query_data = np.array([[[[0, 1,], [2, 3], [ 0, 1], [2, 3]]]]).astype("float32") key_data = np.array([[[[0, 1,], [2, 3], @@ -94,6 +104,8 @@ def sparse_attention(query, 4, 6, 8]]]).astype("int32") sparse_csr_columns_data = np.array([[[0, 1, 0, 1, 2, 3, 2, 3]]]).astype("int32") + key_padding_mask_data = np.array([[1,1,1,0]]).astype("float32") + attention_mask_data = np.array([[1,0,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]]).astype("float32") print(query_data.shape) # (1, 1, 4, 2) print(sparse_csr_offset_data.shape) @@ -111,10 +123,21 @@ def sparse_attention(query, place=paddle.CUDAPlace(0)) columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False, place=paddle.CUDAPlace(0)) + key_padding_mask = paddle.to_tensor(key_padding_mask_data, stop_gradient=False, + place=paddle.CUDAPlace(0)) + attention_mask = paddle.to_tensor(attention_mask_data, stop_gradient=False, + place=paddle.CUDAPlace(0)) + output_mask = paddle.nn.functional.sparse_attention(query, key, + value, offset, columns, + key_padding_mask=key_padding_mask, attn_mask=attention_mask) + print(output_mask) + # [[[[0. , 1. ], + # [1.99830270, 2.99830270], + # [0. , 1. ], + # [0. , 1. ]]]] output = paddle.nn.functional.sparse_attention(query, key, value, offset, columns) - print(output) - + print(output) # [[[[1.60885942, 2.60885954], # [1.99830270, 2.99830270], # [1.60885942, 2.60885954], @@ -122,7 +145,8 @@ def sparse_attention(query, """ if in_dygraph_mode(): result_attention, result_sdd, result_softmax = _C_ops.sparse_attention( - query, key, value, sparse_csr_offset, sparse_csr_columns) + query, key, value, sparse_csr_offset, sparse_csr_columns, + key_padding_mask, attn_mask) return result_attention helper = LayerHelper('sparse_attention', **locals()) @@ -135,7 +159,9 @@ def sparse_attention(query, 'K': key, 'V': value, 'Offset': sparse_csr_offset, - 'Columns': sparse_csr_columns + 'Columns': sparse_csr_columns, + 'KeyPaddingMask': key_padding_mask, + 'AttnMask': attn_mask, } outputs = { 'Out': out, -- GitLab