Unverified commit b7295120 authored by Yiqun Liu and committed by GitHub

Add fused_gate_attention API. (#53432)

* Add fused_gate_attention API.

* Implement FusedDropout API.

* Fix doc and add unittest.

* Skip for non-gpu device.

* Add unittest.
Parent 99399f32
......@@ -176,6 +176,7 @@ class TestDropoutNdBF16Op(OpTest):
class TestDropoutNdAPI(unittest.TestCase):
def setUp(self):
paddle.seed(123)
np.random.seed(123)
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
......@@ -187,11 +188,36 @@ class TestDropoutNdAPI(unittest.TestCase):
with fluid.dygraph.guard(place):
in_np = np.random.random([4, 32, 16]).astype("float32")
input = paddle.to_tensor(in_np)
res1 = dropout_nd(x=input, p=0.0, axis=[0, 1])
res2 = dropout_nd(x=input, p=0.5, axis=[0, 1])
dropout_1 = paddle.incubate.nn.FusedDropout(p=0.0, axis=[0, 1])
dropout_2 = paddle.incubate.nn.FusedDropout(p=0.5, axis=[0, 1])
print(dropout_1)
print(dropout_2)
res1 = dropout_1(input)
res2 = dropout_2(input)
np.testing.assert_allclose(res1.numpy(), in_np, rtol=1e-05)
paddle.enable_static()
def test_error(self):
def _run_illegal_type_p():
dropout = paddle.incubate.nn.FusedDropout(p="test")
self.assertRaises(TypeError, _run_illegal_type_p)
def _run_illegal_value_p():
dropout = paddle.incubate.nn.FusedDropout(p=2)
self.assertRaises(ValueError, _run_illegal_value_p)
def _run_illegal_mode():
dropout = paddle.incubate.nn.FusedDropout(p=0.5, mode="test")
self.assertRaises(ValueError, _run_illegal_mode)
def _run_illegal_type_axis():
dropout = paddle.incubate.nn.FusedDropout(p=0.5, axis="test")
self.assertRaises(TypeError, _run_illegal_type_axis)
if __name__ == '__main__':
unittest.main()
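The test above checks numerical equality only for the p=0.0 case. As an aside, here is a minimal sketch of what axis=[0, 1] means for the new FusedDropout layer (my own illustration, not part of this diff; it assumes the mask-broadcast semantics of the dropout_nd path): a single mask shaped [4, 32, 1] is generated and broadcast along the last axis, so each feature vector x[i, j, :] is either dropped entirely or kept and upscaled by 1/(1-p).

import numpy as np
import paddle

paddle.seed(123)
x = paddle.ones([4, 32, 16], dtype="float32")
# With axis=[0, 1], the mask matches x on axes 0 and 1 and is broadcast over axis 2.
m = paddle.incubate.nn.FusedDropout(p=0.5, axis=[0, 1])
y = m(x).numpy()
for i in range(4):
    for j in range(32):
        row = y[i, j, :]
        # Each slice is either all zeros (dropped) or all 1/(1-p) = 2.0 (kept and upscaled).
        assert np.all(row == 0.0) or np.allclose(row, 2.0)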
......@@ -28,6 +28,7 @@ from eager_op_test import (
from test_sparse_attention_op import get_cuda_version
import paddle
import paddle.incubate.nn.functional as F
from paddle import _legacy_C_ops, nn
from paddle.fluid import core
......@@ -425,5 +426,97 @@ class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
self.batch_size = 2
class TestFusedGateAttentionApi(unittest.TestCase):
def setUp(self):
self.has_gating = True
self.batch_size = 2
self.msa_len = 3
self.res_len = 2
self.q_dim = 4
self.num_heads = 2
self.head_dim = 4
self.m_size = self.res_len
self.kv_dim = self.q_dim
self.out_dim = self.q_dim
self.merge_qkv = self.q_dim == self.kv_dim
self.query_shape = [
self.batch_size,
self.msa_len,
self.res_len,
self.q_dim,
]
self.qkv_weight_shape = [3, self.num_heads, self.head_dim, self.q_dim]
self.attn_mask_shape = [
self.batch_size,
self.msa_len,
1,
1,
self.m_size,
]
self.nonbatched_bias_shape = [
self.batch_size,
1,
self.num_heads,
self.res_len,
self.m_size,
]
self.gating_w_shape = [self.q_dim, self.num_heads, self.head_dim]
self.gating_b_shape = [self.num_heads, self.head_dim]
self.output_w_shape = [self.num_heads, self.head_dim, self.out_dim]
self.output_b_shape = [self.out_dim]
self.out_shape = [
self.batch_size,
self.msa_len,
self.res_len,
self.out_dim,
]
def test_api(self):
if not core.is_compiled_with_cuda():
return  # skip on devices without CUDA support
query = paddle.rand(shape=self.query_shape, dtype="float32")
qkv_weight = paddle.rand(shape=self.qkv_weight_shape, dtype="float32")
attn_mask = paddle.rand(shape=self.attn_mask_shape, dtype="float32")
nonbatched_bias = paddle.rand(
shape=self.nonbatched_bias_shape, dtype="float32"
)
gate_linear_weight = paddle.rand(
shape=self.gating_w_shape, dtype="float32"
)
gate_linear_bias = paddle.rand(
shape=self.gating_b_shape, dtype="float32"
)
out_linear_weight = paddle.rand(
shape=self.output_w_shape, dtype="float32"
)
out_linear_bias = paddle.rand(
shape=self.output_b_shape, dtype="float32"
)
output = F.fused_gate_attention(
query=query,
qkv_weight=qkv_weight,
gate_linear_weight=gate_linear_weight,
gate_linear_bias=gate_linear_bias,
out_linear_weight=out_linear_weight,
out_linear_bias=out_linear_bias,
nonbatched_bias=nonbatched_bias,
attn_mask=attn_mask,
has_gating=True,
merge_qkv=True,
)
print(f"output.shape={output.shape}")
self.assertEqual(output.shape, self.out_shape)
if __name__ == "__main__":
unittest.main()
......@@ -22,6 +22,7 @@ from .layer.fused_transformer import (
) # noqa: F401
from .layer.fused_ec_moe import FusedEcMoe # noqa: F401
from .layer.fused_dropout_add import FusedDropoutAdd # noqa: F401
from .layer.fused_dropout_nd import FusedDropout # noqa: F401
__all__ = [ # noqa
'FusedMultiHeadAttention',
......
......@@ -19,6 +19,7 @@ from .fused_matmul_bias import fused_matmul_bias, fused_linear
from .fused_transformer import fused_bias_dropout_residual_layer_norm
from .fused_ec_moe import fused_ec_moe
from .fused_dropout_add import fused_dropout_add
from .fused_gate_attention import fused_gate_attention
__all__ = [
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import _legacy_C_ops
from paddle.fluid.framework import _non_static_mode
def fused_gate_attention(
query,
key=None,
query_weight=None,
key_weight=None,
value_weight=None,
qkv_weight=None,
gate_linear_weight=None,
gate_linear_bias=None,
out_linear_weight=None,
out_linear_bias=None,
nonbatched_bias=None,
attn_mask=None,
has_gating=True,
merge_qkv=True,
use_flash_attn=False,
):
r"""
Attention maps queries and a set of key-value pairs to outputs, and
Gate Attention performs multiple attention heads in parallel to jointly attend
to information from different representation subspaces. This API only
supports self-attention. The pseudo code is as follows:
.. code-block:: python
c = head_dim ** (-0.5)
q = paddle.einsum('nbqa,ahc->nbqhc', q_data, query_w) * c
k = paddle.einsum('nbka,ahc->nbkhc', m_data, key_w)
v = paddle.einsum('nbka,ahc->nbkhc', m_data, value_w)
logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k) + bias
if nonbatched_bias is not None:
logits += paddle.unsqueeze(nonbatched_bias, axis=1)
weights = paddle.nn.functional.softmax(logits)
weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v)
if has_gating:
gate_values = paddle.einsum('nbqc,chv->nbqhv', q_data, gating_w) + gating_b
gate_values = paddle.nn.functional.sigmoid(gate_values)
weighted_avg *= gate_values
output = paddle.einsum('nbqhc,hco->nbqo', weighted_avg, output_w) + output_b
Args:
query (Tensor): The input query tensor. The shape is [batch_size, msa_len, res_len, q_dim].
key (Tensor, optional): The input key tensor, which can be set when
merge_qkv is False. The shape is [batch_size, msa_len, m_size, kv_dim].
query_weight (Tensor, optional): The weight of query linear, which
should be set when input key is not None. The shape is [q_dim, num_heads, head_dim].
key_weight (Tensor, optional): The weight of key linear, which should
be set when input key is not None. The shape is [kv_dim, num_heads, head_dim].
value_weight (Tensor, optional): The weight of value linear, which should
be set when input key is not None. The shape is [kv_dim, num_heads, head_dim].
qkv_weight (Tensor, optional): The weight of qkv linear, which should
be set when merge_qkv is True. The shape is [3, num_heads, head_dim, q_dim].
gate_linear_weight (Tensor, optional): The weight of gating linear,
which should be set when has_gating is True. The shape is [q_dim, num_heads, head_dim].
gate_linear_bias (Tensor, optional): The bias of gating linear, which
should be set when has_gating is True. The shape is [num_heads, head_dim]. Default None.
out_linear_weight (Tensor, optional): The weight of output linear. The shape is [num_heads, head_dim, q_dim].
out_linear_bias (Tensor, optional): The bias of output linear. The shape is [q_dim]. Default None.
nonbatched_bias (Tensor, optional): The extra bias. The shape is [batch_size, 1, num_heads, res_len, m_size]. Default None.
attn_mask (Tensor, optional): The attention mask. The shape is [batch_size, msa_len, 1, 1, m_size]. Default None.
has_gating (bool, optional): Whether to apply the gating linear. Default True.
merge_qkv (bool, optional): Whether to merge the query, key and value projections and use a single qkv_weight. Default True.
use_flash_attn (bool, optional): Whether to use flash attention. Default False.
Returns:
Tensor: The output Tensor. Its data type and shape are the same as those of `query`.
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.incubate.nn.functional as F
# batch_size = 2
# msa_len = 4
# res_len = 2
# q_dim = 4
# num_heads = 8
# head_dim = 4
# m_size = res_len (when merge_qkv is True)
# query: [batch_size, msa_len, res_len, q_dim]
query = paddle.rand(shape=[2, 4, 2, 4], dtype="float32")
# qkv_weight: [3, num_heads, head_dim, q_dim]
qkv_weight = paddle.rand(shape=[3, 8, 4, 4], dtype="float32")
# nonbatched_bias: [batch_size, 1, num_heads, res_len, m_size]
nonbatched_bias = paddle.rand(shape=[2, 1, 8, 2, 2], dtype="float32")
# attn_mask: [batch_size, msa_len, 1, 1, m_size]
attn_mask = paddle.rand(shape=[2, 4, 1, 1, 2], dtype="float32")
# gate_linear_weight: [q_dim, num_heads, head_dim]
gate_linear_weight = paddle.rand(shape=[4, 8, 4], dtype="float32")
# gate_bias: [num_heads, head_dim]
gate_linear_bias = paddle.rand(shape=[8, 4], dtype="float32")
# out_linear_weight: [num_heads, head_dim, q_dim]
out_linear_weight = paddle.rand(shape=[8, 4, 4], dtype="float32")
# out_linear_bias: [q_dim]
out_linear_bias = paddle.rand(shape=[4], dtype="float32")
# output: [batch_size, msa_len, res_len, q_dim]
output = F.fused_gate_attention(
query=query,
qkv_weight=qkv_weight,
gate_linear_weight=gate_linear_weight,
gate_linear_bias=gate_linear_bias,
out_linear_weight=out_linear_weight,
out_linear_bias=out_linear_bias,
nonbatched_bias=nonbatched_bias,
attn_mask=attn_mask,
has_gating=True,
merge_qkv=True)
print(output.shape)
# [2, 4, 2, 4]
"""
if _non_static_mode():
_, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention(
query,
key,
query_weight,
key_weight,
value_weight,
qkv_weight,
nonbatched_bias,
attn_mask,
gate_linear_weight,
gate_linear_bias,
out_linear_weight,
out_linear_bias,
'has_gating',
has_gating,
'merge_qkv',
merge_qkv,
"use_flash_attn",
use_flash_attn,
)
return out
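The docstring example above covers only the merged-QKV path. Here is a hedged sketch of the separated-projection call described in the Args section (my own illustration, not part of this diff): when merge_qkv=False, key and the per-projection weights are passed instead of qkv_weight. It assumes a GPU build, since the fused kernel is GPU-only, and the choice kv_dim=6 (different from q_dim) is mine, to exercise the separated path.

# required: gpu
import paddle
import paddle.incubate.nn.functional as F

batch_size, msa_len, res_len, m_size = 2, 4, 2, 3
q_dim, kv_dim, num_heads, head_dim = 4, 6, 8, 4

query = paddle.rand([batch_size, msa_len, res_len, q_dim], dtype="float32")
key = paddle.rand([batch_size, msa_len, m_size, kv_dim], dtype="float32")
query_weight = paddle.rand([q_dim, num_heads, head_dim], dtype="float32")
key_weight = paddle.rand([kv_dim, num_heads, head_dim], dtype="float32")
value_weight = paddle.rand([kv_dim, num_heads, head_dim], dtype="float32")
gate_linear_weight = paddle.rand([q_dim, num_heads, head_dim], dtype="float32")
gate_linear_bias = paddle.rand([num_heads, head_dim], dtype="float32")
out_linear_weight = paddle.rand([num_heads, head_dim, q_dim], dtype="float32")
out_linear_bias = paddle.rand([q_dim], dtype="float32")
attn_mask = paddle.rand([batch_size, msa_len, 1, 1, m_size], dtype="float32")

output = F.fused_gate_attention(
    query=query,
    key=key,
    query_weight=query_weight,
    key_weight=key_weight,
    value_weight=value_weight,
    gate_linear_weight=gate_linear_weight,
    gate_linear_bias=gate_linear_bias,
    out_linear_weight=out_linear_weight,
    out_linear_bias=out_linear_bias,
    attn_mask=attn_mask,
    has_gating=True,
    merge_qkv=False,
)
print(output.shape)  # expected: [2, 4, 2, 4]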
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import _legacy_C_ops
from paddle.fluid.framework import _non_static_mode
class FusedDropout(paddle.nn.Layer):
r"""
Dropout is a regularization technique for reducing overfitting by preventing
neuron co-adaptation during training, as described in the paper:
`Improving neural networks by preventing co-adaptation of feature detectors <https://arxiv.org/abs/1207.0580>`_
The dropout operator randomly sets the outputs of some units to zero, while upscaling the others
according to the given dropout probability.
It is an optimized implementation of ``paddle.nn.Dropout``.
In dygraph mode, please use ``eval()`` to switch to evaluation mode, where dropout is disabled.
Parameters:
p (float|int, optional): Probability of setting units to zero. Default: 0.5
axis (int|list|tuple, optional): The axis along which the dropout is performed. Default: None.
mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
1. upscale_in_train (default), upscale the output at training time
- train: :math:`out = input \times \frac{mask}{(1.0 - p)}`
- inference: :math:`out = input`
2. downscale_in_infer, downscale the output at inference
- train: :math:`out = input \times mask`
- inference: :math:`out = input \times (1.0 - p)`
name (str, optional): Name for the operation. Default: None. For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: N-D tensor.
- output: N-D tensor, the same shape as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype="float32")
m = paddle.incubate.nn.FusedDropout(p=0.5)
y_train = m(x)
print(y_train)
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[2., 0., 6.],
# [0., 0., 0.]])
m.eval() # switch the model to test phase
y_test = m(x)
print(y_test)
# Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [[1., 2., 3.],
# [4., 5., 6.]])
"""
def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None):
super().__init__()
if not isinstance(p, (float, int)):
raise TypeError("p argument should be a number")
if p < 0 or p > 1:
raise ValueError("p argument should between 0 and 1")
mode = (
'downgrade_in_infer' if mode == 'downscale_in_infer' else mode
) # semantic transfer
if mode not in ('downgrade_in_infer', 'upscale_in_train'):
raise ValueError(
"mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
)
if axis and not isinstance(axis, (int, list, tuple)):
raise TypeError("datatype of axis argument should be int or list")
self.p = p
self.mode = mode
self.name = name
self.axis = None
if axis is not None:
self.axis = [axis] if isinstance(axis, int) else list(axis)
def forward(self, input):
# fast return for p == 0
if self.p == 0:
return input
if self.axis is not None and _non_static_mode():
seed = None
if paddle.static.default_main_program().random_seed != 0:
seed = paddle.static.default_main_program().random_seed
out, mask = _legacy_C_ops.dropout_nd(
input,
'dropout_prob',
self.p,
'is_test',
not self.training,
'fix_seed',
seed is not None,
'seed',
seed if seed is not None else 0,
'dropout_implementation',
self.mode,
'axis',
self.axis,
)
else:
out = paddle.nn.functional.dropout(
input,
p=self.p,
axis=self.axis,
training=self.training,
mode=self.mode,
name=self.name,
)
return out
def extra_repr(self):
name_str = f', name={self.name}' if self.name else ''
return 'p={}, axis={}, mode={}{}'.format(
self.p, self.axis, self.mode, name_str
)
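To complement the docstring above, a small self-check sketch of the default upscale_in_train behavior (my own illustration, not part of this diff): during training every kept element equals input / (1 - p), and after eval() the layer is the identity.

import numpy as np
import paddle

p = 0.25
x = paddle.ones([8, 16], dtype="float32")
m = paddle.incubate.nn.FusedDropout(p=p)  # mode="upscale_in_train" by default

y_train = m(x).numpy()
kept = y_train[y_train != 0.0]
assert np.allclose(kept, 1.0 / (1.0 - p))  # train: out = input * mask / (1 - p)

m.eval()
y_test = m(x).numpy()
assert np.allclose(y_test, x.numpy())      # inference: out = input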