未验证 提交 c57d1e91 编写于 作者: L Liu-xiandong 提交者: GitHub

Add nn.functional.sparse_attention and some test cases, test=develop (#35757) (#36551)

Add paddle.nn.functional.sparse_attention API

    本个PR主要将sparse_attention功能在python层进行了一层封装,OP的主体代码见:#PR35676

    此外,对于封装的python 接口,增加了相应的单测。
上级 bd40dd9a
......@@ -94,7 +94,7 @@ if (WITH_GPU OR WITH_ROCM)
endif()
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2) )
if ((NOT WIN32) AND (NOT WITH_ROCM) AND (NOT PADDLE_WITH_ARM) AND (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2) )
op_library(sparse_attention_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sparse_attention);\n")
endif()
......
......@@ -456,6 +456,11 @@ list(REMOVE_ITEM TEST_OPS test_imperative_static_runner_while)
# disable this unittest temporarily
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_exception)
# disable sparse_attention which not in suitable env
if ( (NOT WITH_GPU) OR (WIN32) OR (PADDLE_WITH_ARM) OR (WITH_ROCM) )
list(REMOVE_ITEM TEST_OPS test_sparse_attention_op)
endif()
if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_dataset)
list(REMOVE_ITEM TEST_OPS test_dataset_dataloader)
......
......@@ -16,10 +16,13 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.static import Program, program_guard
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.nn.functional as F
import os
import re
import platform
def get_cuda_version():
......@@ -34,22 +37,6 @@ def get_cuda_version():
return -1
def get_linux_platform():
if platform.system().lower() == 'windows':
return 0
elif platform.system().lower() == 'linux':
return 1
else:
return -1
def get_suitable_env():
if get_cuda_version() >= 11020 and get_linux_platform() == 1:
return True
else:
return False
def softmax(x):
max = np.max(x, axis=1, keepdims=True)
e_x = np.exp(x - max)
......@@ -141,8 +128,9 @@ def init_csr_format(batch_size, num_heads, rows, blocksize):
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_suitable_env() == False,
"core is not compiled with CUDA and cuda version need >= 11.2 in windows")
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.2"
)
class TestSparseAttentionOp(OpTest):
def config(self):
self.shape = (1, 1, 16, 8)
......@@ -201,5 +189,130 @@ class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
self.dtype = "float64"
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11020,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.2"
)
class TestSparseAttentionAPI(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (1, 1, 8, 4)
self.blocksize = 2
self.dtype = 'float64'
def test_static_graph(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
Q = paddle.static.data(name="Q", shape=self.shape, dtype=self.dtype)
K = paddle.static.data(name="K", shape=self.shape, dtype=self.dtype)
V = paddle.static.data(name="V", shape=self.shape, dtype=self.dtype)
batch_size, num_heads, rows = self.shape[0], self.shape[
1], self.shape[2]
block_num = rows / self.blocksize
block_last = rows % self.blocksize
sparse_nnz_num = block_num * self.blocksize * self.blocksize + block_last * block_last
offset_shape = (batch_size, num_heads, rows + 1)
columns_shape = (batch_size, num_heads, int(sparse_nnz_num))
offset = paddle.static.data(
name="Offset", shape=offset_shape, dtype="int32")
columns = paddle.static.data(
name="Columns", shape=columns_shape, dtype="int32")
Out = F.sparse_attention(Q, K, V, offset, columns)
Q_np = np.random.random(self.shape).astype(self.dtype)
K_np = np.random.random(self.shape).astype(self.dtype)
V_np = np.random.random(self.shape).astype(self.dtype)
offset_np, columns_np = init_csr_format(
self.shape[0], self.shape[1], self.shape[2], self.blocksize)
offset_np = offset_np.astype('int32')
columns_np = columns_np.astype('int32')
exe = fluid.Executor(self.place)
fetches_result = exe.run(feed={
"Q": Q_np,
"K": K_np,
"V": V_np,
"Offset": offset_np,
"Columns": columns_np
},
fetch_list=[Out])
expected_result, __, __ = ref_batch_sparse_attention(
Q_np, K_np, V_np, offset_np, columns_np)
self.assertTrue(
np.allclose(
fetches_result, expected_result, atol=1e-5))
def test_dygraph(self):
paddle.disable_static()
offset, columns = init_csr_format(self.shape[0], self.shape[1],
self.shape[2], self.blocksize)
offset = offset.astype('int32')
columns = columns.astype('int32')
query = np.random.random(self.shape).astype(self.dtype)
key = np.random.random(self.shape).astype(self.dtype)
value = np.random.random(self.shape).astype(self.dtype)
paddle_query = paddle.to_tensor(query, place=self.place)
paddle_key = paddle.to_tensor(key, place=self.place)
paddle_value = paddle.to_tensor(value, place=self.place)
paddle_offset = paddle.to_tensor(offset, place=self.place)
paddle_colunmns = paddle.to_tensor(columns, place=self.place)
paddle_result = F.sparse_attention(paddle_query, paddle_key,
paddle_value, paddle_offset,
paddle_colunmns)
numpy_result, __, __ = ref_batch_sparse_attention(query, key, value,
offset, columns)
numpy_result = numpy_result.astype(self.dtype)
self.assertTrue(
np.allclose(
paddle_result.numpy(), numpy_result, atol=1e-5))
class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 2, 8, 4)
self.blocksize = 2
self.dtype = 'float32'
class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 2, 64, 32)
self.blocksize = 2
self.dtype = 'float64'
class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (2, 1, 64, 32)
self.blocksize = 2
self.dtype = 'float64'
class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (4, 4, 128, 32)
self.blocksize = 8
self.dtype = 'float64'
class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
def setUp(self):
self.place = paddle.CUDAPlace(0)
self.shape = (3, 3, 35, 15)
self.blocksize = 3
self.dtype = 'float64'
if __name__ == '__main__':
unittest.main()
......@@ -112,6 +112,8 @@ from .input import embedding # noqa: F401
from ...fluid.layers import gather_tree # noqa: F401
from ...fluid.layers import temporal_shift # noqa: F401
from .sparse_attention import sparse_attention
__all__ = [ #noqa
'conv1d',
'conv1d_transpose',
......@@ -207,4 +209,5 @@ __all__ = [ #noqa
'layer_norm',
'instance_norm',
'class_center_sample',
'sparse_attention',
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import paddle
from ...fluid.framework import in_dygraph_mode, default_main_program
from paddle.fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode
from paddle import _C_ops
def sparse_attention(query,
key,
value,
sparse_csr_offset,
sparse_csr_columns,
name=None):
r"""
This operator sparsify the Attention matrix in Transformer module
to achieve the effect of reducing memory consumption and computation.
The sparse layout is expressed in CSR format and contains two parameters,
``offset`` and ``columns``.
.. math::
result=softmax(\frac{ Q * K^T }{\sqrt{d}}) * V
where : ``Q``, ``K``, and ``V`` represent the three input parameters of the attention module.
The dimensions of the three parameters are the same.
``d`` represents the size of the last dimension of the three parameters.
Parameters:
query(Tensor): The query tensor in the Attention module.
It's a 4-D tensor with a shape of
:math:`[batch\_size, num\_heads, seq\_len, head\_dim]`.
The dtype can be ``float32`` and ``float64``.
key(Tensor): The key tensor in the Attention module.
It's a 4-D tensor with a shape of
:math:`[batch\_size, num\_heads, seq\_len, head\_dim]`.
The dtype can be ``float32`` and ``float64``.
value(Tensor): The value tensor in the Attention module.
It's a 4-D tensor with a shape of
:math:`[batch\_size, num\_heads, seq\_len, head\_dim]`.
The dtype can be ``float32`` and ``float64``.
sparse_csr_offset(Tensor): The sparsity feature in the Attention module
is expressed in the CSR format, and the offset represents
the number of non-zero elements in each row of the matrix.
It's a 3-D tensor with a shape of
:math:`[batch\_size, num\_heads, seq\_len + 1]`.
The dtype should be ``int32``.
sparse_csr_columns(Tensor): The sparsity feature in the Attention module
is expressed in the CSR format, and the columns represent
the column index values of non-zero elements in the matrix.
It's a 3-D tensor with a shape of
:math:`[batch\_size, num\_heads, sparse\_nnz]`.
The dtype should be ``int32``.
name(str, optional): The default value is None. Normally there is no need for user
to set this property. For more information, please refer to
:ref:`api_guide_Name`.
Returns:
A Tensor which refers to the result in the Attention module.
It's a 4-D tensor with a shape of
:math:`[batch\_size, num\_heads, seq\_len, head\_dim]`.
The dtype can be ``float32`` and ``float64``.
Examples:
.. code-block:: python
# required: skiptest
import paddle
import numpy as np
query_data = np.array([[[[0, 1,], [2, 3],
[ 0, 1], [2, 3]]]]).astype("float32")
key_data = np.array([[[[0, 1,], [2, 3],
[ 0, 1], [2, 3]]]]).astype("float32")
value_data = np.array([[[[0, 1,], [2, 3],
[ 0, 1], [2, 3]]]]).astype("float32")
sparse_csr_offset_data = np.array([[[0, 2,
4, 6, 8]]]).astype("int32")
sparse_csr_columns_data = np.array([[[0, 1,
0, 1, 2, 3, 2, 3]]]).astype("int32")
print(query_data.shape)
# (1, 1, 4, 2)
print(sparse_csr_offset_data.shape)
# (1, 1, 5)
print(sparse_csr_columns_data.shape)
# (1, 1, 8)
paddle.disable_static()
query = paddle.to_tensor(query_data, stop_gradient=False,
place=paddle.CUDAPlace(0))
key = paddle.to_tensor(key_data, stop_gradient=False,
place=paddle.CUDAPlace(0))
value = paddle.to_tensor(value_data, stop_gradient=False,
place=paddle.CUDAPlace(0))
offset = paddle.to_tensor(sparse_csr_offset_data, stop_gradient=False,
place=paddle.CUDAPlace(0))
columns = paddle.to_tensor(sparse_csr_columns_data, stop_gradient=False,
place=paddle.CUDAPlace(0))
output = paddle.nn.functional.sparse_attention(query, key,
value, offset, columns)
print(output)
# [[[[1.60885942, 2.60885954],
# [1.99830270, 2.99830270],
# [1.60885942, 2.60885954],
# [1.99830270, 2.99830270]]]]
"""
if in_dygraph_mode():
result_attention, result_sdd, result_softmax = _C_ops.sparse_attention(
query, key, value, sparse_csr_offset, sparse_csr_columns)
return result_attention
helper = LayerHelper('sparse_attention', **locals())
dtype = helper.input_dtype(input_param_name='Q')
out = helper.create_variable_for_type_inference(dtype)
result_sdd = helper.create_variable_for_type_inference(dtype)
result_softmax = helper.create_variable_for_type_inference(dtype)
inputs = {
'Q': query,
'K': key,
'V': value,
'Offset': sparse_csr_offset,
'Columns': sparse_csr_columns
}
outputs = {
'Out': out,
'SparseDotSdd': result_sdd,
'Softmax': result_softmax
}
helper.append_op(type='sparse_attention', inputs=inputs, outputs=outputs)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册