fused_gate_attention.h 25.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
21
#include "paddle/phi/kernels/funcs/reduce_function.h"
22 23 24 25 26 27 28 29
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

inline std::string MemoryDebugString(const Tensor& t) {
30 31 32 33 34 35
  int device_id = platform::GetCurrentDeviceId();
  int64_t allocated =
      memory::DeviceMemoryStatCurrentValue("Allocated", device_id);
  int64_t reserved =
      memory::DeviceMemoryStatCurrentValue("Reserved", device_id);

36 37 38
  std::stringstream ss;
  ss << "shape=[" << t.dims()
     << "], size=" << static_cast<float>(t.memory_size()) / (1 << 20)
39 40 41 42
     << " MB, ptr=" << t.data()
     << "; [MEMORY] allocated=" << static_cast<float>(allocated) / (1 << 20)
     << " MB"
     << ", reserved=" << static_cast<float>(reserved) / (1 << 20) << " MB";
43 44 45
  return ss.str();
}

46 47 48 49 50 51 52
/// Allocates `t` as a buffer of element type T on the device context's place
/// (using whatever dims `t` currently carries), then logs the allocation at
/// VLOG level 4, prefixed with `info`.
template <typename T>
void AllocWithDebugInfo(const platform::CUDADeviceContext& dev_ctx,
                        const std::string& info, Tensor* t) {
  const auto& target_place = dev_ctx.GetPlace();
  t->mutable_data<T>(target_place);
  VLOG(4) << info << ": " << MemoryDebugString(*t);
}

53 54 55 56 57 58 59 60
/// Element-wise functor summing three operands, evaluated as (a + b) + c.
/// Used with the ternary BroadcastKernel to add bias and mask in one pass.
template <typename T>
struct TernaryAddFunctor {
  inline HOSTDEVICE T operator()(T a, T b, T c) const {
    // Keep the natural type of (a + b) so promotion/rounding match a + b + c.
    const auto partial_sum = a + b;
    return partial_sum + c;
  }
};

template <typename T>
struct GateAttentionConfig {
 public:
61 62 63 64 65
  const platform::CUDADeviceContext& dev_ctx;

  bool merge_qkv;
  bool has_gating;

66 67 68 69 70
  int64_t batch_size;
  int64_t seq_len_m;
  int64_t seq_len_r;
  int64_t q_dim;
  int64_t kv_dim;
71
  int64_t head_dim;
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
  int64_t m_size;
  int64_t num_heads;

  phi::DDim qkv_out_dims;
  phi::DDim qkv_transpose_out_dims;

  phi::DDim q_out_dims;
  phi::DDim kv_out_dims;
  phi::DDim q_transpose_out_dims;
  phi::DDim kv_transpose_out_dims;

  phi::DDim qk_out_dims;
  phi::DDim softmax_out_dims;
  phi::DDim qktv_out_dims;
  phi::DDim gate_out_dims;

88 89
  GateAttentionConfig(const platform::CUDADeviceContext& dev_ctx,
                      const Tensor* query, const Tensor* key,
90
                      const Tensor* query_weight, const Tensor* qkv_weight,
91 92
                      bool merge_qkv, bool has_gating)
      : dev_ctx(dev_ctx), merge_qkv(merge_qkv), has_gating(has_gating) {
93 94 95 96 97 98 99 100 101 102 103 104 105
    // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
    batch_size = query->dims()[0];
    seq_len_m = query->dims()[1];
    seq_len_r = query->dims()[2];
    q_dim = query->dims()[3];

    if (merge_qkv) {
      PADDLE_ENFORCE_NOT_NULL(
          qkv_weight,
          platform::errors::NotFound("The input qkv_weight can not be nullptr "
                                     "when merge_qkv is true."));

      // When q_dim == kv_dim, QKV matmul can be computed merged.
106
      // qkv_weight: shape=[3, num_heads, head_dim, q_dim]
107
      num_heads = qkv_weight->dims()[1];
108
      head_dim = qkv_weight->dims()[2];
109 110 111
      m_size = seq_len_r;
      kv_dim = q_dim;

112
      qkv_out_dims = {batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim};
113
      qkv_transpose_out_dims = {3,         batch_size, seq_len_m,
114
                                num_heads, seq_len_r,  head_dim};
115 116 117 118 119 120 121 122 123 124 125 126
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          key,
          platform::errors::NotFound(
              "The input key can not be nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          query_weight,
          platform::errors::NotFound("The input query_weight can not be "
                                     "nullptr when merge_qkv is false."));

      // When q_dim != kv_dim, QKV matmul must be computed saparately.
      // key: shape=[batch_size, seq_len_m, m_size, kv_dim]
127
      // query_w: shape=[q_dim, num_heads, head_dim]
128
      num_heads = query_weight->dims()[1];
129
      head_dim = query_weight->dims()[2];
130 131 132
      m_size = key->dims()[2];
      kv_dim = key->dims()[3];

133 134
      q_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
      kv_out_dims = {batch_size, seq_len_m, m_size, num_heads, head_dim};
135
      q_transpose_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r,
136
                              head_dim};
137
      kv_transpose_out_dims = {batch_size, seq_len_m, num_heads, m_size,
138
                               head_dim};
139 140 141 142
    }

    qk_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
    softmax_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
143 144
    qktv_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, head_dim};
    gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
145 146 147
  }

  int64_t GetQuerySize() const {
148
    return batch_size * seq_len_m * seq_len_r * num_heads * head_dim;
149 150
  }

151
  Tensor* GetQKVOut() {
152 153
    if (!qkv_out.IsInitialized()) {
      qkv_out.Resize(qkv_out_dims);
154
      AllocWithDebugInfo<T>(dev_ctx, "qkv_out", &qkv_out);
155 156 157 158
    }
    return &qkv_out;
  }

159
  Tensor* GetQueryOut() {
160 161
    if (!query_out.IsInitialized()) {
      query_out.Resize(q_out_dims);
162
      AllocWithDebugInfo<T>(dev_ctx, "query_out", &query_out);
163 164 165 166
    }
    return &query_out;
  }

167
  Tensor* GetKeyOut() {
168 169
    if (!key_out.IsInitialized()) {
      key_out.Resize(kv_out_dims);
170
      AllocWithDebugInfo<T>(dev_ctx, "key_out", &key_out);
171 172 173 174
    }
    return &key_out;
  }

175
  Tensor* GetValueOut() {
176 177
    if (!value_out.IsInitialized()) {
      value_out.Resize(kv_out_dims);
178
      AllocWithDebugInfo<T>(dev_ctx, "value_out", &value_out);
179 180 181 182
    }
    return &value_out;
  }

183
  Tensor* GetQKOut(Tensor* softmax_out) {
184 185 186 187 188 189
    // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
    int softmax_dim = m_size;
    if (!softmax_out || phi::UseCudnnSoftmax<T>(dev_ctx, softmax_dim, true)) {
      // Not sure whether cudnn softmax can execute inplace.
      if (!qkv_out.IsInitialized()) {
        qk_out.Resize(qk_out_dims);
190
        AllocWithDebugInfo<T>(dev_ctx, "qk_out", &qk_out);
191 192 193
      }
      return &qk_out;
    } else {
194
      // Enable inplace softmax.
195 196 197 198
      return softmax_out;
    }
  }

199 200 201 202 203 204 205 206 207 208 209 210 211 212
  Tensor* GetQKTVOut(Tensor* gate_out) {
    if (has_gating && gate_out) {
      // Reuse gate_out.
      gate_out->Resize(qktv_out_dims);
      return gate_out;
    } else {
      if (!qktv_out.IsInitialized()) {
        qktv_out.Resize(qktv_out_dims);
        AllocWithDebugInfo<T>(dev_ctx, "qktv_out", &qktv_out);
      }
      return &qktv_out;
    }
  }

213 214 215 216 217 218 219 220 221 222 223 224
  void ClearQKVOut() {
    if (qkv_out.IsInitialized()) {
      qkv_out.clear();
    }
  }

  void ClearQKOut() {
    if (qk_out.IsInitialized()) {
      qk_out.clear();
    }
  }

225 226 227 228 229 230
  void ClearQKTVOut() {
    if (qktv_out.IsInitialized()) {
      qktv_out.clear();
    }
  }

231 232 233 234 235 236 237 238 239 240
 protected:
  Tensor qkv_out;
  Tensor query_out;
  Tensor key_out;
  Tensor value_out;
  // qk_out = BatchedGEMM(Q, K^T)
  // qk_out: shape=[batch_size, seq_len_m, num_heads, seq_len_r, m_size]
  // softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
  // The shape of qk_out, softmax_out is the same, thus can be called inplace.
  Tensor qk_out;
241 242
  // qktv_out may reuse gate_out.
  Tensor qktv_out;
243 244 245 246 247
};

/// Extends GateAttentionConfig with lazily allocated gradient buffers for the
/// backward pass. Each gradient uses the same dims as its forward counterpart.
template <typename T>
struct GateAttentionGradConfig : public GateAttentionConfig<T> {
 public:
  GateAttentionGradConfig(const platform::CUDADeviceContext& dev_ctx,
                          const Tensor* query, const Tensor* key,
                          const Tensor* query_weight, const Tensor* qkv_weight,
                          bool merge_qkv, bool has_gating)
      : GateAttentionConfig<T>(dev_ctx, query, key, query_weight, qkv_weight,
                               merge_qkv, has_gating) {}

  Tensor* GetQKVOutGrad() {
    return AllocOnce(&qkv_out_grad, this->qkv_out_dims, "qkv_out_grad");
  }

  Tensor* GetQueryOutGrad() {
    return AllocOnce(&query_out_grad, this->q_out_dims, "query_out_grad");
  }

  Tensor* GetKeyOutGrad() {
    return AllocOnce(&key_out_grad, this->kv_out_dims, "key_out_grad");
  }

  Tensor* GetValueOutGrad() {
    return AllocOnce(&value_out_grad, this->kv_out_dims, "value_out_grad");
  }

  /// Returns the buffer receiving the gradient of qk_out.
  ///
  /// `softmax_out_grad` is reused in place when it is non-null and cuDNN
  /// softmax is not selected for this width (m_size); otherwise a dedicated
  /// qk_out_grad buffer is allocated on first use.
  Tensor* GetQKOutGrad(Tensor* softmax_out_grad) {
    // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
    int softmax_dim = this->m_size;
    const bool reuse_softmax_grad =
        softmax_out_grad != nullptr &&
        !phi::UseCudnnSoftmax<T>(this->dev_ctx, softmax_dim, true);
    if (reuse_softmax_grad) {
      return softmax_out_grad;
    }
    return AllocOnce(&qk_out_grad, this->qk_out_dims, "qk_out_grad");
  }

 protected:
  Tensor qkv_out_grad;
  Tensor query_out_grad;
  Tensor key_out_grad;
  Tensor value_out_grad;
  Tensor qk_out_grad;

 private:
  // Resizes and allocates `buffer` only if it has not been allocated yet.
  Tensor* AllocOnce(Tensor* buffer, const phi::DDim& dims,
                    const std::string& debug_name) {
    if (!buffer->IsInitialized()) {
      buffer->Resize(dims);
      AllocWithDebugInfo<T>(this->dev_ctx, debug_name, buffer);
    }
    return buffer;
  }
};

template <typename T>
class FMHAGateRef {
 public:
  FMHAGateRef(const platform::CUDADeviceContext& dev_ctx, bool merge_qkv)
      : dev_ctx_(dev_ctx), merge_qkv_(merge_qkv) {}

  void ComputeForward(const Tensor* nonbatched_bias, const Tensor* src_mask,
                      Tensor* q_transpose_out, Tensor* k_transpose_out,
                      Tensor* v_transpose_out, Tensor* qkv_transpose_out,
319
                      Tensor* softmax_out, Tensor* fmha_out, Tensor* gate_out,
320 321 322 323 324 325 326 327 328 329 330
                      GateAttentionConfig<T>* config) {
    T* q_ptr = nullptr;
    T* k_ptr = nullptr;
    T* v_ptr = nullptr;
    if (merge_qkv_) {
      // qkv_transpose_out = transpose(qkv_out)
      PADDLE_ENFORCE_NOT_NULL(
          qkv_transpose_out,
          platform::errors::NotFound("The input qkv_transpose_out can not be "
                                     "nullptr when merge_qkv is true."));

331
      Tensor* qkv_out = config->GetQKVOut();
332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353
      ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out);
      config->ClearQKVOut();

      // q_size == k_size
      int64_t q_size = config->GetQuerySize();
      q_ptr = qkv_transpose_out->data<T>();
      k_ptr = q_ptr + q_size;
      v_ptr = k_ptr + q_size;
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          q_transpose_out,
          platform::errors::NotFound("The input q_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          k_transpose_out,
          platform::errors::NotFound("The input k_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          v_transpose_out,
          platform::errors::NotFound("The input v_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));

354 355 356
      Tensor* query_out = config->GetQueryOut();
      Tensor* key_out = config->GetKeyOut();
      Tensor* value_out = config->GetValueOut();
357 358 359 360 361 362 363 364 365 366 367
      ComputeQKVTransposeForward(*query_out, *key_out, *value_out,
                                 q_transpose_out, k_transpose_out,
                                 v_transpose_out);

      // q_size != k_size
      q_ptr = q_transpose_out->data<T>();
      k_ptr = k_transpose_out->data<T>();
      v_ptr = v_transpose_out->data<T>();
    }

    // qk_out = BatchedGEMM(Q, K^T)
368 369
    // [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] *
    //                [batch_size, seq_len_m, num_heads, m_size, head_dim]
370
    // -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size]
371
    Tensor* qk_out = config->GetQKOut(softmax_out);
372 373 374 375 376 377
    T* qk_out_ptr = qk_out->data<T>();

    int64_t gemm_batch_size =
        config->batch_size * config->seq_len_m * config->num_heads;
    int64_t gemm_m = config->seq_len_r;
    int64_t gemm_n = config->m_size;
378
    int64_t gemm_k = config->head_dim;
379

380
    T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
381 382 383 384 385 386 387 388 389 390
    ComputeBatchedGEMM(q_ptr, k_ptr, qk_out_ptr, false, true, gemm_m, gemm_n,
                       gemm_k, gemm_batch_size, alpha);

    // softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
    ComputeBiasMaskSoftmaxForward(nonbatched_bias, src_mask, qk_out,
                                  softmax_out);
    config->ClearQKOut();

    // qktv_out = BatchedGEMM(softmax_out, V)
    // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] *
391 392
    //               [batch_size, seq_len_m, num_heads, m_size, head_dim]
    // -> [batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
393 394
    Tensor* qktv_out = config->GetQKTVOut(gate_out);
    T* qktv_out_ptr = qktv_out->data<T>();
395 396

    gemm_m = config->seq_len_r;
397
    gemm_n = config->head_dim;
398 399 400 401 402 403 404
    gemm_k = config->m_size;

    T* softmax_out_ptr = softmax_out->data<T>();
    ComputeBatchedGEMM(softmax_out_ptr, v_ptr, qktv_out_ptr, false, false,
                       gemm_m, gemm_n, gemm_k, gemm_batch_size);

    // fmha_out = transpose(qktv_out)
405 406 407 408 409
    ComputeQKTVTransposeForward(*qktv_out, fmha_out);
    config->ClearQKTVOut();
    if (config->has_gating) {
      gate_out->Resize(config->gate_out_dims);
    }
410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442
  }

  void ComputeBackward(const Tensor* q_transpose_out,
                       const Tensor* k_transpose_out,
                       const Tensor* v_transpose_out,
                       const Tensor* qkv_transpose_out,
                       const Tensor* softmax_out, const Tensor* fmha_out_grad,
                       Tensor* src_mask_grad, Tensor* nonbatched_bias_grad,
                       GateAttentionGradConfig<T>* config) {
    const T* q_ptr = nullptr;
    const T* k_ptr = nullptr;
    const T* v_ptr = nullptr;

    T* q_grad_ptr = nullptr;
    T* k_grad_ptr = nullptr;
    T* v_grad_ptr = nullptr;

    Tensor q_transpose_out_grad;
    Tensor k_transpose_out_grad;
    Tensor v_transpose_out_grad;
    Tensor qkv_transpose_out_grad;
    if (merge_qkv_) {
      PADDLE_ENFORCE_NOT_NULL(
          qkv_transpose_out,
          platform::errors::NotFound("The input qkv_transpose_out can not be "
                                     "nullptr when merge_qkv is true."));

      int64_t q_size = config->GetQuerySize();
      q_ptr = qkv_transpose_out->data<T>();
      k_ptr = q_ptr + q_size;
      v_ptr = k_ptr + q_size;

      qkv_transpose_out_grad.Resize(config->qkv_transpose_out_dims);
443 444
      AllocWithDebugInfo<T>(dev_ctx_, "qkv_transpose_out_grad",
                            &qkv_transpose_out_grad);
445

446
      q_grad_ptr = qkv_transpose_out_grad.data<T>();
447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477
      k_grad_ptr = q_grad_ptr + q_size;
      v_grad_ptr = k_grad_ptr + q_size;
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          q_transpose_out,
          platform::errors::NotFound("The input q_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          k_transpose_out,
          platform::errors::NotFound("The input k_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          v_transpose_out,
          platform::errors::NotFound("The input v_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));

      q_ptr = q_transpose_out->data<T>();
      k_ptr = k_transpose_out->data<T>();
      v_ptr = v_transpose_out->data<T>();

      q_transpose_out_grad.Resize(config->q_transpose_out_dims);
      k_transpose_out_grad.Resize(config->kv_transpose_out_dims);
      v_transpose_out_grad.Resize(config->kv_transpose_out_dims);

      q_grad_ptr = q_transpose_out_grad.mutable_data<T>(dev_ctx_.GetPlace());
      k_grad_ptr = k_transpose_out_grad.mutable_data<T>(dev_ctx_.GetPlace());
      v_grad_ptr = v_transpose_out_grad.mutable_data<T>(dev_ctx_.GetPlace());
    }

    Tensor softmax_out_grad;
    softmax_out_grad.Resize(config->softmax_out_dims);
478
    AllocWithDebugInfo<T>(dev_ctx_, "softmax_out_grad", &softmax_out_grad);
479 480 481 482 483 484 485

    int64_t gemm_batch_size =
        config->batch_size * config->seq_len_m * config->num_heads;
    {
      // Forward: fmha_out = transpose(qktv_out)
      Tensor qktv_out_grad;
      qktv_out_grad.Resize(config->qktv_out_dims);
486
      AllocWithDebugInfo<T>(dev_ctx_, "qktv_out_grad", &qktv_out_grad);
487 488 489 490 491 492
      ComputeQKTVTransposeBackward(*fmha_out_grad, &qktv_out_grad);

      // Forward: qktv_out = BatchedGEMM(softmax_out, V)
      // Backward:
      //  V_grad = BatchedGEMM(softmax_out^T, qktv_out_grad) (dy = x^T * dout)
      int64_t gemm_m = config->m_size;
493
      int64_t gemm_n = config->head_dim;
494 495 496
      int64_t gemm_k = config->seq_len_r;

      const T* softmax_out_ptr = softmax_out->data<T>();
497
      const T* qktv_out_grad_ptr = qktv_out_grad.data<T>();
498 499 500 501 502 503
      ComputeBatchedGEMM(softmax_out_ptr, qktv_out_grad_ptr, v_grad_ptr, true,
                         false, gemm_m, gemm_n, gemm_k, gemm_batch_size);

      // Backward: softmax_out_grad = qktv_out_grad * V^T (dx = dout * y^T)
      gemm_m = config->seq_len_r;
      gemm_n = config->m_size;
504
      gemm_k = config->head_dim;
505 506 507 508 509 510

      T* softmax_out_grad_ptr = softmax_out_grad.data<T>();
      ComputeBatchedGEMM(qktv_out_grad_ptr, v_ptr, softmax_out_grad_ptr, false,
                         true, gemm_m, gemm_n, gemm_k, gemm_batch_size);
    }

511
    Tensor* qk_out_grad = config->GetQKOutGrad(&softmax_out_grad);
512 513 514 515 516 517 518
    ComputeBiasMaskSoftmaxBackward(&softmax_out_grad, softmax_out,
                                   src_mask_grad, qk_out_grad,
                                   nonbatched_bias_grad);

    // Forward: qk_out = BatchedGEMM(Q, K^T)
    // Backward: k_grad = BatchedGEMM(qk_out_grad^T, Q) (dy = dout^t * x)
    int64_t gemm_m = config->m_size;
519
    int64_t gemm_n = config->head_dim;
520
    int64_t gemm_k = config->seq_len_r;
521
    T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
522 523 524 525 526 527 528

    T* qk_out_grad_ptr = qk_out_grad->data<T>();
    ComputeBatchedGEMM(qk_out_grad_ptr, q_ptr, k_grad_ptr, true, false, gemm_m,
                       gemm_n, gemm_k, gemm_batch_size, alpha);

    // Backward: q_grad = BatchedGEMM(qk_out_grad, K) (dx = dout * y)
    gemm_m = config->seq_len_r;
529
    gemm_n = config->head_dim;
530 531 532 533 534
    gemm_k = config->m_size;
    ComputeBatchedGEMM(qk_out_grad_ptr, k_ptr, q_grad_ptr, false, false, gemm_m,
                       gemm_n, gemm_k, gemm_batch_size, alpha);

    if (merge_qkv_) {
535
      Tensor* qkv_out_grad = config->GetQKVOutGrad();
536 537
      ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad);
    } else {
538 539 540
      Tensor* q_out_grad = config->GetQueryOutGrad();
      Tensor* k_out_grad = config->GetKeyOutGrad();
      Tensor* v_out_grad = config->GetValueOutGrad();
541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572
      ComputeQKVTransposeBackward(q_transpose_out_grad, k_transpose_out_grad,
                                  v_transpose_out_grad, q_out_grad, k_out_grad,
                                  v_out_grad);
    }
  }

  void ComputeQKVTransposeForward(const Tensor& q_out, const Tensor& k_out,
                                  const Tensor& v_out, Tensor* q_transpose_out,
                                  Tensor* k_transpose_out,
                                  Tensor* v_transpose_out) {
    int ndims = 5;
    std::vector<int> perm = {0, 1, 3, 2, 4};
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, q_out, perm, q_transpose_out);
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, k_out, perm, k_transpose_out);
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, v_out, perm, v_transpose_out);
  }

  void ComputeQKVTransposeBackward(const Tensor& q_transpose_out_grad,
                                   const Tensor& k_transpose_out_grad,
                                   const Tensor& v_transpose_out_grad,
                                   Tensor* q_out_grad, Tensor* k_out_grad,
                                   Tensor* v_out_grad) {
    int ndims = 5;
    std::vector<int> perm = {0, 1, 3, 2, 4};
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, q_transpose_out_grad, perm,
                                q_out_grad);
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, k_transpose_out_grad, perm,
                                k_out_grad);
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, v_transpose_out_grad, perm,
                                v_out_grad);
  }

573 574
  // [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] ->
  //         [3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612
  void ComputeQKVTransposeForward(const Tensor& qkv_out,
                                  Tensor* qkv_transpose_out) {
    int ndims = 6;
    std::vector<int> perm = {3, 0, 1, 4, 2, 5};
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, qkv_out, perm,
                                qkv_transpose_out);
  }

  void ComputeQKVTransposeBackward(const Tensor& qkv_transpose_out_grad,
                                   Tensor* qkv_out_grad) {
    int ndims = 6;
    std::vector<int> perm = {1, 2, 4, 0, 3, 5};
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, qkv_transpose_out_grad, perm,
                                qkv_out_grad);
  }

  // [batch_size, seq_len_m, num_head, seq_len_r, c] ->
  //         [batch_size, seq_len_m, seq_len_r, num_head, c]
  void ComputeQKTVTransposeForward(const Tensor& qktv_out, Tensor* fmha_out) {
    int ndims = 5;
    std::vector<int> perm = {0, 1, 3, 2, 4};
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, qktv_out, perm, fmha_out);
  }

  void ComputeQKTVTransposeBackward(const Tensor& fmha_out_grad,
                                    Tensor* qktv_out_grad) {
    int ndims = 5;
    std::vector<int> perm = {0, 1, 3, 2, 4};
    TransposeGPUKernelDriver<T>(dev_ctx_, ndims, fmha_out_grad, perm,
                                qktv_out_grad);
  }

  // qk_out = qk_out + nonbatched_bias + src_mask
  // softmax_out = softmax(src_mask_out)
  void ComputeBiasMaskSoftmaxForward(const Tensor* nonbatched_bias,
                                     const Tensor* src_mask, Tensor* qk_out,
                                     Tensor* softmax_out) {
    if (nonbatched_bias) {
613
      std::vector<const Tensor*> ins = {qk_out, src_mask, nonbatched_bias};
614
      std::vector<Tensor*> outs = {qk_out};
615
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
616 617 618 619
          dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
    } else {
      std::vector<const Tensor*> ins = {qk_out, src_mask};
      std::vector<Tensor*> outs = {qk_out};
620
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651
          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
    }
    phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
  }

  // src_mask_out = qk_out + nonbatched_bias + src_mask
  // softmax_out = softmax(src_mask_out)
  void ComputeBiasMaskSoftmaxBackward(const Tensor* softmax_out_grad,
                                      const Tensor* softmax_out,
                                      Tensor* src_mask_grad,
                                      Tensor* qk_out_grad,
                                      Tensor* nonbatched_bias_grad) {
    PADDLE_ENFORCE_NOT_NULL(
        qk_out_grad,
        platform::errors::NotFound("The qk_out_grad can not be nullptr."));

    PADDLE_ENFORCE_EQ(qk_out_grad->dims(), softmax_out->dims(),
                      platform::errors::InvalidArgument(
                          "The shape of qk_out_grad and softmax_out is "
                          "expected to be the same. But recieved qk_out_grad's "
                          "shape = %s, softmax_out's shape = %s.",
                          qk_out_grad->dims(), softmax_out->dims()));

    PADDLE_ENFORCE_EQ(src_mask_grad, nullptr,
                      platform::errors::InvalidArgument(
                          "src_mask_grad is expected to be nullptr."));

    phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, *softmax_out,
                                            *softmax_out_grad, -1, qk_out_grad);

    if (nonbatched_bias_grad) {
652 653 654
      // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] ->
      //      [batch_size, 1, num_heads, seq_len_r, m_size]
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
655
          dev_ctx_, *qk_out_grad, nonbatched_bias_grad,
656
          kps::IdentityFunctor<T>(), {1});
657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681
    }
  }

 private:
  void ComputeBatchedGEMM(const T* a_ptr, const T* b_ptr, T* c_ptr,
                          bool trans_a, bool trans_b, int64_t m, int64_t n,
                          int64_t k, int64_t batch_size,
                          T alpha = static_cast<T>(1.0),
                          T beta = static_cast<T>(0.0)) {
    CBLAS_TRANSPOSE cblas_trans_a = trans_a ? CblasTrans : CblasNoTrans;
    CBLAS_TRANSPOSE cblas_trans_b = trans_b ? CblasTrans : CblasNoTrans;
    int64_t stride_a = m * k;
    int64_t stride_b = k * n;

    auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
    blas.BatchedGEMM(cblas_trans_a, cblas_trans_b, m, n, k, alpha, a_ptr, b_ptr,
                     beta, c_ptr, batch_size, stride_a, stride_b);
  }

  const platform::CUDADeviceContext& dev_ctx_;
  bool merge_qkv_;
};

}  // namespace operators
}  // namespace paddle