Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
223fb7b3
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
223fb7b3
编写于
6月 21, 2022
作者:
Y
Yiqun Liu
提交者:
GitHub
6月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix code example of fused_attention and fused_feedforward. (#43635)
上级
4aac90ef
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
81 addition
and
56 deletion
+81
-56
paddle/fluid/operators/fused/fused_attention_op.cc
paddle/fluid/operators/fused/fused_attention_op.cc
+15
-10
paddle/fluid/operators/fused/fused_feedforward_op.cc
paddle/fluid/operators/fused/fused_feedforward_op.cc
+19
-11
python/paddle/incubate/nn/functional/fused_transformer.py
python/paddle/incubate/nn/functional/fused_transformer.py
+47
-35
未找到文件。
paddle/fluid/operators/fused/fused_attention_op.cc
浏览文件 @
223fb7b3
...
...
@@ -386,13 +386,15 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
.
SetDefault
(
-
1
);
AddComment
(
R"DOC(
Add fused attention op whose logic is as follows:
// @input: [batch_size, seq_len, 3, num_head, head_dim]
The fused_attention operator is the same as following pseudo codes:
// @input: [batch_size, seq_len, embed_dim]
// @final_out: [batch_size, seq_len, num_heads, head_dim]
residual = input
if (pre_layernorm)
out
= layer_norm(input);
out = compute_qkv(out) +
bias;
// fmha module
query
= layer_norm(input);
out = compute_qkv(query) + qkv_
bias;
// fmha module
{
out = transpose(out, perm=[2, 0, 3, 1, 4]);
out = q * k^t;
...
...
@@ -403,11 +405,14 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
out = transpose(out, perm=[0, 2, 1, 3]);
}
out = out_linear(out);
if (pre_layernorm)
final_out = residual + dropout(bias + out);
else
final_out = layer_norm(residual + dropout(bias + out));
// out linear
out = linear(out);
if add_residual:
out = residual + dropout(out);
else:
out = dropout(out);
if (!pre_layernorm)
out = layer_norm(out);
)DOC"
);
}
};
...
...
paddle/fluid/operators/fused/fused_feedforward_op.cc
浏览文件 @
223fb7b3
...
...
@@ -198,17 +198,25 @@ class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr
<
int
>
(
"ring_id"
,
"ring id for tensor model parallel."
)
.
SetDefault
(
-
1
);
AddComment
(
R"DOC(
the function of fused_feedforward operator is the same as the following pseudo code:
residual = src;
ln1_out = src;
if(pre_layer_norm){
ln1_out = layer_norm(src);
}
out = linear(dropout(activation(dropout(linear(ln1_out)))));
if(!pre_layer_norm) {
out = layer_norm(out);
}
)DOC"
);
The fused_feedforward operator is the same as the following pseudo codes:
residual = src;
if (pre_layer_norm)
ln1_out = layer_norm(src);
else
ln1_out = src;
// linear 1
out = linear(ln1_out);
out = dropout(activation(out));
// linear 2
out = linear(out);
if (add_residual)
out = residual + dropout(out);
else
out = dropout(out);
if (!pre_layer_norm)
out = layer_norm(out);
)DOC"
);
}
};
...
...
python/paddle/incubate/nn/functional/fused_transformer.py
浏览文件 @
223fb7b3
...
...
@@ -55,12 +55,19 @@ def fused_feedforward(x,
.. code-block:: python
residual =
src;
residual =
x
if pre_layer_norm:
src = layer_norm(src)
src = linear(dropout(activation(dropout(linear(src)))))
out = layer_norm1(x)
else:
out = x
out = linear2(dropout1(activation(linear1(src))))
if add_residual:
out = residual + dropout2(out)
else:
out = dropout2(out)
if not pre_layer_norm:
src = layer_norm(out)
out = layer_norm2(out)
Args:
x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16, float32 or float64, the shape is`[batch\_size, sequence\_length, d_model]`.
...
...
@@ -102,15 +109,13 @@ def fused_feedforward(x,
# required: gpu
import paddle
import numpy as np
x_data = np.random.random((1, 8, 8)).astype("float32")
linear1_weight_data = np.random.random((8, 8)).astype("float32")
linear2_weight_data = np.random.random((8, 8)).astype("float32")
x = paddle.to_tensor(x_data)
linear1_weight = paddle.to_tensor(linear1_weight_data)
linear2_weight = paddle.to_tensor(linear2_weight_data)
out = paddle.incubate.nn.functional.fused_feedforward(x, linear1_weight, linear2_weight)
print(out.numpy().shape)
import paddle.incubate.nn.functional as F
x = paddle.randn(shape=(1, 8, 8), dtype="float32")
linear1_weight = paddle.randn(shape=(8, 8), dtype="float32")
linear2_weight = paddle.randn(shape=(8, 8), dtype="float32")
out = F.fused_feedforward(x, linear1_weight, linear2_weight)
print(out.shape)
# (1, 8, 8)
"""
_verify_dropout_rate
(
dropout1_rate
)
...
...
@@ -392,27 +397,34 @@ def fused_multi_head_attention(x,
.. code-block:: python
if pre_layer_norm:
out = layer_norm(x)
out = linear(out) + qkv) + bias
else:
out = linear(x) + bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
# extract q, k and v from out.
q = out[0:1,::]
k = out[1:2,::]
v = out[2:3,::]
out = q * k^t
out = attn_mask + out
out = softmax(out)
out = dropout(out)
out = out * v
out = transpose(out, perm=[0, 2, 1, 3])
out = out_linear(out)
if pre_layer_norm:
out = x + dropout(linear_bias + out)
residual = x
if pre_layer_norm:
out = layer_norm(x)
else:
out = layer_norm(x + dropout(linear_bias + out))
out = x
# compute q, k, v
out = matmul(out, qkv_weight) + qkv_bias
out = transpose(out, perm=[2, 0, 3, 1, 4])
# extract q, k and v from out
q = out[0:1,::] * (head_dim ** -0.5)
k = out[1:2,::]
v = out[2:3,::]
out = matmul(q, k, transpose_y=True)
out = out + attn_mask
out = softmax(out)
out = dropout(out)
out = matmul(out, v)
# combine heads
out = transpose(out, perm=[0, 2, 1, 3])
# project to output
out = linear(out)
if add_residual:
out = residual + dropout(out)
else:
out = dropout(out)
if not pre_layer_norm:
out = layer_norm(out)
Parameters:
x (Tensor): The input tensor of fused_multi_head_attention. The shape is
...
...
@@ -420,7 +432,7 @@ def fused_multi_head_attention(x,
qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
(False). Default False.
(False). Default False.
pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None.
pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None.
ln_scale (Tensor, optional): The weight tensor of layernorm. Default None.
...
...
@@ -432,7 +444,7 @@ def fused_multi_head_attention(x,
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
cache_kv (Tensor, optional): For generation model, cache structure. The shape is `[2, bsz, num_head, seq_len, head_dim]`. Default None.
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
data type is bool, the unwanted positions have `False` values and the others have `True` values.
When the data type is int, the unwanted positions have 0 values and the others have 1 values.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录