/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

namespace paddle {
namespace operators {

inline std::string MemoryDebugString(const phi::DenseTensor& t) {
  int device_id = platform::GetCurrentDeviceId();
  int64_t allocated =
      memory::DeviceMemoryStatCurrentValue("Allocated", device_id);
  int64_t reserved =
      memory::DeviceMemoryStatCurrentValue("Reserved", device_id);

  std::stringstream ss;
  ss << "shape=[" << t.dims()
     << "], size=" << static_cast<float>(t.memory_size()) / (1 << 20)
     << " MB, ptr=" << t.data()
     << "; [MEMORY] allocated=" << static_cast<float>(allocated) / (1 << 20)
     << " MB"
     << ", reserved=" << static_cast<float>(reserved) / (1 << 20) << " MB";
  return ss.str();
}

template <typename T>
void AllocWithDebugInfo(const phi::GPUContext& dev_ctx,
                        const std::string& info,
                        phi::DenseTensor* t) {
  dev_ctx.Alloc<T>(t, t->numel() * sizeof(T));
  VLOG(4) << info << ": " << MemoryDebugString(*t);
}

template <typename T>
struct TernaryAddFunctor {
  inline HOSTDEVICE T operator()(T a, T b, T c) const { return a + b + c; }
};

template <typename T>
struct GateAttentionConfig {
 public:
  const phi::GPUContext& dev_ctx;

  bool merge_qkv;
  bool has_gating;

  int64_t batch_size;
  int64_t seq_len_m;
  int64_t seq_len_r;
  int64_t q_dim;
  int64_t kv_dim;
  int64_t head_dim;
  int64_t m_size;
  int64_t num_heads;

  phi::DDim qkv_out_dims;
  phi::DDim qkv_transpose_out_dims;

  phi::DDim q_out_dims;
  phi::DDim kv_out_dims;
  phi::DDim q_transpose_out_dims;
  phi::DDim kv_transpose_out_dims;

  phi::DDim qk_out_dims;
  phi::DDim softmax_out_dims;
  phi::DDim qktv_out_dims;
  phi::DDim gate_out_dims;

  GateAttentionConfig(const phi::GPUContext& dev_ctx,
                      const phi::DenseTensor* query,
                      const phi::DenseTensor* key,
                      const phi::DenseTensor* query_weight,
                      const phi::DenseTensor* qkv_weight,
                      bool merge_qkv,
                      bool has_gating)
      : dev_ctx(dev_ctx), merge_qkv(merge_qkv), has_gating(has_gating) {
    // query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
    batch_size = query->dims()[0];
    seq_len_m = query->dims()[1];
    seq_len_r = query->dims()[2];
    q_dim = query->dims()[3];

    if (merge_qkv) {
      PADDLE_ENFORCE_NOT_NULL(
          qkv_weight,
          platform::errors::NotFound("The input qkv_weight can not be nullptr "
                                     "when merge_qkv is true."));

      // When q_dim == kv_dim, the Q, K and V projections can be computed as
      // one merged matmul.
      // qkv_weight: shape=[3, num_heads, head_dim, q_dim]
      num_heads = qkv_weight->dims()[1];
      head_dim = qkv_weight->dims()[2];
      m_size = seq_len_r;
      kv_dim = q_dim;

      qkv_out_dims = {batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim};
      qkv_transpose_out_dims = {
          3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim};
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          key,
          platform::errors::NotFound(
              "The input key can not be nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          query_weight,
          platform::errors::NotFound("The input query_weight can not be "
                                     "nullptr when merge_qkv is false."));

      // When q_dim != kv_dim, the Q, K and V projections must be computed
      // separately.
      // key: shape=[batch_size, seq_len_m, m_size, kv_dim]
      // query_w: shape=[q_dim, num_heads, head_dim]
      num_heads = query_weight->dims()[1];
      head_dim = query_weight->dims()[2];
      m_size = key->dims()[2];
      kv_dim = key->dims()[3];

      q_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
      kv_out_dims = {batch_size, seq_len_m, m_size, num_heads, head_dim};
      q_transpose_out_dims = {
          batch_size, seq_len_m, num_heads, seq_len_r, head_dim};
      kv_transpose_out_dims = {
          batch_size, seq_len_m, num_heads, m_size, head_dim};
    }

    qk_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
    softmax_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, m_size};
    qktv_out_dims = {batch_size, seq_len_m, num_heads, seq_len_r, head_dim};
    gate_out_dims = {batch_size, seq_len_m, seq_len_r, num_heads, head_dim};
  }
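
  // Shape bookkeeping example (illustrative numbers, not taken from the op):
  // with merge_qkv = true, query = [1, 12, 256, 64] and
  // qkv_weight = [3, 8, 32, 64], the constructor above resolves to
  //   batch_size = 1, seq_len_m = 12, seq_len_r = m_size = 256,
  //   q_dim = kv_dim = 64, num_heads = 8, head_dim = 32,
  //   qkv_out_dims = [1, 12, 256, 3, 8, 32],
  //   qkv_transpose_out_dims = [3, 1, 12, 8, 256, 32].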

  int64_t GetQuerySize() const {
    return batch_size * seq_len_m * seq_len_r * num_heads * head_dim;
  }

  phi::DenseTensor* GetQKVOut() {
    if (!qkv_out.IsInitialized()) {
      qkv_out.Resize(qkv_out_dims);
      AllocWithDebugInfo<T>(dev_ctx, "qkv_out", &qkv_out);
    }
    return &qkv_out;
  }

  phi::DenseTensor* GetQueryOut() {
    if (!query_out.IsInitialized()) {
      query_out.Resize(q_out_dims);
      AllocWithDebugInfo<T>(dev_ctx, "query_out", &query_out);
    }
    return &query_out;
  }

  phi::DenseTensor* GetKeyOut() {
    if (!key_out.IsInitialized()) {
      key_out.Resize(kv_out_dims);
      AllocWithDebugInfo<T>(dev_ctx, "key_out", &key_out);
    }
    return &key_out;
  }

  phi::DenseTensor* GetValueOut() {
    if (!value_out.IsInitialized()) {
      value_out.Resize(kv_out_dims);
      AllocWithDebugInfo<T>(dev_ctx, "value_out", &value_out);
    }
    return &value_out;
  }

  phi::DenseTensor* GetQKOut(phi::DenseTensor* softmax_out) {
    // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
    int softmax_dim = m_size;
    if (!softmax_out || phi::UseCudnnSoftmax<T>(dev_ctx, softmax_dim, true)) {
      // Not sure whether cudnn softmax can execute inplace.
      if (!qk_out.IsInitialized()) {
        qk_out.Resize(qk_out_dims);
        AllocWithDebugInfo<T>(dev_ctx, "qk_out", &qk_out);
      }
      return &qk_out;
    } else {
      // Enable inplace softmax.
      return softmax_out;
    }
  }

  phi::DenseTensor* GetQKTVOut(phi::DenseTensor* gate_out) {
    if (has_gating && gate_out) {
      // Reuse gate_out.
      gate_out->Resize(qktv_out_dims);
      return gate_out;
    } else {
      if (!qktv_out.IsInitialized()) {
        qktv_out.Resize(qktv_out_dims);
        AllocWithDebugInfo<T>(dev_ctx, "qktv_out", &qktv_out);
      }
      return &qktv_out;
    }
  }

  void ClearQKVOut() {
    if (qkv_out.IsInitialized()) {
      qkv_out.clear();
    }
  }

  void ClearQKOut() {
    if (qk_out.IsInitialized()) {
      qk_out.clear();
    }
  }

  void ClearQKTVOut() {
    if (qktv_out.IsInitialized()) {
      qktv_out.clear();
    }
  }

 protected:
  phi::DenseTensor qkv_out;
  phi::DenseTensor query_out;
  phi::DenseTensor key_out;
  phi::DenseTensor value_out;
  // qk_out = BatchedGEMM(Q, K^T)
  // qk_out: shape=[batch_size, seq_len_m, num_heads, seq_len_r, m_size]
  // softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
  // qk_out and softmax_out have the same shape, so the softmax can be
  // computed inplace.
  phi::DenseTensor qk_out;
  // qktv_out may reuse gate_out.
  phi::DenseTensor qktv_out;
};

template <typename T>
struct GateAttentionGradConfig : public GateAttentionConfig<T> {
 public:
  GateAttentionGradConfig(const phi::GPUContext& dev_ctx,
                          const phi::DenseTensor* query,
                          const phi::DenseTensor* key,
                          const phi::DenseTensor* query_weight,
                          const phi::DenseTensor* qkv_weight,
                          bool merge_qkv,
                          bool has_gating)
      : GateAttentionConfig<T>(dev_ctx,
                               query,
                               key,
                               query_weight,
                               qkv_weight,
                               merge_qkv,
                               has_gating) {}

  phi::DenseTensor* GetQKVOutGrad() {
    if (!qkv_out_grad.IsInitialized()) {
      qkv_out_grad.Resize(this->qkv_out_dims);
      AllocWithDebugInfo<T>(this->dev_ctx, "qkv_out_grad", &qkv_out_grad);
    }
    return &qkv_out_grad;
  }

  phi::DenseTensor* GetQueryOutGrad() {
    if (!query_out_grad.IsInitialized()) {
      query_out_grad.Resize(this->q_out_dims);
      AllocWithDebugInfo<T>(this->dev_ctx, "query_out_grad", &query_out_grad);
    }
    return &query_out_grad;
  }

  phi::DenseTensor* GetKeyOutGrad() {
    if (!key_out_grad.IsInitialized()) {
      key_out_grad.Resize(this->kv_out_dims);
      AllocWithDebugInfo<T>(this->dev_ctx, "key_out_grad", &key_out_grad);
    }
    return &key_out_grad;
  }

  phi::DenseTensor* GetValueOutGrad() {
    if (!value_out_grad.IsInitialized()) {
      value_out_grad.Resize(this->kv_out_dims);
      AllocWithDebugInfo<T>(this->dev_ctx, "value_out_grad", &value_out_grad);
    }
    return &value_out_grad;
  }

  phi::DenseTensor* GetQKOutGrad(phi::DenseTensor* softmax_out_grad) {
    // softmax_dim = qk_out_dim[-1] = qk_out_dim[rank - 1]
    int softmax_dim = this->m_size;
    if (!softmax_out_grad ||
        phi::UseCudnnSoftmax<T>(this->dev_ctx, softmax_dim, true)) {
      if (!qk_out_grad.IsInitialized()) {
        qk_out_grad.Resize(this->qk_out_dims);
        AllocWithDebugInfo<T>(this->dev_ctx, "qk_out_grad", &qk_out_grad);
      }
      return &qk_out_grad;
    } else {
      return softmax_out_grad;
    }
  }

 protected:
  phi::DenseTensor qkv_out_grad;
  phi::DenseTensor query_out_grad;
  phi::DenseTensor key_out_grad;
  phi::DenseTensor value_out_grad;
  phi::DenseTensor qk_out_grad;
};
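
// A minimal usage sketch (illustrative only, not a complete kernel): build a
// GateAttentionConfig (or GateAttentionGradConfig for the backward pass) from
// the op inputs, then drive FMHAGateRef with it:
//
//   GateAttentionConfig<T> config(dev_ctx, query, key, query_weight,
//                                 qkv_weight, merge_qkv, has_gating);
//   FMHAGateRef<T> fmha(dev_ctx, merge_qkv);
//   fmha.ComputeForward(nonbatched_bias, src_mask, q_transpose_out,
//                       k_transpose_out, v_transpose_out, qkv_transpose_out,
//                       softmax_out, fmha_out, gate_out, &config);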

template <typename T>
class FMHAGateRef {
 public:
  FMHAGateRef(const phi::GPUContext& dev_ctx, bool merge_qkv)
      : dev_ctx_(dev_ctx), merge_qkv_(merge_qkv) {}

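  // Forward pass of the fused attention. In summary:
  //   qk_out      = Q * K^T / sqrt(head_dim)
  //   softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
  //   qktv_out    = softmax_out * V
  //   fmha_out    = transpose(qktv_out)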
  void ComputeForward(const phi::DenseTensor* nonbatched_bias,
                      const phi::DenseTensor* src_mask,
                      phi::DenseTensor* q_transpose_out,
                      phi::DenseTensor* k_transpose_out,
                      phi::DenseTensor* v_transpose_out,
                      phi::DenseTensor* qkv_transpose_out,
                      phi::DenseTensor* softmax_out,
                      phi::DenseTensor* fmha_out,
                      phi::DenseTensor* gate_out,
                      GateAttentionConfig<T>* config) {
    T* q_ptr = nullptr;
    T* k_ptr = nullptr;
    T* v_ptr = nullptr;
    if (merge_qkv_) {
      // qkv_transpose_out = transpose(qkv_out)
      PADDLE_ENFORCE_NOT_NULL(
          qkv_transpose_out,
          platform::errors::NotFound("The input qkv_transpose_out can not be "
                                     "nullptr when merge_qkv is true."));

      phi::DenseTensor* qkv_out = config->GetQKVOut();
      ComputeQKVTransposeForward(*qkv_out, qkv_transpose_out);
      config->ClearQKVOut();

      // q_size == k_size
      int64_t q_size = config->GetQuerySize();
      q_ptr = qkv_transpose_out->data<T>();
      k_ptr = q_ptr + q_size;
      v_ptr = k_ptr + q_size;
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          q_transpose_out,
          platform::errors::NotFound("The input q_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          k_transpose_out,
          platform::errors::NotFound("The input k_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          v_transpose_out,
          platform::errors::NotFound("The input v_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));

      phi::DenseTensor* query_out = config->GetQueryOut();
      phi::DenseTensor* key_out = config->GetKeyOut();
      phi::DenseTensor* value_out = config->GetValueOut();
      ComputeQKVTransposeForward(*query_out,
                                 *key_out,
                                 *value_out,
                                 q_transpose_out,
                                 k_transpose_out,
                                 v_transpose_out);

      // q_size != k_size
      q_ptr = q_transpose_out->data<T>();
      k_ptr = k_transpose_out->data<T>();
      v_ptr = v_transpose_out->data<T>();
    }

    // qk_out = BatchedGEMM(Q, K^T)
    // [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] *
    //                [batch_size, seq_len_m, num_heads, m_size, head_dim]
    // -> [batch_size, seq_len_m, num_heads, seq_len_r, m_size]
    phi::DenseTensor* qk_out = config->GetQKOut(softmax_out);
    T* qk_out_ptr = qk_out->data<T>();

    int64_t gemm_batch_size =
        config->batch_size * config->seq_len_m * config->num_heads;
    int64_t gemm_m = config->seq_len_r;
    int64_t gemm_n = config->m_size;
    int64_t gemm_k = config->head_dim;

    T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));
    ComputeBatchedGEMM(q_ptr,
                       k_ptr,
                       qk_out_ptr,
                       false,
                       true,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       gemm_batch_size,
                       alpha);

    // softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
    ComputeBiasMaskSoftmaxForward(
        nonbatched_bias, src_mask, qk_out, softmax_out);
    config->ClearQKOut();

    // qktv_out = BatchedGEMM(softmax_out, V)
    // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] *
    //               [batch_size, seq_len_m, num_heads, m_size, head_dim]
    // -> [batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
    phi::DenseTensor* qktv_out = config->GetQKTVOut(gate_out);
    T* qktv_out_ptr = qktv_out->data<T>();

    gemm_m = config->seq_len_r;
    gemm_n = config->head_dim;
    gemm_k = config->m_size;

    T* softmax_out_ptr = softmax_out->data<T>();
    ComputeBatchedGEMM(softmax_out_ptr,
                       v_ptr,
                       qktv_out_ptr,
                       false,
                       false,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       gemm_batch_size);

    // fmha_out = transpose(qktv_out)
    ComputeQKTVTransposeForward(*qktv_out, fmha_out);
    config->ClearQKTVOut();
    if (config->has_gating) {
      gate_out->Resize(config->gate_out_dims);
    }
  }

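  // Backward pass. Each forward GEMM C = A * B is differentiated with the
  // usual matmul rules, dA = dC * B^T and dB = A^T * dC, as noted inline
  // below.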
  void ComputeBackward(const phi::DenseTensor* q_transpose_out,
                       const phi::DenseTensor* k_transpose_out,
                       const phi::DenseTensor* v_transpose_out,
                       const phi::DenseTensor* qkv_transpose_out,
                       const phi::DenseTensor* softmax_out,
                       const phi::DenseTensor* fmha_out_grad,
                       phi::DenseTensor* src_mask_grad,
                       phi::DenseTensor* nonbatched_bias_grad,
                       GateAttentionGradConfig<T>* config) {
    const T* q_ptr = nullptr;
    const T* k_ptr = nullptr;
    const T* v_ptr = nullptr;

    T* q_grad_ptr = nullptr;
    T* k_grad_ptr = nullptr;
    T* v_grad_ptr = nullptr;

    phi::DenseTensor q_transpose_out_grad;
    phi::DenseTensor k_transpose_out_grad;
    phi::DenseTensor v_transpose_out_grad;
    phi::DenseTensor qkv_transpose_out_grad;
    if (merge_qkv_) {
      PADDLE_ENFORCE_NOT_NULL(
          qkv_transpose_out,
          platform::errors::NotFound("The input qkv_transpose_out can not be "
                                     "nullptr when merge_qkv is true."));

      int64_t q_size = config->GetQuerySize();
      q_ptr = qkv_transpose_out->data<T>();
      k_ptr = q_ptr + q_size;
      v_ptr = k_ptr + q_size;

      qkv_transpose_out_grad.Resize(config->qkv_transpose_out_dims);
      AllocWithDebugInfo<T>(
          dev_ctx_, "qkv_transpose_out_grad", &qkv_transpose_out_grad);

      q_grad_ptr = qkv_transpose_out_grad.data<T>();
      k_grad_ptr = q_grad_ptr + q_size;
      v_grad_ptr = k_grad_ptr + q_size;
    } else {
      PADDLE_ENFORCE_NOT_NULL(
          q_transpose_out,
          platform::errors::NotFound("The input q_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          k_transpose_out,
          platform::errors::NotFound("The input k_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));
      PADDLE_ENFORCE_NOT_NULL(
          v_transpose_out,
          platform::errors::NotFound("The input v_transpose_out can not be "
                                     "nullptr when merge_qkv is false."));

      q_ptr = q_transpose_out->data<T>();
      k_ptr = k_transpose_out->data<T>();
      v_ptr = v_transpose_out->data<T>();

      q_transpose_out_grad.Resize(config->q_transpose_out_dims);
      k_transpose_out_grad.Resize(config->kv_transpose_out_dims);
      v_transpose_out_grad.Resize(config->kv_transpose_out_dims);

      q_grad_ptr = dev_ctx_.Alloc<T>(&q_transpose_out_grad,
                                     q_transpose_out_grad.numel() * sizeof(T));
      k_grad_ptr = dev_ctx_.Alloc<T>(&k_transpose_out_grad,
                                     k_transpose_out_grad.numel() * sizeof(T));
      v_grad_ptr = dev_ctx_.Alloc<T>(&v_transpose_out_grad,
                                     v_transpose_out_grad.numel() * sizeof(T));
    }

    phi::DenseTensor softmax_out_grad;
    softmax_out_grad.Resize(config->softmax_out_dims);
    AllocWithDebugInfo<T>(dev_ctx_, "softmax_out_grad", &softmax_out_grad);

    int64_t gemm_batch_size =
        config->batch_size * config->seq_len_m * config->num_heads;
    {
      // Forward: fmha_out = transpose(qktv_out)
      phi::DenseTensor qktv_out_grad;
      qktv_out_grad.Resize(config->qktv_out_dims);
      AllocWithDebugInfo<T>(dev_ctx_, "qktv_out_grad", &qktv_out_grad);
      ComputeQKTVTransposeBackward(*fmha_out_grad, &qktv_out_grad);

      // Forward: qktv_out = BatchedGEMM(softmax_out, V)
      // Backward:
      //  V_grad = BatchedGEMM(softmax_out^T, qktv_out_grad) (dy = x^T * dout)
      int64_t gemm_m = config->m_size;
      int64_t gemm_n = config->head_dim;
      int64_t gemm_k = config->seq_len_r;

      const T* softmax_out_ptr = softmax_out->data<T>();
      const T* qktv_out_grad_ptr = qktv_out_grad.data<T>();
      ComputeBatchedGEMM(softmax_out_ptr,
                         qktv_out_grad_ptr,
                         v_grad_ptr,
                         true,
                         false,
                         gemm_m,
                         gemm_n,
                         gemm_k,
                         gemm_batch_size);

      // Backward: softmax_out_grad = qktv_out_grad * V^T (dx = dout * y^T)
      gemm_m = config->seq_len_r;
      gemm_n = config->m_size;
      gemm_k = config->head_dim;

      T* softmax_out_grad_ptr = softmax_out_grad.data<T>();
      ComputeBatchedGEMM(qktv_out_grad_ptr,
                         v_ptr,
                         softmax_out_grad_ptr,
                         false,
                         true,
                         gemm_m,
                         gemm_n,
                         gemm_k,
                         gemm_batch_size);
    }

    phi::DenseTensor* qk_out_grad = config->GetQKOutGrad(&softmax_out_grad);
    ComputeBiasMaskSoftmaxBackward(&softmax_out_grad,
                                   softmax_out,
                                   src_mask_grad,
                                   qk_out_grad,
                                   nonbatched_bias_grad);

    // Forward: qk_out = BatchedGEMM(Q, K^T)
    // Backward: k_grad = BatchedGEMM(qk_out_grad^T, Q) (dy = dout^t * x)
    int64_t gemm_m = config->m_size;
    int64_t gemm_n = config->head_dim;
    int64_t gemm_k = config->seq_len_r;
    T alpha = static_cast<T>(1.0 / sqrt(config->head_dim));

    T* qk_out_grad_ptr = qk_out_grad->data<T>();
    ComputeBatchedGEMM(qk_out_grad_ptr,
                       q_ptr,
                       k_grad_ptr,
                       true,
                       false,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       gemm_batch_size,
                       alpha);

    // Backward: q_grad = BatchedGEMM(qk_out_grad, K) (dx = dout * y)
    gemm_m = config->seq_len_r;
    gemm_n = config->head_dim;
    gemm_k = config->m_size;
    ComputeBatchedGEMM(qk_out_grad_ptr,
                       k_ptr,
                       q_grad_ptr,
                       false,
                       false,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       gemm_batch_size,
                       alpha);

    if (merge_qkv_) {
      phi::DenseTensor* qkv_out_grad = config->GetQKVOutGrad();
      ComputeQKVTransposeBackward(qkv_transpose_out_grad, qkv_out_grad);
    } else {
      phi::DenseTensor* q_out_grad = config->GetQueryOutGrad();
      phi::DenseTensor* k_out_grad = config->GetKeyOutGrad();
      phi::DenseTensor* v_out_grad = config->GetValueOutGrad();
      ComputeQKVTransposeBackward(q_transpose_out_grad,
                                  k_transpose_out_grad,
                                  v_transpose_out_grad,
                                  q_out_grad,
                                  k_out_grad,
                                  v_out_grad);
    }
  }

  void ComputeQKVTransposeForward(const phi::DenseTensor& q_out,
                                  const phi::DenseTensor& k_out,
                                  const phi::DenseTensor& v_out,
                                  phi::DenseTensor* q_transpose_out,
                                  phi::DenseTensor* k_transpose_out,
                                  phi::DenseTensor* v_transpose_out) {
    std::vector<int> perm = {0, 1, 3, 2, 4};
    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, q_out, perm, q_transpose_out);
    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, k_out, perm, k_transpose_out);
    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, v_out, perm, v_transpose_out);
  }

  void ComputeQKVTransposeBackward(const phi::DenseTensor& q_transpose_out_grad,
                                   const phi::DenseTensor& k_transpose_out_grad,
                                   const phi::DenseTensor& v_transpose_out_grad,
                                   phi::DenseTensor* q_out_grad,
                                   phi::DenseTensor* k_out_grad,
                                   phi::DenseTensor* v_out_grad) {
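    // {0, 1, 3, 2, 4} swaps the num_heads and seq_len axes and is its own
    // inverse, so the same permutation is used as in the forward transpose.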
    std::vector<int> perm = {0, 1, 3, 2, 4};
    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, q_transpose_out_grad, perm, q_out_grad);
    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, k_transpose_out_grad, perm, k_out_grad);
    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, v_transpose_out_grad, perm, v_out_grad);
  }

  // [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim] ->
  //         [3, batch_size, seq_len_m, num_heads, seq_len_r, head_dim]
  void ComputeQKVTransposeForward(const phi::DenseTensor& qkv_out,
                                  phi::DenseTensor* qkv_transpose_out) {
    std::vector<int> perm = {3, 0, 1, 4, 2, 5};
    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, qkv_out, perm, qkv_transpose_out);
  }

  void ComputeQKVTransposeBackward(
      const phi::DenseTensor& qkv_transpose_out_grad,
      phi::DenseTensor* qkv_out_grad) {
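    // {1, 2, 4, 0, 3, 5} is the inverse of the permutation {3, 0, 1, 4, 2, 5}
    // used in the forward ComputeQKVTransposeForward.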
    std::vector<int> perm = {1, 2, 4, 0, 3, 5};
    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, qkv_transpose_out_grad, perm, qkv_out_grad);
  }

  // [batch_size, seq_len_m, num_heads, seq_len_r, head_dim] ->
  //         [batch_size, seq_len_m, seq_len_r, num_heads, head_dim]
  void ComputeQKTVTransposeForward(const phi::DenseTensor& qktv_out,
                                   phi::DenseTensor* fmha_out) {
    std::vector<int> perm = {0, 1, 3, 2, 4};
    phi::funcs::TransposeGPUKernelDriver<T>(dev_ctx_, qktv_out, perm, fmha_out);
  }

  void ComputeQKTVTransposeBackward(const phi::DenseTensor& fmha_out_grad,
                                    phi::DenseTensor* qktv_out_grad) {
    std::vector<int> perm = {0, 1, 3, 2, 4};
    phi::funcs::TransposeGPUKernelDriver<T>(
        dev_ctx_, fmha_out_grad, perm, qktv_out_grad);
  }

  // qk_out = qk_out + nonbatched_bias + src_mask
  // softmax_out = softmax(qk_out)
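  // nonbatched_bias (when present) and src_mask are broadcast-added onto
  // qk_out in a single fused kernel; the softmax is then taken over the last
  // (m_size) axis.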
  void ComputeBiasMaskSoftmaxForward(const phi::DenseTensor* nonbatched_bias,
                                     const phi::DenseTensor* src_mask,
                                     phi::DenseTensor* qk_out,
                                     phi::DenseTensor* softmax_out) {
    if (nonbatched_bias) {
      std::vector<const phi::DenseTensor*> ins = {
          qk_out, src_mask, nonbatched_bias};
      std::vector<phi::DenseTensor*> outs = {qk_out};
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
          dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
    } else {
      std::vector<const phi::DenseTensor*> ins = {qk_out, src_mask};
      std::vector<phi::DenseTensor*> outs = {qk_out};
      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
    }
    phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
  }

  // Forward: softmax_out = softmax(qk_out + nonbatched_bias + src_mask)
  void ComputeBiasMaskSoftmaxBackward(const phi::DenseTensor* softmax_out_grad,
                                      const phi::DenseTensor* softmax_out,
                                      phi::DenseTensor* src_mask_grad,
                                      phi::DenseTensor* qk_out_grad,
                                      phi::DenseTensor* nonbatched_bias_grad) {
    PADDLE_ENFORCE_NOT_NULL(
        qk_out_grad,
        platform::errors::NotFound("The qk_out_grad can not be nullptr."));

    PADDLE_ENFORCE_EQ(qk_out_grad->dims(),
                      softmax_out->dims(),
                      platform::errors::InvalidArgument(
                          "The shape of qk_out_grad and softmax_out is "
                          "expected to be the same. But recieved qk_out_grad's "
                          "shape = %s, softmax_out's shape = %s.",
                          qk_out_grad->dims(),
                          softmax_out->dims()));

    PADDLE_ENFORCE_EQ(src_mask_grad,
                      nullptr,
                      platform::errors::InvalidArgument(
                          "src_mask_grad is expected to be nullptr."));

    phi::SoftmaxBackwardCUDAKernelDriver<T>(
        dev_ctx_, *softmax_out, *softmax_out_grad, -1, qk_out_grad);

    if (nonbatched_bias_grad) {
      // [batch_size, seq_len_m, num_heads, seq_len_r, m_size] ->
      //      [batch_size, 1, num_heads, seq_len_r, m_size]
      phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          dev_ctx_,
          *qk_out_grad,
          nonbatched_bias_grad,
          kps::IdentityFunctor<T>(),
          {1});
    }
  }

 private:
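  // Thin wrapper over the strided BatchedGEMM in phi: all batch_size
  // sub-problems share the same (m, n, k) and are laid out contiguously with
  // strides m * k for A and k * n for B.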
  void ComputeBatchedGEMM(const T* a_ptr,
                          const T* b_ptr,
                          T* c_ptr,
                          bool trans_a,
                          bool trans_b,
                          int64_t m,
                          int64_t n,
                          int64_t k,
                          int64_t batch_size,
                          T alpha = static_cast<T>(1.0),
                          T beta = static_cast<T>(0.0)) {
    CBLAS_TRANSPOSE cblas_trans_a = trans_a ? CblasTrans : CblasNoTrans;
    CBLAS_TRANSPOSE cblas_trans_b = trans_b ? CblasTrans : CblasNoTrans;
    int64_t stride_a = m * k;
    int64_t stride_b = k * n;

    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
    blas.BatchedGEMM(cblas_trans_a,
                     cblas_trans_b,
                     m,
                     n,
                     k,
                     alpha,
                     a_ptr,
                     b_ptr,
                     beta,
                     c_ptr,
                     batch_size,
                     stride_a,
                     stride_b);
  }

  const phi::GPUContext& dev_ctx_;
  bool merge_qkv_;
};

}  // namespace operators
}  // namespace paddle