Unverified commit e1b5b1da, authored by zhangkaihuo, committed by GitHub

[cherry-pick]Fused transformer encoder layer and fused feedforward layer #36776

This PR adds the layer-level code of fused_transformer, including the FusedFeedForward layer and the FusedTransformerEncoderLayer.
Parent 5402f8e7
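For orientation, a minimal dygraph usage sketch of the two layers this PR adds, based on the docstring examples further down; it assumes a GPU build with the fused attention/feedforward kernels compiled in, and the shapes are illustrative only:

```python
# Usage sketch of the layers added by this PR (requires a GPU build with the
# fused attention/feedforward kernels).
import paddle
from paddle.incubate.nn import FusedFeedForward, FusedTransformerEncoderLayer

paddle.set_device('gpu')

x = paddle.rand((2, 4, 128))                 # [batch_size, seq_len, d_model]
ffn = FusedFeedForward(d_model=128, dim_feedforward=512)
print(ffn(x).shape)                          # [2, 4, 128]

# self-attention mask broadcastable to [batch_size, n_head, seq_len, seq_len]
attn_mask = paddle.rand((2, 2, 4, 4))
encoder_layer = FusedTransformerEncoderLayer(128, 2, 512)
print(encoder_layer(x, attn_mask).shape)     # [2, 4, 128]
```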
...@@ -191,6 +191,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
      continue;
    }
    if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
      if (pair.first == "LnScale" || pair.first == "LnBias" ||
          pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
          pair.first == "Ln1Scale" || pair.first == "Ln1Bias") {
        continue;
      }
    }
    VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
            << GetDtypeStr(*pair.second.cbegin()) << " to float16";
    for (auto& var : pair.second) {
...@@ -223,6 +231,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
        pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
      continue;
    }
    if ((op_type == "fused_attention" || op_type == "fused_feedforward") &&
        dst_type == framework::proto::VarType::FP32) {
      if (pair.first != "LnScale" && pair.first != "LnBias" &&
          pair.first != "Ln2Scale" && pair.first != "Ln2Bias" &&
          pair.first != "Ln1Scale" && pair.first != "Ln1Bias") {
        continue;
      }
    }
    VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
            << GetDtypeStr(*pair.second.cbegin()) << " to "
            << framework::DataTypeToString(dst_type);
......
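These branches keep the LayerNorm scale/bias inputs of the fused ops out of the fp16 cast. A hedged dygraph sketch of how this is exercised, assuming the fused ops are on the dygraph AMP allow list in this build (not something this diff shows):

```python
# Illustrative only: run FusedFeedForward under dygraph AMP. Per the C++
# branches above, Ln*Scale/Ln*Bias inputs are exempted from the float16 cast.
import paddle
from paddle.incubate.nn import FusedFeedForward

paddle.set_device('gpu')                     # fused kernels are GPU-only
layer = FusedFeedForward(d_model=128, dim_feedforward=512)
x = paddle.rand([2, 4, 128])

with paddle.amp.auto_cast():
    out = layer(x)   # matmul inputs may be cast to fp16; LayerNorm params stay fp32
```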
...@@ -104,7 +104,7 @@ black_list = {
    'reduce_sum',
}

# This set contains two types of ops. All of them support fp16 computation.
# One type is considered numerically safe, but may be made unsafe by an
# upstream blacklist op. The other type has no numerically significant
# effects, like stack and flatten2.
...@@ -153,6 +153,8 @@ gray_list = {
    'c_allreduce_sum',
    'concat',
    'split',
    'fused_feedforward',
    'fused_attention',
}

# The set of ops that don't support fp16 calculation
......
...@@ -40,7 +40,7 @@ _fp16_guard_pattern = "__use_fp16__"
def _rename_arg(op, old_name, new_name):
    """
    If an op has old_name input and output, rename these input
    args to new_name.

    Args:
...@@ -80,6 +80,36 @@ def _dtype_to_str(dtype):
        return 'fp32'
def _keep_fp32_input(op, in_name):
    op_type = op.type
    if op_type in ['batch_norm', 'layer_norm']:
        # Scale, Bias, Mean, Variance should be float32.
        return in_name != 'X'
    if op_type == 'fused_bn_add_activation':
        return in_name not in {'X', 'Z'}
    if op_type == 'resnet_unit':
        return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        return in_name in {
            'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias"
        }
    return False


def _keep_fp32_output(op, out_name):
    op_type = op.type
    if op_type in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']:
        return out_name != 'Y'
    if op_type == 'resnet_unit':
        return out_name not in {'Y', 'ConvX', 'ConvZ'}
    if op_type in ['fused_attention', 'fused_feedforward']:
        return out_name in {
            'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean',
            'Ln1Variance'
        }
    return False
def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
    """
    Insert cast op and rename args of input and output.
...@@ -239,16 +269,16 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False):
        ops (list): A list of ops.
        cur_op (Operator): Current operator which has var_name variable.
        var_name (string): Variable name.
        search_all (bool): The type of operator search. Use if \"cur_op\" is not in the \"ops\" set.
    """
    post_op = []
    if search_all:
        """
        \"cur_op\" does not have to be in the list of \"ops\". E.g. \"cur_op\" can come
        from the startup_prog block and the \"ops\" list from the main_prog block.
        By setting idx to -1, we'll start looking for post-ops from the top of the list.
        If search_all is False, assume that \"cur_op\" is in the \"ops\" list,
        so to reduce the search time we can start iterating from the \"cur_op\" idx.
        """
        idx = -1
    else:
...@@ -504,19 +534,19 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
def rewrite_program(main_prog, amp_lists):
    """
    Traverse all ops in the current block and insert cast ops according to
    which set the current op belongs to.

    1. When an op belongs to the black list, add it to the black set.
    2. When an op belongs to the white list, add it to the white set.
    3. When an op belongs to the gray list: if one
       of its inputs is the output of a black set op or black list op,
       add it to the black set. If all of its previous ops are not black
       ops and one of its inputs is the output of a white set op or
       white list op, add it to the white set.
    4. When an op isn't in the lists, add it to the black op set.
    5. Add necessary cast ops to make sure that black set ops will be
       computed in fp32 mode, while white set ops will be computed in
       fp16 mode.

    Args:
......
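A hedged sketch of how the two new helpers behave for the fused ops. A stand-in object with only a `type` attribute is enough here, since the helpers only inspect `op.type`; the module path below is the one used by this file at the time of the commit and is an assumption:

```python
# Illustrative only: exercise _keep_fp32_input/_keep_fp32_output with a
# minimal stand-in op object (the real callers pass framework Operators).
from types import SimpleNamespace
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
    _keep_fp32_input, _keep_fp32_output)

op = SimpleNamespace(type='fused_feedforward')
print(_keep_fp32_input(op, 'Ln1Scale'))   # True  -> kept in float32
print(_keep_fp32_input(op, 'X'))          # False -> may be cast to float16
print(_keep_fp32_output(op, 'Ln1Mean'))   # True  -> kept in float32
print(_keep_fp32_output(op, 'Out'))       # False
```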
...@@ -107,7 +107,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
    q = qkv[0:1, ::]
    q = q.reshape(batch_size, num_head, seq_len, head_dim)
    k = qkv[1:2, ::]  # [1, batch_size, num_head, seq_len, head_dim]
    k = k.reshape(batch_size, num_head, seq_len, head_dim)
    v = qkv[2::]
    v = v.reshape(batch_size, num_head, seq_len, head_dim)
......
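The slicing above keeps a leading singleton axis before reshaping. A self-contained numpy sketch of the same q/k/v split used by the reference computation; the concrete sizes are arbitrary example values:

```python
# Standalone illustration of splitting a stacked qkv tensor of shape
# [3, batch_size, num_head, seq_len, head_dim] into q, k, v.
import numpy as np

batch_size, num_head, seq_len, head_dim = 2, 4, 8, 16
qkv = np.random.rand(3, batch_size, num_head, seq_len,
                     head_dim).astype('float32')

q = qkv[0:1, ::].reshape(batch_size, num_head, seq_len, head_dim)
k = qkv[1:2, ::].reshape(batch_size, num_head, seq_len, head_dim)
v = qkv[2::].reshape(batch_size, num_head, seq_len, head_dim)

# Equivalent without the intermediate singleton axis:
q2, k2, v2 = qkv[0], qkv[1], qkv[2]
assert np.array_equal(q, q2) and np.array_equal(k, k2) and np.array_equal(v, v2)
```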
...@@ -23,6 +23,8 @@ from .tensor import segment_mean
from .tensor import segment_max
from .tensor import segment_min

from . import nn  #noqa: F401

__all__ = [
    'LookAhead',
    'ModelAverage',
......
...@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .layer.fused_transformer import FusedMultiHeadAttention  # noqa: F401
from .layer.fused_transformer import FusedFeedForward  # noqa: F401
from .layer.fused_transformer import FusedTransformerEncoderLayer  # noqa: F401

__all__ = [  #noqa
    'FusedMultiHeadAttention',
    'FusedFeedForward',
    'FusedTransformerEncoderLayer',
]
...@@ -218,7 +218,7 @@ def fused_multi_head_attention(x,
            `[batch\_size, sequence\_len, embed\_dim]`.
        qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
        linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
        pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
            (False). Default False.
        pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None.
        pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None.
...@@ -229,12 +229,12 @@ def fused_multi_head_attention(x,
        qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
            Default None.
        linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
        attn_mask (Tensor, optional): A tensor used in multi-head attention to prevent attention to
            some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
            with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
            data type is bool, the unwanted positions have `False` values and the others have `True` values.
            When the data type is int, the unwanted positions have 0 values and the others have 1 values.
            When the data type is float, the unwanted positions have `-INF` values and the others have 0 values.
            It can be None when nothing needs to be prevented from being attended to. Default None.
        dropout_rate (float, optional): The dropout probability used on attention
            weights to drop some attention targets for the dropout after attention.
......
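For reference, a small sketch of building a boolean padding mask of the broadcastable shape described in the docstring above; the names and sizes here are illustrative only:

```python
# Illustrative only: build a bool attention mask that hides padding positions.
import paddle

batch_size, n_head, seq_len = 2, 2, 4
valid_lens = paddle.to_tensor([4, 2])        # un-padded length of each sequence

# True = attend, False = masked out; broadcastable to
# [batch_size, n_head, seq_len, seq_len]
positions = paddle.arange(seq_len).reshape([1, 1, 1, seq_len])
attn_mask = positions < valid_lens.reshape([batch_size, 1, 1, 1])
print(attn_mask.shape)   # [2, 1, 1, 4]: broadcasts over heads and query positions
```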
...@@ -11,14 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from paddle.nn import functional as F
from paddle.incubate.nn import functional as incubate_f
from paddle.nn import Layer
from paddle.framework import ParamAttr
import paddle
from paddle.nn.layer.transformer import _convert_attention_mask, _convert_param_attr_to_list
from paddle.nn.initializer import Constant
import collections
...@@ -35,16 +33,16 @@ class FusedMultiHeadAttention(Layer):
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention.
        dropout_rate (float, optional): The dropout probability used on attention
            weights to drop some attention targets for the dropout after attention.
            0 for no dropout. Default 0.5.
        attn_dropout_rate (float, optional): The dropout probability used on attention
            weights to drop some attention targets for the dropout in attention.
            0 for no dropout. Default 0.5.
        kdim (int, optional): The feature size in key. If None, assumed equal to
            `embed_dim`. Default None.
        vdim (int, optional): The feature size in value. If None, assumed equal to
            `embed_dim`. Default None.
        normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True)
            or post_layer_norm architecture (False). Default False.
        need_weights (bool, optional): Indicate whether to return the attention
            weights. Now, only False is supported. Default False.
...@@ -56,7 +54,10 @@ class FusedMultiHeadAttention(Layer):
            If it is set to False, this layer will not have trainable bias parameter.
            See usage for details in :code:`ParamAttr` .

    Examples:

        .. code-block:: python

            # required: gpu
            import paddle
            # input: [batch_size, sequence_length, embed_dim]
            query = paddle.rand((2, 4, 128))
...@@ -154,17 +155,17 @@ class FusedMultiHeadAttention(Layer):
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be prevented from being attended to. Default None.
            cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional):
                Now, only None is supported. Default None.

        Returns:
            Tensor|tuple: It is a tensor that has the same shape and data type \
                as `query`, representing attention output.
        """
        if attn_mask is not None:
            # Support bool or int mask
...@@ -192,26 +193,114 @@ class FusedMultiHeadAttention(Layer):
class FusedFeedForward(Layer):
    """
    Parameters:
        d_model (int): The expected feature size in the input and output.
        dim_feedforward (int): The hidden layer size.
        dropout_rate (float, optional): The dropout probability used in pre-process
            and post-process. Default 0.1
        activation (str, optional): The activation function. Default relu.
        act_dropout_rate (float, optional): The dropout probability after activation.
            If None, use the value of `dropout_rate`. Default None
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing or postprocessing. Default False
        weight_attr (ParamAttr, optional): The attribute for the learnable weight of this layer.
            The default value is None and the weight will be initialized to zero. For detailed
            information, please refer to paddle.ParamAttr.
        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias of this layer.
            If it is set to False, no bias will be added to the output. If it is set to None or one
            kind of ParamAttr, a bias parameter will be created according to ParamAttr. For detailed
            information, please refer to paddle.ParamAttr. The default value is None and the bias
            will be initialized to zero.

    Examples:
        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedFeedForward

            fused_feedforward_layer = FusedFeedForward(8, 8)
            x = paddle.rand((1, 8, 8))
            out = fused_feedforward_layer(x)
            print(out.numpy().shape)
            # (1, 8, 8)
    """
    def __init__(self,
                 d_model,
                 dim_feedforward,
                 dropout_rate=0.1,
                 activation="relu",
                 act_dropout_rate=None,
                 normalize_before=False,
                 weight_attr=None,
                 bias_attr=None):
        super(FusedFeedForward, self).__init__()
        assert d_model > 0, (
            "Expected d_model to be greater than 0, but received {}".format(
                d_model))
        assert dim_feedforward > 0, (
            "Expected dim_feedforward to be greater than 0, but received {}".
            format(dim_feedforward))

        self._dtype = self._helper.get_default_dtype()
        self._d_model = d_model
        self._dim_feedforward = dim_feedforward
        self._dropout_rate = dropout_rate
        self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
        self._act_method = activation
        self._normalize_before = normalize_before

        self._linear1_weight = self.create_parameter(
            shape=[d_model, dim_feedforward],
            attr=weight_attr,
            dtype=self._dtype,
            is_bias=False)
        self._linear1_bias = self.create_parameter(
            shape=[dim_feedforward],
            attr=bias_attr,
            dtype=self._dtype,
            is_bias=True)

        self._linear2_weight = self.create_parameter(
            shape=[dim_feedforward, d_model],
            attr=weight_attr,
            dtype=self._dtype,
            is_bias=False)
        self._linear2_bias = self.create_parameter(
            shape=[d_model], attr=bias_attr, dtype=self._dtype, is_bias=True)

        self._ln1_scale = self.create_parameter(
            shape=[d_model],
            attr=None,
            is_bias=False,
            default_initializer=Constant(1.0))
        self._ln1_bias = self.create_parameter(
            shape=[d_model], attr=None, is_bias=True)

        self._ln2_scale = self.create_parameter(
            shape=[d_model],
            attr=None,
            is_bias=False,
            default_initializer=Constant(1.0))
        self._ln2_bias = self.create_parameter(
            shape=[d_model], attr=None, is_bias=True)

    def forward(self, src, cache=None):
        out = incubate_f.fused_feedforward(
            src, self._linear1_weight, self._linear2_weight, self._linear1_bias,
            self._linear2_bias, self._ln1_scale, self._ln1_bias,
            self._ln2_scale, self._ln2_bias, self._dropout_rate,
            self._act_dropout_rate, self._act_method, self._normalize_before)
        return out

class FusedTransformerEncoderLayer(Layer):
    """
    FusedTransformerEncoderLayer is composed of two sub-layers which are self (multi-head)
    attention and feedforward network. Before and after each sub-layer, pre-process
    and post-process would be applied on the input and output accordingly. If
    `normalize_before` is True, pre-process is layer normalization and post-process
...@@ -222,14 +311,14 @@ class FusedTransformerEncoderLayer(Layer):
        d_model (int): The expected feature size in the input and output.
        nhead (int): The number of heads in multi-head attention(MHA).
        dim_feedforward (int): The hidden layer size in the feedforward network(FFN).
        dropout_rate (float, optional): The dropout probability used in pre-process
            and post-process of MHA and FFN sub-layer. Default 0.1
        activation (str, optional): The activation function in the feedforward
            network. Default relu.
        attn_dropout_rate (float, optional): The dropout probability used
            in MHA to drop some attention target. If None, use the value of
            `dropout_rate`. Default None
        act_dropout_rate (float, optional): The dropout probability used after FFN
            activation. If None, use the value of `dropout_rate`. Default None
        normalize_before (bool, optional): Indicate whether to put layer normalization
            into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer
...@@ -241,7 +330,7 @@ class FusedTransformerEncoderLayer(Layer):
            MHA, and `weight_attr[1]` would be used as `weight_attr` for linear in FFN.
            Otherwise, MHA and FFN both use it as `weight_attr` to create parameters.
            Default: None, which means the default weight parameter property is used.
            See usage for details in :code:`ParamAttr` .
        bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property.
            If it is a list/tuple, `bias_attr[0]` would be used as `bias_attr` for
            MHA, and `bias_attr[1]` would be used as `bias_attr` for linear in FFN.
...@@ -249,21 +338,21 @@ class FusedTransformerEncoderLayer(Layer):
            The `False` value means the corresponding layer would not have trainable
            bias parameter. See usage for details in :code:`ParamAttr` . Default: None,
            which means the default bias parameter property is used.

    Examples:

        .. code-block:: python

            # required: gpu
            import paddle
            from paddle.incubate.nn import FusedTransformerEncoderLayer

            # encoder input: [batch_size, src_len, d_model]
            enc_input = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, n_head, src_len, src_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            encoder_layer = FusedTransformerEncoderLayer(128, 2, 512)
            enc_output = encoder_layer(enc_input, attn_mask)  # [2, 4, 128]
    """
...@@ -271,10 +360,10 @@ class FusedTransformerEncoderLayer(Layer):
                 d_model,
                 nhead,
                 dim_feedforward,
                 dropout_rate=0.1,
                 activation="relu",
                 attn_dropout_rate=None,
                 act_dropout_rate=None,
                 normalize_before=False,
                 weight_attr=None,
                 bias_attr=None):
...@@ -283,7 +372,35 @@ class FusedTransformerEncoderLayer(Layer):
        self._config.pop("__class__", None)  # py3

        super(FusedTransformerEncoderLayer, self).__init__()
        assert d_model > 0, ("Expected d_model to be greater than 0, "
                             "but received {}".format(d_model))
        assert nhead > 0, ("Expected nhead to be greater than 0, "
                           "but received {}".format(nhead))
        assert dim_feedforward > 0, (
            "Expected dim_feedforward to be greater than 0, "
            "but received {}".format(dim_feedforward))
        attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate
        act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
        self.normalize_before = normalize_before

        weight_attrs = _convert_param_attr_to_list(weight_attr, 2)
        bias_attrs = _convert_param_attr_to_list(bias_attr, 2)

        self.fused_attn = FusedMultiHeadAttention(
            d_model,
            nhead,
            dropout_rate=attn_dropout_rate,
            weight_attr=weight_attrs[0],
            bias_attr=bias_attrs[0])

        self.ffn = FusedFeedForward(
            d_model,
            dim_feedforward,
            dropout_rate=dropout_rate,
            act_dropout_rate=act_dropout_rate,
            normalize_before=self.normalize_before,
            weight_attr=weight_attrs[1],
            bias_attr=bias_attrs[1])
    def forward(self, src, src_mask=None, cache=None):
        """
...@@ -296,11 +413,11 @@ class FusedTransformerEncoderLayer(Layer):
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be prevented from being attended to. Default None.
            cache (Tensor, optional): It is an instance of `MultiHeadAttention.Cache`.
                See `TransformerEncoderLayer.gen_cache` for more details. It is
...@@ -315,7 +432,16 @@ class FusedTransformerEncoderLayer(Layer):
                incremental length. See `MultiHeadAttention.gen_cache` and \
                `MultiHeadAttention.forward` for more details.
        """
        src_mask = _convert_attention_mask(src_mask, src.dtype)
        if cache is None:
            attn_out = self.fused_attn(src, attn_mask=src_mask)
        else:
            attn_out, incremental_cache = self.fused_attn(
                src, attn_mask=src_mask, cache=cache)

        ffn_out = self.ffn(attn_out)

        return ffn_out if cache is None else (ffn_out, incremental_cache)

class FusedTransformer(Layer):
...@@ -326,12 +452,12 @@ class FusedTransformer(Layer):
    Please refer to `Attention is all you need <http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>`_ ,
    and see `TransformerEncoder` and `TransformerDecoder` for more details.

    Users can configure the model architecture with corresponding parameters.
    Note the usage of `normalize_before` representing where to apply layer
    normalization (in pre-process or post-process of multi-head attention or FFN),
    and some transformer-like models differ on this, such as
    `BERT <https://arxiv.org/abs/1810.04805>`_ and `GPT2 <https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf>`_ .
    The default architecture here places layer normalization in post-process and
    applies another layer normalization on the output of the last encoder/decoder layer.
...@@ -357,30 +483,30 @@ class FusedTransformer(Layer):
            Otherwise, no pre-process and post-process includes dropout, residual
            connection, layer normalization. Default False
        weight_attr(ParamAttr|list|tuple, optional): To specify the weight parameter property.
            If it is a list/tuple, the length of `weight_attr` could be 1, 2 or 3. If it is 3,
            `weight_attr[0]` would be used as `weight_attr` for self attention, `weight_attr[1]`
            would be used as `weight_attr` for cross attention of `TransformerDecoder`,
            and `weight_attr[2]` would be used as `weight_attr` for linear in FFN.
            If it is 2, `weight_attr[0]` would be used as `weight_attr` both for self attention
            and cross attention, and `weight_attr[1]` would be used as `weight_attr` for
            linear in FFN. If it is 1, `weight_attr[0]` would be used as `weight_attr`
            for self attention, cross attention and linear in FFN. Otherwise,
            the three sub-layers all use it as `weight_attr` to create parameters.
            Default: None, which means the default weight parameter property is used.
            See usage for details
            in :code:`ParamAttr` .
        bias_attr (ParamAttr|list|tuple|bool, optional): To specify the bias parameter property.
            If it is a list/tuple, the length of `bias_attr` could be 1, 2 or 3. If it is 3,
            `bias_attr[0]` would be used as `bias_attr` for self attention, `bias_attr[1]`
            would be used as `bias_attr` for cross attention of `TransformerDecoder`,
            and `bias_attr[2]` would be used as `bias_attr` for linear in FFN.
            If it is 2, `bias_attr[0]` would be used as `bias_attr` both for self attention
            and cross attention, and `bias_attr[1]` would be used as `bias_attr` for
            linear in FFN. If it is 1, `bias_attr[0]` would be used as `bias_attr`
            for self attention, cross attention and linear in FFN. Otherwise,
            the three sub-layers all use it as `bias_attr` to create parameters.
            The `False` value means the corresponding layer would not have trainable
            bias parameter. See usage for details in :code:`ParamAttr` .
            Default: None, which means the default bias parameter property is used.
        custom_encoder (Layer, optional): If custom encoder is provided, use it as the encoder.
            Default None
......
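To make the fused feedforward computation concrete, here is a hedged, unfused reference sketch of what `FusedFeedForward.forward` is expected to compute for the post-layer-norm path (`normalize_before=False`, dropout omitted). It mirrors the docstrings above rather than the fused CUDA kernel, so treat it as an approximation:

```python
# Unfused reference sketch of the FFN path (normalize_before=False, dropout
# disabled). Approximation based on the layer's documented behavior, not the
# fused kernel's source.
import paddle
import paddle.nn.functional as F


def ffn_reference(x, w1, b1, w2, b2, ln_scale, ln_bias, epsilon=1e-5):
    # linear1 -> activation -> linear2, residual add, then LayerNorm (Ln2)
    hidden = F.relu(paddle.matmul(x, w1) + b1)
    out = paddle.matmul(hidden, w2) + b2
    return F.layer_norm(x + out, x.shape[-1:], weight=ln_scale, bias=ln_bias,
                        epsilon=epsilon)
```

Under this sketch, an input of shape `[batch_size, seq_len, d_model]` comes back with the same shape, which matches the `FusedFeedForward` docstring example above.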
...@@ -163,6 +163,7 @@ packages=['paddle',
          'paddle.incubate.checkpoint',
          'paddle.incubate.operators',
          'paddle.incubate.tensor',
          'paddle.incubate.nn',
          'paddle.distributed.fleet',
          'paddle.distributed.fleet.base',
          'paddle.distributed.fleet.elastic',
...@@ -230,6 +231,9 @@ packages=['paddle',
          'paddle.text',
          'paddle.text.datasets',
          'paddle.incubate',
          'paddle.incubate.nn',
          'paddle.incubate.nn.functional',
          'paddle.incubate.nn.layer',
          'paddle.io',
          'paddle.optimizer',
          'paddle.nn',
......