Unverified · Commit 2b848aef authored by Yuang Liu, committed by GitHub

Fused attention pass fwd, create the fused_attention op. (#50125)

Parent e6d29e00
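This commit wires up the forward half of the fused-attention IR pass: the pattern drops the separate post_layer_norm flag (pre vs. post layer norm becomes a single switch), and the forward handler rewrites the matched subgraph into one fused_attention op, while the backward handler still only detects and removes the matched grad ops (its TODO remains). Below is a minimal sketch of driving the pass from Python, modeled on the unit test in this diff; the pass name given to new_pass() is an assumption, while the 'fused_attention' op type matches the test's assertion.

import paddle
from paddle.distributed.passes import PassManager, new_pass

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Build a pre-LN multi-head attention block here, e.g. the
    # MultiHeadAttention layer from the test file in this diff.
    ...

# Apply the IR pass; it collapses the matched attention subgraph into a
# single fused_attention op (backward fusion is not emitted yet).
pass_manager = PassManager([new_pass("fused_attention")])  # pass name is an assumption
pass_manager.apply([main_prog], [startup_prog])

op_types = [op.type for op in main_prog.global_block().ops]
print('fused_attention' in op_types)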
......@@ -22,7 +22,6 @@ namespace patterns {
PDNode* FusedAttentionPattern::operator()(PDNode* x,
bool pre_layer_norm,
bool post_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual) {
......@@ -259,7 +258,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node})
.LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node});
if (!add_residual && !post_layer_norm) {
if (!add_residual && pre_layer_norm) {
return out_linear_dropout_out_node;
}
......@@ -276,7 +275,7 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
residual_ele_add_node->LinksFrom({x, out_linear_dropout_out_node})
.LinksTo({residual_ele_add_out_node});
if (!post_layer_norm) {
if (pre_layer_norm) {
return residual_ele_add_out_node;
}
}
......@@ -323,13 +322,12 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
bool pre_layer_norm,
bool post_layer_norm,
bool has_attn_mask,
bool do_dropout,
bool add_residual) {
// post layer norm
PDNode* post_layer_norm_grad_out_node{nullptr};
if (post_layer_norm) {
if (!pre_layer_norm) {
auto* post_layer_norm_grad_node =
pattern->NewNode(post_layer_norm_grad_op_repr())
->assert_is_op("layer_norm_grad");
......@@ -375,7 +373,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
PDNode* residual_ele_add_grad_x_grad_node{nullptr};
if (add_residual) {
PDNode* ele_add_grad_input = x;
if (post_layer_norm) {
if (!pre_layer_norm) {
ele_add_grad_input = post_layer_norm_grad_out_node;
}
auto* residual_ele_add_grad_node =
......@@ -404,7 +402,7 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
// get the real input x for dropout grad
PDNode* out_linear_grad_input_node = x;
if (post_layer_norm && !add_residual) {
if (!pre_layer_norm && !add_residual) {
out_linear_grad_input_node = post_layer_norm_grad_out_node;
} else if (add_residual) {
out_linear_grad_input_node = residual_ele_add_grad_out_node;
......@@ -769,11 +767,11 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
void FusedAttentionsPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
graph = PreMaskDropResPostFwd(graph);
graph = PreMaskDropResPostBwd(graph);
graph = PreMaskDropResFwd(graph);
graph = PreMaskDropResBwd(graph);
}
ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x"))
......@@ -784,7 +782,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
fused_attention_pattern(x,
/* pre_layer_norm */ true,
/* post_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true);
......@@ -835,10 +832,191 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern);
OpDesc fused_attention_op_desc(pre_layer_norm_op_node->Op()->Block());
fused_attention_op_desc.SetType("fused_attention");
fused_attention_op_desc.SetInput("X", {subgraph.at(x)->Name()});
fused_attention_op_desc.SetAttr("pre_layer_norm", true);
GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_scale_node,
pre_layer_norm_scale,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
pre_layer_norm_bias_node, pre_layer_norm_bias, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
pre_layer_norm_out_node, pre_layer_norm_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
pre_layer_norm_mean_node, pre_layer_norm_mean, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_variance_node,
pre_layer_norm_variance,
fused_attention_pattern);
fused_attention_op_desc.SetInput("LnScale",
{pre_layer_norm_scale_node->Name()});
fused_attention_op_desc.SetInput("LnBias",
{pre_layer_norm_bias_node->Name()});
fused_attention_op_desc.SetOutput("LnOut",
{pre_layer_norm_out_node->Name()});
fused_attention_op_desc.SetOutput("LnMean",
{pre_layer_norm_mean_node->Name()});
fused_attention_op_desc.SetOutput("LnVariance",
{pre_layer_norm_variance_node->Name()});
fused_attention_op_desc.SetAttr(
"epsilon",
PADDLE_GET_CONST(float,
pre_layer_norm_op_node->Op()->GetAttr("epsilon")));
fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
std::vector<int> shape = PADDLE_GET_CONST(
std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
fused_attention_op_desc.SetAttr("num_heads", shape[2]);
GET_IR_NODE_FROM_SUBGRAPH(
post_layer_norm_op_node, post_layer_norm_op, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
fuse_qkv_ele_add_bias,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_out_node,
fuse_qkv_ele_add_out,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_transpose_out_node,
fuse_qkv_transpose_out,
fused_attention_pattern);
fused_attention_op_desc.SetInput("QKVW", {fuse_qkv_matmul_w_node->Name()});
fused_attention_op_desc.SetInput("QKVBias",
{fuse_qkv_ele_add_bias_node->Name()});
fused_attention_op_desc.SetOutput("QKVOut",
{fuse_qkv_matmul_out_node->Name()});
fused_attention_op_desc.SetOutput("QKVBiasOut",
{fuse_qkv_ele_add_out_node->Name()});
fused_attention_op_desc.SetOutput("TransposeOut2",
{fuse_qkv_transpose_out_node->Name()});
// TODO(Yuang Liu): finish the handler
GET_IR_NODE_FROM_SUBGRAPH(
qk_matmul_out_node, qk_matmul_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_mask_node,
add_mask_ele_add_mask,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(add_mask_ele_add_out_node,
add_mask_ele_add_out,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qk_softmax_out_node, qk_softmax_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attn_dropout_out_node, attn_dropout_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
attn_dropout_mask_node, attn_dropout_mask, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qkv_matmul_out_node, qkv_matmul_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
qkv_reshape_out_node, qkv_reshape_out, fused_attention_pattern);
fused_attention_op_desc.SetOutput("QKOut", {qk_matmul_out_node->Name()});
fused_attention_op_desc.SetInput("SrcMask",
{add_mask_ele_add_mask_node->Name()});
fused_attention_op_desc.SetOutput("SrcMaskOut",
{add_mask_ele_add_out_node->Name()});
fused_attention_op_desc.SetOutput("SoftmaxOut",
{qk_softmax_out_node->Name()});
fused_attention_op_desc.SetAttr(
"attn_dropout_rate",
PADDLE_GET_CONST(float,
attn_dropout_op_node->Op()->GetAttr("dropout_prob")));
fused_attention_op_desc.SetAttr(
"is_test",
PADDLE_GET_CONST(bool, attn_dropout_op_node->Op()->GetAttr("is_test")));
fused_attention_op_desc.SetAttr(
"attn_dropout_fix_seed",
PADDLE_GET_CONST(bool,
attn_dropout_op_node->Op()->GetAttr("fix_seed")));
fused_attention_op_desc.SetAttr(
"attn_dropout_seed",
PADDLE_GET_CONST(int, attn_dropout_op_node->Op()->GetAttr("seed")));
fused_attention_op_desc.SetAttr(
"attn_dropout_implementation",
PADDLE_GET_CONST(
std::string,
attn_dropout_op_node->Op()->GetAttr("dropout_implementation")));
fused_attention_op_desc.SetOutput("AttnDropoutMaskOut",
{attn_dropout_mask_node->Name()});
fused_attention_op_desc.SetOutput("AttnDropoutOut",
{attn_dropout_out_node->Name()});
fused_attention_op_desc.SetOutput("QKTVOut", {qkv_matmul_out_node->Name()});
fused_attention_op_desc.SetOutput("FMHAOut",
{qkv_reshape_out_node->Name()});
GET_IR_NODE_FROM_SUBGRAPH(
out_linear_matmul_w_node, out_linear_matmul_w, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_out_node,
out_linear_matmul_out,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_bias_node,
out_linear_ele_add_bias,
fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_out_node,
out_linear_ele_add_out,
fused_attention_pattern);
fused_attention_op_desc.SetInput("OutLinearW",
{out_linear_matmul_w_node->Name()});
fused_attention_op_desc.SetInput("OutLinearBias",
{out_linear_ele_add_bias_node->Name()});
fused_attention_op_desc.SetOutput("OutLinearOut",
{out_linear_matmul_out_node->Name()});
GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_mask_node,
out_linear_dropout_mask,
fused_attention_pattern);
fused_attention_op_desc.SetAttr(
"dropout_rate",
PADDLE_GET_CONST(
float, out_linear_dropout_op_node->Op()->GetAttr("dropout_prob")));
fused_attention_op_desc.SetAttr(
"dropout_fix_seed",
PADDLE_GET_CONST(
bool, out_linear_dropout_op_node->Op()->GetAttr("fix_seed")));
fused_attention_op_desc.SetAttr(
"dropout_seed",
PADDLE_GET_CONST(int,
out_linear_dropout_op_node->Op()->GetAttr("seed")));
fused_attention_op_desc.SetAttr(
"dropout_implementation",
PADDLE_GET_CONST(std::string,
out_linear_dropout_op_node->Op()->GetAttr(
"dropout_implementation")));
fused_attention_op_desc.SetOutput("DropoutMaskOut",
{out_linear_dropout_mask_node->Name()});
GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_out_node,
residual_ele_add_out,
fused_attention_pattern);
fused_attention_op_desc.SetAttr("add_residual", true);
fused_attention_op_desc.SetOutput("Y", {residual_ele_add_out_node->Name()});
auto fused_attention_node = g->CreateOpNode(&fused_attention_op_desc);
IR_NODE_LINK_TO(subgraph.at(x), fused_attention_node);
IR_NODE_LINK_TO(pre_layer_norm_scale_node, fused_attention_node);
IR_NODE_LINK_TO(pre_layer_norm_bias_node, fused_attention_node);
IR_NODE_LINK_TO(fuse_qkv_matmul_w_node, fused_attention_node);
IR_NODE_LINK_TO(fuse_qkv_ele_add_bias_node, fused_attention_node);
IR_NODE_LINK_TO(add_mask_ele_add_mask_node, fused_attention_node);
IR_NODE_LINK_TO(out_linear_matmul_w_node, fused_attention_node);
IR_NODE_LINK_TO(out_linear_ele_add_bias_node, fused_attention_node);
IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_out_node);
IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_mean_node);
IR_NODE_LINK_TO(fused_attention_node, pre_layer_norm_variance_node);
IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_matmul_out_node);
IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_ele_add_out_node);
IR_NODE_LINK_TO(fused_attention_node, fuse_qkv_transpose_out_node);
IR_NODE_LINK_TO(fused_attention_node, qk_matmul_out_node);
IR_NODE_LINK_TO(fused_attention_node, add_mask_ele_add_out_node);
IR_NODE_LINK_TO(fused_attention_node, qk_softmax_out_node);
IR_NODE_LINK_TO(fused_attention_node, attn_dropout_mask_node);
IR_NODE_LINK_TO(fused_attention_node, attn_dropout_out_node);
IR_NODE_LINK_TO(fused_attention_node, qkv_matmul_out_node);
IR_NODE_LINK_TO(fused_attention_node, qkv_reshape_out_node);
IR_NODE_LINK_TO(fused_attention_node, out_linear_matmul_out_node);
IR_NODE_LINK_TO(fused_attention_node, out_linear_dropout_mask_node);
IR_NODE_LINK_TO(fused_attention_node, residual_ele_add_out_node);
GraphSafeRemoveNodes(g,
{pre_layer_norm_op_node,
......@@ -858,8 +1036,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
out_linear_matmul_op_node,
out_linear_ele_add_op_node,
out_linear_dropout_op_node,
residual_ele_add_op_node,
post_layer_norm_op_node});
residual_ele_add_op_node});
found_fused_attention++;
};
......@@ -869,18 +1046,17 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostFwd(Graph* graph) const {
return graph;
}
ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const {
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "x"))
->AsInput()
->assert_is_op_input("layer_norm_grad", "Y@GRAD");
->assert_is_op_input("elementwise_add_grad", "Out@GRAD");
patterns::FusedAttentionGradPattern fused_attention_grad_pattern(
gpd.mutable_pattern(), "fused_attention_grad_pattern");
fused_attention_grad_pattern(x,
/* pre_layer_norm */ true,
/* post_layer_norm */ true,
/* has_attn_mask */ true,
/* do_dropout */ true,
/* add_residual */ true);
......@@ -891,9 +1067,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
Graph* g) {
VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion";
GET_IR_NODE_FROM_SUBGRAPH(post_layer_norm_grad_op_node,
post_layer_norm_grad_op,
fused_attention_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node,
residual_ele_add_grad_op,
fused_attention_grad_pattern);
......@@ -953,17 +1126,26 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResPostBwd(Graph* graph) const {
// TODO(Yuang Liu): finish the handler
GraphSafeRemoveNodes(
g, {post_layer_norm_grad_op_node, residual_ele_add_grad_op_node,
out_linear_dropout_grad_op_node, out_linear_ele_add_grad_op_node,
out_linear_matmul_grad_op_node, qkv_reshape_grad_op_node,
qkv_transpose_grad_op_node, qkv_matmul_grad_op_node,
attn_dropout_grad_op_node, qk_softmax_grad_op_node,
add_mask_ele_add_grad_op_node, qk_scale_grad_op_node,
qk_matmul_grad_op_node, fuse_qkv_split_grad_op_node,
fuse_qkv_transpose_grad_op_node, fuse_qkv_reshape_grad_op_node,
fuse_qkv_ele_add_grad_op_node, fuse_qkv_matmul_grad_op_node,
pre_layer_norm_grad_op_node, grad_accumulation_sum_op_node});
GraphSafeRemoveNodes(g,
{residual_ele_add_grad_op_node,
out_linear_dropout_grad_op_node,
out_linear_ele_add_grad_op_node,
out_linear_matmul_grad_op_node,
qkv_reshape_grad_op_node,
qkv_transpose_grad_op_node,
qkv_matmul_grad_op_node,
attn_dropout_grad_op_node,
qk_softmax_grad_op_node,
add_mask_ele_add_grad_op_node,
qk_scale_grad_op_node,
qk_matmul_grad_op_node,
fuse_qkv_split_grad_op_node,
fuse_qkv_transpose_grad_op_node,
fuse_qkv_reshape_grad_op_node,
fuse_qkv_ele_add_grad_op_node,
fuse_qkv_matmul_grad_op_node,
pre_layer_norm_grad_op_node,
grad_accumulation_sum_op_node});
found_fused_attention++;
};
......
......@@ -28,7 +28,7 @@ namespace patterns {
// Declare patterns for multi head attention.
// Can detect:
// 1. Pre layer norm, post layer norm or sandwich layer norm.
// 1. Pre layer norm or post layer norm.
// 2. Add attn mask for qk product before the softmax or not.
// 3. Do attn dropout or not.
// 4. Add residual to the out linear result or not.
......@@ -37,11 +37,10 @@ struct FusedAttentionPattern : public PatternBase {
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
PDNode* operator()(PDNode* x,
bool pre_layer_norm, // do pre ln or not
bool post_layer_norm, // do post ln or not
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
bool pre_layer_norm, // do pre ln or not
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
// pre layer norm
PATTERN_DECL_NODE(pre_layer_norm_op);
......@@ -134,11 +133,10 @@ struct FusedAttentionGradPattern : public PatternBase {
: PatternBase(pattern, name_scope, "fused_attention_pattern") {}
PDNode* operator()(PDNode* x,
bool pre_layer_norm, // pre ln
bool post_layer_norm, // post ln
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
bool pre_layer_norm, // pre ln
bool has_attn_mask, // add attn mask to qk or not
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
// post layer norm grad
PATTERN_DECL_NODE(post_layer_norm_grad_op);
......@@ -275,9 +273,9 @@ class FusedAttentionsPass : public FusePassBase {
// If true, the function name will have an abbreviation part.
// If false, the function name won't contain an abbreviation for it.
ir::Graph* PreMaskDropResPostFwd(Graph* graph) const;
ir::Graph* PreMaskDropResFwd(Graph* graph) const;
ir::Graph* PreMaskDropResPostBwd(Graph* graph) const;
ir::Graph* PreMaskDropResBwd(Graph* graph) const;
};
} // namespace ir
......
......@@ -31,7 +31,6 @@ class MultiHeadAttention(paddle.nn.Layer):
num_heads,
add_residual=True,
pre_ln=True,
post_ln=False,
attn_dropout=True,
):
super(MultiHeadAttention, self).__init__()
......@@ -42,7 +41,6 @@ class MultiHeadAttention(paddle.nn.Layer):
self.add_residual = add_residual
self.pre_ln = pre_ln
self.post_ln = post_ln
self.attn_dropout = attn_dropout
self.head_dim = embed_dim // num_heads
......@@ -90,7 +88,7 @@ class MultiHeadAttention(paddle.nn.Layer):
if self.add_residual:
out = residual + out
if self.post_ln:
if not self.pre_ln:
# post layer norm
out = self.norm2(out)
......@@ -104,7 +102,6 @@ class TestFusedAttentionPass(unittest.TestCase):
def setUp(self):
self.add_residual = True
self.pre_ln = True
self.post_ln = True
self.attn_dropout = True
self.add_mask = True
......@@ -120,6 +117,7 @@ class TestFusedAttentionPass(unittest.TestCase):
).astype('float32')
main_prog = paddle.static.Program()
main_prog.random_seed = 1234
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
......@@ -142,7 +140,6 @@ class TestFusedAttentionPass(unittest.TestCase):
num_heads,
add_residual=self.add_residual,
pre_ln=self.pre_ln,
post_ln=self.post_ln,
attn_dropout=self.attn_dropout,
)
......@@ -157,13 +154,14 @@ class TestFusedAttentionPass(unittest.TestCase):
pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops
assert ops[2].type == 'reduce_mean'
assert ops[4].type == 'reduce_mean_grad'
assert ops[2].type == 'fused_attention'
assert ops[3].type == 'reduce_mean'
assert ops[5].type == 'reduce_mean_grad'
# two ops for linear, one op for reduce mean
# one fill constant
# one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer
assert ops[7].type == 'sgd'
assert ops[8].type == 'sgd'
if __name__ == "__main__":
......
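For orientation, here is a rough eager-mode sketch of the computation the matched subgraph performs (the pre-LN variant with attention mask, attention dropout, and residual), mirroring the MultiHeadAttention layer in the test above. The layer objects, reshape/transpose order, and scaling are illustrative assumptions, not the exact op sequence the pattern asserts on.

import paddle
import paddle.nn.functional as F

def unfused_attention(x, ln, qkv_proj, out_proj, attn_mask, num_heads,
                      dropout_rate=0.1):
    # x: [batch, seq_len, embed_dim]; qkv_proj maps embed_dim -> 3 * embed_dim.
    residual = x
    out = ln(x)                                   # pre layer norm
    qkv = qkv_proj(out)                           # fused qkv projection + bias
    b, s, three_d = qkv.shape
    head_dim = three_d // (3 * num_heads)
    qkv = qkv.reshape([b, s, 3, num_heads, head_dim]).transpose([2, 0, 3, 1, 4])
    q, k, v = qkv[0], qkv[1], qkv[2]
    scores = paddle.matmul(q, k, transpose_y=True) * head_dim**-0.5
    scores = scores + attn_mask                   # add attn mask before softmax
    probs = F.dropout(F.softmax(scores), p=dropout_rate)
    ctx = paddle.matmul(probs, v)                 # [b, num_heads, s, head_dim]
    ctx = ctx.transpose([0, 2, 1, 3]).reshape([b, s, -1])
    out = F.dropout(out_proj(ctx), p=dropout_rate)
    return residual + out                         # residual add; no post LN in pre-LN mode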