diff --git a/paddle/fluid/operators/sparse_attention_op.cc b/paddle/fluid/operators/sparse_attention_op.cc
index 9b6bc1b6290451b9616e51bcacd221f22cb26cf6..a6534543a6515a80886fc61953310f6988f20b3f 100644
--- a/paddle/fluid/operators/sparse_attention_op.cc
+++ b/paddle/fluid/operators/sparse_attention_op.cc
@@ -43,6 +43,14 @@ class SparseAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor, default: Tensor), The input tensor of columns in "
              "CSR sparse format, "
              "whose dimension : `[batch_size, num_heads, sparse_nnz_num]`.");
+    AddInput("KeyPaddingMask",
+             "(Tensor), The input tensor of key padding mask, "
+             "whose dimension : `[batch_size, target_len]`.")
+        .AsDispensable();
+    AddInput("AttnMask",
+             "(Tensor), The input tensor of attention mask, "
+             "whose dimension : `[target_len, target_len]`.")
+        .AsDispensable();
     AddOutput(
         "Out",
         "(Tensor), The output tensor of result in attention, "
diff --git a/paddle/fluid/operators/sparse_attention_op.cu b/paddle/fluid/operators/sparse_attention_op.cu
index 88ee8999c5f4af725b684585496b37565f953269..b937de1bc86842dcd3da42525b20d5bde6832614 100644
--- a/paddle/fluid/operators/sparse_attention_op.cu
+++ b/paddle/fluid/operators/sparse_attention_op.cu
@@ -72,24 +72,32 @@ __global__ void BlockSparseSoftmaxForward(T* softmax, const T* src, T scale,
   const int cur_block_nnz =
       layout_rowptr[cur_block_row + 1] - layout_rowptr[cur_block_row];
 
-  T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize];
-  T attndata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize];
-
-  // read kp mask
-  T cur_kp_mask = (kp_mask == nullptr) ? 0 : kp_mask[cur_row];
+  T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize] = {0};
+  T attndata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize] = {0};
 
   // read tensor data, attn mask
   const int iter = (cur_block_nnz + WarpSize - 1) / WarpSize;
   const T* srcptr = src + layout_rowptr[cur_block_row];
-  T* attnptr = nullptr;
-  if (attn_mask != nullptr) {
-    const T* attnptr = attn_mask + cur_block_row * num_rows;
-  }
+
+  const T* attnptr = (attn_mask == nullptr)
+                         ? nullptr
+                         : (attn_mask + cur_block_row * num_rows);
+  // the column start index in current row
   const int* colindex = layout_colindex + layout_rowptr[cur_block_row];
   for (int j = 0; j < iter; j++) {
     int cur_block_col = j * WarpSize + threadIdx.x;
     int cur_reg_index = j;
     if (cur_block_col < cur_block_nnz) {
+      // read kp mask
+      T cur_kp_mask;
+      if ((kp_mask != nullptr) &&
+          std::abs(kp_mask[colindex[cur_block_col]]) <
+              std::numeric_limits<T>::epsilon()) {
+        cur_kp_mask = -std::numeric_limits<T>::infinity();
+      } else {
+        cur_kp_mask = 0;
+      }
+      // do mask operation
       if ((attnptr != nullptr) &&
           std::abs(attnptr[colindex[cur_block_col]]) <
              std::numeric_limits<T>::epsilon()) {
@@ -197,21 +205,61 @@ template <typename DeviceContext, typename T>
 void SparseSoftmaxForward(const platform::CUDADeviceContext& ctx,
                           const Tensor* offset, const Tensor* columns,
                           Tensor* input, Tensor* output, const int blocksize,
-                          const int num_rows, const int num_cols) {
+                          const int num_rows, const int num_cols,
+                          const Tensor* key_padding_mask,
+                          const Tensor* attn_mask) {
   const int* offset_data = offset->data<int>();
   const int* columns_data = columns->data<int>();
   T* input_data = input->data<T>();
   T* output_data = output->data<T>();
+  // Add mask
+  const T* key_padding_mask_data =
+      (key_padding_mask != nullptr) ? key_padding_mask->data<T>() : nullptr;
+  const T* attn_mask_data =
+      (attn_mask != nullptr) ? attn_mask->data<T>() : nullptr;
 
   const int block_size = 1;
   dim3 blocks(32, 4, 1);
   int grid = (num_rows * block_size + 3) / 4;
   T scaling = static_cast<T>(1.0) / sqrt(static_cast<T>(num_cols));
 
-  const int block_nnz_max = 256;
-  BlockSparseSoftmaxForward<T, block_size, block_nnz_max><<<grid, blocks>>>(
-      output_data, input_data, scaling, nullptr, nullptr, offset_data,
-      columns_data, num_rows);
+  if (num_cols <= 4) {
+    BlockSparseSoftmaxForward<T, block_size, 4><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 4 && num_cols <= 8) {
+    BlockSparseSoftmaxForward<T, block_size, 8><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 8 && num_cols <= 16) {
+    BlockSparseSoftmaxForward<T, block_size, 16><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 16 && num_cols <= 32) {
+    BlockSparseSoftmaxForward<T, block_size, 32><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 32 && num_cols <= 64) {
+    BlockSparseSoftmaxForward<T, block_size, 64><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 64 && num_cols <= 128) {
+    BlockSparseSoftmaxForward<T, block_size, 128><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 128 && num_cols <= 256) {
+    BlockSparseSoftmaxForward<T, block_size, 256><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 256 && num_cols <= 512) {
+    BlockSparseSoftmaxForward<T, block_size, 512><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "The head_dim of query in sparse_attention op should be less than "
+        "or equal to 512"));
+  }
 }
 
 template <typename DeviceContext, typename T>
@@ -231,10 +279,43 @@ void SparseSoftmaxBackward(const platform::CUDADeviceContext& ctx,
   int grid = (num_rows * block_size + 3) / 4;
   T scaling = static_cast<T>(1.0) / sqrt(static_cast<T>(num_cols));
 
-  const int block_nnz_max = 256;
-  BlockSparseSoftmaxBackward<T, block_size, block_nnz_max><<<grid, blocks>>>(
-      dx_data, dout_data, out_data, scaling, offset_data, columns_data,
-      num_rows);
+  if (num_cols <= 4) {
+    BlockSparseSoftmaxBackward<T, block_size, 4><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 4 && num_cols <= 8) {
+    BlockSparseSoftmaxBackward<T, block_size, 8><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 8 && num_cols <= 16) {
+    BlockSparseSoftmaxBackward<T, block_size, 16><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 16 && num_cols <= 32) {
+    BlockSparseSoftmaxBackward<T, block_size, 32><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 32 && num_cols <= 64) {
+    BlockSparseSoftmaxBackward<T, block_size, 64><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 64 && num_cols <= 128) {
+    BlockSparseSoftmaxBackward<T, block_size, 128><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 128 && num_cols <= 256) {
+    BlockSparseSoftmaxBackward<T, block_size, 256><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 256 && num_cols <= 512) {
+    BlockSparseSoftmaxBackward<T, block_size, 512><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "The head_dim of query in sparse_attention op should be less than "
+        "or equal to 512"));
+  }
 }
 
 using VarType = framework::proto::VarType;
@@ -408,6 +489,12 @@ class SparseAttentionCUDAKernel : public framework::OpKernel<T> {
     sparse_dot_sdd_ptr->mutable_data<T>(ctx.GetPlace());
     auto softmax_ptr = ctx.Output<Tensor>("Softmax");
     softmax_ptr->mutable_data<T>(ctx.GetPlace());
+    // add Mask
+    auto* key_padding_mask = ctx.HasInput("KeyPaddingMask")
+                                 ? ctx.Input<Tensor>("KeyPaddingMask")
+                                 : nullptr;
+    auto* attn_mask =
+        ctx.HasInput("AttnMask") ? ctx.Input<Tensor>("AttnMask") : nullptr;
 
     auto output = *output_ptr;
     auto result_sdd = *sparse_dot_sdd_ptr;
@@ -435,9 +522,25 @@ class SparseAttentionCUDAKernel : public framework::OpKernel<T> {
                                &offset_lists[i], &columns_lists[i],
                                &result_sdd_lists[i], M, N, false, true);
 
-      SparseSoftmaxForward<DeviceContext, T>(
-          dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
-          &result_softmax_lists[i], 1, M, N);
+      if (key_padding_mask != nullptr && attn_mask != nullptr) {
+        SparseSoftmaxForward<DeviceContext, T>(
+            dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
+            &result_softmax_lists[i], 1, M, N,
+            key_padding_mask + (i / num_heads) * M, attn_mask);
+      } else if (key_padding_mask != nullptr && attn_mask == nullptr) {
+        SparseSoftmaxForward<DeviceContext, T>(
+            dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
+            &result_softmax_lists[i], 1, M, N,
+            key_padding_mask + (i / num_heads) * M, nullptr);
+      } else if (key_padding_mask == nullptr && attn_mask != nullptr) {
+        SparseSoftmaxForward<DeviceContext, T>(
+            dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
+            &result_softmax_lists[i], 1, M, N, nullptr, attn_mask);
+      } else {
+        SparseSoftmaxForward<DeviceContext, T>(
+            dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
+            &result_softmax_lists[i], 1, M, N, nullptr, nullptr);
+      }
 
       DotDsd<DeviceContext, T>(dev_ctx, &offset_lists[i], &columns_lists[i],
                                &result_softmax_lists[i], &value_lists[i],
diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h
index 7000097e0abcb216a5634cc18a70a49cff33e513..3e1c5b736f27eeb5d8b24c219b4f321333313a6e 100644
--- a/paddle/fluid/pybind/op_function_generator.h
+++ b/paddle/fluid/pybind/op_function_generator.h
@@ -71,6 +71,8 @@ std::map<std::string, std::vector<std::string>> op_ins_map = {
     {"adamw",
      {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
       "Beta2Pow", "MasterParam"}},
+    {"sparse_attention",
+     {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
 };
 
 // NOTE(zhiqiu): Like op_ins_map.
diff --git a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py
index cce4742f164557379a2d3adcea6c5af120a733b8..c016a482f36ec13b9cc1492a47e266f540d699f6 100644
--- a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py
@@ -23,6 +23,7 @@ import paddle.fluid.framework as framework
 import paddle.nn.functional as F
 import os
 import re
+import copy
 
 
 def get_cuda_version():
@@ -37,12 +38,47 @@ def get_cuda_version():
         return -1
 
 
-def softmax(x):
-    max = np.max(x, axis=1, keepdims=True)
-    e_x = np.exp(x - max)
-    sum = np.sum(e_x, axis=1, keepdims=True)
-    f_x = e_x / sum
-    return f_x
+def masked_fill(x):
+    row, col = x.shape[0], x.shape[1]
+    for i in range(row):
+        for j in range(col):
+            if x[i][j] == 0:
+                x[i][j] = float('-inf')
+    return x
+
+
+def init_mask(x):
+    row, col = x.shape[0], x.shape[1]
+    for i in range(row):
+        for j in range(col):
+            if x[i][j] == 0 and (j < 0.8 * col):
+                x[i][j] = 1
+    return x
+
+
+def softmax(x, kp_mask=None, attn_mask=None, bsz=None):
+    if kp_mask is None and attn_mask is None:
+        max = np.max(x, axis=1, keepdims=True)
+        e_x = np.exp(x - max)
+        sum = np.sum(e_x, axis=1, keepdims=True)
+        f_x = e_x / sum
+        return f_x
+    else:
+        # kp_mask
+        current_kp_mask = kp_mask[bsz]
+        row = current_kp_mask.shape[0]
+        current_kp_mask = np.expand_dims(current_kp_mask, 0).repeat(row, axis=0)
+        # attn_mask
+        current_attn_mask = copy.deepcopy(attn_mask)
+        current_attn_mask = masked_fill(current_attn_mask)
+        current_kp_mask = masked_fill(current_kp_mask)
+        x = x + current_kp_mask
+        x = x + current_attn_mask
+        max = np.max(x, axis=1, keepdims=True)
+        e_x = np.exp(x - max)
+        sum = np.sum(e_x, axis=1, keepdims=True)
+        f_x = e_x / sum
+        return f_x
 
 
 def get_csr_value(mat, layout, nnz):
@@ -57,7 +93,14 @@ def get_cuda_version():
     return value
 
 
-def ref_sparse_attention(q, k, v, offset, columns):
+def ref_sparse_attention(q,
+                         k,
+                         v,
+                         offset,
+                         columns,
+                         kp_mask=None,
+                         attn_mask=None,
+                         bsz=None):
     row, col, nnz = q.shape[0], q.shape[1], columns.shape[0]
     mat = np.zeros((row, row))
     for cur_row in range(row):
@@ -74,13 +117,23 @@ def ref_sparse_attention(q, k, v, offset, columns):
         for j in range(row):
             if mat[i][j] == 0:
                 a[i][j] = float('-inf')
-    b = softmax(a)
+    # softmax
+    if kp_mask is None and attn_mask is None:
+        b = softmax(a)
+    else:
+        b = softmax(a, kp_mask=kp_mask, attn_mask=attn_mask, bsz=bsz)
     b_value = get_csr_value(b, mat, nnz)
     result = np.dot(b, v)
     return result, a_value, b_value
 
 
-def ref_batch_sparse_attention(q, k, v, offset, columns):
+def ref_batch_sparse_attention(q,
+                               k,
+                               v,
+                               offset,
+                               columns,
+                               kp_mask=None,
+                               attn_mask=None):
     batch_size, num_heads, row, col = q.shape
     nnz = columns.shape[2]
     result = np.zeros((batch_size, num_heads, row, col))
@@ -90,8 +143,19 @@ def ref_batch_sparse_attention(q, k, v, offset, columns):
         for j in range(num_heads):
             cur_q, cur_k, cur_v, = q[i][j], k[i][j], v[i][j]
             cur_offset, cur_columns = offset[i][j], columns[i][j]
-            cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
-                cur_q, cur_k, cur_v, cur_offset, cur_columns)
+            if kp_mask is None and attn_mask is None:
+                cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
+                    cur_q, cur_k, cur_v, cur_offset, cur_columns)
+            else:
+                cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
+                    cur_q,
+                    cur_k,
+                    cur_v,
+                    cur_offset,
+                    cur_columns,
+                    kp_mask=kp_mask,
+                    attn_mask=attn_mask,
+                    bsz=i)
             result[i][j] = cur_result
             result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax
     return result, result_sdd, result_softmax
@@ -133,9 +197,10 @@ def init_csr_format(batch_size, num_heads, rows, blocksize):
 )
 class TestSparseAttentionOp(OpTest):
     def config(self):
-        self.shape = (1, 1, 16, 8)
-        self.blocksize = 2
+        self.shape = (1, 1, 16, 16)
+        self.blocksize = 4
         self.dtype = "float64"
+        self.use_mask = True
 
     def setUp(self):
         paddle.enable_static()
@@ -145,21 +210,52 @@ class TestSparseAttentionOp(OpTest):
         self.q = np.random.random(self.shape).astype(self.dtype)
         self.k = np.random.random(self.shape).astype(self.dtype)
         self.v = np.random.random(self.shape).astype(self.dtype)
+        # init CSR tensor
         offset, columns = init_csr_format(self.shape[0], self.shape[1],
                                           self.shape[2], self.blocksize)
         self.offset = offset.astype('int32')
         self.columns = columns.astype('int32')
-
-        result, result_sdd, result_softmax = ref_batch_sparse_attention(
-            self.q, self.k, self.v, self.offset, self.columns)
-
-        self.inputs = {
-            'Q': self.q,
-            'K': self.k,
-            'V': self.v,
-            'Offset': self.offset,
-            'Columns': self.columns
-        }
+        # init mask tensor
+        key_padding_mask_shape = (self.shape[0], self.shape[2])
+        attn_mask_shape = (self.shape[2], self.shape[2])
+        key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape)
+        attn_mask = np.random.randint(0, 2, size=attn_mask_shape)
+        key_padding_mask = init_mask(key_padding_mask)
+        attn_mask = init_mask(attn_mask)
+
+        self.key_padding_mask = key_padding_mask.astype(self.dtype)
+        self.attn_mask = attn_mask.astype(self.dtype)
+        if self.use_mask == True:
+            result, result_sdd, result_softmax = ref_batch_sparse_attention(
+                self.q,
+                self.k,
+                self.v,
+                self.offset,
+                self.columns,
+                kp_mask=self.key_padding_mask,
+                attn_mask=self.attn_mask)
+        else:
+            result, result_sdd, result_softmax = ref_batch_sparse_attention(
+                self.q, self.k, self.v, self.offset, self.columns)
+
+        if self.use_mask == True:
+            self.inputs = {
+                'Q': self.q,
+                'K': self.k,
+                'V': self.v,
+                'Offset': self.offset,
+                'Columns': self.columns,
+                'KeyPaddingMask': self.key_padding_mask,
+                'AttnMask': self.attn_mask,
+            }
+        else:
+            self.inputs = {
+                'Q': self.q,
+                'K': self.k,
+                'V': self.v,
+                'Offset': self.offset,
+                'Columns': self.columns,
+            }
         self.outputs = {
             'Out': result.astype(self.dtype),
             'SparseDotSdd': result_sdd.astype(self.dtype),
@@ -180,6 +276,7 @@ class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):
         self.shape = (1, 1, 8, 16)
         self.blocksize = 2
         self.dtype = "float32"
+        self.use_mask = False
 
 
 class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
@@ -187,6 +284,7 @@ class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
         self.shape = (2, 2, 32, 8)
         self.blocksize = 8
         self.dtype = "float64"
+        self.use_mask = False
 
 
 @unittest.skipIf(
@@ -199,6 +297,7 @@ class TestSparseAttentionAPI(unittest.TestCase):
         self.shape = (1, 1, 8, 4)
         self.blocksize = 2
         self.dtype = 'float64'
+        self.use_mask = True
 
     def test_static_graph(self):
         paddle.enable_static()
@@ -219,7 +318,25 @@ class TestSparseAttentionAPI(unittest.TestCase):
                 name="Offset", shape=offset_shape, dtype="int32")
             columns = paddle.static.data(
                 name="Columns", shape=columns_shape, dtype="int32")
-            Out = F.sparse_attention(Q, K, V, offset, columns)
+            key_padding_mask_shape = (self.shape[0], self.shape[2])
+            attn_mask_shape = (self.shape[2], self.shape[2])
+            if self.use_mask == True:
+                key_padding_mask = paddle.static.data(
+                    name="KeyPaddingMask",
+                    shape=key_padding_mask_shape,
+                    dtype=self.dtype)
+                attn_mask = paddle.static.data(
+                    name="AttnMask", shape=attn_mask_shape, dtype=self.dtype)
+                Out = F.sparse_attention(
+                    Q,
+                    K,
+                    V,
+                    offset,
+                    columns,
+                    key_padding_mask=key_padding_mask,
+                    attn_mask=attn_mask)
+            else:
+                Out = F.sparse_attention(Q, K, V, offset, columns)
 
             Q_np = np.random.random(self.shape).astype(self.dtype)
             K_np = np.random.random(self.shape).astype(self.dtype)
@@ -229,17 +346,46 @@ class TestSparseAttentionAPI(unittest.TestCase):
             offset_np = offset_np.astype('int32')
             columns_np = columns_np.astype('int32')
 
+            # init mask tensor
+            key_padding_mask_np = np.random.randint(
+                0, 2, size=key_padding_mask_shape)
+            attn_mask_np = np.random.randint(0, 2, size=attn_mask_shape)
+            key_padding_mask_np = init_mask(key_padding_mask_np)
+            attn_mask_np = init_mask(attn_mask_np)
+            key_padding_mask_np = key_padding_mask_np.astype(self.dtype)
+            attn_mask_np = attn_mask_np.astype(self.dtype)
+
             exe = fluid.Executor(self.place)
-            fetches_result = exe.run(feed={
-                "Q": Q_np,
-                "K": K_np,
-                "V": V_np,
-                "Offset": offset_np,
-                "Columns": columns_np
-            },
-                                     fetch_list=[Out])
-            expected_result, __, __ = ref_batch_sparse_attention(
-                Q_np, K_np, V_np, offset_np, columns_np)
+            if self.use_mask == True:
+                fetches_result = exe.run(feed={
+                    "Q": Q_np,
+                    "K": K_np,
+                    "V": V_np,
+                    "Offset": offset_np,
+                    "Columns": columns_np,
+                    'KeyPaddingMask': key_padding_mask_np,
+                    'AttnMask': attn_mask_np
+                },
+                                         fetch_list=[Out])
+                expected_result, __, __ = ref_batch_sparse_attention(
+                    Q_np,
+                    K_np,
+                    V_np,
+                    offset_np,
+                    columns_np,
+                    kp_mask=key_padding_mask_np,
+                    attn_mask=attn_mask_np)
+            else:
+                fetches_result = exe.run(feed={
+                    "Q": Q_np,
+                    "K": K_np,
+                    "V": V_np,
+                    "Offset": offset_np,
+                    "Columns": columns_np
+                },
+                                         fetch_list=[Out])
+                expected_result, __, __ = ref_batch_sparse_attention(
+                    Q_np, K_np, V_np, offset_np, columns_np)
 
             self.assertTrue(
                 np.allclose(
@@ -254,20 +400,51 @@ class TestSparseAttentionAPI(unittest.TestCase):
         query = np.random.random(self.shape).astype(self.dtype)
         key = np.random.random(self.shape).astype(self.dtype)
         value = np.random.random(self.shape).astype(self.dtype)
+        # init mask tensor
+        key_padding_mask_shape = (self.shape[0], self.shape[2])
+        attn_mask_shape = (self.shape[2], self.shape[2])
+        key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape)
+        attn_mask = np.random.randint(0, 2, size=attn_mask_shape)
+        key_padding_mask = init_mask(key_padding_mask)
+        attn_mask = init_mask(attn_mask)
+        key_padding_mask = key_padding_mask.astype(self.dtype)
+        attn_mask = attn_mask.astype(self.dtype)
 
         paddle_query = paddle.to_tensor(query, place=self.place)
         paddle_key = paddle.to_tensor(key, place=self.place)
         paddle_value = paddle.to_tensor(value, place=self.place)
         paddle_offset = paddle.to_tensor(offset, place=self.place)
         paddle_colunmns = paddle.to_tensor(columns, place=self.place)
-
-        paddle_result = F.sparse_attention(paddle_query, paddle_key,
-                                           paddle_value, paddle_offset,
-                                           paddle_colunmns)
-
-        numpy_result, __, __ = ref_batch_sparse_attention(query, key, value,
-                                                          offset, columns)
-        numpy_result = numpy_result.astype(self.dtype)
+        paddle_kp_mask = paddle.to_tensor(key_padding_mask, place=self.place)
+        paddle_attn_mask = paddle.to_tensor(attn_mask, place=self.place)
+
+        if self.use_mask == True:
+            paddle_result = F.sparse_attention(
+                paddle_query,
+                paddle_key,
+                paddle_value,
+                paddle_offset,
+                paddle_colunmns,
+                key_padding_mask=paddle_kp_mask,
+                attn_mask=paddle_attn_mask)
+
+            numpy_result, __, __ = ref_batch_sparse_attention(
+                query,
+                key,
+                value,
+                offset,
+                columns,
+                kp_mask=key_padding_mask,
+                attn_mask=attn_mask)
+            numpy_result = numpy_result.astype(self.dtype)
+        else:
+            paddle_result = F.sparse_attention(paddle_query, paddle_key,
+                                               paddle_value, paddle_offset,
+                                               paddle_colunmns)
+
+            numpy_result, __, __ = ref_batch_sparse_attention(query, key, value,
+                                                              offset, columns)
+            numpy_result = numpy_result.astype(self.dtype)
 
         self.assertTrue(
             np.allclose(
@@ -280,6 +457,7 @@ class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
         self.shape = (2, 2, 8, 4)
         self.blocksize = 2
         self.dtype = 'float32'
+        self.use_mask = False
 
 
 class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
@@ -288,6 +466,7 @@ class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
         self.shape = (2, 2, 64, 32)
         self.blocksize = 2
         self.dtype = 'float64'
+        self.use_mask = False
 
 
 class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
@@ -296,6 +475,7 @@ class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
         self.shape = (2, 1, 64, 32)
         self.blocksize = 2
         self.dtype = 'float64'
+        self.use_mask = False
 
 
 class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
@@ -304,6 +484,7 @@ class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
         self.shape = (4, 4, 128, 32)
         self.blocksize = 8
         self.dtype = 'float64'
+        self.use_mask = False
 
 
 class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
@@ -312,6 +493,7 @@ class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
         self.shape = (3, 3, 35, 15)
         self.blocksize = 3
         self.dtype = 'float64'
+        self.use_mask = False
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/nn/functional/sparse_attention.py b/python/paddle/nn/functional/sparse_attention.py
index b98e8142f457f1cd75982f5f1b6c373b073c6c22..c39fcb8554a2f6e64610e3135043f86c7943e620 100644
--- a/python/paddle/nn/functional/sparse_attention.py
+++ b/python/paddle/nn/functional/sparse_attention.py
@@ -25,6 +25,8 @@ def sparse_attention(query,
                      value,
                      sparse_csr_offset,
                      sparse_csr_columns,
+                     key_padding_mask=None,
+                     attn_mask=None,
                      name=None):
     r"""
     This operator sparsify the Attention matrix in Transformer module
@@ -68,6 +70,14 @@ def sparse_attention(query,
                         3-D tensor with shape: [batch_size, num_heads, sparse_nnz].
                         The dtype should be int32.
+        key_padding_mask(Tensor, optional): The key padding mask tensor in the Attention module.
+                        2-D tensor with shape: [batch_size, seq_len].
+                        The dtype can be float32 or float64.
+                        A value of 0 means that the position is masked.
+        attn_mask(Tensor, optional): The attention mask tensor in the Attention module.
+                        2-D tensor with shape: [seq_len, seq_len].
+                        The dtype can be float32 or float64.
+                        A value of 0 means that the position is masked.
         name(str, optional): The default value is None. Normally there is no need for user
                         to set this property. For more information, please
                         refer to :ref:`api_guide_Name`.
@@ -83,7 +93,7 @@ def sparse_attention(query,
             # required: skiptest
             import paddle
             import numpy as np
-            
+
             query_data = np.array([[[[0, 1,], [2, 3],
                                     [ 0, 1], [2, 3]]]]).astype("float32")
             key_data = np.array([[[[0, 1,], [2, 3],
@@ -94,6 +104,8 @@ def sparse_attention(query,
                                     4, 6, 8]]]).astype("int32")
             sparse_csr_columns_data = np.array([[[0, 1,
                                     0, 1, 2, 3, 2, 3]]]).astype("int32")
+            key_padding_mask_data = np.array([[1,1,1,0]]).astype("float32")
+            attention_mask_data = np.array([[1,0,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]]).astype("float32")
             print(query_data.shape)
             # (1, 1, 4, 2)
             print(sparse_csr_offset_data.shape)
@@ -111,10 +123,21 @@ def sparse_attention(query,
                                     place=paddle.CUDAPlace(0))
             columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False,
                                     place=paddle.CUDAPlace(0))
+            key_padding_mask = paddle.to_tensor(key_padding_mask_data, stop_gradient=False,
+                                    place=paddle.CUDAPlace(0))
+            attention_mask = paddle.to_tensor(attention_mask_data, stop_gradient=False,
+                                    place=paddle.CUDAPlace(0))
+            output_mask = paddle.nn.functional.sparse_attention(query, key,
+                                    value, offset, columns,
+                                    key_padding_mask=key_padding_mask, attn_mask=attention_mask)
+            print(output_mask)
+            # [[[[0.        , 1.        ],
+            #   [1.99830270, 2.99830270],
+            #   [0.        , 1.        ],
+            #   [0.        , 1.        ]]]]
             output = paddle.nn.functional.sparse_attention(query, key,
                                     value, offset, columns)
-            print(output)
-            
+            print(output)
             # [[[[1.60885942, 2.60885954],
             #   [1.99830270, 2.99830270],
             #   [1.60885942, 2.60885954],
@@ -122,7 +145,8 @@ def sparse_attention(query,
     """
     if in_dygraph_mode():
         result_attention, result_sdd, result_softmax = _C_ops.sparse_attention(
-            query, key, value, sparse_csr_offset, sparse_csr_columns)
+            query, key, value, sparse_csr_offset, sparse_csr_columns,
+            key_padding_mask, attn_mask)
         return result_attention
 
     helper = LayerHelper('sparse_attention', **locals())
@@ -135,7 +159,9 @@ def sparse_attention(query,
         'K': key,
         'V': value,
         'Offset': sparse_csr_offset,
-        'Columns': sparse_csr_columns
+        'Columns': sparse_csr_columns,
+        'KeyPaddingMask': key_padding_mask,
+        'AttnMask': attn_mask,
     }
     outputs = {
         'Out': out,
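
For reference, the numerical effect of the two new inputs can be reproduced with a few lines of NumPy. The sketch below is illustrative only: it mirrors the effect of the masked_fill/softmax helpers in the updated unit test rather than the CUDA kernel, and the helper name apply_masks and its single-head layout are invented for the example.

import numpy as np


def apply_masks(scores, kp_mask, attn_mask):
    # scores:    [seq_len, seq_len] attention logits for one head of one sample
    # kp_mask:   [seq_len] key padding mask, 0 means the key is padded out
    # attn_mask: [seq_len, seq_len] attention mask, 0 means the position is masked
    # Assumes every row keeps at least one visible position, as init_mask in the
    # test guarantees; otherwise the row-wise softmax is undefined.
    masked = scores.astype(np.float64)
    masked[:, kp_mask == 0] = -np.inf  # a padded key is hidden from every query row
    masked[attn_mask == 0] = -np.inf   # element-wise attention mask
    masked -= np.max(masked, axis=1, keepdims=True)  # numerically stable softmax
    e = np.exp(masked)
    return e / np.sum(e, axis=1, keepdims=True)

Positions driven to -inf receive exactly zero attention weight. In the docstring example above, that is why the rows of output_mask that are left with a single visible key reduce to that key's value row, while the unmasked output mixes both keys in every row.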