Unverified  Commit fa463b90  Authored by: Liu-xiandong  Committed by: GitHub

Add sparse_attention mask, test=develop (#37973)

Add key_padding_mask and attn_mask to the sparse_attention API.

1. The key padding mask is a tensor with shape [batch_size, seq_len], and the attention mask is a tensor with shape [seq_len, seq_len]. The data type of both masks must match Q, K, and V, i.e. float32 or float64. A value of 0 in a mask means the corresponding position is masked out.
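As an illustration only (not code from this commit), below is a minimal dense NumPy sketch of the mask semantics described above: wherever a mask value is 0, the corresponding logit is set to -inf before the softmax, so that position receives zero attention weight. The helper name `masked_softmax` is an assumption made for the example, as are float inputs and at least one unmasked key per row.

```python
import numpy as np

def masked_softmax(scores, key_padding_mask=None, attn_mask=None):
    # scores: [seq_len, seq_len] float logits for one (batch, head) pair
    # key_padding_mask: the [seq_len] row of the [batch_size, seq_len] mask for this batch
    # attn_mask: [seq_len, seq_len]; in both masks, 0 means "masked"
    scores = scores.astype(np.float64)
    if key_padding_mask is not None:
        scores[:, key_padding_mask == 0] = -np.inf   # mask whole key columns
    if attn_mask is not None:
        scores[attn_mask == 0] = -np.inf             # mask individual positions
    scores -= scores.max(axis=-1, keepdims=True)     # numerically stable softmax
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)
```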

2. The changed files are mainly paddle/fluid/operators/sparse_attention_op.cu and python/paddle/fluid/tests/unittests/test_sparse_attention_op.py. sparse_attention has three parts: sddmm, softmax, and dsd. Adding the mask operation only requires modifying the softmax part; the other two parts are unaffected. In addition, related unit tests have been added to cover the mask functionality.
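To make that structure concrete, here is another illustration-only sketch (assumed names and dense shapes, reusing `masked_softmax` and the numpy import from the sketch above) of the three stages for a single head. It shows that the masks enter only the softmax stage, while sddmm and dsd are untouched.

```python
def dense_sparse_attention_ref(q, k, v, layout,
                               key_padding_mask=None, attn_mask=None):
    # q, k, v: [seq_len, head_dim]; layout: [seq_len, seq_len] 0/1 sparsity pattern
    seq_len, head_dim = q.shape
    # 1) sddmm: scaled Q*K^T, keeping only positions allowed by the sparse layout
    scores = (q @ k.T) / np.sqrt(head_dim)
    scores[layout == 0] = -np.inf
    # 2) softmax: the only stage the new masks touch
    probs = masked_softmax(scores, key_padding_mask, attn_mask)
    # 3) dsd: multiply the (block-sparse) probabilities by V
    return probs @ v
```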
Parent 524389ee
@@ -43,6 +43,14 @@ class SparseAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor, default: Tensor<int32>), The input tensor of columns in "
              "CSR sparse format, "
              "whose dimension : `[batch_size, num_heads, sparse_nnz_num]`.");
+    AddInput("KeyPaddingMask",
+             "(Tensor), The input tensor of key padding mask"
+             "whose dimension : `[batch_size, target_len]`.")
+        .AsDispensable();
+    AddInput("AttnMask",
+             "(Tensor), The input tensor of attention mask"
+             "whose dimension : `[target_len, target_len]`.")
+        .AsDispensable();
     AddOutput(
         "Out",
         "(Tensor), The output tensor of result in attention, "
...
@@ -72,24 +72,32 @@ __global__ void BlockSparseSoftmaxForward(T* softmax, const T* src, T scale,
   const int cur_block_nnz =
       layout_rowptr[cur_block_row + 1] - layout_rowptr[cur_block_row];
-  T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize];
-  T attndata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize];
+  T srcdata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize] = {0};
+  T attndata[(BlockSize * BlockNnzMax + WarpSize - 1) / WarpSize] = {0};
-  // read kp mask
-  T cur_kp_mask = (kp_mask == nullptr) ? 0 : kp_mask[cur_row];
   // read tensor data, attn mask
   const int iter = (cur_block_nnz + WarpSize - 1) / WarpSize;
   const T* srcptr = src + layout_rowptr[cur_block_row];
-  T* attnptr = nullptr;
-  if (attn_mask != nullptr) {
-    const T* attnptr = attn_mask + cur_block_row * num_rows;
-  }
+  const T* attnptr = (attn_mask == nullptr)
+                         ? nullptr
+                         : (attn_mask + cur_block_row * num_rows);
+  // the coloumn start index in current row
   const int* colindex = layout_colindex + layout_rowptr[cur_block_row];
   for (int j = 0; j < iter; j++) {
     int cur_block_col = j * WarpSize + threadIdx.x;
     int cur_reg_index = j;
     if (cur_block_col < cur_block_nnz) {
+      // read kp mask
+      T cur_kp_mask;
+      if ((kp_mask != nullptr) &&
+          std::abs(kp_mask[colindex[cur_block_col]]) <
+              std::numeric_limits<T>::epsilon()) {
+        cur_kp_mask = -std::numeric_limits<T>::infinity();
+      } else {
+        cur_kp_mask = 0;
+      }
+      // do mask operation
       if ((attnptr != nullptr) &&
           std::abs(attnptr[colindex[cur_block_col]]) <
              std::numeric_limits<T>::epsilon()) {
@@ -197,21 +205,61 @@ template <typename DeviceContext, typename T>
 void SparseSoftmaxForward(const platform::CUDADeviceContext& ctx,
                           const Tensor* offset, const Tensor* columns,
                           Tensor* input, Tensor* output, const int blocksize,
-                          const int num_rows, const int num_cols) {
+                          const int num_rows, const int num_cols,
+                          const Tensor* key_padding_mask,
+                          const Tensor* attn_mask) {
   const int* offset_data = offset->data<int>();
   const int* columns_data = columns->data<int>();
   T* input_data = input->data<T>();
   T* output_data = output->data<T>();
+  // Add mask
+  const T* key_padding_mask_data =
+      (key_padding_mask != nullptr) ? key_padding_mask->data<T>() : nullptr;
+  const T* attn_mask_data =
+      (attn_mask != nullptr) ? attn_mask->data<T>() : nullptr;

   const int block_size = 1;
   dim3 blocks(32, 4, 1);
   int grid = (num_rows * block_size + 3) / 4;
   T scaling = static_cast<T>(1.0) / sqrt(static_cast<T>(num_cols));
-  const int block_nnz_max = 256;
-  BlockSparseSoftmaxForward<T, block_size, block_nnz_max><<<grid, blocks>>>(
-      output_data, input_data, scaling, nullptr, nullptr, offset_data,
-      columns_data, num_rows);
+  if (num_cols <= 4) {
+    BlockSparseSoftmaxForward<T, block_size, 4><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 4 && num_cols <= 8) {
+    BlockSparseSoftmaxForward<T, block_size, 8><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 8 && num_cols <= 16) {
+    BlockSparseSoftmaxForward<T, block_size, 16><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 16 && num_cols <= 32) {
+    BlockSparseSoftmaxForward<T, block_size, 32><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 32 && num_cols <= 64) {
+    BlockSparseSoftmaxForward<T, block_size, 64><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 64 && num_cols <= 128) {
+    BlockSparseSoftmaxForward<T, block_size, 128><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 128 && num_cols <= 256) {
+    BlockSparseSoftmaxForward<T, block_size, 256><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else if (num_cols > 256 && num_cols <= 512) {
+    BlockSparseSoftmaxForward<T, block_size, 512><<<grid, blocks>>>(
+        output_data, input_data, scaling, key_padding_mask_data, attn_mask_data,
+        offset_data, columns_data, num_rows);
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "The head_dim of query in sparse_attention op should less or equal "
+        "512"));
+  }
 }

 template <typename DeviceContext, typename T>
@@ -231,10 +279,43 @@ void SparseSoftmaxBackward(const platform::CUDADeviceContext& ctx,
   int grid = (num_rows * block_size + 3) / 4;
   T scaling = static_cast<T>(1.0) / sqrt(static_cast<T>(num_cols));
-  const int block_nnz_max = 256;
-  BlockSparseSoftmaxBackward<T, block_size, block_nnz_max><<<grid, blocks>>>(
-      dx_data, dout_data, out_data, scaling, offset_data, columns_data,
-      num_rows);
+  if (num_cols <= 4) {
+    BlockSparseSoftmaxBackward<T, block_size, 4><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 4 && num_cols <= 8) {
+    BlockSparseSoftmaxBackward<T, block_size, 8><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 8 && num_cols <= 16) {
+    BlockSparseSoftmaxBackward<T, block_size, 16><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 16 && num_cols <= 32) {
+    BlockSparseSoftmaxBackward<T, block_size, 32><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 32 && num_cols <= 64) {
+    BlockSparseSoftmaxBackward<T, block_size, 64><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 64 && num_cols <= 128) {
+    BlockSparseSoftmaxBackward<T, block_size, 128><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 128 && num_cols <= 256) {
+    BlockSparseSoftmaxBackward<T, block_size, 256><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else if (num_cols > 256 && num_cols <= 512) {
+    BlockSparseSoftmaxBackward<T, block_size, 512><<<grid, blocks>>>(
+        dx_data, dout_data, out_data, scaling, offset_data, columns_data,
+        num_rows);
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "The head_dim of query in sparse_attention op should less or equal "
+        "512"));
+  }
 }

 using VarType = framework::proto::VarType;
@@ -408,6 +489,12 @@ class SparseAttentionCUDAKernel : public framework::OpKernel<T> {
     sparse_dot_sdd_ptr->mutable_data<T>(ctx.GetPlace());
     auto softmax_ptr = ctx.Output<Tensor>("Softmax");
     softmax_ptr->mutable_data<T>(ctx.GetPlace());
+    // add Mask
+    auto* key_padding_mask = ctx.HasInput("KeyPaddingMask")
+                                 ? ctx.Input<Tensor>("KeyPaddingMask")
+                                 : nullptr;
+    auto* attn_mask =
+        ctx.HasInput("AttnMask") ? ctx.Input<Tensor>("AttnMask") : nullptr;

     auto output = *output_ptr;
     auto result_sdd = *sparse_dot_sdd_ptr;
@@ -435,9 +522,25 @@ class SparseAttentionCUDAKernel : public framework::OpKernel<T> {
                                &offset_lists[i], &columns_lists[i],
                                &result_sdd_lists[i], M, N, false, true);
-      SparseSoftmaxForward<DeviceContext, T>(
-          dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
-          &result_softmax_lists[i], 1, M, N);
+      if (key_padding_mask != nullptr && attn_mask != nullptr) {
+        SparseSoftmaxForward<DeviceContext, T>(
+            dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
+            &result_softmax_lists[i], 1, M, N,
+            key_padding_mask + (i / num_heads) * M, attn_mask);
+      } else if (key_padding_mask != nullptr && attn_mask == nullptr) {
+        SparseSoftmaxForward<DeviceContext, T>(
+            dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
+            &result_softmax_lists[i], 1, M, N,
+            key_padding_mask + (i / num_heads) * M, nullptr);
+      } else if (key_padding_mask == nullptr && attn_mask != nullptr) {
+        SparseSoftmaxForward<DeviceContext, T>(
+            dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
+            &result_softmax_lists[i], 1, M, N, nullptr, attn_mask);
+      } else {
+        SparseSoftmaxForward<DeviceContext, T>(
+            dev_ctx, &offset_lists[i], &columns_lists[i], &result_sdd_lists[i],
+            &result_softmax_lists[i], 1, M, N, nullptr, nullptr);
+      }

       DotDsd<DeviceContext, T>(dev_ctx, &offset_lists[i], &columns_lists[i],
                                &result_softmax_lists[i], &value_lists[i],
...
@@ -71,6 +71,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
     {"adamw",
      {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow",
       "Beta2Pow", "MasterParam"}},
+    {"sparse_attention",
+     {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
 };

 // NOTE(zhiqiu): Like op_ins_map.
...
@@ -23,6 +23,7 @@ import paddle.fluid.framework as framework
 import paddle.nn.functional as F
 import os
 import re
+import copy


 def get_cuda_version():
@@ -37,7 +38,42 @@ def get_cuda_version():
         return -1


-def softmax(x):
+def masked_fill(x):
+    row, col = x.shape[0], x.shape[1]
+    for i in range(row):
+        for j in range(col):
+            if x[i][j] == 0:
+                x[i][j] = float('-inf')
+    return x
+
+
+def init_mask(x):
+    row, col = x.shape[0], x.shape[1]
+    for i in range(row):
+        for j in range(col):
+            if x[i][j] == 0 and (j < 0.8 * col):
+                x[i][j] = 1
+    return x
+
+
+def softmax(x, kp_mask=None, attn_mask=None, bsz=None):
+    if kp_mask is None and attn_mask is None:
+        max = np.max(x, axis=1, keepdims=True)
+        e_x = np.exp(x - max)
+        sum = np.sum(e_x, axis=1, keepdims=True)
+        f_x = e_x / sum
+        return f_x
+    else:
+        # kp_mask
+        current_kp_mask = kp_mask[bsz]
+        row = current_kp_mask.shape[0]
+        current_kp_mask = np.expand_dims(current_kp_mask, 0).repeat(row, axis=0)
+        # attn_mask
+        current_attn_mask = copy.deepcopy(attn_mask)
+        current_attn_mask = masked_fill(current_attn_mask)
+        current_kp_mask = masked_fill(current_kp_mask)
+        x = x + current_kp_mask
+        x = x + current_attn_mask
     max = np.max(x, axis=1, keepdims=True)
     e_x = np.exp(x - max)
     sum = np.sum(e_x, axis=1, keepdims=True)
@@ -57,7 +93,14 @@ def get_csr_value(mat, layout, nnz):
     return value


-def ref_sparse_attention(q, k, v, offset, columns):
+def ref_sparse_attention(q,
+                         k,
+                         v,
+                         offset,
+                         columns,
+                         kp_mask=None,
+                         attn_mask=None,
+                         bsz=None):
     row, col, nnz = q.shape[0], q.shape[1], columns.shape[0]
     mat = np.zeros((row, row))
     for cur_row in range(row):
@@ -74,13 +117,23 @@ def ref_sparse_attention(q, k, v, offset, columns):
         for j in range(row):
             if mat[i][j] == 0:
                 a[i][j] = float('-inf')
-    b = softmax(a)
+    # softmax
+    if kp_mask is None and attn_mask is None:
+        b = softmax(a)
+    else:
+        b = softmax(a, kp_mask=kp_mask, attn_mask=attn_mask, bsz=bsz)
     b_value = get_csr_value(b, mat, nnz)
     result = np.dot(b, v)
     return result, a_value, b_value


-def ref_batch_sparse_attention(q, k, v, offset, columns):
+def ref_batch_sparse_attention(q,
+                               k,
+                               v,
+                               offset,
+                               columns,
+                               kp_mask=None,
+                               attn_mask=None):
     batch_size, num_heads, row, col = q.shape
     nnz = columns.shape[2]
     result = np.zeros((batch_size, num_heads, row, col))
@@ -90,8 +143,19 @@ def ref_batch_sparse_attention(q, k, v, offset, columns):
         for j in range(num_heads):
             cur_q, cur_k, cur_v, = q[i][j], k[i][j], v[i][j]
             cur_offset, cur_columns = offset[i][j], columns[i][j]
-            cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
-                cur_q, cur_k, cur_v, cur_offset, cur_columns)
+            if kp_mask is None and attn_mask is None:
+                cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
+                    cur_q, cur_k, cur_v, cur_offset, cur_columns)
+            else:
+                cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
+                    cur_q,
+                    cur_k,
+                    cur_v,
+                    cur_offset,
+                    cur_columns,
+                    kp_mask=kp_mask,
+                    attn_mask=attn_mask,
+                    bsz=i)
             result[i][j] = cur_result
             result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax
     return result, result_sdd, result_softmax
@@ -133,9 +197,10 @@ def init_csr_format(batch_size, num_heads, rows, blocksize):
 )
 class TestSparseAttentionOp(OpTest):
     def config(self):
-        self.shape = (1, 1, 16, 8)
-        self.blocksize = 2
+        self.shape = (1, 1, 16, 16)
+        self.blocksize = 4
         self.dtype = "float64"
+        self.use_mask = True

     def setUp(self):
         paddle.enable_static()
@@ -145,20 +210,51 @@ class TestSparseAttentionOp(OpTest):
         self.q = np.random.random(self.shape).astype(self.dtype)
         self.k = np.random.random(self.shape).astype(self.dtype)
         self.v = np.random.random(self.shape).astype(self.dtype)
+        # init CSR tensor
         offset, columns = init_csr_format(self.shape[0], self.shape[1],
                                           self.shape[2], self.blocksize)
         self.offset = offset.astype('int32')
         self.columns = columns.astype('int32')
+        # init mask tensor
+        key_padding_mask_shape = (self.shape[0], self.shape[2])
+        attn_mask_shape = (self.shape[2], self.shape[2])
+        key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape)
+        attn_mask = np.random.randint(0, 2, size=attn_mask_shape)
+        key_padding_mask = init_mask(key_padding_mask)
+        attn_mask = init_mask(attn_mask)
+        self.key_padding_mask = key_padding_mask.astype(self.dtype)
+        self.attn_mask = attn_mask.astype(self.dtype)

-        result, result_sdd, result_softmax = ref_batch_sparse_attention(
-            self.q, self.k, self.v, self.offset, self.columns)
+        if self.use_mask == True:
+            result, result_sdd, result_softmax = ref_batch_sparse_attention(
+                self.q,
+                self.k,
+                self.v,
+                self.offset,
+                self.columns,
+                kp_mask=self.key_padding_mask,
+                attn_mask=self.attn_mask)
+        else:
+            result, result_sdd, result_softmax = ref_batch_sparse_attention(
+                self.q, self.k, self.v, self.offset, self.columns)

-        self.inputs = {
-            'Q': self.q,
-            'K': self.k,
-            'V': self.v,
-            'Offset': self.offset,
-            'Columns': self.columns
-        }
+        if self.use_mask == True:
+            self.inputs = {
+                'Q': self.q,
+                'K': self.k,
+                'V': self.v,
+                'Offset': self.offset,
+                'Columns': self.columns,
+                'KeyPaddingMask': self.key_padding_mask,
+                'AttnMask': self.attn_mask,
+            }
+        else:
+            self.inputs = {
+                'Q': self.q,
+                'K': self.k,
+                'V': self.v,
+                'Offset': self.offset,
+                'Columns': self.columns,
+            }
         self.outputs = {
             'Out': result.astype(self.dtype),
@@ -180,6 +276,7 @@ class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):
         self.shape = (1, 1, 8, 16)
         self.blocksize = 2
         self.dtype = "float32"
+        self.use_mask = False


 class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
@@ -187,6 +284,7 @@ class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
         self.shape = (2, 2, 32, 8)
         self.blocksize = 8
         self.dtype = "float64"
+        self.use_mask = False


 @unittest.skipIf(
@@ -199,6 +297,7 @@ class TestSparseAttentionAPI(unittest.TestCase):
         self.shape = (1, 1, 8, 4)
         self.blocksize = 2
         self.dtype = 'float64'
+        self.use_mask = True

     def test_static_graph(self):
         paddle.enable_static()
@@ -219,6 +318,24 @@ class TestSparseAttentionAPI(unittest.TestCase):
                 name="Offset", shape=offset_shape, dtype="int32")
             columns = paddle.static.data(
                 name="Columns", shape=columns_shape, dtype="int32")
-            Out = F.sparse_attention(Q, K, V, offset, columns)
+            key_padding_mask_shape = (self.shape[0], self.shape[2])
+            attn_mask_shape = (self.shape[2], self.shape[2])
+            if self.use_mask == True:
+                key_padding_mask = paddle.static.data(
+                    name="KeyPaddingMask",
+                    shape=key_padding_mask_shape,
+                    dtype=self.dtype)
+                attn_mask = paddle.static.data(
+                    name="AttnMask", shape=attn_mask_shape, dtype=self.dtype)
+                Out = F.sparse_attention(
+                    Q,
+                    K,
+                    V,
+                    offset,
+                    columns,
+                    key_padding_mask=key_padding_mask,
+                    attn_mask=attn_mask)
+            else:
+                Out = F.sparse_attention(Q, K, V, offset, columns)

             Q_np = np.random.random(self.shape).astype(self.dtype)
@@ -229,7 +346,36 @@ class TestSparseAttentionAPI(unittest.TestCase):
             offset_np = offset_np.astype('int32')
             columns_np = columns_np.astype('int32')
+            # init mask tensor
+            key_padding_mask_np = np.random.randint(
+                0, 2, size=key_padding_mask_shape)
+            attn_mask_np = np.random.randint(0, 2, size=attn_mask_shape)
+            key_padding_mask_np = init_mask(key_padding_mask_np)
+            attn_mask_np = init_mask(attn_mask_np)
+            key_padding_mask_np = key_padding_mask_np.astype(self.dtype)
+            attn_mask_np = attn_mask_np.astype(self.dtype)

             exe = fluid.Executor(self.place)
-            fetches_result = exe.run(feed={
-                "Q": Q_np,
-                "K": K_np,
+            if self.use_mask == True:
+                fetches_result = exe.run(feed={
+                    "Q": Q_np,
+                    "K": K_np,
+                    "V": V_np,
+                    "Offset": offset_np,
+                    "Columns": columns_np,
+                    'KeyPaddingMask': key_padding_mask_np,
+                    'AttnMask': attn_mask_np
+                },
+                                         fetch_list=[Out])
+                expected_result, __, __ = ref_batch_sparse_attention(
+                    Q_np,
+                    K_np,
+                    V_np,
+                    offset_np,
+                    columns_np,
+                    kp_mask=key_padding_mask_np,
+                    attn_mask=attn_mask_np)
+            else:
+                fetches_result = exe.run(feed={
+                    "Q": Q_np,
+                    "K": K_np,
@@ -254,13 +400,44 @@ class TestSparseAttentionAPI(unittest.TestCase):
         query = np.random.random(self.shape).astype(self.dtype)
         key = np.random.random(self.shape).astype(self.dtype)
         value = np.random.random(self.shape).astype(self.dtype)
+        # init mask tensor
+        key_padding_mask_shape = (self.shape[0], self.shape[2])
+        attn_mask_shape = (self.shape[2], self.shape[2])
+        key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape)
+        attn_mask = np.random.randint(0, 2, size=attn_mask_shape)
+        key_padding_mask = init_mask(key_padding_mask)
+        attn_mask = init_mask(attn_mask)
+        key_padding_mask = key_padding_mask.astype(self.dtype)
+        attn_mask = attn_mask.astype(self.dtype)

         paddle_query = paddle.to_tensor(query, place=self.place)
         paddle_key = paddle.to_tensor(key, place=self.place)
         paddle_value = paddle.to_tensor(value, place=self.place)
         paddle_offset = paddle.to_tensor(offset, place=self.place)
         paddle_colunmns = paddle.to_tensor(columns, place=self.place)
+        paddle_kp_mask = paddle.to_tensor(key_padding_mask, place=self.place)
+        paddle_attn_mask = paddle.to_tensor(attn_mask, place=self.place)

-        paddle_result = F.sparse_attention(paddle_query, paddle_key,
-                                           paddle_value, paddle_offset,
-                                           paddle_colunmns)
+        if self.use_mask == True:
+            paddle_result = F.sparse_attention(
+                paddle_query,
+                paddle_key,
+                paddle_value,
+                paddle_offset,
+                paddle_colunmns,
+                key_padding_mask=paddle_kp_mask,
+                attn_mask=paddle_attn_mask)
+
+            numpy_result, __, __ = ref_batch_sparse_attention(
+                query,
+                key,
+                value,
+                offset,
+                columns,
+                kp_mask=key_padding_mask,
+                attn_mask=attn_mask)
+            numpy_result = numpy_result.astype(self.dtype)
+        else:
+            paddle_result = F.sparse_attention(paddle_query, paddle_key,
+                                               paddle_value, paddle_offset,
+                                               paddle_colunmns)
@@ -280,6 +457,7 @@ class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
         self.shape = (2, 2, 8, 4)
         self.blocksize = 2
         self.dtype = 'float32'
+        self.use_mask = False


 class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
@@ -288,6 +466,7 @@ class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
         self.shape = (2, 2, 64, 32)
         self.blocksize = 2
         self.dtype = 'float64'
+        self.use_mask = False


 class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
@@ -296,6 +475,7 @@ class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
         self.shape = (2, 1, 64, 32)
         self.blocksize = 2
         self.dtype = 'float64'
+        self.use_mask = False


 class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
@@ -304,6 +484,7 @@ class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
         self.shape = (4, 4, 128, 32)
         self.blocksize = 8
         self.dtype = 'float64'
+        self.use_mask = False


 class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
@@ -312,6 +493,7 @@ class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
         self.shape = (3, 3, 35, 15)
         self.blocksize = 3
         self.dtype = 'float64'
+        self.use_mask = False


 if __name__ == '__main__':
...
@@ -25,6 +25,8 @@ def sparse_attention(query,
                      value,
                      sparse_csr_offset,
                      sparse_csr_columns,
+                     key_padding_mask=None,
+                     attn_mask=None,
                      name=None):
     r"""
     This operator sparsify the Attention matrix in Transformer module
@@ -68,6 +70,14 @@ def sparse_attention(query,
                                  3-D tensor with shape:
                                  [batch_size, num_heads, sparse_nnz].
                                  The dtype should be int32.
+        key_padding_mask(Tensor, optional):The key padding mask tensor in the Attention module.
+                                 2-D tensor with shape: [batch_size, seq_len].
+                                 The dtype can be float32 and float64.
+                                 A value of 0 means that the position is masked.
+        attn_mask(Tensor, optional):The attention mask tensor in the Attention module.
+                                 2-D tensor with shape: [seq_len, seq_len].
+                                 The dtype can be float32 and float64.
+                                 A value of 0 means that the position is masked.
         name(str, optional): The default value is None. Normally there is no need for user
                 to set this property. For more information, please refer to
                 :ref:`api_guide_Name`.
@@ -94,6 +104,8 @@ def sparse_attention(query,
                                               4, 6, 8]]]).astype("int32")
            sparse_csr_columns_data = np.array([[[0, 1,
                                     0, 1, 2, 3, 2, 3]]]).astype("int32")
+           key_padding_mask_data = np.array([[1,1,1,0]]).astype("float32")
+           attention_mask_data = np.array([[1,0,1,1],[1,1,1,1],[1,1,1,1],[1,1,1,1]]).astype("float32")
            print(query_data.shape)
            # (1, 1, 4, 2)
            print(sparse_csr_offset_data.shape)
@@ -111,10 +123,21 @@ def sparse_attention(query,
                                      place=paddle.CUDAPlace(0))
            columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False,
                                      place=paddle.CUDAPlace(0))
+           key_padding_mask = paddle.to_tensor(key_padding_mask_data, stop_gradient=False,
+                                     place=paddle.CUDAPlace(0))
+           attention_mask = paddle.to_tensor(attention_mask_data, stop_gradient=False,
+                                     place=paddle.CUDAPlace(0))
+           output_mask = paddle.nn.functional.sparse_attention(query, key,
+                           value, offset, columns,
+                           key_padding_mask=key_padding_mask, attn_mask=attention_mask)
+           print(output_mask)
+           # [[[[0.        , 1.        ],
+           #    [1.99830270, 2.99830270],
+           #    [0.        , 1.        ],
+           #    [0.        , 1.        ]]]]
+
            output = paddle.nn.functional.sparse_attention(query, key,
                            value, offset, columns)
            print(output)
            # [[[[1.60885942, 2.60885954],
            #    [1.99830270, 2.99830270],
            #    [1.60885942, 2.60885954],
@@ -122,7 +145,8 @@ def sparse_attention(query,
     """
     if in_dygraph_mode():
         result_attention, result_sdd, result_softmax = _C_ops.sparse_attention(
-            query, key, value, sparse_csr_offset, sparse_csr_columns)
+            query, key, value, sparse_csr_offset, sparse_csr_columns,
+            key_padding_mask, attn_mask)
         return result_attention

     helper = LayerHelper('sparse_attention', **locals())
@@ -135,7 +159,9 @@ def sparse_attention(query,
         'K': key,
         'V': value,
         'Offset': sparse_csr_offset,
-        'Columns': sparse_csr_columns
+        'Columns': sparse_csr_columns,
+        'KeyPaddingMask': key_padding_mask,
+        'AttnMask': attn_mask,
     }
     outputs = {
         'Out': out,
...