From 7e8ef328df04c4bbd57c38300b3a50e37aa85a7f Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Fri, 3 Feb 2023 15:33:59 +0800
Subject: [PATCH] Fused attention pass backward op replace. (#50186)

---
 .../framework/ir/fused_attention_pass.cc | 391 +++++++++++++++++-
 .../fluid/framework/ir/fused_attention_pass.h | 38 +-
 .../operators/fused/fused_attention_op.cc | 2 +-
 .../unittests/test_fused_attention_pass.py | 11 +-
 4 files changed, 428 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/framework/ir/fused_attention_pass.cc b/paddle/fluid/framework/ir/fused_attention_pass.cc
index 7b0f469ff8..dcf5f05e64 100644
--- a/paddle/fluid/framework/ir/fused_attention_pass.cc
+++ b/paddle/fluid/framework/ir/fused_attention_pass.cc
@@ -766,12 +766,15 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
 void FusedAttentionsPass::ApplyImpl(Graph* graph) const {
 FusePassBase::Init(name_scope_, graph);
+ FusedAttentionPassCache cache;
- graph = PreMaskDropResFwd(graph);
- graph = PreMaskDropResBwd(graph);
+ graph = PreMaskDropResFwd(graph, &cache);
+ graph = PreMaskDropResBwd(graph, &cache);
+ cache.ResetCache();
 }
-ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
+ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
+ Graph* graph, FusedAttentionPassCache* cache) const {
 GraphPatternDetector gpd;
 auto* x = gpd.mutable_pattern()
 ->NewNode(patterns::PDNodeName(name_scope_, "x"))
@@ -792,6 +795,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
 Graph* g) {
 VLOG(3) << "handle FusedMultiHeadAttention pass's fusion";
+ int block_id = g->GetBlockId();
+
 GET_IR_NODE_FROM_SUBGRAPH(
 pre_layer_norm_op_node, pre_layer_norm_op, fused_attention_pattern);
 GET_IR_NODE_FROM_SUBGRAPH(
@@ -833,9 +838,15 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
 GET_IR_NODE_FROM_SUBGRAPH(
 residual_ele_add_op_node, residual_ele_add_op, fused_attention_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(
+ fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern);
+ std::string cache_anchor_name = fuse_qkv_matmul_w_node->Var()->Name();
+
 OpDesc fused_attention_op_desc(pre_layer_norm_op_node->Op()->Block());
 fused_attention_op_desc.SetType("fused_attention");
 fused_attention_op_desc.SetInput("X", {subgraph.at(x)->Name()});
+ cache->InsertIntoCache(GenerateCacheKey(cache_anchor_name, "X", block_id),
+ subgraph.at(x));
 fused_attention_op_desc.SetAttr("pre_layer_norm", true);
 GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_scale_node,
@@ -860,6 +871,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
 {pre_layer_norm_mean_node->Name()});
 fused_attention_op_desc.SetOutput("LnVariance",
 {pre_layer_norm_variance_node->Name()});
+ cache->InsertIntoCache(
+ GenerateCacheKey(cache_anchor_name, "LnScale", block_id),
+ pre_layer_norm_scale_node);
+ cache->InsertIntoCache(
+ GenerateCacheKey(cache_anchor_name, "LnBias", block_id),
+ pre_layer_norm_bias_node);
+ cache->InsertIntoCache(
+ GenerateCacheKey(cache_anchor_name, "LnOut", block_id),
+ pre_layer_norm_out_node);
+ cache->InsertIntoCache(
+ GenerateCacheKey(cache_anchor_name, "LnMean", block_id),
+ pre_layer_norm_mean_node);
+ cache->InsertIntoCache(
+ GenerateCacheKey(cache_anchor_name, "LnVariance", block_id),
+ pre_layer_norm_variance_node);
 fused_attention_op_desc.SetAttr(
 "epsilon",
 PADDLE_GET_CONST(float,
@@ -869,8 +895,6 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const {
 std::vector<int> shape = PADDLE_GET_CONST(
 std::vector<int>,
fuse_qkv_reshape_op_node->Op()->GetAttr("shape")); fused_attention_op_desc.SetAttr("num_heads", shape[2]); - GET_IR_NODE_FROM_SUBGRAPH( - fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern); GET_IR_NODE_FROM_SUBGRAPH( fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern); GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node, @@ -891,6 +915,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const { {fuse_qkv_ele_add_out_node->Name()}); fused_attention_op_desc.SetOutput("TransposeOut2", {fuse_qkv_transpose_out_node->Name()}); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "QKVW", block_id), + fuse_qkv_matmul_w_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "QKVBias", block_id), + fuse_qkv_ele_add_bias_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "QKVOut", block_id), + fuse_qkv_matmul_out_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "QKVBiasOut", block_id), + fuse_qkv_ele_add_out_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "TransposeOut2", block_id), + fuse_qkv_transpose_out_node); GET_IR_NODE_FROM_SUBGRAPH( qk_matmul_out_node, qk_matmul_out, fused_attention_pattern); @@ -911,12 +950,24 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH( qkv_reshape_out_node, qkv_reshape_out, fused_attention_pattern); fused_attention_op_desc.SetOutput("QKOut", {qk_matmul_out_node->Name()}); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "QKOut", block_id), + qk_matmul_out_node); fused_attention_op_desc.SetInput("SrcMask", {add_mask_ele_add_mask_node->Name()}); fused_attention_op_desc.SetOutput("SrcMaskOut", {add_mask_ele_add_out_node->Name()}); fused_attention_op_desc.SetOutput("SoftmaxOut", {qk_softmax_out_node->Name()}); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "SrcMask", block_id), + add_mask_ele_add_mask_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "SrcMaskOut", block_id), + add_mask_ele_add_out_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "SoftmaxOut", block_id), + qk_softmax_out_node); fused_attention_op_desc.SetAttr( "attn_dropout_rate", PADDLE_GET_CONST(float, @@ -943,6 +994,18 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const { fused_attention_op_desc.SetOutput("QKTVOut", {qkv_matmul_out_node->Name()}); fused_attention_op_desc.SetOutput("FMHAOut", {qkv_reshape_out_node->Name()}); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "AttnDropoutMaskOut", block_id), + attn_dropout_mask_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "AttnDropoutOut", block_id), + attn_dropout_out_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "QKTVOut", block_id), + qkv_matmul_out_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "FMHAOut", block_id), + qkv_reshape_out_node); GET_IR_NODE_FROM_SUBGRAPH( out_linear_matmul_w_node, out_linear_matmul_w, fused_attention_pattern); @@ -952,15 +1015,21 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_bias_node, out_linear_ele_add_bias, fused_attention_pattern); - GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_out_node, - out_linear_ele_add_out, - fused_attention_pattern); fused_attention_op_desc.SetInput("OutLinearW", {out_linear_matmul_w_node->Name()}); 
fused_attention_op_desc.SetInput("OutLinearBias", {out_linear_ele_add_bias_node->Name()}); fused_attention_op_desc.SetOutput("OutLinearOut", {out_linear_matmul_out_node->Name()}); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "OutLinearW", block_id), + out_linear_matmul_w_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "OutLinearBias", block_id), + out_linear_ele_add_bias_node); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "OutLinearOut", block_id), + out_linear_matmul_out_node); GET_IR_NODE_FROM_SUBGRAPH(out_linear_dropout_mask_node, out_linear_dropout_mask, fused_attention_pattern); @@ -983,6 +1052,9 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const { "dropout_implementation"))); fused_attention_op_desc.SetOutput("DropoutMaskOut", {out_linear_dropout_mask_node->Name()}); + cache->InsertIntoCache( + GenerateCacheKey(cache_anchor_name, "DropoutMaskOut", block_id), + out_linear_dropout_mask_node); GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_out_node, residual_ele_add_out, @@ -1037,6 +1109,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const { out_linear_ele_add_op_node, out_linear_dropout_op_node, residual_ele_add_op_node}); + found_fused_attention++; }; @@ -1046,7 +1119,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(Graph* graph) const { return graph; } -ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const { +ir::Graph* FusedAttentionsPass::PreMaskDropResBwd( + Graph* graph, FusedAttentionPassCache* cache) const { GraphPatternDetector gpd; auto* x = gpd.mutable_pattern() ->NewNode(patterns::PDNodeName(name_scope_, "x")) @@ -1067,6 +1141,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const { Graph* g) { VLOG(3) << "handle FusedMultiHeadAttention backward pass's fusion"; + int block_id = g->GetBlockId(); + GET_IR_NODE_FROM_SUBGRAPH(residual_ele_add_grad_op_node, residual_ele_add_grad_op, fused_attention_grad_pattern); @@ -1124,7 +1200,302 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(Graph* graph) const { grad_accumulation_sum_op, fused_attention_grad_pattern); - // TODO(Yuang Liu): finish the handler + OpDesc fused_attention_grad_op_desc( + residual_ele_add_grad_op_node->Op()->Block()); + fused_attention_grad_op_desc.SetType("fused_attention_grad"); + fused_attention_grad_op_desc.SetInput("Y@GRAD", {subgraph.at(x)->Name()}); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_matmul_grad_w_node, + fuse_qkv_matmul_grad_w, + fused_attention_grad_pattern); + std::string cache_anchor_name = fuse_qkv_matmul_grad_w_node->Var()->Name(); + + auto* x_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "X", block_id)); + auto* attn_dropout_mask_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "AttnDropoutMaskOut", block_id)); + auto* attn_dropout_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "AttnDropoutOut", block_id)); + auto* dropout_mask_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "DropoutMaskOut", block_id)); + auto* fmha_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "FMHAOut", block_id)); + auto* ln_bias_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "LnBias", block_id)); + auto* ln_mean_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "LnMean", block_id)); + auto* ln_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "LnOut", block_id)); + auto* ln_scale_node = 
cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "LnScale", block_id)); + auto* ln_variance_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "LnVariance", block_id)); + auto* out_linear_bias_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "OutLinearBias", block_id)); + auto* out_linear_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "OutLinearOut", block_id)); + auto* out_linear_w_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "OutLinearW", block_id)); + auto* qk_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "QKOut", block_id)); + auto* qktv_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "QKTVOut", block_id)); + auto* qkv_bias_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "QKVBias", block_id)); + auto* qkv_bias_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "QKVBiasOut", block_id)); + auto* qkv_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "QKVOut", block_id)); + auto* qkv_w_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "QKVW", block_id)); + auto* softmax_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "SoftmaxOut", block_id)); + auto* src_mask_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "SrcMask", block_id)); + auto* src_mask_out_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "SrcMaskOut", block_id)); + auto* transpose_out_2_node = cache->GetNodeFromCache( + GenerateCacheKey(cache_anchor_name, "TransposeOut2", block_id)); + fused_attention_grad_op_desc.SetInput("X", {x_node->Name()}); + fused_attention_grad_op_desc.SetInput("AttnDropoutMaskOut", + {attn_dropout_mask_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("AttnDropoutOut", + {attn_dropout_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("DropoutMaskOut", + {dropout_mask_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("FMHAOut", {fmha_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("LnBias", {ln_bias_node->Name()}); + fused_attention_grad_op_desc.SetInput("LnMean", {ln_mean_node->Name()}); + fused_attention_grad_op_desc.SetInput("LnOut", {ln_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("LnScale", {ln_scale_node->Name()}); + fused_attention_grad_op_desc.SetInput("LnVariance", + {ln_variance_node->Name()}); + fused_attention_grad_op_desc.SetInput("OutLinearBias", + {out_linear_bias_node->Name()}); + fused_attention_grad_op_desc.SetInput("OutLinearOut", + {out_linear_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("OutLinearW", + {out_linear_w_node->Name()}); + fused_attention_grad_op_desc.SetInput("QKOut", {qk_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("QKTVOut", {qktv_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("QKVBias", {qkv_bias_node->Name()}); + fused_attention_grad_op_desc.SetInput("QKVBiasOut", + {qkv_bias_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("QKVOut", {qkv_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("QKVW", {qkv_w_node->Name()}); + fused_attention_grad_op_desc.SetInput("SoftmaxOut", + {softmax_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("SrcMask", {src_mask_node->Name()}); + fused_attention_grad_op_desc.SetInput("SrcMaskOut", + {src_mask_out_node->Name()}); + fused_attention_grad_op_desc.SetInput("TransposeOut2", + 
{transpose_out_2_node->Name()});
+
+ fused_attention_grad_op_desc.SetAttr("add_residual", true);
+ fused_attention_grad_op_desc.SetAttr(
+ "attn_dropout_rate",
+ PADDLE_GET_CONST(
+ float, attn_dropout_grad_op_node->Op()->GetAttr("dropout_prob")));
+ fused_attention_grad_op_desc.SetAttr(
+ "is_test",
+ PADDLE_GET_CONST(bool,
+ attn_dropout_grad_op_node->Op()->GetAttr("is_test")));
+ fused_attention_grad_op_desc.SetAttr(
+ "attn_dropout_fix_seed",
+ PADDLE_GET_CONST(bool,
+ attn_dropout_grad_op_node->Op()->GetAttr("fix_seed")));
+ fused_attention_grad_op_desc.SetAttr(
+ "attn_dropout_seed",
+ PADDLE_GET_CONST(int,
+ attn_dropout_grad_op_node->Op()->GetAttr("seed")));
+ fused_attention_grad_op_desc.SetAttr(
+ "attn_dropout_implementation",
+ PADDLE_GET_CONST(std::string,
+ attn_dropout_grad_op_node->Op()->GetAttr(
+ "dropout_implementation")));
+ fused_attention_grad_op_desc.SetAttr(
+ "dropout_rate",
+ PADDLE_GET_CONST(
+ float,
+ out_linear_dropout_grad_op_node->Op()->GetAttr("dropout_prob")));
+ fused_attention_grad_op_desc.SetAttr(
+ "dropout_fix_seed",
+ PADDLE_GET_CONST(
+ bool, out_linear_dropout_grad_op_node->Op()->GetAttr("fix_seed")));
+ fused_attention_grad_op_desc.SetAttr(
+ "dropout_seed",
+ PADDLE_GET_CONST(
+ int, out_linear_dropout_grad_op_node->Op()->GetAttr("seed")));
+ fused_attention_grad_op_desc.SetAttr(
+ "dropout_implementation",
+ PADDLE_GET_CONST(std::string,
+ out_linear_dropout_grad_op_node->Op()->GetAttr(
+ "dropout_implementation")));
+ fused_attention_grad_op_desc.SetAttr(
+ "epsilon",
+ PADDLE_GET_CONST(
+ float, pre_layer_norm_grad_op_node->Op()->GetAttr("epsilon")));
+ std::vector<int> shape =
+ PADDLE_GET_CONST(std::vector<int>,
+ fuse_qkv_reshape_grad_op_node->Op()->GetAttr("shape"));
+ fused_attention_grad_op_desc.SetAttr("num_heads", shape[2]);
+ fused_attention_grad_op_desc.SetAttr("pre_layer_norm", true);
+ fused_attention_grad_op_desc.SetAttr("transpose_qkv_wb", true);
+
+ // forward op will use default value
+ // but backward op has to set these redundant attrs
+ fused_attention_grad_op_desc.SetAttr(
+ "ln_epsilon",
+ PADDLE_GET_CONST(
+ float, pre_layer_norm_grad_op_node->Op()->GetAttr("epsilon")));
+ fused_attention_grad_op_desc.SetAttr("ring_id", -1);
+
+ GET_IR_NODE_FROM_SUBGRAPH(qkv_matmul_grad_x_grad_node,
+ qkv_matmul_grad_x_grad,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_grad_x_grad_node,
+ out_linear_matmul_grad_x_grad,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_grad_bias_grad_node,
+ pre_layer_norm_grad_bias_grad,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_matmul_grad_x_grad_node,
+ fuse_qkv_matmul_grad_x_grad,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(pre_layer_norm_grad_scale_grad_node,
+ pre_layer_norm_grad_scale_grad,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_grad_bias_grad_node,
+ out_linear_ele_add_grad_bias_grad,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(out_linear_ele_add_grad_x_grad_node,
+ out_linear_ele_add_grad_x_grad,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(out_linear_matmul_grad_w_grad_node,
+ out_linear_matmul_grad_w_grad,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(qk_scale_grad_out_node,
+ qk_scale_grad_out,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(qkv_transpose_grad_out_node,
+ qkv_transpose_grad_out,
+ fused_attention_grad_pattern);
+ GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_grad_bias_grad_node,
+
fuse_qkv_ele_add_grad_bias_grad, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_reshape_grad_out_node, + fuse_qkv_reshape_grad_out, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_grad_x_grad_node, + fuse_qkv_ele_add_grad_x_grad, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_matmul_grad_w_grad_node, + fuse_qkv_matmul_grad_w_grad, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(attn_dropout_grad_out_node, + attn_dropout_grad_out, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(qk_softmax_grad_out_node, + qk_softmax_grad_out, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_split_grad_out_node, + fuse_qkv_split_grad_out, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(grad_accumulation_out_node, + grad_accumulation_out, + fused_attention_grad_pattern); + fused_attention_grad_op_desc.SetOutput( + "AttnDropoutOut@GRAD", {qkv_matmul_grad_x_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "FMHAOut@GRAD", {out_linear_matmul_grad_x_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "LnBias@GRAD", {pre_layer_norm_grad_bias_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "LnOut@GRAD", {fuse_qkv_matmul_grad_x_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "LnScale@GRAD", {pre_layer_norm_grad_scale_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "OutLinearBias@GRAD", {out_linear_ele_add_grad_bias_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "OutLinearOut@GRAD", {out_linear_ele_add_grad_x_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "OutLinearW@GRAD", {out_linear_matmul_grad_w_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput("QKOut@GRAD", + {qk_scale_grad_out_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "QKTVOut@GRAD", {qkv_transpose_grad_out_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "QKVBias@GRAD", {fuse_qkv_ele_add_grad_bias_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "QKVBiasOut@GRAD", {fuse_qkv_reshape_grad_out_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "QKVOut@GRAD", {fuse_qkv_ele_add_grad_x_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "QKVW@GRAD", {fuse_qkv_matmul_grad_w_grad_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "SoftmaxOut@GRAD", {attn_dropout_grad_out_node->Name()}); + fused_attention_grad_op_desc.SetOutput("SrcMaskOut@GRAD", + {qk_softmax_grad_out_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "TransposeOut2@GRAD", {fuse_qkv_split_grad_out_node->Name()}); + fused_attention_grad_op_desc.SetOutput( + "X@GRAD", {grad_accumulation_out_node->Name()}); + + auto fused_attention_grad_node = + g->CreateOpNode(&fused_attention_grad_op_desc); + + IR_NODE_LINK_TO(fused_attention_grad_node, qkv_matmul_grad_x_grad_node); + IR_NODE_LINK_TO(fused_attention_grad_node, + out_linear_matmul_grad_x_grad_node); + IR_NODE_LINK_TO(fused_attention_grad_node, + pre_layer_norm_grad_bias_grad_node); + IR_NODE_LINK_TO(fused_attention_grad_node, + fuse_qkv_matmul_grad_x_grad_node); + IR_NODE_LINK_TO(fused_attention_grad_node, + pre_layer_norm_grad_scale_grad_node); + IR_NODE_LINK_TO(fused_attention_grad_node, + out_linear_ele_add_grad_bias_grad_node); + IR_NODE_LINK_TO(fused_attention_grad_node, + out_linear_ele_add_grad_x_grad_node); + IR_NODE_LINK_TO(fused_attention_grad_node, + out_linear_matmul_grad_w_grad_node); + 
IR_NODE_LINK_TO(fused_attention_grad_node, qk_scale_grad_out_node);
+ IR_NODE_LINK_TO(fused_attention_grad_node, qkv_transpose_grad_out_node);
+ IR_NODE_LINK_TO(fused_attention_grad_node,
+ fuse_qkv_ele_add_grad_bias_grad_node);
+ IR_NODE_LINK_TO(fused_attention_grad_node, fuse_qkv_reshape_grad_out_node);
+ IR_NODE_LINK_TO(fused_attention_grad_node,
+ fuse_qkv_ele_add_grad_x_grad_node);
+ IR_NODE_LINK_TO(fused_attention_grad_node,
+ fuse_qkv_matmul_grad_w_grad_node);
+ IR_NODE_LINK_TO(fused_attention_grad_node, attn_dropout_grad_out_node);
+ IR_NODE_LINK_TO(fused_attention_grad_node, qk_softmax_grad_out_node);
+ IR_NODE_LINK_TO(fused_attention_grad_node, fuse_qkv_split_grad_out_node);
+ IR_NODE_LINK_TO(fused_attention_grad_node, grad_accumulation_out_node);
+
+ IR_NODE_LINK_TO(subgraph.at(x), fused_attention_grad_node);
+ IR_NODE_LINK_TO(x_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(attn_dropout_mask_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(attn_dropout_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(dropout_mask_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(fmha_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(ln_bias_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(ln_mean_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(ln_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(ln_scale_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(ln_variance_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(out_linear_bias_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(out_linear_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(out_linear_w_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(qk_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(qktv_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(qkv_bias_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(qkv_bias_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(qkv_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(qkv_w_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(softmax_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(src_mask_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(src_mask_out_node, fused_attention_grad_node);
+ IR_NODE_LINK_TO(transpose_out_2_node, fused_attention_grad_node);
 GraphSafeRemoveNodes(g, {residual_ele_add_grad_op_node,
diff --git a/paddle/fluid/framework/ir/fused_attention_pass.h b/paddle/fluid/framework/ir/fused_attention_pass.h
index 41a90bd599..222900860a 100644
--- a/paddle/fluid/framework/ir/fused_attention_pass.h
+++ b/paddle/fluid/framework/ir/fused_attention_pass.h
@@ -16,6 +16,7 @@
 #include
 #include
+#include <unordered_map>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
@@ -252,6 +253,31 @@ struct FusedAttentionGradPattern : public PatternBase {
 } // namespace patterns
+class FusedAttentionPassCache {
+ public:
+ ir::Node* GetNodeFromCache(const std::string name) {
+ if (var_name_to_ir_node_cache_.count(name)) {
+ return var_name_to_ir_node_cache_.find(name)->second;
+ }
+ PADDLE_THROW(platform::errors::InvalidArgument(
+ "The key (%s) of FusedAttentionCache does not exist.", name));
+ }
+
+ void InsertIntoCache(const std::string name, ir::Node* node) {
+ if (!var_name_to_ir_node_cache_.count(name)) {
+ var_name_to_ir_node_cache_.insert({name, node});
+ } else {
+ PADDLE_THROW(platform::errors::AlreadyExists(
+ "The key (%s) of FusedAttentionCache already exists.", name));
+ }
+ }
+
+ void ResetCache() { var_name_to_ir_node_cache_.clear(); }
+
+ private:
+
std::unordered_map<std::string, ir::Node*> var_name_to_ir_node_cache_;
+};
+
 class FusedAttentionsPass : public FusePassBase {
 public:
 virtual ~FusedAttentionsPass() {}
@@ -273,9 +299,17 @@ class FusedAttentionsPass : public FusePassBase {
 // If true, the function name will have an abbreviation part.
 // If false, the function name won't contain an abbreviation for it.
- ir::Graph* PreMaskDropResFwd(Graph* graph) const;
+ ir::Graph* PreMaskDropResFwd(Graph* graph,
+ FusedAttentionPassCache* cache) const;
+
+ ir::Graph* PreMaskDropResBwd(Graph* graph,
+ FusedAttentionPassCache* cache) const;
- ir::Graph* PreMaskDropResBwd(Graph* graph) const;
+ const std::string GenerateCacheKey(const std::string anchor,
+ const std::string var_name,
+ int block_id) const {
+ return anchor + "_" + std::to_string(block_id) + "_" + var_name;
+ }
 };
 } // namespace ir
diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc
index 7d00dda194..347d1ba252 100644
--- a/paddle/fluid/operators/fused/fused_attention_op.cc
+++ b/paddle/fluid/operators/fused/fused_attention_op.cc
@@ -375,7 +375,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
 AddOutput("BiasDropoutResidualOut",
 "Result of residual + dropout(src + bias).")
 .AsIntermediate();
- AddOutput("CacheKVOut", "The udpated cache KV.");
+ AddOutput("CacheKVOut", "The updated cache KV.").AsDispensable();
 AddOutput("Y", "Result after attention.");
 AddAttr<int>("num_heads", "The number head for multi_head_attention.")
diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
index 12366a574d..98085c223a 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
@@ -157,11 +157,20 @@ class TestFusedAttentionPass(unittest.TestCase):
 assert ops[2].type == 'fused_attention'
 assert ops[3].type == 'reduce_mean'
 assert ops[5].type == 'reduce_mean_grad'
+ assert ops[6].type == 'fused_attention_grad'
 # two ops for linear, one op for reduce mean
 # one fill constant
 # one op for reduce mean grad, two ops for linear bwd
 # the eighth op should be the optimizer
- assert ops[8].type == 'sgd'
+ assert ops[9].type == 'sgd'
+
+ exe = paddle.static.Executor()
+ exe.run(startup_prog)
+ rst = exe.run(
+ main_prog,
+ feed={'x': x_data, 'attn_mask': mask_data},
+ fetch_list=[loss],
+ )
 if __name__ == "__main__":
--
GitLab
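The cache introduced by this patch keys every forward-pass node on the QKV weight variable name plus the block id, so PreMaskDropResBwd can rebuild the same keys and fetch those nodes when it wires up fused_attention_grad. The sketch below is not part of the patch; it mirrors GenerateCacheKey and the insert-in-forward / look-up-in-backward flow, with plain strings standing in for ir::Node* and a made-up anchor value.

// Standalone sketch of the "<anchor>_<block_id>_<var_name>" cache-key scheme
// used by the pass above. Example values are hypothetical, not from the patch.
#include <iostream>
#include <string>
#include <unordered_map>

// Mirrors FusedAttentionsPass::GenerateCacheKey from this patch:
// anchor + "_" + std::to_string(block_id) + "_" + var_name.
std::string GenerateCacheKey(const std::string& anchor,
                             const std::string& var_name,
                             int block_id) {
  return anchor + "_" + std::to_string(block_id) + "_" + var_name;
}

int main() {
  // The forward handler inserts nodes under these keys; strings stand in
  // for ir::Node* so the sketch stays self-contained.
  std::unordered_map<std::string, std::string> cache;
  const std::string anchor = "qkv_weight_0";  // QKV weight var name (example)
  const int block_id = 0;
  cache[GenerateCacheKey(anchor, "LnOut", block_id)] = "ln_out_node";
  cache[GenerateCacheKey(anchor, "QKVW", block_id)] = "qkv_w_node";

  // The backward handler recomputes the same key to look the node up.
  std::cout << cache.at(GenerateCacheKey(anchor, "QKVW", block_id)) << "\n";
  return 0;
}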