Unverified commit fa463b90, authored by Liu-xiandong and committed by GitHub

Add sparse_attention mask, test=develop (#37973)

Add key_padding_mask and attn_mask to the sparse_attention API

1. The key padding mask is a tensor with shape [batch_size, seq_len], and the attention mask is a tensor with shape [seq_len, seq_len]. The data type of both masks matches that of Q, K, and V, i.e. float32 or float64. A value of 0 in a mask means that the corresponding position is masked out.
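For illustration only, here is a minimal NumPy sketch of this convention, assuming (as in the reference implementation in the unit test) that a 0 entry is converted into an additive -inf bias before the softmax; the helper names are hypothetical, and each row is assumed to keep at least one unmasked position:

```python
import numpy as np

def apply_masks(scores, kp_mask_row, attn_mask):
    # scores:      [seq_len, seq_len] attention logits of one (batch, head) pair
    # kp_mask_row: the [seq_len] slice of the [batch_size, seq_len] key padding mask
    # attn_mask:   [seq_len, seq_len] attention mask
    # A 0 entry in either mask marks a position that must receive zero weight.
    bias_kp = np.where(kp_mask_row[None, :] == 0, -np.inf, 0.0)
    bias_attn = np.where(attn_mask == 0, -np.inf, 0.0)
    return scores + bias_kp + bias_attn

def softmax_rows(x):
    # Numerically stable row-wise softmax; -inf entries end up with zero probability.
    m = np.max(x, axis=1, keepdims=True)
    e = np.exp(x - m)
    return e / np.sum(e, axis=1, keepdims=True)
```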

2. The changed files are mainly paddle/fluid/operators/sparse_attention_op.cu and python/paddle/fluid/tests/unittests/test_sparse_attention_op.py. sparse_attention consists of three parts: sddmm, softmax, and dsd. Adding the mask operation only requires modifying the softmax; the other two parts are unaffected. In addition, tests covering the mask functionality have been added.
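As a rough sketch of why only the softmax stage changes, the dense NumPy reference below walks through the three stages for a single head; the function name and the dense 0/1 layout matrix are illustrative, not the operator's actual CSR implementation:

```python
import numpy as np

def ref_masked_sparse_attention(q, k, v, layout, kp_mask_row=None, attn_mask=None):
    # q, k, v: [seq_len, head_dim]; layout: dense [seq_len, seq_len] 0/1 sparsity pattern
    scale = 1.0 / np.sqrt(q.shape[1])
    # 1) sddmm: scaled Q @ K^T restricted to the sparsity layout (unchanged by this commit)
    scores = scale * (q @ k.T)
    scores[layout == 0] = -np.inf
    # 2) softmax: the only stage the new masks touch
    if kp_mask_row is not None:
        scores = scores + np.where(kp_mask_row[None, :] == 0, -np.inf, 0.0)
    if attn_mask is not None:
        scores = scores + np.where(attn_mask == 0, -np.inf, 0.0)
    probs = np.exp(scores - scores.max(axis=1, keepdims=True))
    probs = probs / probs.sum(axis=1, keepdims=True)
    # 3) dsd: attention probabilities times dense V (unchanged by this commit)
    return probs @ v
```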
Parent 524389ee
......@@ -43,6 +43,14 @@ class SparseAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor, default: Tensor<int32>), The input tensor of columns in "
"CSR sparse format, "
"whose dimension : `[batch_size, num_heads, sparse_nnz_num]`.");
AddInput("KeyPaddingMask",
"(Tensor), The input tensor of key padding mask"
"whose dimension : `[batch_size, target_len]`.")
.AsDispensable();
AddInput("AttnMask",
"(Tensor), The input tensor of attention mask"
"whose dimension : `[target_len, target_len]`.")
.AsDispensable();
AddOutput(
"Out",
"(Tensor), The output tensor of result in attention, "
......
......@@ -72,24 +72,32 @@ __global__ void BlockSparseSoftmaxForward(T* softmax, const T* src, T scale,
const int cur_block_nnz =
layout_rowptr[cur_block_row + 1] - layout_rowptr[cur_block_row];
T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize];
T attndata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize];
// read kp mask
T cur_kp_mask = (kp_mask == nullptr) ? 0 : kp_mask[cur_row];
T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize] = {0};
T attndata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize] = {0};
// read tensor data, attn mask
const int iter = (cur_block_nnz + WarpSize - 1) / WarpSize;
const T* srcptr = src + layout_rowptr[cur_block_row];
T* attnptr = nullptr;
if (attn_mask != nullptr) {
const T* attnptr = attn_mask + cur_block_row * num_rows;
}
const T* attnptr = (attn_mask == nullptr)
? nullptr
: (attn_mask + cur_block_row * num_rows);
// the column start index in the current row
const int* colindex = layout_colindex + layout_rowptr[cur_block_row];
for (int j = 0; j < iter; j++) {
int cur_block_col = j * WarpSize + threadIdx.x;
int cur_reg_index = j;
if (cur_block_col < cur_block_nnz) {
// read kp mask
T cur_kp_mask;
if ((kp_mask != nullptr) &&
std::abs(kp_mask[colindex[cur_block_col]]) <
std::numeric_limits<T>::epsilon()) {
cur_kp_mask = -std::numeric_limits<T>::infinity();
} else {
cur_kp_mask = 0;
}
// do mask operation
if ((attnptr != nullptr) &&
std::abs(attnptr[colindex[cur_block_col]]) <
std::numeric_limits<T>::epsilon()) {
......@@ -197,21 +205,61 @@ template <typename DeviceContext, typename T>
void SparseSoftmaxForward(const platform::CUDADeviceContext& ctx,
const Tensor* offset, const Tensor* columns,
Tensor* input, Tensor* output, const int blocksize,
const int num_rows, const int num_cols) {
const int num_rows, const int num_cols,
const Tensor* key_padding_mask,
const Tensor* attn_mask) {
const int* offset_data = offset->data<int>();
const int* columns_data = columns->data<int>();
T* input_data = input->data<T>();
T* output_data = output->data<T>();
// Add mask
const T* key_padding_mask_data =
(key_padding_mask != nullptr) ? key_padding_mask->data<T>() : nullptr;
const T* attn_mask_data =
(attn_mask != nullptr) ? attn_mask->data<T>() : nullptr;
const int block_size = 1;
dim3 blocks(32, 4, 1);
int grid = (num_rows * block_size + 3) / 4;
T scaling = static_cast<T>(1.0) / sqrt(static_cast<T>(num_cols));
const int block_nnz_max = 256;
BlockSparseSoftmaxForward<T, block_size, block_nnz_max><<<grid, blocks>>>(
output_data, input_data, scaling, nullptr, nullptr, offset_data,
columns_data, num_rows);
if (num_cols <= 4) {
BlockSparseSoftmaxForward<T, block_size, 4><<<grid, blocks>>>(
output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
offset_data, columns_data, num_rows);
} else if (num_cols > 4 && num_cols <= 8) {
BlockSparseSoftmaxForward<T, block_size, 8><<<grid, blocks>>>(
output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
offset_data, columns_data, num_rows);
} else if (num_cols > 8 && num_cols <= 16) {
BlockSparseSoftmaxForward<T, block_size, 16><<<grid, blocks>>>(
output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
offset_data, columns_data, num_rows);
} else if (num_cols > 16 && num_cols <= 32) {
BlockSparseSoftmaxForward<T, block_size, 32><<<grid, blocks>>>(
output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
offset_data, columns_data, num_rows);
} else if (num_cols > 32 && num_cols <= 64) {
BlockSparseSoftmaxForward<T, block_size, 64><<<grid, blocks>>>(
output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
offset_data, columns_data, num_rows);
} else if (num_cols > 64 && num_cols <= 128) {
BlockSparseSoftmaxForward<T, block_size, 128><<<grid, blocks>>>(
output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
offset_data, columns_data, num_rows);
} else if (num_cols > 128 && num_cols <= 256) {
BlockSparseSoftmaxForward<T, block_size, 256><<<grid, blocks>>>(
output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
offset_data, columns_data, num_rows);
} else if (num_cols > 256 && num_cols <= 512) {
BlockSparseSoftmaxForward<T, block_size, 512><<<grid, blocks>>>(
output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
offset_data, columns_data, num_rows);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The head_dim of query in sparse_attention op should be less than or "
"equal to 512"));
}
}
template <typename DeviceContext, typename T>
......@@ -231,10 +279,43 @@ void SparseSoftmaxBackward(const platform::CUDADeviceContext& ctx,
int grid = (num_rows * block_size + 3) / 4;
T scaling = static_cast<T>(1.0) / sqrt(static_cast<T>(num_cols));
const int block_nnz_max = 256;
BlockSparseSoftmaxBackward<T, block_size, block_nnz_max><<<grid, blocks>>>(
dx_data, dout_data, out_data, scaling, offset_data, columns_data,
num_rows);
if (num_cols <= 4) {
BlockSparseSoftmaxBackward<T, block_size, 4><<<grid, blocks>>>(
dx_data, dout_data, out_data, scaling, offset_data, columns_data,
num_rows);
} else if (num_cols > 4 && num_cols <= 8) {
BlockSparseSoftmaxBackward<T, block_size, 8><<<grid, blocks>>>(
dx_data, dout_data, out_data, scaling, offset_data, columns_data,
num_rows);
} else if (num_cols > 8 && num_cols <= 16) {
BlockSparseSoftmaxBackward<T, block_size, 16><<<grid, blocks>>>(
dx_data, dout_data, out_data, scaling, offset_data, columns_data,
num_rows);
} else if (num_cols > 16 && num_cols <= 32) {
BlockSparseSoftmaxBackward<T, block_size, 32><<<grid, blocks>>>(
dx_data, dout_data, out_data, scaling, offset_data, columns_data,
num_rows);
} else if (num_cols > 32 && num_cols <= 64) {
BlockSparseSoftmaxBackward<T, block_size, 64><<<grid, blocks>>>(
dx_data, dout_data, out_data, scaling, offset_data, columns_data,
num_rows);
} else if (num_cols > 64 && num_cols <= 128) {
BlockSparseSoftmaxBackward<T, block_size, 128><<<grid, blocks>>>(
dx_data, dout_data, out_data, scaling, offset_data, columns_data,
num_rows);
} else if (num_cols > 128 && num_cols <= 256) {
BlockSparseSoftmaxBackward<T, block_size, 256><<<grid, blocks>>>(
dx_data, dout_data, out_data, scaling, offset_data, columns_data,
num_rows);
} else if (num_cols > 256 && num_cols <= 512) {
BlockSparseSoftmaxBackward<T, block_size, 512><<<grid, blocks>>>(
dx_data, dout_data, out_data, scaling, offset_data, columns_data,
num_rows);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The head_dim of query in sparse_attention op should be less than or "
"equal to 512"));
}
}
using VarType = framework::proto::VarType;
......@@ -408,6 +489,12 @@ class SparseAttentionCUDAKernel : public framework::OpKernel<T> {
sparse_dot_sdd_ptr->mutable_data<T>(ctx.GetPlace());
auto softmax_ptr = ctx.Output<Tensor>("Softmax");
softmax_ptr->mutable_data<T>(ctx.GetPlace());
// add Mask
auto* key_padding_mask = ctx.HasInput("KeyPaddingMask")
? ctx.Input<Tensor>("KeyPaddingMask")
: nullptr;
auto* attn_mask =
ctx.HasInput("AttnMask") ? ctx.Input<Tensor>("AttnMask") : nullptr;
auto output = *output_ptr;
auto result_sdd = *sparse_dot_sdd_ptr;
......@@ -435,9 +522,25 @@ class SparseAttentionCUDAKernel : public framework::OpKernel<T> {
&offset_lists[i], &columns_lists[i],
&result_sdd_lists[i], M, N, false, true);
SparseSoftmaxForward<DeviceContext, T>(
dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
&result_softmax_lists[i], 1, M, N);
if (key_padding_mask != nullptr && attn_mask != nullptr) {
SparseSoftmaxForward<DeviceContext, T>(
dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
&result_softmax_lists[i], 1, M, N,
key_padding_mask + (i / num_heads) * M, attn_mask);
} else if (key_padding_mask != nullptr && attn_mask == nullptr) {
SparseSoftmaxForward<DeviceContext, T>(
dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
&result_softmax_lists[i], 1, M, N,
key_padding_mask + (i / num_heads) * M, nullptr);
} else if (key_padding_mask == nullptr && attn_mask != nullptr) {
SparseSoftmaxForward<DeviceContext, T>(
dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
&result_softmax_lists[i], 1, M, N, nullptr, attn_mask);
} else {
SparseSoftmaxForward<DeviceContext, T>(
dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
&result_softmax_lists[i], 1, M, N, nullptr, nullptr);
}
DotDsd<DeviceContext, T>(dev_ctx, &offset_lists[i], &columns_lists[i],
&result_softmax_lists[i], &value_lists[i],
......
......@@ -71,6 +71,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"adamw",
{"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
"Beta2Pow", "MasterParam"}},
{"sparse_attention",
{"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......
......@@ -23,6 +23,7 @@ import paddle.fluid.framework as framework
import paddle.nn.functional as F
import os
import re
import copy
def get_cuda_version():
......@@ -37,12 +38,47 @@ def get_cuda_version():
return -1
def softmax(x):
max = np.max(x, axis=1, keepdims=True)
e_x = np.exp(x - max)
sum = np.sum(e_x, axis=1, keepdims=True)
f_x = e_x / sum
return f_x
def masked_fill(x):
row, col = x.shape[0], x.shape[1]
for i in range(row):
for j in range(col):
if x[i][j] == 0:
x[i][j] = float('-inf')
return x
def init_mask(x):
row, col = x.shape[0], x.shape[1]
for i in range(row):
for j in range(col):
if x[i][j] == 0 and (j < 0.8 * col):
x[i][j] = 1
return x
def softmax(x, kp_mask=None, attn_mask=None, bsz=None):
if kp_mask is None and attn_mask is None:
max = np.max(x, axis=1, keepdims=True)
e_x = np.exp(x - max)
sum = np.sum(e_x, axis=1, keepdims=True)
f_x = e_x / sum
return f_x
else:
# kp_mask
current_kp_mask = kp_mask[bsz]
row = current_kp_mask.shape[0]
current_kp_mask = np.expand_dims(current_kp_mask, 0).repeat(row, axis=0)
# attn_mask
current_attn_mask = copy.deepcopy(attn_mask)
current_attn_mask = masked_fill(current_attn_mask)
current_kp_mask = masked_fill(current_kp_mask)
x = x + current_kp_mask
x = x + current_attn_mask
max = np.max(x, axis=1, keepdims=True)
e_x = np.exp(x - max)
sum = np.sum(e_x, axis=1, keepdims=True)
f_x = e_x / sum
return f_x
def get_csr_value(mat, layout, nnz):
......@@ -57,7 +93,14 @@ def get_csr_value(mat, layout, nnz):
return value
def ref_sparse_attention(q, k, v, offset, columns):
def ref_sparse_attention(q,
k,
v,
offset,
columns,
kp_mask=None,
attn_mask=None,
bsz=None):
row, col, nnz = q.shape[0], q.shape[1], columns.shape[0]
mat = np.zeros((row, row))
for cur_row in range(row):
......@@ -74,13 +117,23 @@ def ref_sparse_attention(q, k, v, offset, columns):
for j in range(row):
if mat[i][j] == 0:
a[i][j] = float('-inf')
b = softmax(a)
# softmax
if kp_mask is None and attn_mask is None:
b = softmax(a)
else:
b = softmax(a, kp_mask=kp_mask, attn_mask=attn_mask, bsz=bsz)
b_value = get_csr_value(b, mat, nnz)
result = np.dot(b, v)
return result, a_value, b_value
def ref_batch_sparse_attention(q, k, v, offset, columns):
def ref_batch_sparse_attention(q,
k,
v,
offset,
columns,
kp_mask=None,
attn_mask=None):
batch_size, num_heads, row, col = q.shape
nnz = columns.shape[2]
result = np.zeros((batch_size, num_heads, row, col))
......@@ -90,8 +143,19 @@ def ref_batch_sparse_attention(q, k, v, offset, columns):
for j in range(num_heads):
cur_q, cur_k, cur_v, = q[i][j], k[i][j], v[i][j]
cur_offset, cur_columns = offset[i][j], columns[i][j]
cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
cur_q, cur_k, cur_v, cur_offset, cur_columns)
if kp_mask is None and attn_mask is None:
cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
cur_q, cur_k, cur_v, cur_offset, cur_columns)
else:
cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
cur_q,
cur_k,
cur_v,
cur_offset,
cur_columns,
kp_mask=kp_mask,
attn_mask=attn_mask,
bsz=i)
result[i][j] = cur_result
result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax
return result, result_sdd, result_softmax
......@@ -133,9 +197,10 @@ def init_csr_format(batch_size, num_heads, rows, blocksize):
)
class TestSparseAttentionOp(OpTest):
def config(self):
self.shape = (1, 1, 16, 8)
self.blocksize = 2
self.shape = (1, 1, 16, 16)
self.blocksize = 4
self.dtype = "float64"
self.use_mask = True
def setUp(self):
paddle.enable_static()
......@@ -145,21 +210,52 @@ class TestSparseAttentionOp(OpTest):
self.q = np.random.random(self.shape).astype(self.dtype)
self.k = np.random.random(self.shape).astype(self.dtype)
self.v = np.random.random(self.shape).astype(self.dtype)
# init CSR tensor
offset, columns = init_csr_format(self.shape[0], self.shape[1],
self.shape[2], self.blocksize)
self.offset = offset.astype('int32')
self.columns = columns.astype('int32')
result, result_sdd, result_softmax = ref_batch_sparse_attention(
self.q, self.k, self.v, self.offset, self.columns)
self.inputs = {
'Q': self.q,
'K': self.k,
'V': self.v,
'Offset': self.offset,
'Columns': self.columns
}
# init mask tensor
key_padding_mask_shape = (self.shape[0], self.shape[2])
attn_mask_shape = (self.shape[2], self.shape[2])
key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape)
attn_mask = np.random.randint(0, 2, size=attn_mask_shape)
key_padding_mask = init_mask(key_padding_mask)
attn_mask = init_mask(attn_mask)
self.key_padding_mask = key_padding_mask.astype(self.dtype)
self.attn_mask = attn_mask.astype(self.dtype)
if self.use_mask == True:
result, result_sdd, result_softmax = ref_batch_sparse_attention(
self.q,
self.k,
self.v,
self.offset,
self.columns,
kp_mask=self.key_padding_mask,
attn_mask=self.attn_mask)
else:
result, result_sdd, result_softmax = ref_batch_sparse_attention(
self.q, self.k, self.v, self.offset, self.columns)
if self.use_mask == True:
self.inputs = {
'Q': self.q,
'K': self.k,
'V': self.v,
'Offset': self.offset,
'Columns': self.columns,
'KeyPaddingMask': self.key_padding_mask,
'AttnMask': self.attn_mask,
}
else:
self.inputs = {
'Q': self.q,
'K': self.k,
'V': self.v,
'Offset': self.offset,
'Columns': self.columns,
}
self.outputs = {
'Out': result.astype(self.dtype),
'SparseDotSdd': result_sdd.astype(self.dtype),
......@@ -180,6 +276,7 @@ class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):
self.shape = (1, 1, 8, 16)
self.blocksize = 2
self.dtype = "float32"
self.use_mask = False
class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
......@@ -187,6 +284,7 @@ class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
self.shape = (2, 2, 32, 8)
self.blocksize = 8
self.dtype = "float64"
self.use_mask = False
@unittest.skipIf(
......@@ -199,6 +297,7 @@ class TestSparseAttentionAPI(unittest.TestCase):
self.shape = (1, 1, 8, 4)
self.blocksize = 2
self.dtype = 'float64'
self.use_mask = True
def test_static_graph(self):
paddle.enable_static()
......@@ -219,7 +318,25 @@ class TestSparseAttentionAPI(unittest.TestCase):
name="Offset", shape=offset_shape, dtype="int32")
columns = paddle.static.data(
name="Columns", shape=columns_shape, dtype="int32")
Out = F.sparse_attention(Q, K, V, offset, columns)
key_padding_mask_shape = (self.shape[0], self.shape[2])
attn_mask_shape = (self.shape[2], self.shape[2])
if self.use_mask == True:
key_padding_mask = paddle.static.data(
name="KeyPaddingMask",
shape=key_padding_mask_shape,
dtype=self.dtype)
attn_mask = paddle.static.data(
name="AttnMask", shape=attn_mask_shape, dtype=self.dtype)
Out = F.sparse_attention(
Q,
K,
V,
offset,
columns,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask)
else:
Out = F.sparse_attention(Q, K, V, offset, columns)
Q_np = np.random.random(self.shape).astype(self.dtype)
K_np = np.random.random(self.shape).astype(self.dtype)
......@@ -229,17 +346,46 @@ class TestSparseAttentionAPI(unittest.TestCase):
offset_np = offset_np.astype('int32')
columns_np = columns_np.astype('int32')
# init mask tensor
key_padding_mask_np = np.random.randint(
0, 2, size=key_padding_mask_shape)
attn_mask_np = np.random.randint(0, 2, size=attn_mask_shape)
key_padding_mask_np = init_mask(key_padding_mask_np)
attn_mask_np = init_mask(attn_mask_np)
key_padding_mask_np = key_padding_mask_np.astype(self.dtype)
attn_mask_np = attn_mask_np.astype(self.dtype)
exe = fluid.Executor(self.place)
fetches_result = exe.run(feed={
"Q": Q_np,
"K": K_np,
"V": V_np,
"Offset": offset_np,
"Columns": columns_np
},
fetch_list=[Out])
expected_result, __, __ = ref_batch_sparse_attention(
Q_np, K_np, V_np, offset_np, columns_np)
if self.use_mask == True:
fetches_result = exe.run(feed={
"Q": Q_np,
"K": K_np,
"V": V_np,
"Offset": offset_np,
"Columns": columns_np,
'KeyPaddingMask': key_padding_mask_np,
'AttnMask': attn_mask_np
},
fetch_list=[Out])
expected_result, __, __ = ref_batch_sparse_attention(
Q_np,
K_np,
V_np,
offset_np,
columns_np,
kp_mask=key_padding_mask_np,
attn_mask=attn_mask_np)
else:
fetches_result = exe.run(feed={
"Q": Q_np,
"K": K_np,
"V": V_np,
"Offset": offset_np,
"Columns": columns_np
},
fetch_list=[Out])
expected_result, __, __ = ref_batch_sparse_attention(
Q_np, K_np, V_np, offset_np, columns_np)
self.assertTrue(
np.allclose(
......@@ -254,20 +400,51 @@ class TestSparseAttentionAPI(unittest.TestCase):
query = np.random.random(self.shape).astype(self.dtype)
key = np.random.random(self.shape).astype(self.dtype)
value = np.random.random(self.shape).astype(self.dtype)
# init mask tensor
key_padding_mask_shape = (self.shape[0], self.shape[2])
attn_mask_shape = (self.shape[2], self.shape[2])
key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape)
attn_mask = np.random.randint(0, 2, size=attn_mask_shape)
key_padding_mask = init_mask(key_padding_mask)
attn_mask = init_mask(attn_mask)
key_padding_mask = key_padding_mask.astype(self.dtype)
attn_mask = attn_mask.astype(self.dtype)
paddle_query = paddle.to_tensor(query, place=self.place)
paddle_key = paddle.to_tensor(key, place=self.place)
paddle_value = paddle.to_tensor(value, place=self.place)
paddle_offset = paddle.to_tensor(offset, place=self.place)
paddle_colunmns = paddle.to_tensor(columns, place=self.place)
paddle_result = F.sparse_attention(paddle_query, paddle_key,
paddle_value, paddle_offset,
paddle_colunmns)
numpy_result, __, __ = ref_batch_sparse_attention(query, key, value,
offset, columns)
numpy_result = numpy_result.astype(self.dtype)
paddle_kp_mask = paddle.to_tensor(key_padding_mask, place=self.place)
paddle_attn_mask = paddle.to_tensor(attn_mask, place=self.place)
if self.use_mask == True:
paddle_result = F.sparse_attention(
paddle_query,
paddle_key,
paddle_value,
paddle_offset,
paddle_colunmns,
key_padding_mask=paddle_kp_mask,
attn_mask=paddle_attn_mask)
numpy_result, __, __ = ref_batch_sparse_attention(
query,
key,
value,
offset,
columns,
kp_mask=key_padding_mask,
attn_mask=attn_mask)
numpy_result = numpy_result.astype(self.dtype)
else:
paddle_result = F.sparse_attention(paddle_query, paddle_key,
paddle_value, paddle_offset,
paddle_colunmns)
numpy_result, __, __ = ref_batch_sparse_attention(query, key, value,
offset, columns)
numpy_result = numpy_result.astype(self.dtype)
self.assertTrue(
np.allclose(
......@@ -280,6 +457,7 @@ class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
self.shape = (2, 2, 8, 4)
self.blocksize = 2
self.dtype = 'float32'
self.use_mask = False
class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
......@@ -288,6 +466,7 @@ class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
self.shape = (2, 2, 64, 32)
self.blocksize = 2
self.dtype = 'float64'
self.use_mask = False
class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
......@@ -296,6 +475,7 @@ class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
self.shape = (2, 1, 64, 32)
self.blocksize = 2
self.dtype = 'float64'
self.use_mask = False
class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
......@@ -304,6 +484,7 @@ class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
self.shape = (4, 4, 128, 32)
self.blocksize = 8
self.dtype = 'float64'
self.use_mask = False
class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
......@@ -312,6 +493,7 @@ class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
self.shape = (3, 3, 35, 15)
self.blocksize = 3
self.dtype = 'float64'
self.use_mask = False
if __name__ == '__main__':
......
......@@ -25,6 +25,8 @@ def sparse_attention(query,
value,
sparse_csr_offset,
sparse_csr_columns,
key_padding_mask=None,
attn_mask=None,
name=None):
r"""
This operator sparsifies the Attention matrix in the Transformer module
......@@ -68,6 +70,14 @@ def sparse_attention(query,
3-D tensor with shape:
[batch_size, num_heads, sparse_nnz].
The dtype should be int32.
key_padding_mask(Tensor, optional): The key padding mask tensor in the Attention module.
2-D tensor with shape: [batch_size, seq_len].
The dtype can be float32 or float64.
A value of 0 means that the position is masked.
attn_mask(Tensor, optional): The attention mask tensor in the Attention module.
2-D tensor with shape: [seq_len, seq_len].
The dtype can be float32 or float64.
A value of 0 means that the position is masked.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
......@@ -83,7 +93,7 @@ def sparse_attention(query,
# required: skiptest
import paddle
import numpy as np
query_data = np.array([[[[0, 1,], [2, 3],
[ 0, 1], [2, 3]]]]).astype("float32")
key_data = np.array([[[[0, 1,], [2, 3],
......@@ -94,6 +104,8 @@ def sparse_attention(query,
4, 6, 8]]]).astype("int32")
sparse_csr_columns_data = np.array([[[0, 1,
0, 1, 2, 3, 2, 3]]]).astype("int32")
key_padding_mask_data = np.array([[1,1,1,0]]).astype("float32")
attention_mask_data = np.array([[1,0,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]]).astype("float32")
print(query_data.shape)
# (1, 1, 4, 2)
print(sparse_csr_offset_data.shape)
......@@ -111,10 +123,21 @@ def sparse_attention(query,
place=paddle.CUDAPlace(0))
columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False,
place=paddle.CUDAPlace(0))
key_padding_mask = paddle.to_tensor(key_padding_mask_data, stop_gradient=False,
place=paddle.CUDAPlace(0))
attention_mask = paddle.to_tensor(attention_mask_data, stop_gradient=False,
place=paddle.CUDAPlace(0))
output_mask = paddle.nn.functional.sparse_attention(query, key,
value, offset, columns,
key_padding_mask=key_padding_mask, attn_mask=attention_mask)
print(output_mask)
# [[[[0. , 1. ],
# [1.99830270, 2.99830270],
# [0. , 1. ],
# [0. , 1. ]]]]
output = paddle.nn.functional.sparse_attention(query, key,
value, offset, columns)
print(output)
# [[[[1.60885942, 2.60885954],
# [1.99830270, 2.99830270],
# [1.60885942, 2.60885954],
......@@ -122,7 +145,8 @@ def sparse_attention(query,
"""
if in_dygraph_mode():
result_attention, result_sdd, result_softmax = _C_ops.sparse_attention(
query, key, value, sparse_csr_offset, sparse_csr_columns)
query, key, value, sparse_csr_offset, sparse_csr_columns,
key_padding_mask, attn_mask)
return result_attention
helper = LayerHelper('sparse_attention', **locals())
......@@ -135,7 +159,9 @@ def sparse_attention(query,
'K': key,
'V': value,
'Offset': sparse_csr_offset,
'Columns': sparse_csr_columns
'Columns': sparse_csr_columns,
'KeyPaddingMask': key_padding_mask,
'AttnMask': attn_mask,
}
outputs = {
'Out': out,
......