diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index b0d86f6db9f960bc9b5e4c8d06ce368b6cfb4f1f..f2ea692ad088085becd56b6ebfdde2af84abe468 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -191,6 +191,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, continue; } + if ((op_type == "fused_attention" || op_type == "fused_feedforward")) { + if (pair.first == "LnScale" || pair.first == "LnBias" || + pair.first == "Ln2Scale" || pair.first == "Ln2Bias" || + pair.first == "Ln1Scale" || pair.first == "Ln1Bias") { + continue; + } + } + VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to float16"; for (auto& var : pair.second) { @@ -223,6 +231,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, pair.first == "X" && dst_type == framework::proto::VarType::FP32) { continue; } + if ((op_type == "fused_attention" || op_type == "fused_feedforward") && + dst_type == framework::proto::VarType::FP32) { + if (pair.first != "LnScale" && pair.first != "LnBias" && + pair.first != "Ln2Scale" && pair.first != "Ln2Bias" && + pair.first != "Ln1Scale" && pair.first != "Ln1Bias") { + continue; + } + } VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from " << GetDtypeStr(*pair.second.cbegin()) << " to " << framework::DataTypeToString(dst_type); diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 5b662b09f1cf611f7a53ad5f7c89e3a9d0f19c16..95e597c703b4e4e004c0d133abfe6966c6df9734 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -104,7 +104,7 @@ black_list = { 'reduce_sum', } -# This set contains two types of ops. All ops supported fp16 calculation. One +# This set contains two types of ops. All ops supported fp16 calculation. 
One # of two types is considered numerically-safe, but may be made unsafe by an # upstream blacklist op. Another type do not have numerically-significant # effects, like stack, flatten2. @@ -153,6 +153,8 @@ gray_list = { 'c_allreduce_sum', 'concat', 'split', + 'fused_feedforward', + 'fused_attention', } # The set of ops that don't support fp16 calculation diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 6317be9a2e2e2051accf0d10d2b7faa30a4d307d..36546c1de12048d0327e859b83016fc73cffd4f7 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -40,7 +40,7 @@ _fp16_guard_pattern = "__use_fp16__" def _rename_arg(op, old_name, new_name): """ - If an op has old_name input and output, rename these input + If an op has old_name input and output, rename these input args new_name. Args: @@ -89,6 +89,10 @@ def _keep_fp32_input(op, in_name): return in_name not in {'X', 'Z'} if op_type == 'resnet_unit': return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'} + if op_type in ['fused_attention', 'fused_feedforward']: + return in_name in { + 'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias" + } return False @@ -98,6 +102,11 @@ def _keep_fp32_output(op, out_name): return out_name != 'Y' if op_type == 'resnet_unit': return out_name not in {'Y', 'ConvX', 'ConvZ'} + if op_type in ['fused_attention', 'fused_feedforward']: + return out_name in { + 'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean', + 'Ln1Variance' + } return False @@ -256,16 +265,16 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False): ops (list): A list of ops. cur_op (Operator): Current operator which has var_name variable. var_name (string): Variable name. - search_all (bool): The type of operator search. Use if \"cur_op\" is not in the \"ops\" set. + search_all (bool): The type of operator search. 
Use if \"cur_op\" is not in the \"ops\" set. """ post_op = [] if search_all: """ - \"cur_op\" do not have to be in list of \"ops\". E.g. \"cur_op\" can come - from startup_prog block and \"ops\" list from main_prog block. - By setting idx to -1, we'll start looking for post-ops from the top of the list. - If search_all is False, assume that \"cur_op\" is in \"ops\" list, - so to reduce the time of search we can start iterating from \"cur_op\" idx. + \"cur_op\" do not have to be in list of \"ops\". E.g. \"cur_op\" can come + from startup_prog block and \"ops\" list from main_prog block. + By setting idx to -1, we'll start looking for post-ops from the top of the list. + If search_all is False, assume that \"cur_op\" is in \"ops\" list, + so to reduce the time of search we can start iterating from \"cur_op\" idx. """ idx = -1 else: @@ -517,19 +526,19 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): def rewrite_program(main_prog, amp_lists): """ - Traverse all ops in current block and insert cast op according to + Traverse all ops in current block and insert cast op according to which set current op belongs to. 1. When an op belongs to the black list, add it to black set 2. When an op belongs to the white list, add it to white set - 3. When an op belongs to the gray list. If one - of its inputs is the output of black set op or black list op, - add it to black set. If all of its previous ops are not black - op and one of its inputs is the output of white set op or + 3. When an op belongs to the gray list. If one + of its inputs is the output of black set op or black list op, + add it to black set. If all of its previous ops are not black + op and one of its inputs is the output of white set op or white list op, add it to white set. 4. When an op isn't in the lists, add it to black op set. - 5. Add necessary cast ops to make sure that black set op will be - computed in fp32 mode, while white set op will be computed in + 5. 
Add necessary cast ops to make sure that black set op will be + computed in fp32 mode, while white set op will be computed in fp16 mode. Args: diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py index e59ecc19d05cb9a7fb0811518bbdda155b16c731..5fa9446763b1fe711490806c60a36754ac4e2cb7 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py @@ -107,7 +107,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, q = qkv[0:1, ::] q = q.reshape(batch_size, num_head, seq_len, head_dim) - k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim] + k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim] k = k.reshape(batch_size, num_head, seq_len, head_dim) v = qkv[2::] v = v.reshape(batch_size, num_head, seq_len, head_dim) diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index 644b934814020f9d781f771f19896126186e50cd..f44e38347e5383822ce2f4e6b17fa9211031ae72 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -23,6 +23,8 @@ from .tensor import segment_mean from .tensor import segment_max from .tensor import segment_min +from . import nn #noqa: F401 + __all__ = [ 'LookAhead', 'ModelAverage', diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index aada78e4ec6a49497e89c38515f9d5cde6943488..f359ec1e0d8425f975b119fcc3f876bb348c766c 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 +from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401 +from .layer.fused_transformer import FusedFeedForward # noqa: F401 +from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401 __all__ = [ #noqa 'FusedMultiHeadAttention', + 'FusedFeedForward', + 'FusedTransformerEncoderLayer', + ] diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 68109b4ae694ac63cf5e0bc3adbec460f0b46db2..f6922838418074a5fb2934a9b2c68087641a129f 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -218,7 +218,7 @@ def fused_multi_head_attention(x, `[batch\_size, sequence\_len, embed\_dim]`. qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`. linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`. - pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture + pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture (False). Default False. pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None. pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None. @@ -229,12 +229,12 @@ def fused_multi_head_attention(x, qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`. Default None. linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None. - attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to - some unwanted positions, usually the paddings or the subsequent positions. It is a tensor - with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. 
When the - data type is bool, the unwanted positions have `False` values and the others have `True` values. - When the data type is int, the unwanted positions have 0 values and the others have 1 values. - When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. + attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to + some unwanted positions, usually the paddings or the subsequent positions. It is a tensor + with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the + data type is bool, the unwanted positions have `False` values and the others have `True` values. + When the data type is int, the unwanted positions have 0 values and the others have 1 values. + When the data type is float, the unwanted positions have `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. dropout_rate (float, optional): The dropout probability used on attention weights to drop some attention targets for the dropout after attention. diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index 16588dcef3d27ddf87d6e3967fa4105ae527da21..bc887875c773d5a9013c51635b820b1aa9c0a01c 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -import copy from paddle.nn import functional as F from paddle.incubate.nn import functional as incubate_f from paddle.nn import Layer from paddle.framework import ParamAttr import paddle -from paddle.nn.layer.transformer import _convert_attention_mask +from paddle.nn.layer.transformer import _convert_attention_mask, _convert_param_attr_to_list from paddle.nn.initializer import Constant import collections @@ -35,16 +33,16 @@ class FusedMultiHeadAttention(Layer): embed_dim (int): The expected feature size in the input and output. num_heads (int): The number of heads in multi-head attention. dropout_rate (float, optional): The dropout probability used on attention - weights to drop some attention targets for the dropout after attention. + weights to drop some attention targets for the dropout after attention. 0 for no dropout. Default 0.5. attn_dropout_rate (float, optional): The dropout probability used on attention - weights to drop some attention targets for the dropout in attention. + weights to drop some attention targets for the dropout in attention. 0 for no dropout. Default 0.5. kdim (int, optional): The feature size in key. If None, assumed equal to `embed_dim`. Default None. vdim (int, optional): The feature size in value. If None, assumed equal to `embed_dim`. Default None. - normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True) + normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True) or post_layer_norm architecture (False). Default False. need_weights (bool, optional): Indicate whether to return the attention weights. Now, only False is supported. Default False. @@ -56,7 +54,10 @@ class FusedMultiHeadAttention(Layer): If it is set to False, this layer will not have trainable bias parameter. See usage for details in :code:`ParamAttr` . Examples: + .. 
code-block:: python + + # required: gpu import paddle # input: [batch_size, sequence_length, embed_dim] query = paddle.rand((2, 4, 128)) @@ -154,17 +155,17 @@ class FusedMultiHeadAttention(Layer): to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. - When the data type is bool, the unwanted positions have `False` - values and the others have `True` values. When the data type is - int, the unwanted positions have 0 values and the others have 1 - values. When the data type is float, the unwanted positions have - `-INF` values and the others have 0 values. It can be None when + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional): Now, only None is supported. Default None. Returns: Tensor|tuple: It is a tensor that has the same shape and data type \ - as `query`, representing attention output. + as `query`, representing attention output. """ if attn_mask is not None: # Support bool or int mask @@ -192,26 +193,114 @@ class FusedMultiHeadAttention(Layer): class FusedFeedForward(Layer): + """ + Parameters: + d_model (int): The expected feature size in the input and output. + dim_feedforward (int): The hidden layer size. + dropout_rate (float, optional): The dropout probability used in pre-process + and post-precess. Default 0.1 + activation (str, optional): The activation function. Default relu. + act_dropout_rate (float, optional): The dropout probability after activition. 
+ If None, use the value of `dropout_rate`. Default None + normalize_before (bool, optional): Indicate whether to put layer normalization + into preprocessing or postprocessing. Default False + weight_attr (ParamAttr, optional): The attribute for the learnable weight of this layer. + The default value is None and the weight will be initialized to zero. For detailed + information, please refer to paddle.ParamAttr. + bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias of this layer. + If it is set to False, no bias will be added to the output. If it is set to None or one + kind of ParamAttr, a bias parameter will be created according to ParamAttr. For detailed + information, please refer to paddle.ParamAttr. The default value is None and the bias + will be initialized to zero. + + Examples: + .. code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn import FusedFeedForward + + fused_feedforward_layer = FusedFeedForward(8, 8) + x = paddle.rand((1, 8, 8)) + out = fused_feedforward_layer(x) + print(out.numpy().shape) + # (1, 8, 8) + """ + def __init__(self, d_model, dim_feedforward, - dropout=0.1, + dropout_rate=0.1, activation="relu", - act_dropout=None, + act_dropout_rate=None, normalize_before=False, weight_attr=None, bias_attr=None): super(FusedFeedForward, self).__init__() - raise NotImplementedError() + assert d_model > 0, ( + "Expected d_model to be greater than 0, but recieved {}".format( + d_model)) + assert dim_feedforward > 0, ( + "Expected dim_feedforward to be greater than 0, but recieved {}". 
+ format(dim_feedforward)) + + self._dtype = self._helper.get_default_dtype() + self._d_model = d_model + self._dim_feedforward = dim_feedforward + self._dropout_rate = dropout_rate + self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate + self._act_method = activation + self._normalize_before = normalize_before + + self._linear1_weight = self.create_parameter( + shape=[d_model, dim_feedforward], + attr=weight_attr, + dtype=self._dtype, + is_bias=False) + self._linear1_bias = self.create_parameter( + shape=[dim_feedforward], + attr=bias_attr, + dtype=self._dtype, + is_bias=True) + + self._linear2_weight = self.create_parameter( + shape=[dim_feedforward, d_model], + attr=weight_attr, + dtype=self._dtype, + is_bias=False) + + self._linear2_bias = self.create_parameter( + shape=[d_model], attr=bias_attr, dtype=self._dtype, is_bias=True) + + self._ln1_scale = self.create_parameter( + shape=[d_model], + attr=None, + is_bias=False, + default_initializer=Constant(1.0)) + self._ln1_bias = self.create_parameter( + shape=[d_model], attr=None, is_bias=True) + + self._ln2_scale = self.create_parameter( + shape=[d_model], + attr=None, + is_bias=False, + default_initializer=Constant(1.0)) + self._ln2_bias = self.create_parameter( + shape=[d_model], attr=None, is_bias=True) def forward(self, src, cache=None): - raise NotImplementedError() + out = incubate_f.fused_feedforward( + src, self._linear1_weight, self._linear2_weight, self._linear1_bias, + self._linear2_bias, self._ln1_scale, self._ln1_bias, + self._ln2_scale, self._ln2_bias, self._dropout_rate, + self._act_dropout_rate, self._act_method, self._normalize_before) + return out class FusedTransformerEncoderLayer(Layer): """ - TransformerEncoderLayer is composed of two sub-layers which are self (multi-head) + FusedTransformerEncoderLayer is composed of two sub-layers which are self (multi-head) attention and feedforward network. 
Before and after each sub-layer, pre-process and post-precess would be applied on the input and output accordingly. If `normalize_before` is True, pre-process is layer normalization and post-precess @@ -222,14 +311,14 @@ class FusedTransformerEncoderLayer(Layer): d_model (int): The expected feature size in the input and output. nhead (int): The number of heads in multi-head attention(MHA). dim_feedforward (int): The hidden layer size in the feedforward network(FFN). - dropout (float, optional): The dropout probability used in pre-process + dropout_rate (float, optional): The dropout probability used in pre-process and post-precess of MHA and FFN sub-layer. Default 0.1 activation (str, optional): The activation function in the feedforward network. Default relu. - attn_dropout (float, optional): The dropout probability used + attn_dropout_rate (float, optional): The dropout probability used in MHA to drop some attention target. If None, use the value of `dropout`. Default None - act_dropout (float, optional): The dropout probability used after FFN + act_dropout_rate (float, optional): The dropout probability used after FFN activition. If None, use the value of `dropout`. Default None normalize_before (bool, optional): Indicate whether to put layer normalization into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer @@ -241,7 +330,7 @@ class FusedTransformerEncoderLayer(Layer): MHA, and `weight_attr[1]` would be used as `weight_attr` for linear in FFN. Otherwise, MHA and FFN both use it as `weight_attr` to create parameters. Default: None, which means the default weight parameter property is used. - See usage for details in :code:`ParamAttr` . + See usage for details in :code:`ParamAttr` . bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property. If it is a list/tuple, `bias_attr[0]` would be used as `bias_attr` for MHA, and `bias_attr[1]` would be used as `bias_attr` for linear in FFN. 
@@ -249,21 +338,21 @@ class FusedTransformerEncoderLayer(Layer): The `False` value means the corresponding layer would not have trainable bias parameter. See usage for details in :code:`ParamAttr` . Default: None, which means the default bias parameter property is used. - + Examples: .. code-block:: python - + # required: gpu import paddle - from paddle.nn import TransformerEncoderLayer + from paddle.incubate.nn import FusedTransformerEncoderLayer # encoder input: [batch_size, src_len, d_model] enc_input = paddle.rand((2, 4, 128)) # self attention mask: [batch_size, n_head, src_len, src_len] attn_mask = paddle.rand((2, 2, 4, 4)) - encoder_layer = TransformerEncoderLayer(128, 2, 512) + encoder_layer = FusedTransformerEncoderLayer(128, 2, 512) enc_output = encoder_layer(enc_input, attn_mask) # [2, 4, 128] """ @@ -271,10 +360,10 @@ class FusedTransformerEncoderLayer(Layer): d_model, nhead, dim_feedforward, - dropout=0.1, + dropout_rate=0.1, activation="relu", - attn_dropout=None, - act_dropout=None, + attn_dropout_rate=None, + act_dropout_rate=None, normalize_before=False, weight_attr=None, bias_attr=None): @@ -283,7 +372,35 @@ class FusedTransformerEncoderLayer(Layer): self._config.pop("__class__", None) # py3 super(FusedTransformerEncoderLayer, self).__init__() - raise NotImplementedError() + assert d_model > 0, ("Expected d_model to be greater than 0, " + "but recieved {}".format(d_model)) + assert nhead > 0, ("Expected nhead to be greater than 0, " + "but recieved {}".format(nhead)) + assert dim_feedforward > 0, ( + "Expected dim_feedforward to be greater than 0, " + "but recieved {}".format(dim_feedforward)) + attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate + act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate + self.normalize_before = normalize_before + + weight_attrs = _convert_param_attr_to_list(weight_attr, 2) + bias_attrs = _convert_param_attr_to_list(bias_attr, 2) + + self.fused_attn = 
FusedMultiHeadAttention( + d_model, + nhead, + dropout_rate=dropout_rate, attn_dropout_rate=attn_dropout_rate, normalize_before=self.normalize_before, + weight_attr=weight_attrs[0], + bias_attr=bias_attrs[0]) + + self.ffn = FusedFeedForward( + d_model, + dim_feedforward, + dropout_rate=dropout_rate, + act_dropout_rate=act_dropout_rate, + normalize_before=self.normalize_before, + weight_attr=weight_attrs[1], + bias_attr=bias_attrs[1]) def forward(self, src, src_mask=None, cache=None): """ @@ -296,11 +413,11 @@ class FusedTransformerEncoderLayer(Layer): to prevents attention to some unwanted positions, usually the paddings or the subsequent positions. It is a tensor with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. - When the data type is bool, the unwanted positions have `False` - values and the others have `True` values. When the data type is - int, the unwanted positions have 0 values and the others have 1 - values. When the data type is float, the unwanted positions have - `-INF` values and the others have 0 values. It can be None when + When the data type is bool, the unwanted positions have `False` + values and the others have `True` values. When the data type is + int, the unwanted positions have 0 values and the others have 1 + values. When the data type is float, the unwanted positions have + `-INF` values and the others have 0 values. It can be None when nothing wanted or needed to be prevented attention to. Default None. cache (Tensor, optional): It is an instance of `MultiHeadAttention.Cache`. See `TransformerEncoderLayer.gen_cache` for more details. It is @@ -315,7 +432,16 @@ class FusedTransformerEncoderLayer(Layer): incremental length. See `MultiHeadAttention.gen_cache` and \ `MultiHeadAttention.forward` for more details. 
""" - raise NotImplementedError() + src_mask = _convert_attention_mask(src_mask, src.dtype) + if cache is None: + attn_out = self.fused_attn(src, attn_mask=src_mask) + else: + attn_out, incremental_cache = self.fused_attn( + src, attn_mask=src_mask, cache=cache) + + ffn_out = self.ffn(attn_out) + + return ffn_out if cache is None else (ffn_out, incremental_cache) class FusedTransformer(Layer): @@ -326,12 +452,12 @@ class FusedTransformer(Layer): Please refer to `Attention is all you need `_ , and see `TransformerEncoder` and `TransformerDecoder` for more details. - + Users can configurate the model architecture with corresponding parameters. Note the usage of `normalize_before` representing where to apply layer normalization (in pre-process or post-precess of multi-head attention or FFN), and some transformer like models are different on this, such as - `BERT `_ and `GPT2 `_ . + `BERT `_ and `GPT2 `_ . The default architecture here places layer normalization in post-process and applies another layer normalization on the output of last encoder/decoder layer. @@ -357,30 +483,30 @@ class FusedTransformer(Layer): Otherwise, no pre-process and post-precess includes dropout, residual connection, layer normalization. Default False weight_attr(ParamAttr|list|tuple, optional): To specify the weight parameter property. - If it is a list/tuple, the length of `weight_attr` could be 1, 2 or 3. If it is 3, - `weight_attr[0]` would be used as `weight_attr` for self attention, `weight_attr[1]` - would be used as `weight_attr` for cross attention of `TransformerDecoder`, - and `weight_attr[2]` would be used as `weight_attr` for linear in FFN. - If it is 2, `weight_attr[0]` would be used as `weight_attr` both for self attention - and cross attntion and `weight_attr[1]` would be used as `weight_attr` for - linear in FFN. If it is 1, `weight_attr[0]` would be used as `weight_attr` - for self attention, cross attention and linear in FFN. 
Otherwise, - the three sub-layers all uses it as `weight_attr` to create parameters. - Default: None, which means the default weight parameter property is used. + If it is a list/tuple, the length of `weight_attr` could be 1, 2 or 3. If it is 3, + `weight_attr[0]` would be used as `weight_attr` for self attention, `weight_attr[1]` + would be used as `weight_attr` for cross attention of `TransformerDecoder`, + and `weight_attr[2]` would be used as `weight_attr` for linear in FFN. + If it is 2, `weight_attr[0]` would be used as `weight_attr` both for self attention + and cross attntion and `weight_attr[1]` would be used as `weight_attr` for + linear in FFN. If it is 1, `weight_attr[0]` would be used as `weight_attr` + for self attention, cross attention and linear in FFN. Otherwise, + the three sub-layers all uses it as `weight_attr` to create parameters. + Default: None, which means the default weight parameter property is used. See usage for details - in :code:`ParamAttr` . + in :code:`ParamAttr` . bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property. - If it is a list/tuple, the length of `bias_attr` could be 1, 2 or 3. If it is 3, - `bias_attr[0]` would be used as `bias_attr` for self attention, `bias_attr[1]` - would be used as `bias_attr` for cross attention of `TransformerDecoder`, - and `bias_attr[2]` would be used as `bias_attr` for linear in FFN. - If it is 2, `bias_attr[0]` would be used as `bias_attr` both for self attention - and cross attntion and `bias_attr[1]` would be used as `bias_attr` for - linear in FFN. If it is 1, `bias_attr[0]` would be used as `bias_attr` - for self attention, cross attention and linear in FFN. Otherwise, - the three sub-layers all uses it as `bias_attr` to create parameters. - The `False` value means the corresponding layer would not have trainable - bias parameter. See usage for details in :code:`ParamAttr` . + If it is a list/tuple, the length of `bias_attr` could be 1, 2 or 3. 
If it is 3, + `bias_attr[0]` would be used as `bias_attr` for self attention, `bias_attr[1]` + would be used as `bias_attr` for cross attention of `TransformerDecoder`, + and `bias_attr[2]` would be used as `bias_attr` for linear in FFN. + If it is 2, `bias_attr[0]` would be used as `bias_attr` both for self attention + and cross attntion and `bias_attr[1]` would be used as `bias_attr` for + linear in FFN. If it is 1, `bias_attr[0]` would be used as `bias_attr` + for self attention, cross attention and linear in FFN. Otherwise, + the three sub-layers all uses it as `bias_attr` to create parameters. + The `False` value means the corresponding layer would not have trainable + bias parameter. See usage for details in :code:`ParamAttr` . Default: None,which means the default bias parameter property is used. custom_encoder (Layer, optional): If custom encoder is provided, use it as the encoder. Default None diff --git a/python/setup.py.in b/python/setup.py.in index b10d5df541f2ff8527b06565cc2b297396d26867..b246225cbab230b7439dffaf0f248d3cf9f4b913 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -163,6 +163,7 @@ packages=['paddle', 'paddle.incubate.checkpoint', 'paddle.incubate.operators', 'paddle.incubate.tensor', + 'paddle.incubate.nn', 'paddle.distributed.fleet', 'paddle.distributed.fleet.base', 'paddle.distributed.fleet.elastic', @@ -230,6 +231,9 @@ packages=['paddle', 'paddle.text', 'paddle.text.datasets', 'paddle.incubate', + 'paddle.incubate.nn', + 'paddle.incubate.nn.functional', + 'paddle.incubate.nn.layer', 'paddle.io', 'paddle.optimizer', 'paddle.nn',