未验证 提交 9f3613f3 编写于 作者: Z zhangkaihuo 提交者: GitHub

Fused transformer encoder layer and fused feedforward layer (#36604)

本PR是fused_transformer的layer层代码,包含FusedFeedForward的layer层代码和FusedTransformerEncoderLayer的代码。
上级 e6253152
......@@ -191,6 +191,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
if (pair.first == "LnScale" || pair.first == "LnBias" ||
pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
pair.first == "Ln1Scale" || pair.first == "Ln1Bias") {
continue;
}
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float16";
for (auto& var : pair.second) {
......@@ -223,6 +231,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforwad") &&
dst_type == framework::proto::VarType::FP32) {
if (pair.first != "LnScale" && pair.first != "LnBias" &&
pair.first != "Ln2Scale" && pair.first != "Ln2Bias" &&
pair.first != "Ln1Scale" && pair.first != "Ln1Bias") {
continue;
}
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
......
......@@ -153,6 +153,8 @@ gray_list = {
'c_allreduce_sum',
'concat',
'split',
'fused_feedforward',
'fused_attention',
}
# The set of ops that don't support fp16 calculation
......
......@@ -89,6 +89,10 @@ def _keep_fp32_input(op, in_name):
return in_name not in {'X', 'Z'}
if op_type == 'resnet_unit':
return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return in_name in {
'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias"
}
return False
......@@ -98,6 +102,11 @@ def _keep_fp32_output(op, out_name):
return out_name != 'Y'
if op_type == 'resnet_unit':
return out_name not in {'Y', 'ConvX', 'ConvZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return out_name in {
'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean',
'Ln1Variance'
}
return False
......
......@@ -23,6 +23,8 @@ from .tensor import segment_mean
from .tensor import segment_max
from .tensor import segment_min
from . import nn #noqa: F401
__all__ = [
'LookAhead',
'ModelAverage',
......
......@@ -13,7 +13,12 @@
# limitations under the License.
from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401
from .layer.fused_transformer import FusedFeedForward # noqa: F401
from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401
__all__ = [ #noqa
'FusedMultiHeadAttention',
'FusedFeedForward',
'FusedTransformerEncoderLayer',
]
......@@ -11,14 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from paddle.nn import functional as F
from paddle.incubate.nn import functional as incubate_f
from paddle.nn import Layer
from paddle.framework import ParamAttr
import paddle
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle.nn.layer.transformer import _convert_attention_mask, _convert_param_attr_to_list
from paddle.nn.initializer import Constant
import collections
......@@ -56,7 +54,10 @@ class FusedMultiHeadAttention(Layer):
If it is set to False, this layer will not have trainable bias parameter.
See usage for details in :code:`ParamAttr` .
Examples:
.. code-block:: python
# required: gpu
import paddle
# input: [batch_size, sequence_length, embed_dim]
query = paddle.rand((2, 4, 128))
......@@ -192,26 +193,114 @@ class FusedMultiHeadAttention(Layer):
class FusedFeedForward(Layer):
"""
Parameters:
d_model (int): The expected feature size in the input and output.
dim_feedforward (int): The hidden layer size.
dropout_rate (float, optional): The dropout probability used in pre-process
and post-precess. Default 0.1
activation (str, optional): The activation function. Default relu.
act_dropout_rate (float, optional): The dropout probability after activition.
If None, use the value of `dropout_rate`. Default None
normalize_before (bool, optional): Indicate whether to put layer normalization
into, preprocessing or postprocessing. Default False
weight_attr (ParamAttr, optional): The attribute for the learnable weight of this layer.
The default value is None and the weight will be initialized to zero. For detailed
information, please refer to paddle.ParamAttr.
bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias of thi layer.
If it is set to False, no bias will be added to the output. If it is set to None or one
kind of ParamAttr, a bias parameter will be created according to ParamAttr. For detailed
information, please refer to paddle.ParamAttr. The default value is None and the bias
will be initialized to zero.
Examples:
.. code-block:: python
# required: gpu
import paddle
from paddle.incubate.nn import FusedFeedForward
fused_feedforward_layer = FusedFeedForward(8, 8)
x = paddle.rand((1, 8, 8))
out = fused_feedforward_layer(x)
print(out.numpy().shape)
# (1, 8, 8)
"""
def __init__(self,
d_model,
dim_feedforward,
dropout=0.1,
dropout_rate=0.1,
activation="relu",
act_dropout=None,
act_dropout_rate=None,
normalize_before=False,
weight_attr=None,
bias_attr=None):
super(FusedFeedForward, self).__init__()
raise NotImplementedError()
assert d_model > 0, (
"Expected d_model to be greater than 0, but recieved {}".format(
d_model))
assert dim_feedforward > 0, (
"Expected dim_feedforward to be greater than 0, but recieved {}".
format(dim_feedforward))
self._dtype = self._helper.get_default_dtype()
self._d_model = d_model
self._dim_feedforward = dim_feedforward
self._dropout_rate = dropout_rate
self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
self._act_method = activation
self._normalize_before = normalize_before
self._linear1_weight = self.create_parameter(
shape=[d_model, dim_feedforward],
attr=weight_attr,
dtype=self._dtype,
is_bias=False)
self._linear1_bias = self.create_parameter(
shape=[dim_feedforward],
attr=bias_attr,
dtype=self._dtype,
is_bias=True)
self._linear2_weight = self.create_parameter(
shape=[dim_feedforward, d_model],
attr=weight_attr,
dtype=self._dtype,
is_bias=False)
self._linear2_bias = self.create_parameter(
shape=[d_model], attr=bias_attr, dtype=self._dtype, is_bias=True)
self._ln1_scale = self.create_parameter(
shape=[d_model],
attr=None,
is_bias=False,
default_initializer=Constant(1.0))
self._ln1_bias = self.create_parameter(
shape=[d_model], attr=None, is_bias=True)
self._ln2_scale = self.create_parameter(
shape=[d_model],
attr=None,
is_bias=False,
default_initializer=Constant(1.0))
self._ln2_bias = self.create_parameter(
shape=[d_model], attr=None, is_bias=True)
def forward(self, src, cache=None):
raise NotImplementedError()
out = incubate_f.fused_feedforward(
src, self._linear1_weight, self._linear2_weight, self._linear1_bias,
self._linear2_bias, self._ln1_scale, self._ln1_bias,
self._ln2_scale, self._ln2_bias, self._dropout_rate,
self._act_dropout_rate, self._act_method, self._normalize_before)
return out
class FusedTransformerEncoderLayer(Layer):
"""
TransformerEncoderLayer is composed of two sub-layers which are self (multi-head)
FusedTransformerEncoderLayer is composed of two sub-layers which are self (multi-head)
attention and feedforward network. Before and after each sub-layer, pre-process
and post-precess would be applied on the input and output accordingly. If
`normalize_before` is True, pre-process is layer normalization and post-precess
......@@ -222,14 +311,14 @@ class FusedTransformerEncoderLayer(Layer):
d_model (int): The expected feature size in the input and output.
nhead (int): The number of heads in multi-head attention(MHA).
dim_feedforward (int): The hidden layer size in the feedforward network(FFN).
dropout (float, optional): The dropout probability used in pre-process
dropout_rate (float, optional): The dropout probability used in pre-process
and post-precess of MHA and FFN sub-layer. Default 0.1
activation (str, optional): The activation function in the feedforward
network. Default relu.
attn_dropout (float, optional): The dropout probability used
attn_dropout_rate (float, optional): The dropout probability used
in MHA to drop some attention target. If None, use the value of
`dropout`. Default None
act_dropout (float, optional): The dropout probability used after FFN
act_dropout_rate (float, optional): The dropout probability used after FFN
activition. If None, use the value of `dropout`. Default None
normalize_before (bool, optional): Indicate whether to put layer normalization
into preprocessing of MHA and FFN sub-layers. If True, pre-process is layer
......@@ -257,13 +346,13 @@ class FusedTransformerEncoderLayer(Layer):
# required: gpu
import paddle
from paddle.nn import TransformerEncoderLayer
from paddle.incubate.nn import FusedTransformerEncoderLayer
# encoder input: [batch_size, src_len, d_model]
enc_input = paddle.rand((2, 4, 128))
# self attention mask: [batch_size, n_head, src_len, src_len]
attn_mask = paddle.rand((2, 2, 4, 4))
encoder_layer = TransformerEncoderLayer(128, 2, 512)
encoder_layer = FusedTransformerEncoderLayer(128, 2, 512)
enc_output = encoder_layer(enc_input, attn_mask) # [2, 4, 128]
"""
......@@ -271,10 +360,10 @@ class FusedTransformerEncoderLayer(Layer):
d_model,
nhead,
dim_feedforward,
dropout=0.1,
dropout_rate=0.1,
activation="relu",
attn_dropout=None,
act_dropout=None,
attn_dropout_rate=None,
act_dropout_rate=None,
normalize_before=False,
weight_attr=None,
bias_attr=None):
......@@ -283,7 +372,35 @@ class FusedTransformerEncoderLayer(Layer):
self._config.pop("__class__", None) # py3
super(FusedTransformerEncoderLayer, self).__init__()
raise NotImplementedError()
assert d_model > 0, ("Expected d_model to be greater than 0, "
"but recieved {}".format(d_model))
assert nhead > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(nhead))
assert dim_feedforward > 0, (
"Expected dim_feedforward to be greater than 0, "
"but recieved {}".format(dim_feedforward))
attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate
act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate
self.normalize_before = normalize_before
weight_attrs = _convert_param_attr_to_list(weight_attr, 2)
bias_attrs = _convert_param_attr_to_list(bias_attr, 2)
self.fused_attn = FusedMultiHeadAttention(
d_model,
nhead,
dropout_rate=attn_dropout_rate,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0])
self.ffn = FusedFeedForward(
d_model,
dim_feedforward,
dropout_rate=dropout_rate,
act_dropout_rate=act_dropout_rate,
normalize_before=self.normalize_before,
weight_attr=weight_attrs[1],
bias_attr=bias_attrs[1])
def forward(self, src, src_mask=None, cache=None):
"""
......@@ -315,7 +432,16 @@ class FusedTransformerEncoderLayer(Layer):
incremental length. See `MultiHeadAttention.gen_cache` and \
`MultiHeadAttention.forward` for more details.
"""
raise NotImplementedError()
src_mask = _convert_attention_mask(src_mask, src.dtype)
if cache is None:
attn_out = self.fused_attn(src, attn_mask=src_mask)
else:
attn_out, incremental_cache = self.fused_attn(
src, attn_mask=src_mask, cache=cache)
ffn_out = self.ffn(attn_out)
return ffn_out if cache is None else (ffn_out, incremental_cache)
class FusedTransformer(Layer):
......
......@@ -163,6 +163,7 @@ packages=['paddle',
'paddle.incubate.checkpoint',
'paddle.incubate.operators',
'paddle.incubate.tensor',
'paddle.incubate.nn',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.elastic',
......@@ -230,6 +231,9 @@ packages=['paddle',
'paddle.text',
'paddle.text.datasets',
'paddle.incubate',
'paddle.incubate.nn',
'paddle.incubate.nn.functional',
'paddle.incubate.nn.layer',
'paddle.io',
'paddle.optimizer',
'paddle.nn',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册