Unverified commit fb19648a, authored by zhangkaihuo, committed by GitHub

cherry-pick #75b734 (#49201)

Parent cdab3a44
@@ -93,14 +93,9 @@ def get_csr_value(mat, layout, nnz):
return value
def ref_sparse_attention(q,
k,
v,
offset,
columns,
kp_mask=None,
attn_mask=None,
bsz=None):
def ref_sparse_attention(
q, k, v, offset, columns, kp_mask=None, attn_mask=None, bsz=None
):
row, col, nnz = q.shape[0], q.shape[1], columns.shape[0]
mat = np.zeros((row, row))
for cur_row in range(row):
@@ -111,7 +106,7 @@ def ref_sparse_attention(q,
mat[cur_row][cur_col] = 1
a = np.dot(q, k.T) * mat
a_value = get_csr_value(a, mat, nnz)
scaling = float(col)**-0.5
scaling = float(col) ** -0.5
a = scaling * a
for i in range(row):
for j in range(row):
@@ -127,13 +122,9 @@ def ref_sparse_attention(q,
return result, a_value, b_value
def ref_batch_sparse_attention(q,
k,
v,
offset,
columns,
kp_mask=None,
attn_mask=None):
def ref_batch_sparse_attention(
q, k, v, offset, columns, kp_mask=None, attn_mask=None
):
batch_size, num_heads, row, col = q.shape
nnz = columns.shape[2]
result = np.zeros((batch_size, num_heads, row, col))
@@ -141,11 +132,16 @@ def ref_batch_sparse_attention(q,
result_softmax = np.zeros((batch_size, num_heads, nnz))
for i in range(batch_size):
for j in range(num_heads):
cur_q, cur_k, cur_v, = q[i][j], k[i][j], v[i][j]
cur_q, cur_k, cur_v, = (
q[i][j],
k[i][j],
v[i][j],
)
cur_offset, cur_columns = offset[i][j], columns[i][j]
if kp_mask is None and attn_mask is None:
cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
cur_q, cur_k, cur_v, cur_offset, cur_columns)
cur_q, cur_k, cur_v, cur_offset, cur_columns
)
else:
cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
cur_q,
@@ -155,7 +151,8 @@ def ref_batch_sparse_attention(q,
cur_columns,
kp_mask=kp_mask,
attn_mask=attn_mask,
bsz=i)
bsz=i,
)
result[i][j] = cur_result
result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax
return result, result_sdd, result_softmax
@@ -193,10 +190,9 @@ def init_csr_format(batch_size, num_heads, rows, blocksize):
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
)
class TestSparseAttentionOp(OpTest):
def config(self):
self.shape = (1, 1, 16, 16)
self.blocksize = 4
@@ -212,8 +208,9 @@ class TestSparseAttentionOp(OpTest):
self.k = np.random.random(self.shape).astype(self.dtype)
self.v = np.random.random(self.shape).astype(self.dtype)
# init CSR tensor
offset, columns = init_csr_format(self.shape[0], self.shape[1],
self.shape[2], self.blocksize)
offset, columns = init_csr_format(
self.shape[0], self.shape[1], self.shape[2], self.blocksize
)
self.offset = offset.astype('int32')
self.columns = columns.astype('int32')
# init mask tensor
@@ -234,10 +231,12 @@ class TestSparseAttentionOp(OpTest):
self.offset,
self.columns,
kp_mask=self.key_padding_mask,
attn_mask=self.attn_mask)
attn_mask=self.attn_mask,
)
else:
result, result_sdd, result_softmax = ref_batch_sparse_attention(
self.q, self.k, self.v, self.offset, self.columns)
self.q, self.k, self.v, self.offset, self.columns
)
if self.use_mask == True:
self.inputs = {
@@ -260,7 +259,7 @@ class TestSparseAttentionOp(OpTest):
self.outputs = {
'Out': result.astype(self.dtype),
'SparseDotSdd': result_sdd.astype(self.dtype),
'Softmax': result_softmax.astype(self.dtype)
'Softmax': result_softmax.astype(self.dtype),
}
def test_check_output(self):
@@ -273,7 +272,6 @@ class TestSparseAttentionOp(OpTest):
class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):
def config(self):
self.shape = (1, 1, 8, 16)
self.blocksize = 2
@@ -282,7 +280,6 @@ class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):
class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
def config(self):
self.shape = (2, 2, 32, 8)
self.blocksize = 8
@@ -292,10 +289,9 @@ class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
)
class TestSparseAttentionAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (1, 1, 8, 4)
@@ -310,54 +306,62 @@ class TestSparseAttentionAPI(unittest.TestCase):
K = paddle.static.data(name="K", shape=self.shape, dtype=self.dtype)
V = paddle.static.data(name="V", shape=self.shape, dtype=self.dtype)
batch_size, num_heads, rows = self.shape[0], self.shape[
1], self.shape[2]
batch_size, num_heads, rows = (
self.shape[0],
self.shape[1],
self.shape[2],
)
block_num = rows / self.blocksize
block_last = rows % self.blocksize
sparse_nnz_num = block_num * self.blocksize * self.blocksize + block_last * block_last
sparse_nnz_num = (
block_num * self.blocksize * self.blocksize
+ block_last * block_last
)
offset_shape = (batch_size, num_heads, rows + 1)
columns_shape = (batch_size, num_heads, int(sparse_nnz_num))
offset = paddle.static.data(name="Offset",
shape=offset_shape,
dtype="int32")
columns = paddle.static.data(name="Columns",
shape=columns_shape,
dtype="int32")
offset = paddle.static.data(
name="Offset", shape=offset_shape, dtype="int32"
)
columns = paddle.static.data(
name="Columns", shape=columns_shape, dtype="int32"
)
key_padding_mask_shape = (self.shape[0], self.shape[2])
attn_mask_shape = (self.shape[2], self.shape[2])
if self.use_mask == True:
key_padding_mask = paddle.static.data(
name="KeyPaddingMask",
shape=key_padding_mask_shape,
dtype=self.dtype)
attn_mask = paddle.static.data(name="AttnMask",
shape=attn_mask_shape,
dtype=self.dtype)
Out = F.sparse_attention(Q,
K,
V,
offset,
columns,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask)
dtype=self.dtype,
)
attn_mask = paddle.static.data(
name="AttnMask", shape=attn_mask_shape, dtype=self.dtype
)
Out = F.sparse_attention(
Q,
K,
V,
offset,
columns,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
else:
Out = F.sparse_attention(Q, K, V, offset, columns)
Q_np = np.random.random(self.shape).astype(self.dtype)
K_np = np.random.random(self.shape).astype(self.dtype)
V_np = np.random.random(self.shape).astype(self.dtype)
offset_np, columns_np = init_csr_format(self.shape[0],
self.shape[1],
self.shape[2],
self.blocksize)
offset_np, columns_np = init_csr_format(
self.shape[0], self.shape[1], self.shape[2], self.blocksize
)
offset_np = offset_np.astype('int32')
columns_np = columns_np.astype('int32')
# init mask tensor
key_padding_mask_np = np.random.randint(0,
2,
size=key_padding_mask_shape)
key_padding_mask_np = np.random.randint(
0, 2, size=key_padding_mask_shape
)
attn_mask_np = np.random.randint(0, 2, size=attn_mask_shape)
key_padding_mask_np = init_mask(key_padding_mask_np)
attn_mask_np = init_mask(attn_mask_np)
@@ -366,16 +370,18 @@ class TestSparseAttentionAPI(unittest.TestCase):
exe = fluid.Executor(self.place)
if self.use_mask == True:
fetches_result = exe.run(feed={
"Q": Q_np,
"K": K_np,
"V": V_np,
"Offset": offset_np,
"Columns": columns_np,
'KeyPaddingMask': key_padding_mask_np,
'AttnMask': attn_mask_np
},
fetch_list=[Out])
fetches_result = exe.run(
feed={
"Q": Q_np,
"K": K_np,
"V": V_np,
"Offset": offset_np,
"Columns": columns_np,
'KeyPaddingMask': key_padding_mask_np,
'AttnMask': attn_mask_np,
},
fetch_list=[Out],
)
expected_result, __, __ = ref_batch_sparse_attention(
Q_np,
K_np,
@@ -383,28 +389,32 @@ class TestSparseAttentionAPI(unittest.TestCase):
offset_np,
columns_np,
kp_mask=key_padding_mask_np,
attn_mask=attn_mask_np)
attn_mask=attn_mask_np,
)
else:
fetches_result = exe.run(feed={
"Q": Q_np,
"K": K_np,
"V": V_np,
"Offset": offset_np,
"Columns": columns_np
},
fetch_list=[Out])
fetches_result = exe.run(
feed={
"Q": Q_np,
"K": K_np,
"V": V_np,
"Offset": offset_np,
"Columns": columns_np,
},
fetch_list=[Out],
)
expected_result, __, __ = ref_batch_sparse_attention(
Q_np, K_np, V_np, offset_np, columns_np)
Q_np, K_np, V_np, offset_np, columns_np
)
np.testing.assert_allclose(fetches_result,
expected_result,
rtol=1e-05,
atol=1e-05)
np.testing.assert_allclose(
fetches_result[0], expected_result, rtol=1e-05, atol=1e-05
)
def test_dygraph(self):
paddle.disable_static()
offset, columns = init_csr_format(self.shape[0], self.shape[1],
self.shape[2], self.blocksize)
offset, columns = init_csr_format(
self.shape[0], self.shape[1], self.shape[2], self.blocksize
)
offset = offset.astype('int32')
columns = columns.astype('int32')
query = np.random.random(self.shape).astype(self.dtype)
@@ -429,13 +439,15 @@ class TestSparseAttentionAPI(unittest.TestCase):
paddle_attn_mask = paddle.to_tensor(attn_mask, place=self.place)
if self.use_mask == True:
paddle_result = F.sparse_attention(paddle_query,
paddle_key,
paddle_value,
paddle_offset,
paddle_colunmns,
key_padding_mask=paddle_kp_mask,
attn_mask=paddle_attn_mask)
paddle_result = F.sparse_attention(
paddle_query,
paddle_key,
paddle_value,
paddle_offset,
paddle_colunmns,
key_padding_mask=paddle_kp_mask,
attn_mask=paddle_attn_mask,
)
numpy_result, __, __ = ref_batch_sparse_attention(
query,
@@ -444,25 +456,29 @@ class TestSparseAttentionAPI(unittest.TestCase):
offset,
columns,
kp_mask=key_padding_mask,
attn_mask=attn_mask)
attn_mask=attn_mask,
)
numpy_result = numpy_result.astype(self.dtype)
else:
paddle_result = F.sparse_attention(paddle_query, paddle_key,
paddle_value, paddle_offset,
paddle_colunmns)
paddle_result = F.sparse_attention(
paddle_query,
paddle_key,
paddle_value,
paddle_offset,
paddle_colunmns,
)
numpy_result, __, __ = ref_batch_sparse_attention(
query, key, value, offset, columns)
query, key, value, offset, columns
)
numpy_result = numpy_result.astype(self.dtype)
np.testing.assert_allclose(paddle_result.numpy(),
numpy_result,
rtol=1e-05,
atol=1e-05)
np.testing.assert_allclose(
paddle_result.numpy(), numpy_result, rtol=1e-05, atol=1e-05
)
class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 2, 8, 4)
@@ -472,7 +488,6 @@ class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 2, 64, 32)
@@ -482,7 +497,6 @@ class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 1, 64, 32)
@@ -492,7 +506,6 @@ class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (4, 4, 128, 32)
@@ -502,7 +515,6 @@ class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (3, 3, 35, 15)
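
For readers skimming the diff, the API these tests exercise can be invoked in dynamic-graph mode roughly as follows. This is a minimal sketch, not part of the change: it assumes a GPU build of Paddle with CUDA >= 11.3 (as the skip conditions above require) and substitutes a purely block-diagonal CSR layout for the test's `init_csr_format`, whose body is not shown in these hunks.

```python
import numpy as np
import paddle
import paddle.nn.functional as F

# Shapes mirror TestSparseAttentionAPI: (batch, heads, seq_len, head_dim),
# with seq_len divisible by blocksize so there is no trailing partial block.
batch, heads, seq, head_dim, blocksize = 1, 1, 8, 4, 4

# Block-diagonal CSR pattern: each row attends to the `blocksize` columns of
# its own diagonal block (a simplified stand-in for init_csr_format).
offset = np.zeros((batch, heads, seq + 1), dtype='int32')
columns = np.zeros((batch, heads, seq * blocksize), dtype='int32')
for b in range(batch):
    for h in range(heads):
        nnz = 0
        for r in range(seq):
            offset[b][h][r] = nnz
            start = (r // blocksize) * blocksize
            for c in range(start, start + blocksize):
                columns[b][h][nnz] = c
                nnz += 1
        offset[b][h][seq] = nnz

q = paddle.rand((batch, heads, seq, head_dim), dtype='float32')
k = paddle.rand((batch, heads, seq, head_dim), dtype='float32')
v = paddle.rand((batch, heads, seq, head_dim), dtype='float32')

out = F.sparse_attention(q, k, v, paddle.to_tensor(offset), paddle.to_tensor(columns))
print(out.shape)  # [1, 1, 8, 4], same shape as q
```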