// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

namespace patterns {

// Declare patterns for multi head attention.
// Can detect:
// 1. Pre layer norm or post layer norm.
// 2. Add attn mask for qk product before the softmax or not.
// 3. Do attn dropout or not.
// 4. Add residual to the out linear result or not.
// 5. Use tensor model parallel or not.
struct FusedAttentionPattern : public PatternBase {
  FusedAttentionPattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "fused_attention_pattern") {}

  PDNode* operator()(PDNode* x,
                     bool pre_layer_norm,  // do pre ln or not
                     bool has_attn_mask,   // add attn mask to qk or not
                     bool do_dropout,      // dropout the softmax(qk) or not
                     bool add_residual,    // add residual to out linear or not
                     bool use_mp);         // use tensor parallel or not

  // pre layer norm
  PATTERN_DECL_NODE(pre_layer_norm_op);
  PATTERN_DECL_NODE(pre_layer_norm_scale);
  PATTERN_DECL_NODE(pre_layer_norm_bias);
  PATTERN_DECL_NODE(pre_layer_norm_out);
  PATTERN_DECL_NODE(pre_layer_norm_mean);
  PATTERN_DECL_NODE(pre_layer_norm_variance);

  // c_identity for mp
  PATTERN_DECL_NODE(c_identity_op);
  PATTERN_DECL_NODE(c_identity_out);

  // fuse qkv projection
  PATTERN_DECL_NODE(fuse_qkv_matmul_op);
  PATTERN_DECL_NODE(fuse_qkv_matmul_w);
  PATTERN_DECL_NODE(fuse_qkv_matmul_out);

  PATTERN_DECL_NODE(fuse_qkv_ele_add_op);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_bias);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_out);

  PATTERN_DECL_NODE(fuse_qkv_reshape_op);
  PATTERN_DECL_NODE(fuse_qkv_reshape_out);
  PATTERN_DECL_NODE(fuse_qkv_reshape_x_shape);

  PATTERN_DECL_NODE(fuse_qkv_transpose_op);
  PATTERN_DECL_NODE(fuse_qkv_transpose_out);
  PATTERN_DECL_NODE(fuse_qkv_transpose_x_shape);

  PATTERN_DECL_NODE(fuse_qkv_split_op);
  PATTERN_DECL_NODE(fuse_qkv_split_out_q);  // q
  PATTERN_DECL_NODE(fuse_qkv_split_out_k);  // k
  PATTERN_DECL_NODE(fuse_qkv_split_out_v);  // v

  // core attention
  PATTERN_DECL_NODE(qk_matmul_op);
  PATTERN_DECL_NODE(qk_matmul_out);

  PATTERN_DECL_NODE(qk_scale_op);
  PATTERN_DECL_NODE(qk_scale_out);

  PATTERN_DECL_NODE(add_mask_ele_add_op);
  PATTERN_DECL_NODE(add_mask_ele_add_mask);
  PATTERN_DECL_NODE(add_mask_ele_add_out);

  PATTERN_DECL_NODE(qk_softmax_op);
  PATTERN_DECL_NODE(qk_softmax_out);

  PATTERN_DECL_NODE(attn_dropout_op);
  PATTERN_DECL_NODE(attn_dropout_out);
  PATTERN_DECL_NODE(attn_dropout_mask);

  PATTERN_DECL_NODE(qkv_matmul_op);
  PATTERN_DECL_NODE(qkv_matmul_out);

  PATTERN_DECL_NODE(qkv_transpose_op);
  PATTERN_DECL_NODE(qkv_transpose_out);
  PATTERN_DECL_NODE(qkv_transpose_x_shape);

  PATTERN_DECL_NODE(qkv_reshape_op);
  PATTERN_DECL_NODE(qkv_reshape_out);
  PATTERN_DECL_NODE(qkv_reshape_x_shape);

  // out linear
  PATTERN_DECL_NODE(out_linear_matmul_op);
  PATTERN_DECL_NODE(out_linear_matmul_w);
  PATTERN_DECL_NODE(out_linear_matmul_out);

  PATTERN_DECL_NODE(out_linear_ele_add_op);
  PATTERN_DECL_NODE(out_linear_ele_add_bias);
  PATTERN_DECL_NODE(out_linear_ele_add_out);

  // allreduce for mp
  PATTERN_DECL_NODE(mp_allreudce_sum_op);
  PATTERN_DECL_NODE(mp_allreudce_sum_out);

  PATTERN_DECL_NODE(out_linear_dropout_op);
  PATTERN_DECL_NODE(out_linear_dropout_out);
  PATTERN_DECL_NODE(out_linear_dropout_mask);

  // residual
  PATTERN_DECL_NODE(residual_ele_add_op);
  PATTERN_DECL_NODE(residual_ele_add_out);

  // post layer norm
  PATTERN_DECL_NODE(post_layer_norm_op);
  PATTERN_DECL_NODE(post_layer_norm_scale);
  PATTERN_DECL_NODE(post_layer_norm_bias);
  PATTERN_DECL_NODE(post_layer_norm_out);
  PATTERN_DECL_NODE(post_layer_norm_mean);
  PATTERN_DECL_NODE(post_layer_norm_variance);
};

// Declare the grad pattern for multi head attention
struct FusedAttentionGradPattern : public PatternBase {
  FusedAttentionGradPattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "fused_attention_grad_pattern") {}

  PDNode* operator()(PDNode* x,
                     bool pre_layer_norm,  // pre ln
                     bool has_attn_mask,   // add attn mask to qk or not
                     bool do_dropout,      // dropout the softmax(qk) or not
                     bool add_residual,    // add residual to out linear or not
                     bool use_mp);         // use tensor parallel or not

  // post layer norm grad
  PATTERN_DECL_NODE(post_layer_norm_grad_op);
  PATTERN_DECL_NODE(post_layer_norm_grad_scale);
  PATTERN_DECL_NODE(post_layer_norm_grad_bias);
  PATTERN_DECL_NODE(post_layer_norm_grad_mean);
  PATTERN_DECL_NODE(post_layer_norm_grad_variance);
  PATTERN_DECL_NODE(post_layer_norm_grad_x);
  PATTERN_DECL_NODE(post_layer_norm_grad_scale_grad);
  PATTERN_DECL_NODE(post_layer_norm_grad_bias_grad);
  PATTERN_DECL_NODE(post_layer_norm_grad_x_grad);

  // residual grad
  PATTERN_DECL_NODE(residual_ele_add_grad_op);
  PATTERN_DECL_NODE(residual_ele_add_grad_x);
  PATTERN_DECL_NODE(residual_ele_add_grad_bias);
  PATTERN_DECL_NODE(residual_ele_add_grad_bias_grad);
  PATTERN_DECL_NODE(residual_ele_add_grad_x_grad);

  // out linear grad
  PATTERN_DECL_NODE(out_linear_dropout_grad_op);
  PATTERN_DECL_NODE(out_linear_dropout_grad_mask);
  PATTERN_DECL_NODE(out_linear_dropout_grad_out);

  // c_identity for mp
  PATTERN_DECL_NODE(mp_allreudce_sum_grad_op);  // c_identity
  PATTERN_DECL_NODE(mp_allreudce_sum_grad_out);

  PATTERN_DECL_NODE(out_linear_ele_add_grad_op);
  PATTERN_DECL_NODE(out_linear_ele_add_grad_x);
  PATTERN_DECL_NODE(out_linear_ele_add_grad_bias);
  PATTERN_DECL_NODE(out_linear_ele_add_grad_x_grad);
  PATTERN_DECL_NODE(out_linear_ele_add_grad_bias_grad);

  PATTERN_DECL_NODE(out_linear_matmul_grad_op);
  PATTERN_DECL_NODE(out_linear_matmul_grad_x);
  PATTERN_DECL_NODE(out_linear_matmul_grad_w);
  PATTERN_DECL_NODE(out_linear_matmul_grad_x_grad);
  PATTERN_DECL_NODE(out_linear_matmul_grad_w_grad);

  // core attention grad
  PATTERN_DECL_NODE(qkv_reshape_grad_op);
  PATTERN_DECL_NODE(qkv_reshape_grad_x_shape);
  PATTERN_DECL_NODE(qkv_reshape_grad_out);

  PATTERN_DECL_NODE(qkv_transpose_grad_op);
  PATTERN_DECL_NODE(qkv_transpose_grad_x_shape);
  PATTERN_DECL_NODE(qkv_transpose_grad_out);

  PATTERN_DECL_NODE(qkv_matmul_grad_op);
  PATTERN_DECL_NODE(qkv_matmul_grad_x);
  PATTERN_DECL_NODE(qkv_matmul_grad_w);
  PATTERN_DECL_NODE(qkv_matmul_grad_x_grad);
  PATTERN_DECL_NODE(qkv_matmul_grad_w_grad);

  PATTERN_DECL_NODE(attn_dropout_grad_op);
  PATTERN_DECL_NODE(attn_dropout_grad_mask);
  PATTERN_DECL_NODE(attn_dropout_grad_out);

  PATTERN_DECL_NODE(qk_softmax_grad_op);
  PATTERN_DECL_NODE(qk_softmax_grad_fwd_out);
  PATTERN_DECL_NODE(qk_softmax_grad_out);

  PATTERN_DECL_NODE(add_mask_ele_add_grad_op);
  PATTERN_DECL_NODE(add_mask_ele_add_grad_x);
  PATTERN_DECL_NODE(add_mask_ele_add_grad_bias);
  PATTERN_DECL_NODE(add_mask_ele_add_grad_x_grad);

  PATTERN_DECL_NODE(qk_scale_grad_op);
  PATTERN_DECL_NODE(qk_scale_grad_out);

  PATTERN_DECL_NODE(qk_matmul_grad_op);
  PATTERN_DECL_NODE(qk_matmul_grad_x);
  PATTERN_DECL_NODE(qk_matmul_grad_w);
  PATTERN_DECL_NODE(qk_matmul_grad_x_grad);
  PATTERN_DECL_NODE(qk_matmul_grad_w_grad);

  // fuse qkv projection grad
  PATTERN_DECL_NODE(fuse_qkv_split_grad_op);  // concat op
  PATTERN_DECL_NODE(fuse_qkv_split_grad_out);

  PATTERN_DECL_NODE(fuse_qkv_transpose_grad_op);
  PATTERN_DECL_NODE(fuse_qkv_transpose_grad_x_shape);
  PATTERN_DECL_NODE(fuse_qkv_transpose_grad_out);

  PATTERN_DECL_NODE(fuse_qkv_reshape_grad_op);
  PATTERN_DECL_NODE(fuse_qkv_reshape_grad_x_shape);
  PATTERN_DECL_NODE(fuse_qkv_reshape_grad_out);

  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_op);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x_grad);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias_grad);

  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_op);
  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x);
  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w);
  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad);
  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad);

  // allreduce for mp
  PATTERN_DECL_NODE(c_identity_grad_op);  // mp_allreduce_sum
  PATTERN_DECL_NODE(c_identity_grad_out);

  // pre layer norm grad
  PATTERN_DECL_NODE(pre_layer_norm_grad_op);
  PATTERN_DECL_NODE(pre_layer_norm_grad_scale);
  PATTERN_DECL_NODE(pre_layer_norm_grad_bias);
  PATTERN_DECL_NODE(pre_layer_norm_grad_mean);
  PATTERN_DECL_NODE(pre_layer_norm_grad_variance);
  PATTERN_DECL_NODE(pre_layer_norm_grad_x);
  PATTERN_DECL_NODE(pre_layer_norm_grad_scale_grad);
  PATTERN_DECL_NODE(pre_layer_norm_grad_bias_grad);
  PATTERN_DECL_NODE(pre_layer_norm_grad_x_grad);

  // grad accumulation
  PATTERN_DECL_NODE(grad_accumulation_sum_op);
  PATTERN_DECL_NODE(grad_accumulation_out);
};

}  // namespace patterns

class FusedAttentionPassCache {
 public:
  ir::Node* GetNodeFromCache(const std::string name) {
    if (var_name_to_ir_node_cache_.count(name)) {
      return var_name_to_ir_node_cache_.find(name)->second;
    }
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The key (%s) of FusedAttentionCache does not exist.", name));
  }

  void InsertIntoCache(const std::string name, ir::Node* node) {
    if (!var_name_to_ir_node_cache_.count(name)) {
      var_name_to_ir_node_cache_.insert({name, node});
    } else {
      PADDLE_THROW(platform::errors::AlreadyExists(
          "The key (%s) of FusedAttentionCache already exists.", name));
    }
  }

  void ResetCache() { var_name_to_ir_node_cache_.clear(); }

 private:
  std::unordered_map<std::string, ir::Node*> var_name_to_ir_node_cache_;
};

class FusedAttentionsPass : public FusePassBase {
 public:
  virtual ~FusedAttentionsPass() {}

 protected:
  void ApplyImpl(Graph* graph) const;

  const std::string name_scope_{"fused_attention_pass"};

 private:
  // The name rule for the helper function.
  // The function name will contain at most seven parts in order:
  // 1. Do pre layer norm? [Pre]
  // 2. Add mask in the core attention part? [Mask]
  // 3. Do dropout in the core attention part? [Drop]
  // 4. Add residual? [Res]
  // 5. Do post layer norm? [Post]
  // 6. Forward or Backward? [Fwd/Bwd]
  // 7. Use tensor model parallel? [MP]
  // If true, the function name will have an abbreviation part.
  // If false, the function name won't contain an abbreviation for it.
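  // For example, PreMaskDropResFwd below handles the forward graph that uses
  // pre layer norm, adds the attention mask, applies attention dropout and
  // adds the residual, without tensor model parallel, while PreMaskDropResMPBwd
  // handles the matching backward graph with tensor model parallel enabled.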
  ir::Graph* PreMaskDropResFwd(Graph* graph,
                               FusedAttentionPassCache* cache) const;

  ir::Graph* PreMaskDropResBwd(Graph* graph,
                               FusedAttentionPassCache* cache) const;

  ir::Graph* PreMaskDropResMPFwd(Graph* graph,
                                 FusedAttentionPassCache* cache) const;

  ir::Graph* PreMaskDropResMPBwd(Graph* graph,
                                 FusedAttentionPassCache* cache) const;

  ir::Graph* ForwardHandlerHelper(Graph* graph,
                                  FusedAttentionPassCache* cache,
                                  bool pre_layer_norm,
                                  bool has_attn_mask,
                                  bool do_dropout,
                                  bool add_residual,
                                  bool use_mp) const;

  ir::Graph* BackwardHandlerHelper(Graph* graph,
                                   FusedAttentionPassCache* cache,
                                   bool pre_layer_norm,
                                   bool has_attn_mask,
                                   bool do_dropout,
                                   bool add_residual,
                                   bool use_mp) const;

  const std::string GenerateCacheKey(const std::string anchor,
                                     const std::string var_name,
                                     int block_id) const {
    return anchor + "_" + std::to_string(block_id) + "_" + var_name;
  }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
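
// A minimal usage sketch, assuming the pass is registered in the companion
// .cc file under the name "fused_attention_pass" and fetched through the
// standard ir::PassRegistry (the snippet is illustrative and not part of this
// header):
//
//   auto pass =
//       paddle::framework::ir::PassRegistry::Instance().Get(
//           "fused_attention_pass");
//   graph.reset(pass->Apply(graph.release()));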