diff --git a/python/paddle/fluid/tests/unittests/test_dropout_nd_op.py b/python/paddle/fluid/tests/unittests/test_dropout_nd_op.py
index 7a36599a580f3d8dd37421c9ef92f9f1c194c624..e2d5ae8685cc94b4d97619b529f91a0240a65f1e 100644
--- a/python/paddle/fluid/tests/unittests/test_dropout_nd_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dropout_nd_op.py
@@ -176,6 +176,7 @@ class TestDropoutNdBF16Op(OpTest):
 
 class TestDropoutNdAPI(unittest.TestCase):
     def setUp(self):
+        paddle.seed(123)
         np.random.seed(123)
         self.places = [fluid.CPUPlace()]
         if core.is_compiled_with_cuda():
@@ -187,11 +188,36 @@ class TestDropoutNdAPI(unittest.TestCase):
             with fluid.dygraph.guard(place):
                 in_np = np.random.random([4, 32, 16]).astype("float32")
                 input = paddle.to_tensor(in_np)
-                res1 = dropout_nd(x=input, p=0.0, axis=[0, 1])
-                res2 = dropout_nd(x=input, p=0.5, axis=[0, 1])
+                dropout_1 = paddle.incubate.nn.FusedDropout(p=0.0, axis=[0, 1])
+                dropout_2 = paddle.incubate.nn.FusedDropout(p=0.5, axis=[0, 1])
+                print(dropout_1)
+                print(dropout_2)
+                res1 = dropout_1(input)
+                res2 = dropout_2(input)
             np.testing.assert_allclose(res1.numpy(), in_np, rtol=1e-05)
         paddle.enable_static()
 
+    def test_error(self):
+        def _run_illegal_type_p():
+            dropout = paddle.incubate.nn.FusedDropout(p="test")
+
+        self.assertRaises(TypeError, _run_illegal_type_p)
+
+        def _run_illegal_value_p():
+            dropout = paddle.incubate.nn.FusedDropout(p=2)
+
+        self.assertRaises(ValueError, _run_illegal_value_p)
+
+        def _run_illegal_mode():
+            dropout = paddle.incubate.nn.FusedDropout(p=0.5, mode="test")
+
+        self.assertRaises(ValueError, _run_illegal_mode)
+
+        def _run_illegal_type_axis():
+            dropout = paddle.incubate.nn.FusedDropout(p=0.5, axis="test")
+
+        self.assertRaises(TypeError, _run_illegal_type_axis)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py
index 70face3ce5b6748f57e6e2c77f31dc4fc44cf824..d9bc88d4b2c6d57a64eb0aae19e58f06e6d9d4a2 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py
@@ -28,6 +28,7 @@ from eager_op_test import (
 from test_sparse_attention_op import get_cuda_version
 
 import paddle
+import paddle.incubate.nn.functional as F
 from paddle import _legacy_C_ops, nn
 from paddle.fluid import core
 
@@ -425,5 +426,97 @@ class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
         self.batch_size = 2
 
 
+class TestFusedGateAttentionApi(unittest.TestCase):
+    def setUp(self):
+        self.has_gating = True
+        self.batch_size = 2
+        self.msa_len = 3
+        self.res_len = 2
+        self.q_dim = 4
+        self.num_heads = 2
+        self.head_dim = 4
+        self.m_size = self.res_len
+        self.kv_dim = self.q_dim
+        self.out_dim = self.q_dim
+        self.merge_qkv = self.q_dim == self.kv_dim
+
+        self.query_shape = [
+            self.batch_size,
+            self.msa_len,
+            self.res_len,
+            self.q_dim,
+        ]
+        self.qkv_weight_shape = [3, self.num_heads, self.head_dim, self.q_dim]
+
+        self.attn_mask_shape = [
+            self.batch_size,
+            self.msa_len,
+            1,
+            1,
+            self.m_size,
+        ]
+        self.nonbatched_bias_shape = [
+            self.batch_size,
+            1,
+            self.num_heads,
+            self.res_len,
+            self.m_size,
+        ]
+
+        self.gating_w_shape = [self.q_dim, self.num_heads, self.head_dim]
+        self.gating_b_shape = [self.num_heads, self.head_dim]
+
+        self.output_w_shape = [self.num_heads, self.head_dim, self.out_dim]
+        self.output_b_shape = [self.out_dim]
+
+        self.out_shape = [
+            self.batch_size,
+            self.msa_len,
+            self.res_len,
+            self.out_dim,
+        ]
+
+    def test_api(self):
+        if not core.is_compiled_with_cuda():
+            return
+
+        query = paddle.rand(shape=self.query_shape, dtype="float32")
+        qkv_weight = paddle.rand(shape=self.qkv_weight_shape, dtype="float32")
+
+        attn_mask = paddle.rand(shape=self.attn_mask_shape, dtype="float32")
+        nonbatched_bias = paddle.rand(
+            shape=self.nonbatched_bias_shape, dtype="float32"
+        )
+
+        gate_linear_weight = paddle.rand(
+            shape=self.gating_w_shape, dtype="float32"
+        )
+        gate_linear_bias = paddle.rand(
+            shape=self.gating_b_shape, dtype="float32"
+        )
+
+        out_linear_weight = paddle.rand(
+            shape=self.output_w_shape, dtype="float32"
+        )
+        out_linear_bias = paddle.rand(
+            shape=self.output_b_shape, dtype="float32"
+        )
+
+        output = F.fused_gate_attention(
+            query=query,
+            qkv_weight=qkv_weight,
+            gate_linear_weight=gate_linear_weight,
+            gate_linear_bias=gate_linear_bias,
+            out_linear_weight=out_linear_weight,
+            out_linear_bias=out_linear_bias,
+            nonbatched_bias=nonbatched_bias,
+            attn_mask=attn_mask,
+            has_gating=True,
+            merge_qkv=True,
+        )
+        print(f"output.shape={output.shape}")
+        self.assertEqual(output.shape, self.out_shape)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py
index 3b6869f88c6283b304e61fd9b7e7c35eac9e622f..c663d6248feb0fcd96dafee72f551eeef3fd2ddb 100644
--- a/python/paddle/incubate/nn/__init__.py
+++ b/python/paddle/incubate/nn/__init__.py
@@ -22,6 +22,7 @@ from .layer.fused_transformer import (
 )  # noqa: F401
 from .layer.fused_ec_moe import FusedEcMoe  # noqa: F401
 from .layer.fused_dropout_add import FusedDropoutAdd  # noqa: F401
+from .layer.fused_dropout_nd import FusedDropout  # noqa: F401
 
 __all__ = [  # noqa
     'FusedMultiHeadAttention',
diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py
index e5d17294329d464910f9ccb8e6c8e88f9a979284..ccccadd284e9edcaeb172616941e877568f0533f 100644
--- a/python/paddle/incubate/nn/functional/__init__.py
+++ b/python/paddle/incubate/nn/functional/__init__.py
@@ -19,6 +19,7 @@ from .fused_matmul_bias import fused_matmul_bias, fused_linear
 from .fused_transformer import fused_bias_dropout_residual_layer_norm
 from .fused_ec_moe import fused_ec_moe
 from .fused_dropout_add import fused_dropout_add
+from .fused_gate_attention import fused_gate_attention
 
 
 __all__ = [
diff --git a/python/paddle/incubate/nn/functional/fused_gate_attention.py b/python/paddle/incubate/nn/functional/fused_gate_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..13833683449b59af48a871f387214467c55aff79
--- /dev/null
+++ b/python/paddle/incubate/nn/functional/fused_gate_attention.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import _legacy_C_ops
+from paddle.fluid.framework import _non_static_mode
+
+
+def fused_gate_attention(
+    query,
+    key=None,
+    query_weight=None,
+    key_weight=None,
+    value_weight=None,
+    qkv_weight=None,
+    gate_linear_weight=None,
+    gate_linear_bias=None,
+    out_linear_weight=None,
+    out_linear_bias=None,
+    nonbatched_bias=None,
+    attn_mask=None,
+    has_gating=True,
+    merge_qkv=True,
+    use_flash_attn=False,
+):
+    r"""
+    Attention maps queries and a set of key-value pairs to outputs, and Gate
+    Attention performs multiple attention computations in parallel to jointly
+    attend to information from different representation subspaces. This API
+    only supports self-attention. The pseudo code is as follows:
+
+    .. code-block:: python
+
+        c = head_dim ** (-0.5)
+        q = paddle.einsum('nbqa,ahc->nbqhc', q_data, query_w) * c
+        k = paddle.einsum('nbka,ahc->nbkhc', m_data, key_w)
+        v = paddle.einsum('nbka,ahc->nbkhc', m_data, value_w)
+        logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k) + bias
+
+        if nonbatched_bias is not None:
+            logits += paddle.unsqueeze(nonbatched_bias, axis=1)
+
+        weights = paddle.nn.functional.softmax(logits)
+        weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v)
+
+        if has_gating:
+            gate_values = paddle.einsum('nbqc,chv->nbqhv', q_data, gating_w) + gating_b
+            gate_values = paddle.nn.functional.sigmoid(gate_values)
+            weighted_avg *= gate_values
+
+        output = paddle.einsum('nbqhc,hco->nbqo', weighted_avg, output_w) + output_b
+
+
+    Args:
+        query (Tensor): The input query tensor. The shape is [batch_size, msa_len, res_len, q_dim].
+        key (Tensor, optional): The input key tensor, which can be set when
+            merge_qkv is False. The shape is [batch_size, msa_len, m_size, kv_dim].
+        query_weight (Tensor, optional): The weight of query linear, which
+            should be set when input key is not None. The shape is [q_dim, num_heads, head_dim].
+        key_weight (Tensor, optional): The weight of key linear, which should
+            be set when input key is not None. The shape is [kv_dim, num_heads, head_dim].
+        value_weight (Tensor, optional): The weight of value linear, which should
+            be set when input key is not None. The shape is [kv_dim, num_heads, head_dim].
+        qkv_weight (Tensor, optional): The weight of qkv linear, which should
+            be set when merge_qkv is True. The shape is [3, num_heads, head_dim, q_dim].
+        gate_linear_weight (Tensor, optional): The weight of gating linear,
+            which should be set when has_gating is True. The shape is [q_dim, num_heads, head_dim].
+        gate_linear_bias (Tensor, optional): The bias of gating linear, which
+            should be set when has_gating is True. The shape is [num_heads, head_dim]. Default None.
+        out_linear_weight (Tensor, optional): The weight of output linear. The shape is [num_heads, head_dim, q_dim].
+        out_linear_bias (Tensor, optional): The bias of output linear. The shape is [q_dim]. Default None.
+        nonbatched_bias (Tensor, optional): The extra bias. The shape is [batch_size, 1, num_heads, res_len, m_size]. Default None.
+        attn_mask (Tensor, optional): The attention mask. The shape is [batch_size, msa_len, 1, 1, m_size]. Default None.
+        has_gating (bool, optional): Whether to apply the gating linear. Default True.
+        merge_qkv (bool, optional): Whether to merge the query, key and value
+            projections into a single qkv linear, which requires q_dim == kv_dim. Default True.
+        use_flash_attn (bool, optional): Whether to use flash attention. Default False.
+
+    Returns:
+        Tensor: The output Tensor, whose data type and shape are the same as `query`.
+
+    Examples:
+
+        .. code-block:: python
+
+            # required: gpu
+            import paddle
+            import paddle.incubate.nn.functional as F
+
+            # batch_size = 2
+            # msa_len = 4
+            # res_len = 2
+            # q_dim = 4
+            # num_heads = 8
+            # head_dim = 4
+            # m_size = res_len (when merge_qkv is True)
+
+            # query: [batch_size, msa_len, res_len, q_dim]
+            query = paddle.rand(shape=[2, 4, 2, 4], dtype="float32")
+
+            # qkv_weight: [3, num_heads, head_dim, q_dim]
+            qkv_weight = paddle.rand(shape=[3, 8, 4, 4], dtype="float32")
+
+            # nonbatched_bias: [batch_size, 1, num_heads, res_len, m_size]
+            nonbatched_bias = paddle.rand(shape=[2, 1, 8, 2, 2], dtype="float32")
+
+            # attn_mask: [batch_size, msa_len, 1, 1, m_size]
+            attn_mask = paddle.rand(shape=[2, 4, 1, 1, 2], dtype="float32")
+
+            # gate_linear_weight: [q_dim, num_heads, head_dim]
+            gate_linear_weight = paddle.rand(shape=[4, 8, 4], dtype="float32")
+            # gate_linear_bias: [num_heads, head_dim]
+            gate_linear_bias = paddle.rand(shape=[8, 4], dtype="float32")
+
+            # out_linear_weight: [num_heads, head_dim, q_dim]
+            out_linear_weight = paddle.rand(shape=[8, 4, 4], dtype="float32")
+            # out_linear_bias: [q_dim]
+            out_linear_bias = paddle.rand(shape=[4], dtype="float32")
+
+            # output: [batch_size, msa_len, res_len, q_dim]
+            output = F.fused_gate_attention(
+                query=query,
+                qkv_weight=qkv_weight,
+                gate_linear_weight=gate_linear_weight,
+                gate_linear_bias=gate_linear_bias,
+                out_linear_weight=out_linear_weight,
+                out_linear_bias=out_linear_bias,
+                nonbatched_bias=nonbatched_bias,
+                attn_mask=attn_mask,
+                has_gating=True,
+                merge_qkv=True)
+            print(output.shape)
+            # [2, 4, 2, 4]
+
+    """
+    if _non_static_mode():
+        _, _, _, _, _, _, _, out = _legacy_C_ops.fused_gate_attention(
+            query,
+            key,
+            query_weight,
+            key_weight,
+            value_weight,
+            qkv_weight,
+            nonbatched_bias,
+            attn_mask,
+            gate_linear_weight,
+            gate_linear_bias,
+            out_linear_weight,
+            out_linear_bias,
+            'has_gating',
+            has_gating,
+            'merge_qkv',
+            merge_qkv,
+            "use_flash_attn",
+            use_flash_attn,
+        )
+        return out
diff --git a/python/paddle/incubate/nn/layer/fused_dropout_nd.py b/python/paddle/incubate/nn/layer/fused_dropout_nd.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8aecf0c3978a68558a2178290a79c2b71c04217
--- /dev/null
+++ b/python/paddle/incubate/nn/layer/fused_dropout_nd.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import _legacy_C_ops
+from paddle.fluid.framework import _non_static_mode
+
+
+class FusedDropout(paddle.nn.Layer):
+    r"""
+    Dropout is a regularization technique for reducing overfitting by preventing
+    neuron co-adaptation during training, as described in the paper
+    `Improving neural networks by preventing co-adaptation of feature detectors
+    <https://arxiv.org/abs/1207.0580>`_ .
+    The dropout operator randomly sets the outputs of some units to zero, while
+    scaling up the remaining outputs according to the given dropout probability.
+
+    It is an optimized implementation of ``paddle.nn.Dropout``.
+
+    In dygraph mode, please use ``eval()`` to switch to evaluation mode, where
+    dropout is disabled.
+
+    Parameters:
+        p (float|int, optional): Probability of setting units to zero. Default: 0.5.
+        axis (int|list|tuple, optional): The axis along which the dropout is performed. Default: None.
+        mode (str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']
+
+            1. upscale_in_train (default), upscale the output at training time
+
+                - train: :math:`out = input \times \frac{mask}{(1.0 - p)}`
+                - inference: :math:`out = input`
+
+            2. downscale_in_infer, downscale the output at inference
+
+                - train: :math:`out = input \times mask`
+                - inference: :math:`out = input \times (1.0 - p)`
+        name (str, optional): Name for the operation. Default: None. For more
+            information, please refer to :ref:`api_guide_Name`.
+
+    Shape:
+        - input: N-D tensor.
+        - output: N-D tensor, the same shape as input.
+
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], dtype="float32")
+            m = paddle.incubate.nn.FusedDropout(p=0.5)
+
+            y_train = m(x)
+            print(y_train)
+            # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [[2., 0., 6.],
+            #         [0., 0., 0.]])
+
+            m.eval()  # switch the model to test phase
+            y_test = m(x)
+            print(y_test)
+            # Tensor(shape=[2, 3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+            #        [[1., 2., 3.],
+            #         [4., 5., 6.]])
+    """
+
+    def __init__(self, p=0.5, axis=None, mode="upscale_in_train", name=None):
+        super().__init__()
+
+        if not isinstance(p, (float, int)):
+            raise TypeError("p argument should be a number")
+        if p < 0 or p > 1:
+            raise ValueError("p argument should be between 0 and 1")
+
+        if mode not in ('downscale_in_infer', 'upscale_in_train'):
+            raise ValueError(
+                "mode argument should be 'downscale_in_infer' or 'upscale_in_train'"
+            )
+
+        if axis is not None and not isinstance(axis, (int, list, tuple)):
+            raise TypeError(
+                "datatype of axis argument should be int, list or tuple"
+            )
+
+        self.p = p
+        self.mode = mode
+        self.name = name
+
+        self.axis = None
+        if axis is not None:
+            self.axis = [axis] if isinstance(axis, int) else list(axis)
+
+    def forward(self, input):
+        # fast return for p == 0
+        if self.p == 0:
+            return input
+
+        if self.axis is not None and _non_static_mode():
+            seed = None
+            if paddle.static.default_main_program().random_seed != 0:
+                seed = paddle.static.default_main_program().random_seed
+
+            # The legacy dropout_nd op names the 'downscale_in_infer'
+            # semantics 'downgrade_in_infer'.
+            dropout_implementation = (
+                'downgrade_in_infer'
+                if self.mode == 'downscale_in_infer'
+                else self.mode
+            )
+            out, mask = _legacy_C_ops.dropout_nd(
+                input,
+                'dropout_prob',
+                self.p,
+                'is_test',
+                not self.training,
+                'fix_seed',
+                seed is not None,
+                'seed',
+                seed if seed is not None else 0,
+                'dropout_implementation',
+                dropout_implementation,
+                'axis',
+                self.axis,
+            )
+        else:
+            out = paddle.nn.functional.dropout(
+                input,
+                p=self.p,
+                axis=self.axis,
+                training=self.training,
+                mode=self.mode,
+                name=self.name,
+            )
+        return out
+
+    def extra_repr(self):
+        name_str = f', name={self.name}' if self.name else ''
+        return 'p={}, axis={}, mode={}{}'.format(
+            self.p, self.axis, self.mode, name_str
+        )