From 9f3613f312e35e9226e61e1f1d663ef0dbcf2446 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Wed, 27 Oct 2021 14:27:22 +0800 Subject: [PATCH] Fused transformer encoder layer and fused feedforward layer (#36604) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 本PR是fused_transformer的layer层代码,包含FusedFeedForward的layer层代码和FusedTransformerEncoderLayer的代码。 --- paddle/fluid/imperative/amp_auto_cast.cc | 16 ++ .../contrib/mixed_precision/fp16_lists.py | 4 +- .../contrib/mixed_precision/fp16_utils.py | 37 ++- .../unittests/test_fused_attention_op_api.py | 2 +- python/paddle/incubate/__init__.py | 2 + python/paddle/incubate/nn/__init__.py | 7 +- .../nn/functional/fused_transformer.py | 14 +- .../incubate/nn/layer/fused_transformer.py | 244 +++++++++++++----- python/setup.py.in | 4 + 9 files changed, 247 insertions(+), 83 deletions(-) diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index b0d86f6db9..f2ea692ad0 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -191,6 +191,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, continue; } + if ((op_type == "fused_attention" || op_type == "fused_feedforward")) { + if (pair.first == "LnScale" || pair.first == "LnBias" || + pair.first == "Ln2Scale" || pair.first == "Ln2Bias" || + pair.first == "Ln1Scale" || pair.first == "Ln1Bias") { + continue; + } + } + VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to float16"; for (auto& var : pair.second) { @@ -223,6 +231,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, pair.first == "X" && dst_type == framework::proto::VarType::FP32) { continue; } + if ((op_type == "fused_attention" || op_type == "fused_feedforwad") && + dst_type == framework::proto::VarType::FP32) { + if (pair.first != "LnScale" && pair.first != "LnBias" && + pair.first != "Ln2Scale" && pair.first != "Ln2Bias" && + pair.first != "Ln1Scale" && pair.first != "Ln1Bias") { + continue; + } + } VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to " << framework::DataTypeToString(dst_type); diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 5b662b09f1..95e597c703 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -104,7 +104,7 @@ black_list = { 'reduce_sum', } -# This set contains two types of ops. All ops supported fp16 calculation. One +# This set contains two types of ops. All ops supported fp16 calculation. One # of two types is considered numerically-safe, but may be made unsafe by an # upstream blacklist op. Another type do not have numerically-significant # effects, like stack, flatten2. @@ -153,6 +153,8 @@ gray_list = { 'c_allreduce_sum', 'concat', 'split', + 'fused_feedforward', + 'fused_attention', } # The set of ops that don't support fp16 calculation diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 6317be9a2e..36546c1de1 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -40,7 +40,7 @@ _fp16_guard_pattern = "__use_fp16__" def _rename_arg(op, old_name, new_name): """ - If an op has old_name input and output, rename these input + If an op has old_name input and output, rename these input args new_name. Args: @@ -89,6 +89,10 @@ def _keep_fp32_input(op, in_name): return in_name not in {'X', 'Z'} if op_type == 'resnet_unit': return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'} + if op_type in ['fused_attention', 'fused_feedforward']: + return in_name in { + 'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" + } return False @@ -98,6 +102,11 @@ def _keep_fp32_output(op, out_name): return out_name != 'Y' if op_type == 'resnet_unit': return out_name not in {'Y', 'ConvX', 'ConvZ'} + if op_type in ['fused_attention', 'fused_feedforward']: + return out_name in { + 'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean', + 'Ln1Variance' + } return False @@ -256,16 +265,16 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False): ops (list): A list of ops. cur_op (Operator): Current operator which has var_name variable. var_name (string): Variable name. - search_all (bool): The type of operator search. Use if \"cur_op\" is not in the \"ops\" set. + search_all (bool): The type of operator search. Use if \"cur_op\" is not in the \"ops\" set. """ post_op = [] if search_all: """ - \"cur_op\" do not have to be in list of \"ops\". E.g. \"cur_op\" can come - from startup_prog block and \"ops\" list from main_prog block. - By setting idx to -1, we'll start looking for post-ops from the top of the list. - If search_all is False, assume that \"cur_op\" is in \"ops\" list, - so to reduce the time of search we can start iterating from \"cur_op\" idx. + \"cur_op\" do not have to be in list of \"ops\". E.g. \"cur_op\" can come + from startup_prog block and \"ops\" list from main_prog block. + By setting idx to -1, we'll start looking for post-ops from the top of the list. + If search_all is False, assume that \"cur_op\" is in \"ops\" list, + so to reduce the time of search we can start iterating from \"cur_op\" idx. """ idx = -1 else: @@ -517,19 +526,19 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): def rewrite_program(main_prog, amp_lists): """ - Traverse all ops in current block and insert cast op according to + Traverse all ops in current block and insert cast op according to which set current op belongs to. 1. When an op belongs to the black list, add it to black set 2. When an op belongs to the white list, add it to white set - 3. When an op belongs to the gray list. If one - of its inputs is the output of black set op or black list op, - add it to black set. If all of its previous ops are not black - op and one of its inputs is the output of white set op or + 3. When an op belongs to the gray list. If one + of its inputs is the output of black set op or black list op, + add it to black set. If all of its previous ops are not black + op and one of its inputs is the output of white set op or white list op, add it to white set. 4. When an op isn't in the lists, add it to black op set. - 5. Add necessary cast ops to make sure that black set op will be - computed in fp32 mode, while white set op will be computed in + 5. Add necessary cast ops to make sure that black set op will be + computed in fp32 mode, while white set op will be computed in fp16 mode. Args: diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py index e59ecc19d0..5fa9446763 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py @@ -107,7 +107,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, q = qkv[0:1, ::] q = q.reshape(batch_size, num_head, seq_len, head_dim) - k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim] + k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim] k = k.reshape(batch_size, num_head, seq_len, head_dim) v = qkv[2::] v = v.reshape(batch_size, num_head, seq_len, head_dim) diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 644b934814..f44e38347e 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -23,6 +23,8 @@ from .tensor import segment_mean from .tensor import segment_max from .tensor import segment_min +from . import nn #noqa: F401 + __all__ = [ 'LookAhead', 'ModelAverage', diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index aada78e4ec..f359ec1e0d 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 +from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 +from .layer.fused_transformer import FusedFeedForward # noqa: F401 +from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401 __all__ = [ #noqa 'FusedMultiHeadAttention', + 'FusedFeedForward', + 'FusedTransformerEncoderLayer', + ] diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 68109b4ae6..f692283841 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -218,7 +218,7 @@ def fused_multi_head_attention(x, `[batch\_size, sequence\_len, embed\_dim]`. qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`. linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`. - pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture + pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture (False). Default False. pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None. pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None. @@ -229,12 +229,12 @@ def fused_multi_head_attention(x, qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`. Default None. linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None. - attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to - some unwanted positions, usually the paddings or the subsequent positions. It is a tensor - with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the - data type is bool, the unwanted positions have `False` values and the others have `True` values. - When the data type is int, the unwanted positions have 0 values and the others have 1 values. - When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. + attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to + some unwanted positions, usually the paddings or the subsequent positions. It is a tensor + with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the + data type is bool, the unwanted positions have `False` values and the others have `True` values. + When the data type is int, the unwanted positions have 0 values and the others have 1 values. + When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. dropout_rate (float, optional): The dropout probability used on attention weights to drop some attention targets for the dropout after attention. diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 16588dcef3..bc887875c7 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import copy from paddle.nn import functional as F from paddle.incubate.nn import functional as incubate_f from paddle.nn import Layer from paddle.framework import ParamAttr import paddle -from paddle.nn.layer.transformer import _convert_attention_mask +from paddle.nn.layer.transformer import _convert_attention_mask, _convert_param_attr_to_list from paddle.nn.initializer import Constant import collections @@ -35,16 +33,16 @@ class FusedMultiHeadAttention(Layer): embed_dim (int): The expected feature size in the input and output. num_heads (int): The number of heads in multi-head attention. dropout_rate (float, optional): The dropout probability used on attention - weights to drop some attention targets for the dropout after attention. + weights to drop some attention targets for the dropout after attention. 0 for no dropout. Default 0.5. attn_dropout_rate (float, optional): The dropout probability used on attention - weights to drop some attention targets for the dropout in attention. + weights to drop some attention targets for the dropout in attention. 0 for no dropout. Default 0.5. kdim (int, optional): The feature size in key. If None, assumed equal to `embed_dim`. Default None. vdim (int, optional): The feature size in value. If None, assumed equal to `embed_dim`. Default None. - normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True) + normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True) or post_layer_norm architecture (False). Default False. need_weights (bool, optional): Indicate whether to return the attention weights. Now, only False is supported. Default False. @@ -56,7 +54,10 @@ class FusedMultiHeadAttention(Layer): If it is set to False, this layer will not have trainable bias parameter. See usage for details in :code:`ParamAttr` . Examples: + .. code-block:: python + + # required: gpu import paddle # input: [batch_size, sequence_length, embed_dim] query = paddle.rand((2, 4, 128)) @@ -154,17 +155,17 @@ class FusedMultiHeadAttention(Layer): to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. - When the data type is bool, the unwanted positions have `False` - values and the others have `True` values. When the data type is - int, the unwanted positions have 0 values and the others have 1 - values. When the data type is float, the unwanted positions have - `-INF` values and the others have 0 values. It can be None when + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional): Now, only None is supported. Default None. Returns: Tensor|tuple: It is a tensor that has the same shape and data type \ - as `query`, representing attention output. + as `query`, representing attention output. """ if attn_mask is not None: # Support bool or int mask @@ -192,26 +193,114 @@ class FusedMultiHeadAttention(Layer): class FusedFeedForward(Layer): + """ + Parameters: + d_model (int): The expected feature size in the input and output. + dim_feedforward (int): The hidden layer size. + dropout_rate (float, optional): The dropout probability used in pre-process + and post-precess. Default 0.1 + activation (str, optional): The activation function. Default relu. + act_dropout_rate (float, optional): The dropout probability after activition. + If None, use the value of `dropout_rate`. Default None + normalize_before (bool, optional): Indicate whether to put layer normalization + into, preprocessing or postprocessing. Default False + weight_attr (ParamAttr, optional): The attribute for the learnable weight of this layer. + The default value is None and the weight will be initialized to zero. For detailed + information, please refer to paddle.ParamAttr. + bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias of thi layer. + If it is set to False, no bias will be added to the output. If it is set to None or one + kind of ParamAttr, a bias parameter will be created according to ParamAttr. For detailed + information, please refer to paddle.ParamAttr. The default value is None and the bias + will be initialized to zero. + + Examples: + .. code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn import FusedFeedForward + + fused_feedforward_layer = FusedFeedForward(8, 8) + x = paddle.rand((1, 8, 8)) + out = fused_feedforward_layer(x) + print(out.numpy().shape) + # (1, 8, 8) + """ + def __init__(self, d_model, dim_feedforward, - dropout=0.1, + dropout_rate=0.1, activation="relu", - act_dropout=None, + act_dropout_rate=None, normalize_before=False, weight_attr=None, bias_attr=None): super(FusedFeedForward, self).__init__() - raise NotImplementedError() + assert d_model > 0, ( + "Expected d_model to be greater than 0, but recieved {}".format( + d_model)) + assert dim_feedforward > 0, ( + "Expected dim_feedforward to be greater than 0, but recieved {}". + format(dim_feedforward)) + + self._dtype = self._helper.get_default_dtype() + self._d_model = d_model + self._dim_feedforward = dim_feedforward + self._dropout_rate = dropout_rate + self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate + self._act_method = activation + self._normalize_before = normalize_before + + self._linear1_weight = self.create_parameter( + shape=[d_model, dim_feedforward], + attr=weight_attr, + dtype=self._dtype, + is_bias=False) + self._linear1_bias = self.create_parameter( + shape=[dim_feedforward], + attr=bias_attr, + dtype=self._dtype, + is_bias=True) + + self._linear2_weight = self.create_parameter( + shape=[dim_feedforward, d_model], + attr=weight_attr, + dtype=self._dtype, + is_bias=False) + + self._linear2_bias = self.create_parameter( + shape=[d_model], attr=bias_attr, dtype=self._dtype, is_bias=True) + + self._ln1_scale = self.create_parameter( + shape=[d_model], + attr=None, + is_bias=False, + default_initializer=Constant(1.0)) + self._ln1_bias = self.create_parameter( + shape=[d_model], attr=None, is_bias=True) + + self._ln2_scale = self.create_parameter( + shape=[d_model], + attr=None, + is_bias=False, + default_initializer=Constant(1.0)) + self._ln2_bias = self.create_parameter( + shape=[d_model], attr=None, is_bias=True) def forward(self, src, cache=None): - raise NotImplementedError() + out = incubate_f.fused_feedforward( + src, self._linear1_weight, self._linear2_weight, self._linear1_bias, + self._linear2_bias, self._ln1_scale, self._ln1_bias, + self._ln2_scale, self._ln2_bias, self._dropout_rate, + self._act_dropout_rate, self._act_method, self._normalize_before) + return out class FusedTransformerEncoderLayer(Layer): """ - TransformerEncoderLayer is composed of two sub-layers which are self (multi-head) + FusedTransformerEncoderLayer is composed of two sub-layers which are self (multi-head) attention and feedforward network. Before and after each sub-layer, pre-process and post-precess would be applied on the input and output accordingly. If `normalize_before` is True, pre-process is layer normalization and post-precess @@ -222,14 +311,14 @@ class FusedTransformerEncoderLayer(Layer): d_model (int): The expected feature size in the input and output. nhead (int): The number of heads in multi-head attention(MHA). dim_feedforward (int): The hidden layer size in the feedforward network(FFN). - dropout (float, optional): The dropout probability used in pre-process + dropout_rate (float, optional): The dropout probability used in pre-process and post-precess of MHA and FFN sub-layer. Default 0.1 activation (str, optional): The activation function in the feedforward network. Default relu. - attn_dropout (float, optional): The dropout probability used + attn_dropout_rate (float, optional): The dropout probability used in MHA to drop some attention target. If None, use the value of `dropout`. Default None - act_dropout (float, optional): The dropout probability used after FFN + act_dropout_rate (float, optional): The dropout probability used after FFN activition. If None, use the value of `dropout`. Default None normalize_before (bool, optional): Indicate whether to put layer normalization into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer @@ -241,7 +330,7 @@ class FusedTransformerEncoderLayer(Layer): MHA, and `weight_attr[1]` would be used as `weight_attr` for linear in FFN. Otherwise, MHA and FFN both use it as `weight_attr` to create parameters. Default: None, which means the default weight parameter property is used. - See usage for details in :code:`ParamAttr` . + See usage for details in :code:`ParamAttr` . bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property. If it is a list/tuple, `bias_attr[0]` would be used as `bias_attr` for MHA, and `bias_attr[1]` would be used as `bias_attr` for linear in FFN. @@ -249,21 +338,21 @@ class FusedTransformerEncoderLayer(Layer): The `False` value means the corresponding layer would not have trainable bias parameter. See usage for details in :code:`ParamAttr` . Default: None, which means the default bias parameter property is used. - + Examples: .. code-block:: python - + # required: gpu import paddle - from paddle.nn import TransformerEncoderLayer + from paddle.incubate.nn import FusedTransformerEncoderLayer # encoder input: [batch_size, src_len, d_model] enc_input = paddle.rand((2, 4, 128)) # self attention mask: [batch_size, n_head, src_len, src_len] attn_mask = paddle.rand((2, 2, 4, 4)) - encoder_layer = TransformerEncoderLayer(128, 2, 512) + encoder_layer = FusedTransformerEncoderLayer(128, 2, 512) enc_output = encoder_layer(enc_input, attn_mask) # [2, 4, 128] """ @@ -271,10 +360,10 @@ class FusedTransformerEncoderLayer(Layer): d_model, nhead, dim_feedforward, - dropout=0.1, + dropout_rate=0.1, activation="relu", - attn_dropout=None, - act_dropout=None, + attn_dropout_rate=None, + act_dropout_rate=None, normalize_before=False, weight_attr=None, bias_attr=None): @@ -283,7 +372,35 @@ class FusedTransformerEncoderLayer(Layer): self._config.pop("__class__", None) # py3 super(FusedTransformerEncoderLayer, self).__init__() - raise NotImplementedError() + assert d_model > 0, ("Expected d_model to be greater than 0, " + "but recieved {}".format(d_model)) + assert nhead > 0, ("Expected nhead to be greater than 0, " + "but recieved {}".format(nhead)) + assert dim_feedforward > 0, ( + "Expected dim_feedforward to be greater than 0, " + "but recieved {}".format(dim_feedforward)) + attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate + act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate + self.normalize_before = normalize_before + + weight_attrs = _convert_param_attr_to_list(weight_attr, 2) + bias_attrs = _convert_param_attr_to_list(bias_attr, 2) + + self.fused_attn = FusedMultiHeadAttention( + d_model, + nhead, + dropout_rate=attn_dropout_rate, + weight_attr=weight_attrs[0], + bias_attr=bias_attrs[0]) + + self.ffn = FusedFeedForward( + d_model, + dim_feedforward, + dropout_rate=dropout_rate, + act_dropout_rate=act_dropout_rate, + normalize_before=self.normalize_before, + weight_attr=weight_attrs[1], + bias_attr=bias_attrs[1]) def forward(self, src, src_mask=None, cache=None): """ @@ -296,11 +413,11 @@ class FusedTransformerEncoderLayer(Layer): to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. - When the data type is bool, the unwanted positions have `False` - values and the others have `True` values. When the data type is - int, the unwanted positions have 0 values and the others have 1 - values. When the data type is float, the unwanted positions have - `-INF` values and the others have 0 values. It can be None when + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. cache (Tensor, optional): It is an instance of `MultiHeadAttention.Cache`. See `TransformerEncoderLayer.gen_cache` for more details. It is @@ -315,7 +432,16 @@ class FusedTransformerEncoderLayer(Layer): incremental length. See `MultiHeadAttention.gen_cache` and \ `MultiHeadAttention.forward` for more details. """ - raise NotImplementedError() + src_mask = _convert_attention_mask(src_mask, src.dtype) + if cache is None: + attn_out = self.fused_attn(src, attn_mask=src_mask) + else: + attn_out, incremental_cache = self.fused_attn( + src, attn_mask=src_mask, cache=cache) + + ffn_out = self.ffn(attn_out) + + return ffn_out if cache is None else (ffn_out, incremental_cache) class FusedTransformer(Layer): @@ -326,12 +452,12 @@ class FusedTransformer(Layer): Please refer to `Attention is all you need `_ , and see `TransformerEncoder` and `TransformerDecoder` for more details. - + Users can configurate the model architecture with corresponding parameters. Note the usage of `normalize_before` representing where to apply layer normalization (in pre-process or post-precess of multi-head attention or FFN), and some transformer like models are different on this, such as - `BERT `_ and `GPT2 `_ . + `BERT `_ and `GPT2 `_ . The default architecture here places layer normalization in post-process and applies another layer normalization on the output of last encoder/decoder layer. @@ -357,30 +483,30 @@ class FusedTransformer(Layer): Otherwise, no pre-process and post-precess includes dropout, residual connection, layer normalization. Default False weight_attr(ParamAttr|list|tuple, optional): To specify the weight parameter property. - If it is a list/tuple, the length of `weight_attr` could be 1, 2 or 3. If it is 3, - `weight_attr[0]` would be used as `weight_attr` for self attention, `weight_attr[1]` - would be used as `weight_attr` for cross attention of `TransformerDecoder`, - and `weight_attr[2]` would be used as `weight_attr` for linear in FFN. - If it is 2, `weight_attr[0]` would be used as `weight_attr` both for self attention - and cross attntion and `weight_attr[1]` would be used as `weight_attr` for - linear in FFN. If it is 1, `weight_attr[0]` would be used as `weight_attr` - for self attention, cross attention and linear in FFN. Otherwise, - the three sub-layers all uses it as `weight_attr` to create parameters. - Default: None, which means the default weight parameter property is used. + If it is a list/tuple, the length of `weight_attr` could be 1, 2 or 3. If it is 3, + `weight_attr[0]` would be used as `weight_attr` for self attention, `weight_attr[1]` + would be used as `weight_attr` for cross attention of `TransformerDecoder`, + and `weight_attr[2]` would be used as `weight_attr` for linear in FFN. + If it is 2, `weight_attr[0]` would be used as `weight_attr` both for self attention + and cross attntion and `weight_attr[1]` would be used as `weight_attr` for + linear in FFN. If it is 1, `weight_attr[0]` would be used as `weight_attr` + for self attention, cross attention and linear in FFN. Otherwise, + the three sub-layers all uses it as `weight_attr` to create parameters. + Default: None, which means the default weight parameter property is used. See usage for details - in :code:`ParamAttr` . + in :code:`ParamAttr` . bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property. - If it is a list/tuple, the length of `bias_attr` could be 1, 2 or 3. If it is 3, - `bias_attr[0]` would be used as `bias_attr` for self attention, `bias_attr[1]` - would be used as `bias_attr` for cross attention of `TransformerDecoder`, - and `bias_attr[2]` would be used as `bias_attr` for linear in FFN. - If it is 2, `bias_attr[0]` would be used as `bias_attr` both for self attention - and cross attntion and `bias_attr[1]` would be used as `bias_attr` for - linear in FFN. If it is 1, `bias_attr[0]` would be used as `bias_attr` - for self attention, cross attention and linear in FFN. Otherwise, - the three sub-layers all uses it as `bias_attr` to create parameters. - The `False` value means the corresponding layer would not have trainable - bias parameter. See usage for details in :code:`ParamAttr` . + If it is a list/tuple, the length of `bias_attr` could be 1, 2 or 3. If it is 3, + `bias_attr[0]` would be used as `bias_attr` for self attention, `bias_attr[1]` + would be used as `bias_attr` for cross attention of `TransformerDecoder`, + and `bias_attr[2]` would be used as `bias_attr` for linear in FFN. + If it is 2, `bias_attr[0]` would be used as `bias_attr` both for self attention + and cross attntion and `bias_attr[1]` would be used as `bias_attr` for + linear in FFN. If it is 1, `bias_attr[0]` would be used as `bias_attr` + for self attention, cross attention and linear in FFN. Otherwise, + the three sub-layers all uses it as `bias_attr` to create parameters. + The `False` value means the corresponding layer would not have trainable + bias parameter. See usage for details in :code:`ParamAttr` . Default: None,which means the default bias parameter property is used. custom_encoder (Layer, optional): If custom encoder is provided, use it as the encoder. Default None diff --git a/python/setup.py.in b/python/setup.py.in index b10d5df541..b246225cba 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -163,6 +163,7 @@ packages=['paddle', 'paddle.incubate.checkpoint', 'paddle.incubate.operators', 'paddle.incubate.tensor', + 'paddle.incubate.nn', 'paddle.distributed.fleet', 'paddle.distributed.fleet.base', 'paddle.distributed.fleet.elastic', @@ -230,6 +231,9 @@ packages=['paddle', 'paddle.text', 'paddle.text.datasets', 'paddle.incubate', + 'paddle.incubate.nn', + 'paddle.incubate.nn.functional', + 'paddle.incubate.nn.layer', 'paddle.io', 'paddle.optimizer', 'paddle.nn', -- GitLab