Unverified · Commit fb19648a authored by zhangkaihuo and committed by GitHub

cherry-pick #75b734 (#49201)

Parent cdab3a44
@@ -93,14 +93,9 @@ def get_csr_value(mat, layout, nnz):
     return value


-def ref_sparse_attention(q,
-                         k,
-                         v,
-                         offset,
-                         columns,
-                         kp_mask=None,
-                         attn_mask=None,
-                         bsz=None):
+def ref_sparse_attention(
+    q, k, v, offset, columns, kp_mask=None, attn_mask=None, bsz=None
+):
     row, col, nnz = q.shape[0], q.shape[1], columns.shape[0]
     mat = np.zeros((row, row))
     for cur_row in range(row):
@@ -111,7 +106,7 @@ def ref_sparse_attention(q,
             mat[cur_row][cur_col] = 1
     a = np.dot(q, k.T) * mat
     a_value = get_csr_value(a, mat, nnz)
-    scaling = float(col)**-0.5
+    scaling = float(col) ** -0.5
     a = scaling * a
     for i in range(row):
         for j in range(row):
@@ -127,13 +122,9 @@ def ref_sparse_attention(q,
     return result, a_value, b_value


-def ref_batch_sparse_attention(q,
-                               k,
-                               v,
-                               offset,
-                               columns,
-                               kp_mask=None,
-                               attn_mask=None):
+def ref_batch_sparse_attention(
+    q, k, v, offset, columns, kp_mask=None, attn_mask=None
+):
     batch_size, num_heads, row, col = q.shape
     nnz = columns.shape[2]
     result = np.zeros((batch_size, num_heads, row, col))
@@ -141,11 +132,16 @@ def ref_batch_sparse_attention(q,
     result_softmax = np.zeros((batch_size, num_heads, nnz))
     for i in range(batch_size):
         for j in range(num_heads):
-            cur_q, cur_k, cur_v, = q[i][j], k[i][j], v[i][j]
+            cur_q, cur_k, cur_v, = (
+                q[i][j],
+                k[i][j],
+                v[i][j],
+            )
             cur_offset, cur_columns = offset[i][j], columns[i][j]
             if kp_mask is None and attn_mask is None:
                 cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
-                    cur_q, cur_k, cur_v, cur_offset, cur_columns)
+                    cur_q, cur_k, cur_v, cur_offset, cur_columns
+                )
             else:
                 cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
                     cur_q,
@@ -155,7 +151,8 @@ def ref_batch_sparse_attention(q,
                     cur_columns,
                     kp_mask=kp_mask,
                     attn_mask=attn_mask,
-                    bsz=i)
+                    bsz=i,
+                )
             result[i][j] = cur_result
             result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax
     return result, result_sdd, result_softmax
@@ -193,10 +190,9 @@ def init_csr_format(batch_size, num_heads, rows, blocksize):

 @unittest.skipIf(
     not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
 )
 class TestSparseAttentionOp(OpTest):
-
     def config(self):
         self.shape = (1, 1, 16, 16)
         self.blocksize = 4
@@ -212,8 +208,9 @@ class TestSparseAttentionOp(OpTest):
         self.k = np.random.random(self.shape).astype(self.dtype)
         self.v = np.random.random(self.shape).astype(self.dtype)
         # init CSR tensor
-        offset, columns = init_csr_format(self.shape[0], self.shape[1],
-                                          self.shape[2], self.blocksize)
+        offset, columns = init_csr_format(
+            self.shape[0], self.shape[1], self.shape[2], self.blocksize
+        )
         self.offset = offset.astype('int32')
         self.columns = columns.astype('int32')
         # init mask tensor
@@ -234,10 +231,12 @@ class TestSparseAttentionOp(OpTest):
                 self.offset,
                 self.columns,
                 kp_mask=self.key_padding_mask,
-                attn_mask=self.attn_mask)
+                attn_mask=self.attn_mask,
+            )
         else:
             result, result_sdd, result_softmax = ref_batch_sparse_attention(
-                self.q, self.k, self.v, self.offset, self.columns)
+                self.q, self.k, self.v, self.offset, self.columns
+            )

         if self.use_mask == True:
             self.inputs = {
@@ -260,7 +259,7 @@ class TestSparseAttentionOp(OpTest):
         self.outputs = {
             'Out': result.astype(self.dtype),
             'SparseDotSdd': result_sdd.astype(self.dtype),
-            'Softmax': result_softmax.astype(self.dtype)
+            'Softmax': result_softmax.astype(self.dtype),
         }

     def test_check_output(self):
@@ -273,7 +272,6 @@ class TestSparseAttentionOp(OpTest):


 class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):
-
     def config(self):
         self.shape = (1, 1, 8, 16)
         self.blocksize = 2
@@ -282,7 +280,6 @@ class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):


 class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
-
     def config(self):
         self.shape = (2, 2, 32, 8)
         self.blocksize = 8
@@ -292,10 +289,9 @@ class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):

 @unittest.skipIf(
     not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
 )
 class TestSparseAttentionAPI(unittest.TestCase):
-
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (1, 1, 8, 4)
@@ -310,54 +306,62 @@ class TestSparseAttentionAPI(unittest.TestCase):
             K = paddle.static.data(name="K", shape=self.shape, dtype=self.dtype)
             V = paddle.static.data(name="V", shape=self.shape, dtype=self.dtype)
-            batch_size, num_heads, rows = self.shape[0], self.shape[
-                1], self.shape[2]
+            batch_size, num_heads, rows = (
+                self.shape[0],
+                self.shape[1],
+                self.shape[2],
+            )
             block_num = rows / self.blocksize
             block_last = rows % self.blocksize
-            sparse_nnz_num = block_num * self.blocksize * self.blocksize + block_last * block_last
+            sparse_nnz_num = (
+                block_num * self.blocksize * self.blocksize
+                + block_last * block_last
+            )
             offset_shape = (batch_size, num_heads, rows + 1)
             columns_shape = (batch_size, num_heads, int(sparse_nnz_num))
-            offset = paddle.static.data(name="Offset",
-                                        shape=offset_shape,
-                                        dtype="int32")
-            columns = paddle.static.data(name="Columns",
-                                         shape=columns_shape,
-                                         dtype="int32")
+            offset = paddle.static.data(
+                name="Offset", shape=offset_shape, dtype="int32"
+            )
+            columns = paddle.static.data(
+                name="Columns", shape=columns_shape, dtype="int32"
+            )
             key_padding_mask_shape = (self.shape[0], self.shape[2])
             attn_mask_shape = (self.shape[2], self.shape[2])
             if self.use_mask == True:
                 key_padding_mask = paddle.static.data(
                     name="KeyPaddingMask",
                     shape=key_padding_mask_shape,
-                    dtype=self.dtype)
-                attn_mask = paddle.static.data(name="AttnMask",
-                                               shape=attn_mask_shape,
-                                               dtype=self.dtype)
-                Out = F.sparse_attention(Q,
-                                         K,
-                                         V,
-                                         offset,
-                                         columns,
-                                         key_padding_mask=key_padding_mask,
-                                         attn_mask=attn_mask)
+                    dtype=self.dtype,
+                )
+                attn_mask = paddle.static.data(
+                    name="AttnMask", shape=attn_mask_shape, dtype=self.dtype
+                )
+                Out = F.sparse_attention(
+                    Q,
+                    K,
+                    V,
+                    offset,
+                    columns,
+                    key_padding_mask=key_padding_mask,
+                    attn_mask=attn_mask,
+                )
             else:
                 Out = F.sparse_attention(Q, K, V, offset, columns)
             Q_np = np.random.random(self.shape).astype(self.dtype)
             K_np = np.random.random(self.shape).astype(self.dtype)
             V_np = np.random.random(self.shape).astype(self.dtype)
-            offset_np, columns_np = init_csr_format(self.shape[0],
-                                                    self.shape[1],
-                                                    self.shape[2],
-                                                    self.blocksize)
+            offset_np, columns_np = init_csr_format(
+                self.shape[0], self.shape[1], self.shape[2], self.blocksize
+            )
             offset_np = offset_np.astype('int32')
             columns_np = columns_np.astype('int32')
             # init mask tensor
-            key_padding_mask_np = np.random.randint(0,
-                                                    2,
-                                                    size=key_padding_mask_shape)
+            key_padding_mask_np = np.random.randint(
+                0, 2, size=key_padding_mask_shape
+            )
             attn_mask_np = np.random.randint(0, 2, size=attn_mask_shape)
             key_padding_mask_np = init_mask(key_padding_mask_np)
             attn_mask_np = init_mask(attn_mask_np)
@@ -366,16 +370,18 @@ class TestSparseAttentionAPI(unittest.TestCase):
             exe = fluid.Executor(self.place)
             if self.use_mask == True:
-                fetches_result = exe.run(feed={
-                    "Q": Q_np,
-                    "K": K_np,
-                    "V": V_np,
-                    "Offset": offset_np,
-                    "Columns": columns_np,
-                    'KeyPaddingMask': key_padding_mask_np,
-                    'AttnMask': attn_mask_np
-                },
-                                         fetch_list=[Out])
+                fetches_result = exe.run(
+                    feed={
+                        "Q": Q_np,
+                        "K": K_np,
+                        "V": V_np,
+                        "Offset": offset_np,
+                        "Columns": columns_np,
+                        'KeyPaddingMask': key_padding_mask_np,
+                        'AttnMask': attn_mask_np,
+                    },
+                    fetch_list=[Out],
+                )
                 expected_result, __, __ = ref_batch_sparse_attention(
                     Q_np,
                     K_np,
@@ -383,28 +389,32 @@ class TestSparseAttentionAPI(unittest.TestCase):
                     offset_np,
                     columns_np,
                     kp_mask=key_padding_mask_np,
-                    attn_mask=attn_mask_np)
+                    attn_mask=attn_mask_np,
+                )
             else:
-                fetches_result = exe.run(feed={
-                    "Q": Q_np,
-                    "K": K_np,
-                    "V": V_np,
-                    "Offset": offset_np,
-                    "Columns": columns_np
-                },
-                                         fetch_list=[Out])
+                fetches_result = exe.run(
+                    feed={
+                        "Q": Q_np,
+                        "K": K_np,
+                        "V": V_np,
+                        "Offset": offset_np,
+                        "Columns": columns_np,
+                    },
+                    fetch_list=[Out],
+                )
                 expected_result, __, __ = ref_batch_sparse_attention(
-                    Q_np, K_np, V_np, offset_np, columns_np)
+                    Q_np, K_np, V_np, offset_np, columns_np
+                )
-            np.testing.assert_allclose(fetches_result,
-                                       expected_result,
-                                       rtol=1e-05,
-                                       atol=1e-05)
+            np.testing.assert_allclose(
+                fetches_result[0], expected_result, rtol=1e-05, atol=1e-05
+            )

     def test_dygraph(self):
         paddle.disable_static()
-        offset, columns = init_csr_format(self.shape[0], self.shape[1],
-                                          self.shape[2], self.blocksize)
+        offset, columns = init_csr_format(
+            self.shape[0], self.shape[1], self.shape[2], self.blocksize
+        )
         offset = offset.astype('int32')
         columns = columns.astype('int32')
         query = np.random.random(self.shape).astype(self.dtype)
@@ -429,13 +439,15 @@ class TestSparseAttentionAPI(unittest.TestCase):
         paddle_attn_mask = paddle.to_tensor(attn_mask, place=self.place)
         if self.use_mask == True:
-            paddle_result = F.sparse_attention(paddle_query,
-                                               paddle_key,
-                                               paddle_value,
-                                               paddle_offset,
-                                               paddle_colunmns,
-                                               key_padding_mask=paddle_kp_mask,
-                                               attn_mask=paddle_attn_mask)
+            paddle_result = F.sparse_attention(
+                paddle_query,
+                paddle_key,
+                paddle_value,
+                paddle_offset,
+                paddle_colunmns,
+                key_padding_mask=paddle_kp_mask,
+                attn_mask=paddle_attn_mask,
+            )
             numpy_result, __, __ = ref_batch_sparse_attention(
                 query,
@@ -444,25 +456,29 @@ class TestSparseAttentionAPI(unittest.TestCase):
                 offset,
                 columns,
                 kp_mask=key_padding_mask,
-                attn_mask=attn_mask)
+                attn_mask=attn_mask,
+            )
             numpy_result = numpy_result.astype(self.dtype)
         else:
-            paddle_result = F.sparse_attention(paddle_query, paddle_key,
-                                               paddle_value, paddle_offset,
-                                               paddle_colunmns)
+            paddle_result = F.sparse_attention(
+                paddle_query,
+                paddle_key,
+                paddle_value,
+                paddle_offset,
+                paddle_colunmns,
+            )
             numpy_result, __, __ = ref_batch_sparse_attention(
-                query, key, value, offset, columns)
+                query, key, value, offset, columns
+            )
             numpy_result = numpy_result.astype(self.dtype)
-        np.testing.assert_allclose(paddle_result.numpy(),
-                                   numpy_result,
-                                   rtol=1e-05,
-                                   atol=1e-05)
+        np.testing.assert_allclose(
+            paddle_result.numpy(), numpy_result, rtol=1e-05, atol=1e-05
+        )


 class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
-
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (2, 2, 8, 4)
@@ -472,7 +488,6 @@ class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):


 class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
-
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (2, 2, 64, 32)
@@ -482,7 +497,6 @@ class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):


 class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
-
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (2, 1, 64, 32)
@@ -492,7 +506,6 @@ class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):


 class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
-
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (4, 4, 128, 32)
@@ -502,7 +515,6 @@ class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):


 class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
-
     def setUp(self):
         self.place = paddle.CUDAPlace(0)
         self.shape = (3, 3, 35, 15)
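The tests in this diff all follow one pattern: build a block-sparse CSR layout (offset, columns), run F.sparse_attention, and compare against the NumPy reference ref_batch_sparse_attention. A minimal dygraph sketch of that call pattern is given below. It is illustrative only: it assumes a CUDA build of Paddle with CUDA >= 11.3 (the same condition the skipIf decorators check) and assumes F.sparse_attention resolves to paddle.nn.functional.sparse_attention; the test file's imports and the body of the init_csr_format helper are outside this diff, so a simple block-diagonal CSR layout is built inline instead.

import numpy as np
import paddle
import paddle.nn.functional as F  # assumption: functional namespace providing sparse_attention

paddle.set_device('gpu')  # the operator is CUDA-only (CUDA >= 11.3)

batch, heads, rows, head_dim, blocksize = 1, 1, 8, 4, 4

# Block-diagonal sparsity in CSR form: row r attends only to the `blocksize`
# columns of its own block, so every row holds exactly `blocksize` entries.
offset_np = np.arange(0, rows * blocksize + 1, blocksize, dtype='int32')
columns_np = np.concatenate([
    np.tile(np.arange(b * blocksize, (b + 1) * blocksize, dtype='int32'),
            blocksize)
    for b in range(rows // blocksize)
])
offset_np = np.tile(offset_np, (batch, heads, 1))    # (batch, heads, rows + 1)
columns_np = np.tile(columns_np, (batch, heads, 1))  # (batch, heads, nnz)

shape = (batch, heads, rows, head_dim)
query = paddle.to_tensor(np.random.random(shape).astype('float32'))
key = paddle.to_tensor(np.random.random(shape).astype('float32'))
value = paddle.to_tensor(np.random.random(shape).astype('float32'))

out = F.sparse_attention(
    query,
    key,
    value,
    paddle.to_tensor(offset_np),
    paddle.to_tensor(columns_np),
)
print(out.shape)  # [1, 1, 8, 4]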