diff --git a/paddle/fluid/framework/ir/fused_attention_pass.cc b/paddle/fluid/framework/ir/fused_attention_pass.cc index bebed80c469c1e6dfff3bd229b86e94b45945ceb..2e8f96faf01dad73951bdec10309f1f6eaf5db76 100644 --- a/paddle/fluid/framework/ir/fused_attention_pass.cc +++ b/paddle/fluid/framework/ir/fused_attention_pass.cc @@ -24,7 +24,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x, bool pre_layer_norm, bool has_attn_mask, bool do_dropout, - bool add_residual) { + bool add_residual, + bool use_mp) { // pre layer norm pattern PDNode* pre_layer_norm_out_node{nullptr}; if (pre_layer_norm) { @@ -51,6 +52,28 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x, pre_layer_norm_variance_node}); } + // c_identity for mp + PDNode* c_identity_input_node = pre_layer_norm ? pre_layer_norm_out_node : x; + PDNode* c_identity_out_node{nullptr}; + if (use_mp) { + auto* c_identity_node = + pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity"); + if (pre_layer_norm) { + c_identity_input_node->assert_is_op_input("c_identity", "X"); + } + c_identity_out_node = pattern->NewNode(c_identity_out_repr()) + ->assert_is_op_output("c_identity"); + c_identity_node->LinksFrom({c_identity_input_node}) + .LinksTo({c_identity_out_node}); + } + + PDNode* fuse_qkv_input_node = x; + if (use_mp) { + fuse_qkv_input_node = c_identity_out_node; + } else if (pre_layer_norm) { + fuse_qkv_input_node = pre_layer_norm_out_node; + } + // fuse qkv pattern auto* fuse_qkv_matmul_node = pattern->NewNode(fuse_qkv_matmul_op_repr())->assert_is_op("matmul_v2"); @@ -58,15 +81,11 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x, ->assert_is_op_input("matmul_v2", "Y"); auto* fuse_qkv_matmul_out_node = pattern->NewNode(fuse_qkv_matmul_out_repr()) ->assert_is_op_output("matmul_v2"); - if (pre_layer_norm) { - pre_layer_norm_out_node->assert_is_op_input("matmul_v2", "X"); - fuse_qkv_matmul_node - ->LinksFrom({pre_layer_norm_out_node, fuse_qkv_matmul_w_node}) - .LinksTo({fuse_qkv_matmul_out_node}); - } else { - fuse_qkv_matmul_node->LinksFrom({x, fuse_qkv_matmul_w_node}) - .LinksTo({fuse_qkv_matmul_out_node}); + if (pre_layer_norm || use_mp) { + fuse_qkv_input_node->assert_is_op_input("matmul_v2", "X"); } + fuse_qkv_matmul_node->LinksFrom({fuse_qkv_input_node, fuse_qkv_matmul_w_node}) + .LinksTo({fuse_qkv_matmul_out_node}); auto* fuse_qkv_ele_add_node = pattern->NewNode(fuse_qkv_ele_add_op_repr()) ->assert_is_op("elementwise_add"); @@ -246,6 +265,20 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x, ->LinksFrom({out_linear_matmul_out_node, out_linear_ele_add_bias_node}) .LinksTo({out_linear_ele_add_out_node}); + PDNode* mp_allreduce_out_node{nullptr}; + if (use_mp) { + mp_allreduce_out_node = pattern->NewNode(mp_allreudce_sum_out_repr()) + ->assert_is_op_output("mp_allreduce_sum"); + auto* mp_allreduce_node = pattern->NewNode(mp_allreudce_sum_op_repr()) + ->assert_is_op("mp_allreduce_sum"); + out_linear_ele_add_out_node->assert_is_op_input("mp_allreduce_sum"); + mp_allreduce_node->LinksFrom({out_linear_ele_add_out_node}) + .LinksTo({mp_allreduce_out_node}); + } + + PDNode* out_linear_dropout_input_node = + use_mp ? 
mp_allreduce_out_node : out_linear_ele_add_out_node; + auto* out_linear_dropout_node = pattern->NewNode(out_linear_dropout_op_repr())->assert_is_op("dropout"); auto* out_linear_dropout_mask_node = @@ -254,8 +287,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x, auto* out_linear_dropout_out_node = pattern->NewNode(out_linear_dropout_out_repr()) ->assert_is_op_output("dropout"); - out_linear_ele_add_out_node->assert_is_op_input("dropout", "X"); - out_linear_dropout_node->LinksFrom({out_linear_ele_add_out_node}) + out_linear_dropout_input_node->assert_is_op_input("dropout", "X"); + out_linear_dropout_node->LinksFrom({out_linear_dropout_input_node}) .LinksTo({out_linear_dropout_mask_node, out_linear_dropout_out_node}); if (!add_residual && pre_layer_norm) { @@ -324,7 +357,8 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x, bool pre_layer_norm, bool has_attn_mask, bool do_dropout, - bool add_residual) { + bool add_residual, + bool use_mp) { // post layer norm PDNode* post_layer_norm_grad_out_node{nullptr}; if (!pre_layer_norm) { @@ -424,6 +458,20 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x, {out_linear_grad_input_node, out_linear_dropout_grad_mask_node}) .LinksTo({out_linear_dropout_grad_out_node}); + PDNode* mp_c_identity_out_node{nullptr}; + if (use_mp) { + mp_c_identity_out_node = pattern->NewNode(mp_allreudce_sum_grad_out_repr()) + ->assert_is_op_output("c_identity", "Out"); + auto* mp_c_identity_node = pattern->NewNode(mp_allreudce_sum_grad_op_repr()) + ->assert_is_op("c_identity"); + out_linear_dropout_grad_out_node->assert_is_op_input("c_identity"); + mp_c_identity_node->LinksFrom({out_linear_dropout_grad_out_node}) + .LinksTo({mp_c_identity_out_node}); + } + + PDNode* out_linear_ele_add_grad_input_node = + use_mp ? mp_c_identity_out_node : out_linear_dropout_grad_out_node; + auto* out_linear_ele_add_grad_node = pattern->NewNode(out_linear_ele_add_grad_op_repr()) ->assert_is_op("elementwise_add_grad"); @@ -439,10 +487,10 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x, auto* out_linear_ele_add_grad_bias_grad_node = pattern->NewNode(out_linear_ele_add_grad_bias_grad_repr()) ->assert_is_op_output("elementwise_add_grad", "Y@GRAD"); - out_linear_dropout_grad_out_node->assert_is_op_input("elementwise_add_grad", - "Out@GRAD"); + out_linear_ele_add_grad_input_node->assert_is_op_input("elementwise_add_grad", + "Out@GRAD"); out_linear_ele_add_grad_node - ->LinksFrom({out_linear_dropout_grad_out_node, + ->LinksFrom({out_linear_ele_add_grad_input_node, out_linear_ele_add_grad_x_node, out_linear_ele_add_grad_bias_node}) .LinksTo({out_linear_ele_add_grad_x_grad_node, @@ -699,54 +747,78 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x, .LinksTo( {fuse_qkv_matmul_grad_x_grad_node, fuse_qkv_matmul_grad_w_grad_node}); - if (!pre_layer_norm) { - return fuse_qkv_matmul_grad_x_grad_node; + PDNode* mp_allreduce_out_node{nullptr}; + if (use_mp) { + mp_allreduce_out_node = pattern->NewNode(c_identity_grad_out_repr()) + ->assert_is_op_output("c_allreduce_sum", "Out"); + auto* mp_allreduce_node = pattern->NewNode(c_identity_grad_op_repr()) + ->assert_is_op("c_allreduce_sum"); + fuse_qkv_matmul_grad_x_grad_node->assert_is_op_input("c_allreduce_sum", + "X"); + mp_allreduce_node->LinksFrom({fuse_qkv_matmul_grad_x_grad_node}) + .LinksTo({mp_allreduce_out_node}); + } + + PDNode* pre_layer_norm_input_node = + use_mp ? 
mp_allreduce_out_node : fuse_qkv_matmul_grad_x_grad_node; + if (!pre_layer_norm && !add_residual) { + return pre_layer_norm_input_node; + } + + PDNode* pre_layer_norm_grad_x_grad_node{nullptr}; + + if (pre_layer_norm) { + // pre layer norm + auto* pre_layer_norm_grad_node = + pattern->NewNode(pre_layer_norm_grad_op_repr()) + ->assert_is_op("layer_norm_grad"); + auto* pre_layer_norm_grad_scale_node = + pattern->NewNode(pre_layer_norm_grad_scale_repr()) + ->assert_is_op_input("layer_norm_grad", "Scale"); + auto* pre_layer_norm_grad_bias_node = + pattern->NewNode(pre_layer_norm_grad_bias_repr()) + ->assert_is_op_input("layer_norm_grad", "Bias"); + auto* pre_layer_norm_grad_mean_node = + pattern->NewNode(pre_layer_norm_grad_mean_repr()) + ->assert_is_op_input("layer_norm_grad", "Mean"); + auto* pre_layer_norm_grad_variance_node = + pattern->NewNode(pre_layer_norm_grad_variance_repr()) + ->assert_is_op_input("layer_norm_grad", "Variance"); + auto* pre_layer_norm_grad_x_node = + add_residual ? residual_ele_add_grad_x_node + : pattern->NewNode(pre_layer_norm_grad_x_repr()) + ->assert_is_op_input("layer_norm_grad", "X"); + auto* pre_layer_norm_grad_scale_grad_node = + pattern->NewNode(pre_layer_norm_grad_scale_grad_repr()) + ->assert_is_op_output("layer_norm_grad", "Scale@GRAD"); + auto* pre_layer_norm_grad_bias_grad_node = + pattern->NewNode(pre_layer_norm_grad_bias_grad_repr()) + ->assert_is_op_output("layer_norm_grad", "Bias@GRAD"); + pre_layer_norm_grad_x_grad_node = + pattern->NewNode(pre_layer_norm_grad_x_grad_repr()) + ->assert_is_op_output("layer_norm_grad", "X@GRAD"); + pre_layer_norm_input_node->assert_is_op_input("layer_norm_grad", "Y@GRAD"); + pre_layer_norm_grad_node + ->LinksFrom({pre_layer_norm_input_node, + pre_layer_norm_grad_scale_node, + pre_layer_norm_grad_bias_node, + pre_layer_norm_grad_mean_node, + pre_layer_norm_grad_variance_node, + pre_layer_norm_grad_x_node}) + .LinksTo({pre_layer_norm_grad_scale_grad_node, + pre_layer_norm_grad_bias_grad_node, + pre_layer_norm_grad_x_grad_node}); } - // pre layer norm - auto* pre_layer_norm_grad_node = - pattern->NewNode(pre_layer_norm_grad_op_repr()) - ->assert_is_op("layer_norm_grad"); - auto* pre_layer_norm_grad_scale_node = - pattern->NewNode(pre_layer_norm_grad_scale_repr()) - ->assert_is_op_input("layer_norm_grad", "Scale"); - auto* pre_layer_norm_grad_bias_node = - pattern->NewNode(pre_layer_norm_grad_bias_repr()) - ->assert_is_op_input("layer_norm_grad", "Bias"); - auto* pre_layer_norm_grad_mean_node = - pattern->NewNode(pre_layer_norm_grad_mean_repr()) - ->assert_is_op_input("layer_norm_grad", "Mean"); - auto* pre_layer_norm_grad_variance_node = - pattern->NewNode(pre_layer_norm_grad_variance_repr()) - ->assert_is_op_input("layer_norm_grad", "Variance"); - auto* pre_layer_norm_grad_x_node = - add_residual ? 
residual_ele_add_grad_x_node - : pattern->NewNode(pre_layer_norm_grad_x_repr()) - ->assert_is_op_input("layer_norm_grad", "X"); - auto* pre_layer_norm_grad_scale_grad_node = - pattern->NewNode(pre_layer_norm_grad_scale_grad_repr()) - ->assert_is_op_output("layer_norm_grad", "Scale@GRAD"); - auto* pre_layer_norm_grad_bias_grad_node = - pattern->NewNode(pre_layer_norm_grad_bias_grad_repr()) - ->assert_is_op_output("layer_norm_grad", "Bias@GRAD"); - auto* pre_layer_norm_grad_x_grad_node = - pattern->NewNode(pre_layer_norm_grad_x_grad_repr()) - ->assert_is_op_output("layer_norm_grad", "X@GRAD"); - fuse_qkv_matmul_grad_x_grad_node->assert_is_op_input("layer_norm_grad", - "Y@GRAD"); - pre_layer_norm_grad_node - ->LinksFrom({fuse_qkv_matmul_grad_x_grad_node, - pre_layer_norm_grad_scale_node, - pre_layer_norm_grad_bias_node, - pre_layer_norm_grad_mean_node, - pre_layer_norm_grad_variance_node, - pre_layer_norm_grad_x_node}) - .LinksTo({pre_layer_norm_grad_scale_grad_node, - pre_layer_norm_grad_bias_grad_node, - pre_layer_norm_grad_x_grad_node}); + PDNode* grad_accumulation_x_input_node = fuse_qkv_matmul_grad_x_grad_node; + if (pre_layer_norm) { + grad_accumulation_x_input_node = pre_layer_norm_grad_x_grad_node; + } else if (use_mp) { + grad_accumulation_x_input_node = mp_allreduce_out_node; + } if (!add_residual) { - return pre_layer_norm_grad_x_grad_node; + return grad_accumulation_x_input_node; } auto* grad_accumulation_sum_node = @@ -754,9 +826,11 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x, auto* grad_accumulation_sum_out_node = pattern->NewNode(grad_accumulation_out_repr()) ->assert_is_op_output("sum"); + residual_ele_add_grad_x_grad_node->assert_is_op_input("sum"); + grad_accumulation_x_input_node->assert_is_op_input("sum"); grad_accumulation_sum_node ->LinksFrom( - {pre_layer_norm_grad_x_grad_node, residual_ele_add_grad_x_grad_node}) + {grad_accumulation_x_input_node, residual_ele_add_grad_x_grad_node}) .LinksTo({grad_accumulation_sum_out_node}); return grad_accumulation_sum_out_node; @@ -771,10 +845,64 @@ void FusedAttentionsPass::ApplyImpl(Graph* graph) const { graph = PreMaskDropResFwd(graph, &cache); graph = PreMaskDropResBwd(graph, &cache); cache.ResetCache(); + + graph = PreMaskDropResMPFwd(graph, &cache); + graph = PreMaskDropResMPBwd(graph, &cache); + cache.ResetCache(); } ir::Graph* FusedAttentionsPass::PreMaskDropResFwd( Graph* graph, FusedAttentionPassCache* cache) const { + return ForwardHandlerHelper(graph, + cache, + /* pre_layer_norm */ true, + /* has_attn_mask */ true, + /* do_dropout */ true, + /* add_residual */ true, + /* use_mp */ false); +} + +ir::Graph* FusedAttentionsPass::PreMaskDropResBwd( + Graph* graph, FusedAttentionPassCache* cache) const { + return BackwardHandlerHelper(graph, + cache, + /* pre_layer_norm */ true, + /* has_attn_mask */ true, + /* do_dropout */ true, + /* add_residual */ true, + /* use_mp */ false); +} + +ir::Graph* FusedAttentionsPass::PreMaskDropResMPFwd( + Graph* graph, FusedAttentionPassCache* cache) const { + return ForwardHandlerHelper(graph, + cache, + /* pre_layer_norm */ true, + /* has_attn_mask */ true, + /* do_dropout */ true, + /* add_residual */ true, + /* use_mp */ true); +} + +ir::Graph* FusedAttentionsPass::PreMaskDropResMPBwd( + Graph* graph, FusedAttentionPassCache* cache) const { + return BackwardHandlerHelper(graph, + cache, + /* pre_layer_norm */ true, + /* has_attn_mask */ true, + /* do_dropout */ true, + /* add_residual */ true, + /* use_mp */ true); +} + +ir::Graph* 
FusedAttentionsPass::ForwardHandlerHelper( + Graph* graph, + FusedAttentionPassCache* cache, + bool pre_layer_norm, + bool has_attn_mask, + bool do_dropout, + bool add_residual, + bool use_mp) const { GraphPatternDetector gpd; auto* x = gpd.mutable_pattern() ->NewNode(patterns::PDNodeName(name_scope_, "x")) @@ -783,11 +911,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd( patterns::FusedAttentionPattern fused_attention_pattern( gpd.mutable_pattern(), "fused_attention_pattern"); - fused_attention_pattern(x, - /* pre_layer_norm */ true, - /* has_attn_mask */ true, - /* do_dropout */ true, - /* add_residual */ true); + fused_attention_pattern( + x, pre_layer_norm, has_attn_mask, do_dropout, add_residual, use_mp); int found_fused_attention = 0; @@ -840,10 +965,44 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd( GET_IR_NODE_FROM_SUBGRAPH( fuse_qkv_matmul_w_node, fuse_qkv_matmul_w, fused_attention_pattern); + + std::unordered_set remove_nodes = {pre_layer_norm_op_node, + fuse_qkv_matmul_op_node, + fuse_qkv_ele_add_op_node, + fuse_qkv_reshape_op_node, + fuse_qkv_transpose_op_node, + fuse_qkv_split_op_node, + qk_matmul_op_node, + qk_scale_op_node, + add_mask_ele_add_op_node, + qk_softmax_op_node, + attn_dropout_op_node, + qkv_matmul_op_node, + qkv_transpose_op_node, + qkv_reshape_op_node, + out_linear_matmul_op_node, + out_linear_ele_add_op_node, + out_linear_dropout_op_node, + residual_ele_add_op_node}; + + int ring_id = -1; + if (use_mp) { + GET_IR_NODE_FROM_SUBGRAPH( + c_identity_op_node, c_identity_op, fused_attention_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mp_allreudce_sum_op_node, + mp_allreudce_sum_op, + fused_attention_pattern); + remove_nodes.insert(c_identity_op_node); + remove_nodes.insert(mp_allreudce_sum_op_node); + ring_id = PADDLE_GET_CONST( + int, mp_allreudce_sum_op_node->Op()->GetAttr("ring_id")); + } + std::string cache_anchor_name = fuse_qkv_matmul_w_node->Var()->Name(); OpDesc fused_attention_op_desc(pre_layer_norm_op_node->Op()->Block()); fused_attention_op_desc.SetType("fused_attention"); + fused_attention_op_desc.SetAttr("ring_id", ring_id); fused_attention_op_desc.SetInput("X", {subgraph.at(x)->Name()}); cache->InsertIntoCache(GenerateCacheKey(cache_anchor_name, "X", block_id), subgraph.at(x)); @@ -1090,25 +1249,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd( IR_NODE_LINK_TO(fused_attention_node, out_linear_dropout_mask_node); IR_NODE_LINK_TO(fused_attention_node, residual_ele_add_out_node); - GraphSafeRemoveNodes(g, - {pre_layer_norm_op_node, - fuse_qkv_matmul_op_node, - fuse_qkv_ele_add_op_node, - fuse_qkv_reshape_op_node, - fuse_qkv_transpose_op_node, - fuse_qkv_split_op_node, - qk_matmul_op_node, - qk_scale_op_node, - add_mask_ele_add_op_node, - qk_softmax_op_node, - attn_dropout_op_node, - qkv_matmul_op_node, - qkv_transpose_op_node, - qkv_reshape_op_node, - out_linear_matmul_op_node, - out_linear_ele_add_op_node, - out_linear_dropout_op_node, - residual_ele_add_op_node}); + GraphSafeRemoveNodes(g, remove_nodes); found_fused_attention++; }; @@ -1119,8 +1260,14 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd( return graph; } -ir::Graph* FusedAttentionsPass::PreMaskDropResBwd( - Graph* graph, FusedAttentionPassCache* cache) const { +ir::Graph* FusedAttentionsPass::BackwardHandlerHelper( + Graph* graph, + FusedAttentionPassCache* cache, + bool pre_layer_norm, + bool has_attn_mask, + bool do_dropout, + bool add_residual, + bool use_mp) const { GraphPatternDetector gpd; auto* x = gpd.mutable_pattern() ->NewNode(patterns::PDNodeName(name_scope_, "x")) 
@@ -1129,11 +1276,8 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd( patterns::FusedAttentionGradPattern fused_attention_grad_pattern( gpd.mutable_pattern(), "fused_attention_grad_pattern"); - fused_attention_grad_pattern(x, - /* pre_layer_norm */ true, - /* has_attn_mask */ true, - /* do_dropout */ true, - /* add_residual */ true); + fused_attention_grad_pattern( + x, pre_layer_norm, has_attn_mask, do_dropout, add_residual, use_mp); int found_fused_attention = 0; @@ -1200,6 +1344,41 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd( grad_accumulation_sum_op, fused_attention_grad_pattern); + std::unordered_set remove_nodes = { + residual_ele_add_grad_op_node, + out_linear_dropout_grad_op_node, + out_linear_ele_add_grad_op_node, + out_linear_matmul_grad_op_node, + qkv_reshape_grad_op_node, + qkv_transpose_grad_op_node, + qkv_matmul_grad_op_node, + attn_dropout_grad_op_node, + qk_softmax_grad_op_node, + add_mask_ele_add_grad_op_node, + qk_scale_grad_op_node, + qk_matmul_grad_op_node, + fuse_qkv_split_grad_op_node, + fuse_qkv_transpose_grad_op_node, + fuse_qkv_reshape_grad_op_node, + fuse_qkv_ele_add_grad_op_node, + fuse_qkv_matmul_grad_op_node, + pre_layer_norm_grad_op_node, + grad_accumulation_sum_op_node}; + + int ring_id = -1; + if (use_mp) { + GET_IR_NODE_FROM_SUBGRAPH(mp_allreudce_sum_grad_op_node, + mp_allreudce_sum_grad_op, + fused_attention_grad_pattern); + GET_IR_NODE_FROM_SUBGRAPH(c_identity_grad_op_node, + c_identity_grad_op, + fused_attention_grad_pattern); + remove_nodes.insert(mp_allreudce_sum_grad_op_node); + remove_nodes.insert(c_identity_grad_op_node); + ring_id = PADDLE_GET_CONST( + int, mp_allreudce_sum_grad_op_node->Op()->GetAttr("ring_id")); + } + OpDesc fused_attention_grad_op_desc( residual_ele_add_grad_op_node->Op()->Block()); fused_attention_grad_op_desc.SetType("fused_attention_grad"); @@ -1347,7 +1526,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd( "ln_epsilon", PADDLE_GET_CONST( float, pre_layer_norm_grad_op_node->Op()->GetAttr("epsilon"))); - fused_attention_grad_op_desc.SetAttr("ring_id", -1); + fused_attention_grad_op_desc.SetAttr("ring_id", ring_id); GET_IR_NODE_FROM_SUBGRAPH(qkv_matmul_grad_x_grad_node, qkv_matmul_grad_x_grad, @@ -1497,26 +1676,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd( IR_NODE_LINK_TO(src_mask_out_node, fused_attention_grad_node); IR_NODE_LINK_TO(transpose_out_2_node, fused_attention_grad_node); - GraphSafeRemoveNodes(g, - {residual_ele_add_grad_op_node, - out_linear_dropout_grad_op_node, - out_linear_ele_add_grad_op_node, - out_linear_matmul_grad_op_node, - qkv_reshape_grad_op_node, - qkv_transpose_grad_op_node, - qkv_matmul_grad_op_node, - attn_dropout_grad_op_node, - qk_softmax_grad_op_node, - add_mask_ele_add_grad_op_node, - qk_scale_grad_op_node, - qk_matmul_grad_op_node, - fuse_qkv_split_grad_op_node, - fuse_qkv_transpose_grad_op_node, - fuse_qkv_reshape_grad_op_node, - fuse_qkv_ele_add_grad_op_node, - fuse_qkv_matmul_grad_op_node, - pre_layer_norm_grad_op_node, - grad_accumulation_sum_op_node}); + GraphSafeRemoveNodes(g, remove_nodes); found_fused_attention++; }; diff --git a/paddle/fluid/framework/ir/fused_attention_pass.h b/paddle/fluid/framework/ir/fused_attention_pass.h index 222900860a7bd80bcd9dec8c456fcfbcd7a99199..79d051f6dad6d1683c68403c0a3dc31a420f430e 100644 --- a/paddle/fluid/framework/ir/fused_attention_pass.h +++ b/paddle/fluid/framework/ir/fused_attention_pass.h @@ -33,6 +33,7 @@ namespace patterns { // 2. Add attn mask for qk product before the softmax or not. // 3. Do attn dropout or not. 
// 4. Add residual to the out linear result or not. +// 5. Use model tensor parallel or not. struct FusedAttentionPattern : public PatternBase { FusedAttentionPattern(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "fused_attention_pattern") {} @@ -41,7 +42,8 @@ struct FusedAttentionPattern : public PatternBase { bool pre_layer_norm, // do pre ln or not bool has_attn_mask, // add attn mask to qk or not bool do_dropout, // dropout the softmax(qk) or not - bool add_residual); // add residual to out linear or not + bool add_residual, // add residual to out linear or not + bool use_mp); // use tensor parallel or not // pre layer norm PATTERN_DECL_NODE(pre_layer_norm_op); @@ -51,6 +53,10 @@ struct FusedAttentionPattern : public PatternBase { PATTERN_DECL_NODE(pre_layer_norm_mean); PATTERN_DECL_NODE(pre_layer_norm_variance); + // c_identity for mp + PATTERN_DECL_NODE(c_identity_op); + PATTERN_DECL_NODE(c_identity_out); + // fuse qkv projection PATTERN_DECL_NODE(fuse_qkv_matmul_op); PATTERN_DECL_NODE(fuse_qkv_matmul_w); @@ -111,6 +117,10 @@ struct FusedAttentionPattern : public PatternBase { PATTERN_DECL_NODE(out_linear_ele_add_bias); PATTERN_DECL_NODE(out_linear_ele_add_out); + // allreudce for mp + PATTERN_DECL_NODE(mp_allreudce_sum_op); + PATTERN_DECL_NODE(mp_allreudce_sum_out); + PATTERN_DECL_NODE(out_linear_dropout_op); PATTERN_DECL_NODE(out_linear_dropout_out); PATTERN_DECL_NODE(out_linear_dropout_mask); @@ -131,13 +141,14 @@ struct FusedAttentionPattern : public PatternBase { // Declare the grad pattern for multi head attention struct FusedAttentionGradPattern : public PatternBase { FusedAttentionGradPattern(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "fused_attention_pattern") {} + : PatternBase(pattern, name_scope, "fused_attention_grad_pattern") {} PDNode* operator()(PDNode* x, bool pre_layer_norm, // pre ln bool has_attn_mask, // add attn mask to qk or not bool do_dropout, // dropout the softmax(qk) or not - bool add_residual); // add residual to out linear or not + bool add_residual, // add residual to out linear or not + bool use_mp); // use tensor parallel or not // post layer norm grad PATTERN_DECL_NODE(post_layer_norm_grad_op); @@ -162,6 +173,10 @@ struct FusedAttentionGradPattern : public PatternBase { PATTERN_DECL_NODE(out_linear_dropout_grad_mask); PATTERN_DECL_NODE(out_linear_dropout_grad_out); + // c_identity for mp + PATTERN_DECL_NODE(mp_allreudce_sum_grad_op); // c_identity + PATTERN_DECL_NODE(mp_allreudce_sum_grad_out); + PATTERN_DECL_NODE(out_linear_ele_add_grad_op); PATTERN_DECL_NODE(out_linear_ele_add_grad_x); PATTERN_DECL_NODE(out_linear_ele_add_grad_bias); @@ -235,6 +250,10 @@ struct FusedAttentionGradPattern : public PatternBase { PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad); PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad); + // allreduce for mp + PATTERN_DECL_NODE(c_identity_grad_op); // mp_allreduce_sum + PATTERN_DECL_NODE(c_identity_grad_out); + // pre layer norm grad PATTERN_DECL_NODE(pre_layer_norm_grad_op); PATTERN_DECL_NODE(pre_layer_norm_grad_scale); @@ -296,6 +315,7 @@ class FusedAttentionsPass : public FusePassBase { // 4. Add residual? [Res] // 5. Do post layer norm? [Post] // 6. Forward or Backward? [Fwd/Bwd] + // 7. Use tensor model parallel? [MP] // If true, the function name will have an abbreviation part. // If false, the function name won't contain an abbreviation for it. 
@@ -305,6 +325,28 @@ class FusedAttentionsPass : public FusePassBase { ir::Graph* PreMaskDropResBwd(Graph* graph, FusedAttentionPassCache* cache) const; + ir::Graph* PreMaskDropResMPFwd(Graph* graph, + FusedAttentionPassCache* cache) const; + + ir::Graph* PreMaskDropResMPBwd(Graph* graph, + FusedAttentionPassCache* cache) const; + + ir::Graph* ForwardHandlerHelper(Graph* graph, + FusedAttentionPassCache* cache, + bool pre_layer_norm, + bool has_attn_mask, + bool do_dropout, + bool add_residual, + bool use_mp) const; + + ir::Graph* BackwardHandlerHelper(Graph* graph, + FusedAttentionPassCache* cache, + bool pre_layer_norm, + bool has_attn_mask, + bool do_dropout, + bool add_residual, + bool use_mp) const; + const std::string GenerateCacheKey(const std::string anchor, const std::string var_name, int block_id) const { diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 347d1ba25215081eedb44c552e965ed65d1d921f..8ae1a60ad3b940177a18c878ef7c980c345467d7 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -120,6 +120,7 @@ class FusedAttentionOp : public framework::OperatorWithKernel { auto y_dim = ctx->GetInputDim("QKVW"); int dim_head; int hidden_size; + int nranks = 1; if (transpose_qkv_wb) { PADDLE_ENFORCE_EQ(y_dim.size(), 2, @@ -149,8 +150,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "The dimensions of qkv_weight must be 2" "(dim_embed, 3 * dim_embed).")); + } else { + // compute the mp nranks + nranks = (y_dim[0] * 3) / y_dim[1]; } - dim_head = y_dim[0] / num_heads; + dim_head = y_dim[0] / (num_heads * nranks); hidden_size = y_dim[0]; } else { PADDLE_ENFORCE_EQ(y_dim.size(), @@ -210,11 +214,13 @@ class FusedAttentionOp : public framework::OperatorWithKernel { } if (transpose_qkv_wb) { - // [batch_size, seq_len, 3 * hidden_size] - ctx->SetOutputDim("QKVOut", {x_dim[0], x_dim[1], 3 * hidden_size}); + // [batch_size, seq_len, 3 * num_heads * dim_head] + ctx->SetOutputDim("QKVOut", + {x_dim[0], x_dim[1], 3 * num_heads * dim_head}); if (ctx->HasInput("QKVBias")) { - ctx->SetOutputDim("QKVBiasOut", {x_dim[0], x_dim[1], 3 * hidden_size}); + ctx->SetOutputDim("QKVBiasOut", + {x_dim[0], x_dim[1], 3 * num_heads * dim_head}); } } else { // [batch_size, seq_len, 3, num_head, head_size] diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 91dbf71bbc845827b942c98505d18a24da64eb3f..d4f1a543984026f716af4a563ce494ab8fd95920 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -217,13 +217,15 @@ class FusedAttentionOpKernel : public framework::OpKernel { int num_head; int dim_head; + int nranks = 1; // get num_head and dim_head in two different ways if (!transpose_qkv_wb) { num_head = qkv_w_dims[1]; dim_head = qkv_w_dims[2]; } else { + nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1]; num_head = num_heads; - dim_head = dim_embed / num_head; + dim_head = dim_embed / (num_head * nranks); } int bsz_seq = batch_size * max_seq_len; @@ -579,12 +581,14 @@ class FusedAttentionGradKernel : public framework::OpKernel { int dim_embed = input_x_dims[2]; int num_head; int dim_head; + int nranks = 1; if (!transpose_qkv_wb) { num_head = qkv_w_dims[1]; dim_head = qkv_w_dims[2]; } else { + nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1]; num_head = num_heads; - dim_head = dim_embed / num_head; + 
dim_head = dim_embed / (num_head * nranks); } int bsz_seq = batch_size * max_seq_len; diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt index d8fa0fb0285958cab5dc61dbef0cce4524a204c6..c61705bdc15c2868fe271b662311063929fe18b4 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt @@ -908,3 +908,15 @@ if((WITH_GPU) AND (LINUX)) set_tests_properties(test_dygraph_save_for_auto_infer PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST") endif() +if(WITH_GPU) + bash_test_modules( + test_fused_attention_pass_with_mp + START_BASH + test_fused_attention_pass_with_mp.sh + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=21400;http_proxy=;https_proxy=") + set_tests_properties(test_fused_attention_pass_with_mp PROPERTIES TIMEOUT + "120") +endif() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/fused_attention_pass_with_mp.py b/python/paddle/fluid/tests/unittests/collective/fleet/fused_attention_pass_with_mp.py new file mode 100644 index 0000000000000000000000000000000000000000..b3dc61ce9e5aeeb1fce1a3b190fc7a5f6ff87018 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/fleet/fused_attention_pass_with_mp.py @@ -0,0 +1,241 @@ +# Copyright (c) 2013 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import unittest + +import numpy as np + +import paddle +import paddle.distributed.fleet as fleet +import paddle.fluid as fluid +import paddle.nn.functional as F +from paddle.distributed.passes import PassManager, new_pass + +paddle.enable_static() + + +class MultiHeadAttentionWithMP(paddle.nn.Layer): + def __init__( + self, + embed_dim, + num_heads, + add_residual=True, + pre_ln=True, + attn_dropout=True, + ): + super(MultiHeadAttentionWithMP, self).__init__() + self.embed_dim = embed_dim + self.kdim = embed_dim + self.vdim = embed_dim + self.num_heads = num_heads + + self.add_residual = add_residual + self.pre_ln = pre_ln + self.attn_dropout = attn_dropout + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + assert num_heads % 2 == 0 + self.num_heads = num_heads // 2 + + self.norm1 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5) + self.norm2 = paddle.nn.LayerNorm(embed_dim, epsilon=1e-5) + + self.qkv_proj = paddle.nn.Linear( + embed_dim, 3 * self.num_heads * self.head_dim + ) + self.out_proj = paddle.nn.Linear( + self.num_heads * self.head_dim, embed_dim + ) + self.dropout = paddle.nn.Dropout(1e-10, mode="upscale_in_train") + + def forward(self, x, attn_mask=None): + residual = x + + if self.pre_ln: + # pre layer norm + x = self.norm1(x) + + x = paddle.distributed.collective._c_identity(x) + + # compute qkv + qkv = self.qkv_proj(x) + qkv = paddle.reshape(qkv, [0, 0, 3 * self.num_heads, self.head_dim]) + qkv = paddle.transpose(qkv, [0, 2, 1, 3]) + q, k, v = paddle.split(qkv, num_or_sections=3, axis=1) + + # compute core attention + q = paddle.scale(q, scale=self.head_dim**-0.5) + product = paddle.matmul(x=q, y=k, transpose_y=True) + if attn_mask is not None: + product = product + attn_mask + weights = F.softmax(product) + if self.attn_dropout: + weights = F.dropout( + weights, 0.1, training=self.training, mode="upscale_in_train" + ) + out = paddle.matmul(weights, v) + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + out = paddle.distributed.collective._mp_allreduce( + out, use_calc_stream=True, use_model_parallel=True + ) + out = self.dropout(out) + if self.add_residual: + out = residual + out + + if not self.pre_ln: + # post layer norm + out = self.norm2(out) + + return out + + +class TestFusedAttentionPassWithMP(unittest.TestCase): + def setUp(self): + fleet.init() + self.endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',') + self.current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT") + self.nranks = len(self.endpoints) + self.rank = self.endpoints.index(self.current_endpoint) + self.gpu_id = int(os.getenv("FLAGS_selected_gpus")) + self.place = fluid.CUDAPlace(self.gpu_id) + self.exe = fluid.Executor(self.place) + self.endpoints.remove(self.current_endpoint) + self.other_endpoints = self.endpoints + self.add_residual = True + self.pre_ln = True + self.attn_dropout = True + self.add_mask = True + self.x_data = None + self.mask_data = None + + def get_rst(self, use_pass=False): + batch_size = 2 + seq_len = 1024 + hidden_size = 768 + num_heads = 12 + + np.random.seed(1234) + if self.x_data is None: + self.x_data = np.random.rand(batch_size, seq_len, seq_len).astype( + 'float32' + ) + self.mask_data = np.random.rand( + batch_size, num_heads // 2, seq_len, seq_len + ).astype('float32') + + main_prog = paddle.static.Program() + main_prog.random_seed = 
1234 + startup_prog = paddle.static.Program() + startup_prog.random_seed = 1234 + + with paddle.static.program_guard(main_prog, startup_prog): + data = paddle.static.data( + name="x", + shape=[-1, seq_len, seq_len], + dtype='float32', + ) + if self.add_mask: + attn_mask = paddle.static.data( + name="attn_mask", + shape=[-1, num_heads // 2, seq_len, seq_len], + dtype='float32', + ) + else: + attn_mask = None + + data_linear = paddle.nn.Linear(seq_len, hidden_size) + multi_head_attn = MultiHeadAttentionWithMP( + hidden_size, + num_heads, + add_residual=self.add_residual, + pre_ln=self.pre_ln, + attn_dropout=self.attn_dropout, + ) + + attn_input = data_linear(data) + out = multi_head_attn(attn_input, attn_mask) + loss = paddle.mean(out) + + sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001) + sgd_optimizer.minimize(loss) + + startup_block = startup_prog.global_block() + nccl_id_var = startup_block.create_var( + name=fluid.unique_name.generate('nccl_id'), + persistable=True, + type=fluid.core.VarDesc.VarType.RAW, + ) + startup_block.append_op( + type='c_gen_nccl_id', + inputs={}, + outputs={'Out': nccl_id_var}, + attrs={ + 'rank': self.rank, + 'endpoint': self.current_endpoint, + 'other_endpoints': self.other_endpoints, + }, + ) + startup_block.append_op( + type='c_comm_init', + inputs={'X': nccl_id_var}, + outputs={}, + attrs={ + 'nranks': self.nranks, + 'rank': self.rank, + 'ring_id': 0, + 'device_id': self.gpu_id, + }, + ) + + if use_pass: + pass_manager = PassManager([new_pass("fused_attention")]) + pass_manager.apply([main_prog], [startup_prog]) + + ops = main_prog.global_block().ops + assert ops[2].type == 'fused_attention' + assert ops[3].type == 'reduce_mean' + assert ops[5].type == 'reduce_mean_grad' + assert ops[6].type == 'fused_attention_grad' + # two ops for linear, one op for reduce mean + # one fill constant + # one op for reduce mean grad, two ops for linear bwd + # the eighth op should be the optimizer + assert ops[9].type == 'sgd' + + self.exe.run(startup_prog) + for i in range(2): + rst = self.exe.run( + main_prog, + feed={'x': self.x_data, 'attn_mask': self.mask_data}, + fetch_list=[loss], + ) + + return rst + + def test_pass(self): + fused_rst = self.get_rst(use_pass=True) + non_fused_rst = self.get_rst() + assert np.allclose(fused_rst, non_fused_rst, atol=1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_fused_attention_pass_with_mp.sh b/python/paddle/fluid/tests/unittests/collective/fleet/test_fused_attention_pass_with_mp.sh new file mode 100644 index 0000000000000000000000000000000000000000..d00f2fdbac0e1d85325c1c24c66e005273f2bbc2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_fused_attention_pass_with_mp.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e +# use default values +# FIXME: random fails on Unknown command lines -c (or -m). 
+CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch fused_attention_pass_with_mp.py diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv index 954e6b7b42a75702ab5cf15bb40066fc28eeb449..d63c61bfa1a7c56f70c391c83a2d264dd8981429 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv +++ b/python/paddle/fluid/tests/unittests/collective/fleet/testslist.csv @@ -58,6 +58,7 @@ test_fleet_recompute_meta_optimizer,LINUX;WIN32,GPU;XPU;ASCEND;ASCEND_CL,,,test_ test_fleet_private_function,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_new_group,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_new_group.sh,2,,http_proxy=;https_proxy=, test_c_comm_init_op,LINUX,GPU;XPU;ASCEND;ASCEND_CL,120,DIST,test_c_comm_init_op.sh,2,,http_proxy=;https_proxy=, +test_fused_attention_pass_with_mp,LINUX,GPU;;;,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=, test_ir_pass_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_mnist,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_se_resnext,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
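
A minimal illustrative sketch (Python, assumed layout) of where the two tensor-parallel ops sit in the forward subgraph that FusedAttentionPattern matches when use_mp is true, for the pre-LN / mask / dropout / residual variant this pass handles; only ops named in this diff are listed, the core attention ops are elided:

# Illustrative only: op types in the order the use_mp=True forward pattern links them.
forward_ops_with_mp = [
    "layer_norm",          # pre layer norm (pre_layer_norm=True)
    "c_identity",          # tensor parallel: identity in forward, allreduce in backward
    "matmul_v2",           # fused QKV projection with the column-parallel weight
    "elementwise_add",     # QKV bias
    # ... core attention (reshape / transpose / split / qk / softmax / dropout / qkv) ...
    "matmul_v2",           # out linear with the row-parallel weight
    "elementwise_add",     # out linear bias
    "mp_allreduce_sum",    # tensor parallel: sum the partial out-linear results
    "dropout",
    "elementwise_add",     # residual add (add_residual=True)
]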
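
The nranks / dim_head recovery added to fused_attention_op for the transpose_qkv_wb path follows from the column-parallel QKV weight shape; a worked example with assumed values matching the unit test above (hidden size 768, 12 heads, mp degree 2, so each rank keeps 6 heads):

# Assumed shapes: each rank holds a [dim_embed, 3 * dim_embed / nranks] QKV weight,
# and the op's num_heads attribute is the per-rank head count (the test halves num_heads).
dim_embed = 768
total_heads = 12
mp_degree = 2
num_heads = total_heads // mp_degree               # 6 heads per rank
y_dim = (dim_embed, 3 * dim_embed // mp_degree)    # (768, 1152), sharded QKVW dims

nranks = (y_dim[0] * 3) // y_dim[1]                # (768 * 3) / 1152 = 2
dim_head = y_dim[0] // (num_heads * nranks)        # 768 / (6 * 2) = 64
assert nranks == mp_degree and dim_head == dim_embed // total_heads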