// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

// Declare patterns for multi head attention.
// Can detect:
32
// 1. Pre layer norm or post layer norm.
33 34 35
// 2. Add attn mask for qk product before the softmax or not.
// 3. Do attn dropout or not.
// 4. Add residual to the out linear result or not.
36
// 5. Use model tensor parallel or not.
37 38 39 40 41
struct FusedAttentionPattern : public PatternBase {
  FusedAttentionPattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "fused_attention_pattern") {}

  PDNode* operator()(PDNode* x,
42 43 44
                     bool pre_layer_norm,  // do pre ln or not
                     bool has_attn_mask,   // add attn mask to qk or not
                     bool do_dropout,      // dropout the softmax(qk) or not
45 46
                     bool add_residual,    // add residual to out linear or not
                     bool use_mp);         // use tensor parallel or not
47 48 49 50 51 52 53 54 55

  // pre layer norm
  PATTERN_DECL_NODE(pre_layer_norm_op);
  PATTERN_DECL_NODE(pre_layer_norm_scale);
  PATTERN_DECL_NODE(pre_layer_norm_bias);
  PATTERN_DECL_NODE(pre_layer_norm_out);
  PATTERN_DECL_NODE(pre_layer_norm_mean);
  PATTERN_DECL_NODE(pre_layer_norm_variance);

56 57 58 59
  // c_identity for mp
  PATTERN_DECL_NODE(c_identity_op);
  PATTERN_DECL_NODE(c_identity_out);

60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
  // fuse qkv projection
  PATTERN_DECL_NODE(fuse_qkv_matmul_op);
  PATTERN_DECL_NODE(fuse_qkv_matmul_w);
  PATTERN_DECL_NODE(fuse_qkv_matmul_out);

  PATTERN_DECL_NODE(fuse_qkv_ele_add_op);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_bias);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_out);

  PATTERN_DECL_NODE(fuse_qkv_reshape_op);
  PATTERN_DECL_NODE(fuse_qkv_reshape_out);
  PATTERN_DECL_NODE(fuse_qkv_reshape_x_shape);

  PATTERN_DECL_NODE(fuse_qkv_transpose_op);
  PATTERN_DECL_NODE(fuse_qkv_transpose_out);
  PATTERN_DECL_NODE(fuse_qkv_transpose_x_shape);

  PATTERN_DECL_NODE(fuse_qkv_split_op);
  PATTERN_DECL_NODE(fuse_qkv_split_out_q);  // q
  PATTERN_DECL_NODE(fuse_qkv_split_out_k);  // k
  PATTERN_DECL_NODE(fuse_qkv_split_out_v);  // v

  // core attention
  PATTERN_DECL_NODE(qk_matmul_op);
  PATTERN_DECL_NODE(qk_matmul_out);

  PATTERN_DECL_NODE(qk_scale_op);
  PATTERN_DECL_NODE(qk_scale_out);

  PATTERN_DECL_NODE(add_mask_ele_add_op);
  PATTERN_DECL_NODE(add_mask_ele_add_mask);
  PATTERN_DECL_NODE(add_mask_ele_add_out);

  PATTERN_DECL_NODE(qk_softmax_op);
  PATTERN_DECL_NODE(qk_softmax_out);

  PATTERN_DECL_NODE(attn_dropout_op);
  PATTERN_DECL_NODE(attn_dropout_out);
  PATTERN_DECL_NODE(attn_dropout_mask);

  PATTERN_DECL_NODE(qkv_matmul_op);
  PATTERN_DECL_NODE(qkv_matmul_out);

  PATTERN_DECL_NODE(qkv_transpose_op);
  PATTERN_DECL_NODE(qkv_transpose_out);
  PATTERN_DECL_NODE(qkv_transpose_x_shape);

  PATTERN_DECL_NODE(qkv_reshape_op);
  PATTERN_DECL_NODE(qkv_reshape_out);
  PATTERN_DECL_NODE(qkv_reshape_x_shape);

  // out linear
  PATTERN_DECL_NODE(out_linear_matmul_op);
  PATTERN_DECL_NODE(out_linear_matmul_w);
  PATTERN_DECL_NODE(out_linear_matmul_out);

  PATTERN_DECL_NODE(out_linear_ele_add_op);
  PATTERN_DECL_NODE(out_linear_ele_add_bias);
  PATTERN_DECL_NODE(out_linear_ele_add_out);

120 121 122 123
  // allreudce for mp
  PATTERN_DECL_NODE(mp_allreudce_sum_op);
  PATTERN_DECL_NODE(mp_allreudce_sum_out);

124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
  PATTERN_DECL_NODE(out_linear_dropout_op);
  PATTERN_DECL_NODE(out_linear_dropout_out);
  PATTERN_DECL_NODE(out_linear_dropout_mask);

  // residual
  PATTERN_DECL_NODE(residual_ele_add_op);
  PATTERN_DECL_NODE(residual_ele_add_out);

  // post layer norm
  PATTERN_DECL_NODE(post_layer_norm_op);
  PATTERN_DECL_NODE(post_layer_norm_scale);
  PATTERN_DECL_NODE(post_layer_norm_bias);
  PATTERN_DECL_NODE(post_layer_norm_out);
  PATTERN_DECL_NODE(post_layer_norm_mean);
  PATTERN_DECL_NODE(post_layer_norm_variance);
};

// Declare the grad pattern for multi head attention
struct FusedAttentionGradPattern : public PatternBase {
  FusedAttentionGradPattern(PDPattern* pattern, const std::string& name_scope)
144
      : PatternBase(pattern, name_scope, "fused_attention_grad_pattern") {}
145 146

  PDNode* operator()(PDNode* x,
147 148 149
                     bool pre_layer_norm,  // pre ln
                     bool has_attn_mask,   // add attn mask to qk or not
                     bool do_dropout,      // dropout the softmax(qk) or not
150 151
                     bool add_residual,    // add residual to out linear or not
                     bool use_mp);         // use tensor parallel or not
152

153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
  // post layer norm grad
  PATTERN_DECL_NODE(post_layer_norm_grad_op);
  PATTERN_DECL_NODE(post_layer_norm_grad_scale);
  PATTERN_DECL_NODE(post_layer_norm_grad_bias);
  PATTERN_DECL_NODE(post_layer_norm_grad_mean);
  PATTERN_DECL_NODE(post_layer_norm_grad_variance);
  PATTERN_DECL_NODE(post_layer_norm_grad_x);
  PATTERN_DECL_NODE(post_layer_norm_grad_scale_grad);
  PATTERN_DECL_NODE(post_layer_norm_grad_bias_grad);
  PATTERN_DECL_NODE(post_layer_norm_grad_x_grad);

  // residual grad
  PATTERN_DECL_NODE(residual_ele_add_grad_op);
  PATTERN_DECL_NODE(residual_ele_add_grad_x);
  PATTERN_DECL_NODE(residual_ele_add_grad_bias);
  PATTERN_DECL_NODE(residual_ele_add_grad_bias_grad);
  PATTERN_DECL_NODE(residual_ele_add_grad_x_grad);

  // out linear grad
  PATTERN_DECL_NODE(out_linear_dropout_grad_op);
  PATTERN_DECL_NODE(out_linear_dropout_grad_mask);
  PATTERN_DECL_NODE(out_linear_dropout_grad_out);

176 177 178 179
  // c_identity for mp
  PATTERN_DECL_NODE(mp_allreudce_sum_grad_op);  // c_identity
  PATTERN_DECL_NODE(mp_allreudce_sum_grad_out);

180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
  PATTERN_DECL_NODE(out_linear_ele_add_grad_op);
  PATTERN_DECL_NODE(out_linear_ele_add_grad_x);
  PATTERN_DECL_NODE(out_linear_ele_add_grad_bias);
  PATTERN_DECL_NODE(out_linear_ele_add_grad_x_grad);
  PATTERN_DECL_NODE(out_linear_ele_add_grad_bias_grad);

  PATTERN_DECL_NODE(out_linear_matmul_grad_op);
  PATTERN_DECL_NODE(out_linear_matmul_grad_x);
  PATTERN_DECL_NODE(out_linear_matmul_grad_w);
  PATTERN_DECL_NODE(out_linear_matmul_grad_x_grad);
  PATTERN_DECL_NODE(out_linear_matmul_grad_w_grad);

  // core attention grad
  PATTERN_DECL_NODE(qkv_reshape_grad_op);
  PATTERN_DECL_NODE(qkv_reshape_grad_x_shape);
  PATTERN_DECL_NODE(qkv_reshape_grad_out);

  PATTERN_DECL_NODE(qkv_transpose_grad_op);
  PATTERN_DECL_NODE(qkv_transpose_grad_x_shape);
  PATTERN_DECL_NODE(qkv_transpose_grad_out);

  PATTERN_DECL_NODE(qkv_matmul_grad_op);
  PATTERN_DECL_NODE(qkv_matmul_grad_x);
  PATTERN_DECL_NODE(qkv_matmul_grad_w);
  PATTERN_DECL_NODE(qkv_matmul_grad_x_grad);
  PATTERN_DECL_NODE(qkv_matmul_grad_w_grad);

  PATTERN_DECL_NODE(attn_dropout_grad_op);
  PATTERN_DECL_NODE(attn_dropout_grad_mask);
  PATTERN_DECL_NODE(attn_dropout_grad_out);

  PATTERN_DECL_NODE(qk_softmax_grad_op);
  PATTERN_DECL_NODE(qk_softmax_grad_fwd_out);
  PATTERN_DECL_NODE(qk_softmax_grad_out);

  PATTERN_DECL_NODE(add_mask_ele_add_grad_op);
  PATTERN_DECL_NODE(add_mask_ele_add_grad_x);
  PATTERN_DECL_NODE(add_mask_ele_add_grad_bias);
  PATTERN_DECL_NODE(add_mask_ele_add_grad_x_grad);

  PATTERN_DECL_NODE(qk_scale_grad_op);
  PATTERN_DECL_NODE(qk_scale_grad_out);

  PATTERN_DECL_NODE(qk_matmul_grad_op);
  PATTERN_DECL_NODE(qk_matmul_grad_x);
  PATTERN_DECL_NODE(qk_matmul_grad_w);
  PATTERN_DECL_NODE(qk_matmul_grad_x_grad);
  PATTERN_DECL_NODE(qk_matmul_grad_w_grad);

  // fuse qkv projection grad
  PATTERN_DECL_NODE(fuse_qkv_split_grad_op);  // concat op
  PATTERN_DECL_NODE(fuse_qkv_split_grad_out);

  PATTERN_DECL_NODE(fuse_qkv_transpose_grad_op);
  PATTERN_DECL_NODE(fuse_qkv_transpose_grad_x_shape);
  PATTERN_DECL_NODE(fuse_qkv_transpose_grad_out);

  PATTERN_DECL_NODE(fuse_qkv_reshape_grad_op);
  PATTERN_DECL_NODE(fuse_qkv_reshape_grad_x_shape);
  PATTERN_DECL_NODE(fuse_qkv_reshape_grad_out);

  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_op);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x_grad);
  PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias_grad);

  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_op);
  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x);
  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w);
  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad);
  PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad);

253 254 255 256
  // allreduce for mp
  PATTERN_DECL_NODE(c_identity_grad_op);  // mp_allreduce_sum
  PATTERN_DECL_NODE(c_identity_grad_out);

257 258 259 260 261 262 263 264 265 266 267 268 269 270
  // pre layer norm grad
  PATTERN_DECL_NODE(pre_layer_norm_grad_op);
  PATTERN_DECL_NODE(pre_layer_norm_grad_scale);
  PATTERN_DECL_NODE(pre_layer_norm_grad_bias);
  PATTERN_DECL_NODE(pre_layer_norm_grad_mean);
  PATTERN_DECL_NODE(pre_layer_norm_grad_variance);
  PATTERN_DECL_NODE(pre_layer_norm_grad_x);
  PATTERN_DECL_NODE(pre_layer_norm_grad_scale_grad);
  PATTERN_DECL_NODE(pre_layer_norm_grad_bias_grad);
  PATTERN_DECL_NODE(pre_layer_norm_grad_x_grad);

  // grad accumulation
  PATTERN_DECL_NODE(grad_accumulation_sum_op);
  PATTERN_DECL_NODE(grad_accumulation_out);
271 272 273 274
};

}  // namespace patterns

275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299
class FusedAttentionPassCache {
 public:
  ir::Node* GetNodeFromCache(const std::string name) {
    if (var_name_to_ir_node_cache_.count(name)) {
      return var_name_to_ir_node_cache_.find(name)->second;
    }
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The key (%d) of FusedAttentionCache does not exist.", name));
  }

  void InsertIntoCache(const std::string name, ir::Node* node) {
    if (!var_name_to_ir_node_cache_.count(name)) {
      var_name_to_ir_node_cache_.insert({name, node});
    } else {
      PADDLE_THROW(platform::errors::AlreadyExists(
          "The key (%d) of FusedAttentionCache already exist.", name));
    }
  }

  void ResetCache() { var_name_to_ir_node_cache_.clear(); }

 private:
  std::unordered_map<std::string, ir::Node*> var_name_to_ir_node_cache_;
};

300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317
class FusedAttentionsPass : public FusePassBase {
 public:
  virtual ~FusedAttentionsPass() {}

 protected:
  void ApplyImpl(Graph* graph) const;

  const std::string name_scope_{"fused_attention_pass"};

 private:
  // The name rule for the helper function.
  // The function name will contain at most five parts in order:
  // 1. Do pre layer norm? [Pre]
  // 2. Add mask in the core attention part? [Mask]
  // 3. Do dropout in the core attention part? [Drop]
  // 4. Add residual? [Res]
  // 5. Do post layer norm? [Post]
  // 6. Forward or Backward? [Fwd/Bwd]
318
  // 7. Use tensor model parallel? [MP]
319 320 321
  // If true, the function name will have an abbreviation part.
  // If false, the function name won't contain an abbreviation for it.

322 323 324 325 326
  ir::Graph* PreMaskDropResFwd(Graph* graph,
                               FusedAttentionPassCache* cache) const;

  ir::Graph* PreMaskDropResBwd(Graph* graph,
                               FusedAttentionPassCache* cache) const;
327

328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349
  ir::Graph* PreMaskDropResMPFwd(Graph* graph,
                                 FusedAttentionPassCache* cache) const;

  ir::Graph* PreMaskDropResMPBwd(Graph* graph,
                                 FusedAttentionPassCache* cache) const;

  ir::Graph* ForwardHandlerHelper(Graph* graph,
                                  FusedAttentionPassCache* cache,
                                  bool pre_layer_norm,
                                  bool has_attn_mask,
                                  bool do_dropout,
                                  bool add_residual,
                                  bool use_mp) const;

  ir::Graph* BackwardHandlerHelper(Graph* graph,
                                   FusedAttentionPassCache* cache,
                                   bool pre_layer_norm,
                                   bool has_attn_mask,
                                   bool do_dropout,
                                   bool add_residual,
                                   bool use_mp) const;

350 351 352 353 354
  const std::string GenerateCacheKey(const std::string anchor,
                                     const std::string var_name,
                                     int block_id) const {
    return anchor + "_" + std::to_string(block_id) + "_" + var_name;
  }
355 356 357 358 359
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle