未验证 提交 9f3613f3 编写于 作者: Z zhangkaihuo 提交者: GitHub

Fused transformer encoder layer and fused feedforward layer (#36604)

本PR是fused_transformer的layer层代码,包含FusedFeedForward的layer层代码和FusedTransformerEncoderLayer的代码。
上级 e6253152
......@@ -191,6 +191,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
if (pair.first == "LnScale" || pair.first == "LnBias" ||
pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
pair.first == "Ln1Scale" || pair.first == "Ln1Bias") {
continue;
}
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float16";
for (auto& var : pair.second) {
......@@ -223,6 +231,14 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforwad") &&
dst_type == framework::proto::VarType::FP32) {
if (pair.first != "LnScale" && pair.first != "LnBias" &&
pair.first != "Ln2Scale" && pair.first != "Ln2Bias" &&
pair.first != "Ln1Scale" && pair.first != "Ln1Bias") {
continue;
}
}
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
......
......@@ -104,7 +104,7 @@ black_list = {
'reduce_sum',
}
# This set contains two types of ops. All ops supported fp16 calculation. One
# This set contains two types of ops. All ops supported fp16 calculation. One
# of two types is considered numerically-safe, but may be made unsafe by an
# upstream blacklist op. Another type do not have numerically-significant
# effects, like stack, flatten2.
......@@ -153,6 +153,8 @@ gray_list = {
'c_allreduce_sum',
'concat',
'split',
'fused_feedforward',
'fused_attention',
}
# The set of ops that don't support fp16 calculation
......
......@@ -40,7 +40,7 @@ _fp16_guard_pattern = "__use_fp16__"
def _rename_arg(op, old_name, new_name):
"""
If an op has old_name input and output, rename these input
If an op has old_name input and output, rename these input
args new_name.
Args:
......@@ -89,6 +89,10 @@ def _keep_fp32_input(op, in_name):
return in_name not in {'X', 'Z'}
if op_type == 'resnet_unit':
return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return in_name in {
'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias"
}
return False
......@@ -98,6 +102,11 @@ def _keep_fp32_output(op, out_name):
return out_name != 'Y'
if op_type == 'resnet_unit':
return out_name not in {'Y', 'ConvX', 'ConvZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return out_name in {
'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean',
'Ln1Variance'
}
return False
......@@ -256,16 +265,16 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False):
ops (list): A list of ops.
cur_op (Operator): Current operator which has var_name variable.
var_name (string): Variable name.
search_all (bool): The type of operator search. Use if \"cur_op\" is not in the \"ops\" set.
search_all (bool): The type of operator search. Use if \"cur_op\" is not in the \"ops\" set.
"""
post_op = []
if search_all:
"""
\"cur_op\" do not have to be in list of \"ops\". E.g. \"cur_op\" can come
from startup_prog block and \"ops\" list from main_prog block.
By setting idx to -1, we'll start looking for post-ops from the top of the list.
If search_all is False, assume that \"cur_op\" is in \"ops\" list,
so to reduce the time of search we can start iterating from \"cur_op\" idx.
\"cur_op\" do not have to be in list of \"ops\". E.g. \"cur_op\" can come
from startup_prog block and \"ops\" list from main_prog block.
By setting idx to -1, we'll start looking for post-ops from the top of the list.
If search_all is False, assume that \"cur_op\" is in \"ops\" list,
so to reduce the time of search we can start iterating from \"cur_op\" idx.
"""
idx = -1
else:
......@@ -517,19 +526,19 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
def rewrite_program(main_prog, amp_lists):
"""
Traverse all ops in current block and insert cast op according to
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
1. When an op belongs to the black list, add it to black set
2. When an op belongs to the white list, add it to white set
3. When an op belongs to the gray list. If one
of its inputs is the output of black set op or black list op,
add it to black set. If all of its previous ops are not black
op and one of its inputs is the output of white set op or
3. When an op belongs to the gray list. If one
of its inputs is the output of black set op or black list op,
add it to black set. If all of its previous ops are not black
op and one of its inputs is the output of white set op or
white list op, add it to white set.
4. When an op isn't in the lists, add it to black op set.
5. Add necessary cast ops to make sure that black set op will be
computed in fp32 mode, while white set op will be computed in
5. Add necessary cast ops to make sure that black set op will be
computed in fp32 mode, while white set op will be computed in
fp16 mode.
Args:
......
......@@ -107,7 +107,7 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
q = qkv[0:1, ::]
q = q.reshape(batch_size, num_head, seq_len, head_dim)
k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim]
k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim]
k = k.reshape(batch_size, num_head, seq_len, head_dim)
v = qkv[2::]
v = v.reshape(batch_size, num_head, seq_len, head_dim)
......
......@@ -23,6 +23,8 @@ from .tensor import segment_mean
from .tensor import segment_max
from .tensor import segment_min
from . import nn #noqa: F401
__all__ = [
'LookAhead',
'ModelAverage',
......
......@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401
from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401
from .layer.fused_transformer import FusedFeedForward # noqa: F401
from .layer.fused_transformer import FusedTransformerEncoderLayer # noqa: F401
__all__ = [ #noqa
'FusedMultiHeadAttention',
'FusedFeedForward',
'FusedTransformerEncoderLayer',
]
......@@ -218,7 +218,7 @@ def fused_multi_head_attention(x,
`[batch\_size, sequence\_len, embed\_dim]`.
qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
(False). Default False.
pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None.
pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None.
......@@ -229,12 +229,12 @@ def fused_multi_head_attention(x,
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
data type is bool, the unwanted positions have `False` values and the others have `True` values.
When the data type is int, the unwanted positions have 0 values and the others have 1 values.
When the data type is float, the unwanted positions have `-INF` values and the others have 0 values.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
data type is bool, the unwanted positions have `False` values and the others have `True` values.
When the data type is int, the unwanted positions have 0 values and the others have 1 values.
When the data type is float, the unwanted positions have `-INF` values and the others have 0 values.
It can be None when nothing wanted or needed to be prevented attention to. Default None.
dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout after attention.
......
......@@ -163,6 +163,7 @@ packages=['paddle',
'paddle.incubate.checkpoint',
'paddle.incubate.operators',
'paddle.incubate.tensor',
'paddle.incubate.nn',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.elastic',
......@@ -230,6 +231,9 @@ packages=['paddle',
'paddle.text',
'paddle.text.datasets',
'paddle.incubate',
'paddle.incubate.nn',
'paddle.incubate.nn.functional',
'paddle.incubate.nn.layer',
'paddle.io',
'paddle.optimizer',
'paddle.nn',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册