// fused_gate_attention.h
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cmath>
#include <sstream>
#include <string>
#include <vector>

#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

inline std::string MemoryDebugString(const Tensor& t) {
30 31 32 33 34 35
  int device_id = platform::GetCurrentDeviceId();
  int64_t allocated =
      memory::DeviceMemoryStatCurrentValue("Allocated", device_id);
  int64_t reserved =
      memory::DeviceMemoryStatCurrentValue("Reserved", device_id);

36 37 38
  std::stringstream ss;
  ss << "shape=[" << t.dims()
     << "], size=" << static_cast<float>(t.memory_size()) / (1 << 20)
39 40 41 42
     << " MB, ptr=" << t.data()
     << "; [MEMORY] allocated=" << static_cast<float>(allocated) / (1 << 20)
     << " MB"
     << ", reserved=" << static_cast<float>(reserved) / (1 << 20) << " MB";
43 44 45
  return ss.str();
}

46
template <typename T>
L
Leo Chen 已提交
47
void AllocWithDebugInfo(const phi::GPUContext& dev_ctx,
48 49
                        const std::string& info,
                        Tensor* t) {
50
  dev_ctx.Alloc<T>(t, t->numel() * sizeof(T));
51 52 53
  VLOG(4) << info << ": " << MemoryDebugString(*t);
}

54 55 56 57 58 59 60 61
template <typename T>
struct TernaryAddFunctor {
  inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; }
};

template <typename T>
struct GateAttentionConfig {
 public:
L
Leo Chen 已提交
62
  const phi::GPUContext& dev_ctx;
63 64 65 66

  bool merge_qkv;
  bool has_gating;

67 68 69 70 71
  int64_t batch_size;
  int64_t seq_len_m;
  int64_t seq_len_r;
  int64_t q_dim;
  int64_t kv_dim;
72
  int64_t head_dim;
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
  int64_t m_size;
  int64_t num_heads;

  phi::DDim qkv_out_dims;
  phi::DDim qkv_transpose_out_dims;

  phi::DDim q_out_dims;
  phi::DDim kv_out_dims;
  phi::DDim q_transpose_out_dims;
  phi::DDim kv_transpose_out_dims;

  phi::DDim qk_out_dims;
  phi::DDim softmax_out_dims;
  phi::DDim qktv_out_dims;
  phi::DDim gate_out_dims;

L
Leo Chen 已提交
89
  GateAttentionConfig(const phi::GPUContext& dev_ctx,
90 91 92 93 94 95
                      const Tensor* query,
                      const Tensor* key,
                      const Tensor* query_weight,
                      const Tensor* qkv_weight,
                      bool merge_qkv,
                      bool has_gating)
96
      : dev_ctx(dev_ctx), merge_qkv(merge_qkv), has_gating(has_gating) {
97 98 99 100 101 102 103 104 105 106 107 108 109
    // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
    batch_size = query->dims()[0];
    seq_len_m = query->dims()[1];
    seq_len_r = query->dims()[2];
    q_dim = query->dims()[3];

    if (merge_qkv) {
      PADDLE_ENFORCE_NOT_NULL(
          qkv_weight,
          platform::errors::NotFound("The input qkv_weight can not be nullptr "
                                     "when merge_qkv is true."));

      // When q_dim == kv_dim, QKV matmul can be computed merged.
110
      // qkv_weight: shape=[3, num_heads, head_dim, q_dim]
111
      num_heads = qkv_weight->dims()[1];
112
      head_dim = qkv_weight->dims()[2];
113 114 115
      m_size = seq_len_r;
      kv_dim = q_dim;

116
      qkv_out_dims = {batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim};
117 118
      qkv_transpose_out_dims = {
          3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim};
119 120 121 122 123 124 125 126 127 128 129 130
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          key,
          platform::errors::NotFound(
              "The input key can not be nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          query_weight,
          platform::errors::NotFound("The input query_weight can not be "
                                     "nullptr when merge_qkv is false."));

      // When q_dim != kv_dim, QKV matmul must be computed saparately.
      // key: shape=[batch_size, seq_len_m, m_size, kv_dim]
131
      // query_w: shape=[q_dim, num_heads, head_dim]
132
      num_heads = query_weight->dims()[1];
133
      head_dim = query_weight->dims()[2];
134 135 136
      m_size = key->dims()[2];
      kv_dim = key->dims()[3];

137 138
      q_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
      kv_out_dims = {batch_size, seq_len_m, m_size, num_heads, head_dim};
139 140 141 142
      q_transpose_out_dims = {
          batch_size, seq_len_m, num_heads, seq_len_r, head_dim};
      kv_transpose_out_dims = {
          batch_size, seq_len_m, num_heads, m_size, head_dim};
143 144 145 146
    }

    qk_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
    softmax_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
147 148
    qktv_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, head_dim};
    gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
149 150 151
  }

  int64_t GetQuerySize() const {
152
    return batch_size * seq_len_m * seq_len_r * num_heads * head_dim;
153 154
  }

155
  Tensor* GetQKVOut() {
156 157
    if (!qkv_out.IsInitialized()) {
      qkv_out.Resize(qkv_out_dims);
158
      AllocWithDebugInfo<T>(dev_ctx, "qkv_out", &qkv_out);
159 160 161 162
    }
    return &qkv_out;
  }

163
  Tensor* GetQueryOut() {
164 165
    if (!query_out.IsInitialized()) {
      query_out.Resize(q_out_dims);
166
      AllocWithDebugInfo<T>(dev_ctx, "query_out", &query_out);
167 168 169 170
    }
    return &query_out;
  }

171
  Tensor* GetKeyOut() {
172 173
    if (!key_out.IsInitialized()) {
      key_out.Resize(kv_out_dims);
174
      AllocWithDebugInfo<T>(dev_ctx, "key_out", &key_out);
175 176 177 178
    }
    return &key_out;
  }

179
  Tensor* GetValueOut() {
180 181
    if (!value_out.IsInitialized()) {
      value_out.Resize(kv_out_dims);
182
      AllocWithDebugInfo<T>(dev_ctx, "value_out", &value_out);
183 184 185 186
    }
    return &value_out;
  }

187
  Tensor* GetQKOut(Tensor* softmax_out) {
188 189 190 191 192 193
    // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
    int softmax_dim = m_size;
    if (!softmax_out || phi::UseCudnnSoftmax<T>(dev_ctx, softmax_dim, true)) {
      // Not sure whether cudnn softmax can execute inplace.
      if (!qkv_out.IsInitialized()) {
        qk_out.Resize(qk_out_dims);
194
        AllocWithDebugInfo<T>(dev_ctx, "qk_out", &qk_out);
195 196 197
      }
      return &qk_out;
    } else {
198
      // Enable inplace softmax.
199 200 201 202
      return softmax_out;
    }
  }

203 204 205 206 207 208 209 210 211 212 213 214 215 216
  Tensor* GetQKTVOut(Tensor* gate_out) {
    if (has_gating && gate_out) {
      // Reuse gate_out.
      gate_out->Resize(qktv_out_dims);
      return gate_out;
    } else {
      if (!qktv_out.IsInitialized()) {
        qktv_out.Resize(qktv_out_dims);
        AllocWithDebugInfo<T>(dev_ctx, "qktv_out", &qktv_out);
      }
      return &qktv_out;
    }
  }

217 218 219 220 221 222 223 224 225 226 227 228
  void ClearQKVOut() {
    if (qkv_out.IsInitialized()) {
      qkv_out.clear();
    }
  }

  void ClearQKOut() {
    if (qk_out.IsInitialized()) {
      qk_out.clear();
    }
  }

229 230 231 232 233 234
  void ClearQKTVOut() {
    if (qktv_out.IsInitialized()) {
      qktv_out.clear();
    }
  }

235 236 237 238 239 240 241 242 243 244
 protected:
  Tensor qkv_out;
  Tensor query_out;
  Tensor key_out;
  Tensor value_out;
  // qk_out = BatchedGEMM(Q, K^T)
  // qk_out: shape=[batch_size, seq_len_m, num_heads, seq_len_r, m_size]
  // softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
  // The shape of qk_out, softmax_out is the same, thus can be called inplace.
  Tensor qk_out;
245 246
  // qktv_out may reuse gate_out.
  Tensor qktv_out;
247 248 249 250 251
};

// Backward-pass companion of GateAttentionConfig. It reuses the forward shape
// bookkeeping and adds lazily allocated gradient buffers. Each Get*Grad()
// accessor sizes and allocates its tensor on the first call and returns the
// same tensor on every later call.
template <typename T>
struct GateAttentionGradConfig : public GateAttentionConfig<T> {
 public:
  GateAttentionGradConfig(const phi::GPUContext& dev_ctx,
                          const Tensor* query,
                          const Tensor* key,
                          const Tensor* query_weight,
                          const Tensor* qkv_weight,
                          bool merge_qkv,
                          bool has_gating)
      : GateAttentionConfig<T>(dev_ctx,
                               query,
                               key,
                               query_weight,
                               qkv_weight,
                               merge_qkv,
                               has_gating) {}

  // Gradient of the merged QKV projection output (qkv_out_dims).
  Tensor* GetQKVOutGrad() {
    if (qkv_out_grad.IsInitialized()) {
      return &qkv_out_grad;
    }
    qkv_out_grad.Resize(this->qkv_out_dims);
    AllocWithDebugInfo<T>(this->dev_ctx, "qkv_out_grad", &qkv_out_grad);
    return &qkv_out_grad;
  }

  // Gradient of the separate query projection output (q_out_dims).
  Tensor* GetQueryOutGrad() {
    if (query_out_grad.IsInitialized()) {
      return &query_out_grad;
    }
    query_out_grad.Resize(this->q_out_dims);
    AllocWithDebugInfo<T>(this->dev_ctx, "query_out_grad", &query_out_grad);
    return &query_out_grad;
  }

  // Gradient of the separate key projection output (kv_out_dims).
  Tensor* GetKeyOutGrad() {
    if (key_out_grad.IsInitialized()) {
      return &key_out_grad;
    }
    key_out_grad.Resize(this->kv_out_dims);
    AllocWithDebugInfo<T>(this->dev_ctx, "key_out_grad", &key_out_grad);
    return &key_out_grad;
  }

  // Gradient of the separate value projection output (kv_out_dims).
  Tensor* GetValueOutGrad() {
    if (value_out_grad.IsInitialized()) {
      return &value_out_grad;
    }
    value_out_grad.Resize(this->kv_out_dims);
    AllocWithDebugInfo<T>(this->dev_ctx, "value_out_grad", &value_out_grad);
    return &value_out_grad;
  }

  // Destination for the softmax backward result. softmax_out_grad is returned
  // (in-place update) when it is given and cuDNN softmax is not selected.
  // Otherwise a dedicated qk_out_grad buffer (qk_out_dims) is used.
  Tensor* GetQKOutGrad(Tensor* softmax_out_grad) {
    // The softmax runs over the last axis of qk_out, whose extent is m_size.
    int softmax_dim = this->m_size;
    const bool reuse_softmax_grad =
        softmax_out_grad != nullptr &&
        !phi::UseCudnnSoftmax<T>(this->dev_ctx, softmax_dim, true);
    if (reuse_softmax_grad) {
      return softmax_out_grad;
    }
    if (!qk_out_grad.IsInitialized()) {
      qk_out_grad.Resize(this->qk_out_dims);
      AllocWithDebugInfo<T>(this->dev_ctx, "qk_out_grad", &qk_out_grad);
    }
    return &qk_out_grad;
  }

 protected:
  Tensor qkv_out_grad;
  Tensor query_out_grad;
  Tensor key_out_grad;
  Tensor value_out_grad;
  Tensor qk_out_grad;
};

template <typename T>
class FMHAGateRef {
 public:
L
Leo Chen 已提交
325
  FMHAGateRef(const phi::GPUContext& dev_ctx, bool merge_qkv)
326 327
      : dev_ctx_(dev_ctx), merge_qkv_(merge_qkv) {}

328 329 330 331 332 333 334 335 336
  void ComputeForward(const Tensor* nonbatched_bias,
                      const Tensor* src_mask,
                      Tensor* q_transpose_out,
                      Tensor* k_transpose_out,
                      Tensor* v_transpose_out,
                      Tensor* qkv_transpose_out,
                      Tensor* softmax_out,
                      Tensor* fmha_out,
                      Tensor* gate_out,
337 338 339 340 341 342 343 344 345 346 347
                      GateAttentionConfig<T>* config) {
    T* q_ptr = nullptr;
    T* k_ptr = nullptr;
    T* v_ptr = nullptr;
    if (merge_qkv_) {
      // qkv_transpose_out = transpose(qkv_out)
      PADDLE_ENFORCE_NOT_NULL(
          qkv_transpose_out,
          platform::errors::NotFound("The input qkv_transpose_out can not be "
                                     "nullptr when merge_qkv is true."));

348
      Tensor* qkv_out = config->GetQKVOut();
349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370
      ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out);
      config->ClearQKVOut();

      // q_size == k_size
      int64_t q_size = config->GetQuerySize();
      q_ptr = qkv_transpose_out->data<T>();
      k_ptr = q_ptr + q_size;
      v_ptr = k_ptr + q_size;
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          q_transpose_out,
          platform::errors::NotFound("The input q_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          k_transpose_out,
          platform::errors::NotFound("The input k_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          v_transpose_out,
          platform::errors::NotFound("The input v_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));

371 372 373
      Tensor* query_out = config->GetQueryOut();
      Tensor* key_out = config->GetKeyOut();
      Tensor* value_out = config->GetValueOut();
374 375 376 377 378
      ComputeQKVTransposeForward(*query_out,
                                 *key_out,
                                 *value_out,
                                 q_transpose_out,
                                 k_transpose_out,
379 380 381 382 383 384 385 386 387
                                 v_transpose_out);

      // q_size != k_size
      q_ptr = q_transpose_out->data<T>();
      k_ptr = k_transpose_out->data<T>();
      v_ptr = v_transpose_out->data<T>();
    }

    // qk_out = BatchedGEMM(Q, K^T)
388 389
    // [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] *
    //                [batch_size, seq_len_m, num_heads, m_size, head_dim]
390
    // -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size]
391
    Tensor* qk_out = config->GetQKOut(softmax_out);
392 393 394 395 396 397
    T* qk_out_ptr = qk_out->data<T>();

    int64_t gemm_batch_size =
        config->batch_size * config->seq_len_m * config->num_heads;
    int64_t gemm_m = config->seq_len_r;
    int64_t gemm_n = config->m_size;
398
    int64_t gemm_k = config->head_dim;
399

400
    T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
401 402 403 404 405 406 407 408 409 410
    ComputeBatchedGEMM(q_ptr,
                       k_ptr,
                       qk_out_ptr,
                       false,
                       true,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       gemm_batch_size,
                       alpha);
411 412

    // softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
413 414
    ComputeBiasMaskSoftmaxForward(
        nonbatched_bias, src_mask, qk_out, softmax_out);
415 416 417 418
    config->ClearQKOut();

    // qktv_out = BatchedGEMM(softmax_out, V)
    // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] *
419 420
    //               [batch_size, seq_len_m, num_heads, m_size, head_dim]
    // -> [batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
421 422
    Tensor* qktv_out = config->GetQKTVOut(gate_out);
    T* qktv_out_ptr = qktv_out->data<T>();
423 424

    gemm_m = config->seq_len_r;
425
    gemm_n = config->head_dim;
426 427 428
    gemm_k = config->m_size;

    T* softmax_out_ptr = softmax_out->data<T>();
429 430 431 432 433 434 435 436 437
    ComputeBatchedGEMM(softmax_out_ptr,
                       v_ptr,
                       qktv_out_ptr,
                       false,
                       false,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       gemm_batch_size);
438 439

    // fmha_out = transpose(qktv_out)
440 441 442 443 444
    ComputeQKTVTransposeForward(*qktv_out, fmha_out);
    config->ClearQKTVOut();
    if (config->has_gating) {
      gate_out->Resize(config->gate_out_dims);
    }
445 446 447 448 449 450
  }

  void ComputeBackward(const Tensor* q_transpose_out,
                       const Tensor* k_transpose_out,
                       const Tensor* v_transpose_out,
                       const Tensor* qkv_transpose_out,
451 452 453 454
                       const Tensor* softmax_out,
                       const Tensor* fmha_out_grad,
                       Tensor* src_mask_grad,
                       Tensor* nonbatched_bias_grad,
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
                       GateAttentionGradConfig<T>* config) {
    const T* q_ptr = nullptr;
    const T* k_ptr = nullptr;
    const T* v_ptr = nullptr;

    T* q_grad_ptr = nullptr;
    T* k_grad_ptr = nullptr;
    T* v_grad_ptr = nullptr;

    Tensor q_transpose_out_grad;
    Tensor k_transpose_out_grad;
    Tensor v_transpose_out_grad;
    Tensor qkv_transpose_out_grad;
    if (merge_qkv_) {
      PADDLE_ENFORCE_NOT_NULL(
          qkv_transpose_out,
          platform::errors::NotFound("The input qkv_transpose_out can not be "
                                     "nullptr when merge_qkv is true."));

      int64_t q_size = config->GetQuerySize();
      q_ptr = qkv_transpose_out->data<T>();
      k_ptr = q_ptr + q_size;
      v_ptr = k_ptr + q_size;

      qkv_transpose_out_grad.Resize(config->qkv_transpose_out_dims);
480 481
      AllocWithDebugInfo<T>(
          dev_ctx_, "qkv_transpose_out_grad", &qkv_transpose_out_grad);
482

483
      q_grad_ptr = qkv_transpose_out_grad.data<T>();
484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507
      k_grad_ptr = q_grad_ptr + q_size;
      v_grad_ptr = k_grad_ptr + q_size;
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          q_transpose_out,
          platform::errors::NotFound("The input q_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          k_transpose_out,
          platform::errors::NotFound("The input k_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          v_transpose_out,
          platform::errors::NotFound("The input v_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));

      q_ptr = q_transpose_out->data<T>();
      k_ptr = k_transpose_out->data<T>();
      v_ptr = v_transpose_out->data<T>();

      q_transpose_out_grad.Resize(config->q_transpose_out_dims);
      k_transpose_out_grad.Resize(config->kv_transpose_out_dims);
      v_transpose_out_grad.Resize(config->kv_transpose_out_dims);

508 509 510 511 512 513
      q_grad_ptr = dev_ctx_.Alloc<T>(&q_transpose_out_grad,
                                     q_transpose_out_grad.numel() * sizeof(T));
      k_grad_ptr = dev_ctx_.Alloc<T>(&k_transpose_out_grad,
                                     k_transpose_out_grad.numel() * sizeof(T));
      v_grad_ptr = dev_ctx_.Alloc<T>(&v_transpose_out_grad,
                                     v_transpose_out_grad.numel() * sizeof(T));
514 515 516 517
    }

    Tensor softmax_out_grad;
    softmax_out_grad.Resize(config->softmax_out_dims);
518
    AllocWithDebugInfo<T>(dev_ctx_, "softmax_out_grad", &softmax_out_grad);
519 520 521 522 523 524 525

    int64_t gemm_batch_size =
        config->batch_size * config->seq_len_m * config->num_heads;
    {
      // Forward: fmha_out = transpose(qktv_out)
      Tensor qktv_out_grad;
      qktv_out_grad.Resize(config->qktv_out_dims);
526
      AllocWithDebugInfo<T>(dev_ctx_, "qktv_out_grad", &qktv_out_grad);
527 528 529 530 531 532
      ComputeQKTVTransposeBackward(*fmha_out_grad, &qktv_out_grad);

      // Forward: qktv_out = BatchedGEMM(softmax_out, V)
      // Backward:
      //  V_grad = BatchedGEMM(softmax_out^T, qktv_out_grad) (dy = x^T * dout)
      int64_t gemm_m = config->m_size;
533
      int64_t gemm_n = config->head_dim;
534 535 536
      int64_t gemm_k = config->seq_len_r;

      const T* softmax_out_ptr = softmax_out->data<T>();
537
      const T* qktv_out_grad_ptr = qktv_out_grad.data<T>();
538 539 540 541 542 543 544 545 546
      ComputeBatchedGEMM(softmax_out_ptr,
                         qktv_out_grad_ptr,
                         v_grad_ptr,
                         true,
                         false,
                         gemm_m,
                         gemm_n,
                         gemm_k,
                         gemm_batch_size);
547 548 549 550

      // Backward: softmax_out_grad = qktv_out_grad * V^T (dx = dout * y^T)
      gemm_m = config->seq_len_r;
      gemm_n = config->m_size;
551
      gemm_k = config->head_dim;
552 553

      T* softmax_out_grad_ptr = softmax_out_grad.data<T>();
554 555 556 557 558 559 560 561 562
      ComputeBatchedGEMM(qktv_out_grad_ptr,
                         v_ptr,
                         softmax_out_grad_ptr,
                         false,
                         true,
                         gemm_m,
                         gemm_n,
                         gemm_k,
                         gemm_batch_size);
563 564
    }

565
    Tensor* qk_out_grad = config->GetQKOutGrad(&softmax_out_grad);
566 567 568 569
    ComputeBiasMaskSoftmaxBackward(&softmax_out_grad,
                                   softmax_out,
                                   src_mask_grad,
                                   qk_out_grad,
570 571 572 573 574
                                   nonbatched_bias_grad);

    // Forward: qk_out = BatchedGEMM(Q, K^T)
    // Backward: k_grad = BatchedGEMM(qk_out_grad^T, Q) (dy = dout^t * x)
    int64_t gemm_m = config->m_size;
575
    int64_t gemm_n = config->head_dim;
576
    int64_t gemm_k = config->seq_len_r;
577
    T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
578 579

    T* qk_out_grad_ptr = qk_out_grad->data<T>();
580 581 582 583 584 585 586 587 588 589
    ComputeBatchedGEMM(qk_out_grad_ptr,
                       q_ptr,
                       k_grad_ptr,
                       true,
                       false,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       gemm_batch_size,
                       alpha);
590 591 592

    // Backward: q_grad = BatchedGEMM(qk_out_grad, K) (dx = dout * y)
    gemm_m = config->seq_len_r;
593
    gemm_n = config->head_dim;
594
    gemm_k = config->m_size;
595 596 597 598 599 600 601 602 603 604
    ComputeBatchedGEMM(qk_out_grad_ptr,
                       k_ptr,
                       q_grad_ptr,
                       false,
                       false,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       gemm_batch_size,
                       alpha);
605 606

    if (merge_qkv_) {
607
      Tensor* qkv_out_grad = config->GetQKVOutGrad();
608 609
      ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad);
    } else {
610 611 612
      Tensor* q_out_grad = config->GetQueryOutGrad();
      Tensor* k_out_grad = config->GetKeyOutGrad();
      Tensor* v_out_grad = config->GetValueOutGrad();
613 614 615 616 617
      ComputeQKVTransposeBackward(q_transpose_out_grad,
                                  k_transpose_out_grad,
                                  v_transpose_out_grad,
                                  q_out_grad,
                                  k_out_grad,
618 619 620 621
                                  v_out_grad);
    }
  }

622 623 624 625
  void ComputeQKVTransposeForward(const Tensor& q_out,
                                  const Tensor& k_out,
                                  const Tensor& v_out,
                                  Tensor* q_transpose_out,
626 627 628
                                  Tensor* k_transpose_out,
                                  Tensor* v_transpose_out) {
    std::vector<int> perm = {0, 1, 3, 2, 4};
629 630 631
    TransposeGPUKernelDriver<T>(dev_ctx_, q_out, perm, q_transpose_out);
    TransposeGPUKernelDriver<T>(dev_ctx_, k_out, perm, k_transpose_out);
    TransposeGPUKernelDriver<T>(dev_ctx_, v_out, perm, v_transpose_out);
632 633 634 635 636
  }

  void ComputeQKVTransposeBackward(const Tensor& q_transpose_out_grad,
                                   const Tensor& k_transpose_out_grad,
                                   const Tensor& v_transpose_out_grad,
637 638
                                   Tensor* q_out_grad,
                                   Tensor* k_out_grad,
639 640
                                   Tensor* v_out_grad) {
    std::vector<int> perm = {0, 1, 3, 2, 4};
641
    TransposeGPUKernelDriver<T>(
642
        dev_ctx_, q_transpose_out_grad, perm, q_out_grad);
643
    TransposeGPUKernelDriver<T>(
644
        dev_ctx_, k_transpose_out_grad, perm, k_out_grad);
645
    TransposeGPUKernelDriver<T>(
646
        dev_ctx_, v_transpose_out_grad, perm, v_out_grad);
647 648
  }

649 650
  // [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] ->
  //         [3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
651 652 653
  void ComputeQKVTransposeForward(const Tensor& qkv_out,
                                  Tensor* qkv_transpose_out) {
    std::vector<int> perm = {3, 0, 1, 4, 2, 5};
654
    TransposeGPUKernelDriver<T>(dev_ctx_, qkv_out, perm, qkv_transpose_out);
655 656 657 658 659
  }

  void ComputeQKVTransposeBackward(const Tensor& qkv_transpose_out_grad,
                                   Tensor* qkv_out_grad) {
    std::vector<int> perm = {1, 2, 4, 0, 3, 5};
660
    TransposeGPUKernelDriver<T>(
661
        dev_ctx_, qkv_transpose_out_grad, perm, qkv_out_grad);
662 663 664 665 666 667
  }

  // [batch_size, seq_len_m, num_head, seq_len_r, c] ->
  //         [batch_size, seq_len_m, seq_len_r, num_head, c]
  void ComputeQKTVTransposeForward(const Tensor& qktv_out, Tensor* fmha_out) {
    std::vector<int> perm = {0, 1, 3, 2, 4};
668
    TransposeGPUKernelDriver<T>(dev_ctx_, qktv_out, perm, fmha_out);
669 670 671 672 673
  }

  void ComputeQKTVTransposeBackward(const Tensor& fmha_out_grad,
                                    Tensor* qktv_out_grad) {
    std::vector<int> perm = {0, 1, 3, 2, 4};
674
    TransposeGPUKernelDriver<T>(dev_ctx_, fmha_out_grad, perm, qktv_out_grad);
675 676 677 678 679
  }

  // qk_out = qk_out + nonbatched_bias + src_mask
  // softmax_out = softmax(src_mask_out)
  void ComputeBiasMaskSoftmaxForward(const Tensor* nonbatched_bias,
680 681
                                     const Tensor* src_mask,
                                     Tensor* qk_out,
682 683
                                     Tensor* softmax_out) {
    if (nonbatched_bias) {
684
      std::vector<const Tensor*> ins = {qk_out, src_mask, nonbatched_bias};
685
      std::vector<Tensor*> outs = {qk_out};
686
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
687 688 689 690
          dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
    } else {
      std::vector<const Tensor*> ins = {qk_out, src_mask};
      std::vector<Tensor*> outs = {qk_out};
691
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707
          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
    }
    phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
  }

  // src_mask_out = qk_out + nonbatched_bias + src_mask
  // softmax_out = softmax(src_mask_out)
  void ComputeBiasMaskSoftmaxBackward(const Tensor* softmax_out_grad,
                                      const Tensor* softmax_out,
                                      Tensor* src_mask_grad,
                                      Tensor* qk_out_grad,
                                      Tensor* nonbatched_bias_grad) {
    PADDLE_ENFORCE_NOT_NULL(
        qk_out_grad,
        platform::errors::NotFound("The qk_out_grad can not be nullptr."));

708 709
    PADDLE_ENFORCE_EQ(qk_out_grad->dims(),
                      softmax_out->dims(),
710 711 712 713
                      platform::errors::InvalidArgument(
                          "The shape of qk_out_grad and softmax_out is "
                          "expected to be the same. But recieved qk_out_grad's "
                          "shape = %s, softmax_out's shape = %s.",
714 715
                          qk_out_grad->dims(),
                          softmax_out->dims()));
716

717 718
    PADDLE_ENFORCE_EQ(src_mask_grad,
                      nullptr,
719 720 721
                      platform::errors::InvalidArgument(
                          "src_mask_grad is expected to be nullptr."));

722 723
    phi::SoftmaxBackwardCUDAKernelDriver<T>(
        dev_ctx_, *softmax_out, *softmax_out_grad, -1, qk_out_grad);
724 725

    if (nonbatched_bias_grad) {
726 727 728
      // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] ->
      //      [batch_size, 1, num_heads, seq_len_r, m_size]
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
729 730 731 732 733
          dev_ctx_,
          *qk_out_grad,
          nonbatched_bias_grad,
          kps::IdentityFunctor<T>(),
          {1});
734 735 736 737
    }
  }

 private:
738 739 740 741 742 743 744 745 746
  void ComputeBatchedGEMM(const T* a_ptr,
                          const T* b_ptr,
                          T* c_ptr,
                          bool trans_a,
                          bool trans_b,
                          int64_t m,
                          int64_t n,
                          int64_t k,
                          int64_t batch_size,
747 748 749 750 751 752 753
                          T alpha = static_cast<T>(1.0),
                          T beta = static_cast<T>(0.0)) {
    CBLAS_TRANSPOSE cblas_trans_a = trans_a ? CblasTrans : CblasNoTrans;
    CBLAS_TRANSPOSE cblas_trans_b = trans_b ? CblasTrans : CblasNoTrans;
    int64_t stride_a = m * k;
    int64_t stride_b = k * n;

L
Leo Chen 已提交
754
    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
755 756 757 758 759 760 761 762 763 764 765 766 767
    blas.BatchedGEMM(cblas_trans_a,
                     cblas_trans_b,
                     m,
                     n,
                     k,
                     alpha,
                     a_ptr,
                     b_ptr,
                     beta,
                     c_ptr,
                     batch_size,
                     stride_a,
                     stride_b);
768 769
  }

L
Leo Chen 已提交
770
  const phi::GPUContext& dev_ctx_;
771 772 773 774 775
  bool merge_qkv_;
};

}  // namespace operators
}  // namespace paddle