/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

namespace paddle {
namespace operators {

class AttnDropoutParam {
 public:
  AttnDropoutParam() {
    is_test_ = false;
    dropout_implementation_ = "downgrade_in_infer";
    dropout_prob_ = 0.5;
    is_upscale_in_train_ = false;
    is_fix_seed_ = false;
    seed_val_ = 0;
    seed_ = nullptr;
  }
41 42 43 44 45 46
  AttnDropoutParam(bool is_test,
                   const std::string dropout_implementation,
                   float dropout_prob,
                   bool is_upscale_in_train,
                   bool is_fix_seed,
                   int seed_val,
47
                   const phi::DenseTensor* seed) {
48 49 50 51 52 53 54 55 56 57 58 59 60 61
    is_test_ = is_test;
    dropout_implementation_ = dropout_implementation;
    dropout_prob_ = dropout_prob;
    is_upscale_in_train_ = is_upscale_in_train;
    is_fix_seed_ = is_fix_seed;
    seed_val_ = seed_val;
    seed_ = seed;
  }
  bool is_test_;
  std::string dropout_implementation_;
  float dropout_prob_;
  bool is_upscale_in_train_;
  bool is_fix_seed_;
  int seed_val_;
62
  const phi::DenseTensor* seed_;
63 64 65 66 67
};

template <typename T>
class FMHARef {
 public:
L
Leo Chen 已提交
68
  FMHARef(const phi::GPUContext& dev_ctx,
69 70 71 72
          int64_t batch_size,
          int64_t seq_len,
          int64_t num_head,
          int64_t head_dim,
73 74 75 76 77 78 79 80 81 82
          AttnDropoutParam param)
      : dev_ctx_(dev_ctx),
        batch_size_(batch_size),
        seq_len_(seq_len),
        num_head_(num_head),
        head_dim_(head_dim),
        dropout_param_(param) {}

  ~FMHARef() {}

83 84 85 86 87 88 89 90 91 92 93 94
  void ComputeForward(const phi::DenseTensor& qkv_input_tensor,
                      const phi::DenseTensor* cache_kv_tensor,
                      const phi::DenseTensor* src_mask_tensor,
                      phi::DenseTensor* transpose_2_out_tensor,
                      phi::DenseTensor* cache_kv_out_tensor,
                      phi::DenseTensor* qk_out_tensor,
                      phi::DenseTensor* src_mask_out_tensor,
                      phi::DenseTensor* softmax_out_tensor,
                      phi::DenseTensor* dropout_mask_out_tensor,
                      phi::DenseTensor* dropout_out_tensor,
                      phi::DenseTensor* qktv_out_tensor,
                      phi::DenseTensor* fmha_out_tensor) {
95
    // input shape: [bs, seq_len, 3, num_head, head_dim]
96
    // transpose with perm [2, 0, 3, 1, 4],
97 98
    // output_shape: [3, bs, num_head, seq_len, head_dim]
    std::vector<int> perm_1 = {2, 0, 3, 1, 4};
99
    phi::funcs::TransposeGPUKernelDriver<T>(
100
        dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor);
101 102 103 104 105 106 107
    T* qkv_data = transpose_2_out_tensor->data<T>();
    T* qk_out_data = qk_out_tensor->data<T>();
    T* qktv_out_data = qktv_out_tensor->data<T>();
    T* softmax_out_data = softmax_out_tensor->data<T>();
    T* dropout_out_data = dropout_out_tensor->data<T>();
    T* fmha_out_data = fmha_out_tensor->data<T>();

108 109 110 111 112 113 114 115 116 117 118
    auto out_seq_len = seq_len_;
    if (cache_kv_tensor) {
      // kv [2, bs, num_head, seq_len, head_dim]
      auto kv_tensor = transpose_2_out_tensor->Slice(1, 3);
      phi::funcs::ConcatFunctor<phi::GPUContext, T> concat;
      // out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
      concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor);
      out_seq_len = cache_kv_out_tensor->dims()[3];
    }

    int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
119
    T* q_ptr = qkv_data;
120 121 122 123 124 125 126 127 128 129 130 131
    T* k_ptr = nullptr;
    T* v_ptr = nullptr;

    if (cache_kv_tensor) {
      int64_t k_size = cache_kv_out_tensor->numel() / 2;
      k_ptr = cache_kv_out_tensor->data<T>();
      v_ptr = k_ptr + k_size;
    } else {
      int64_t k_size = q_size;
      k_ptr = q_ptr + q_size;
      v_ptr = k_ptr + k_size;
    }
132

W
WangXi 已提交
133 134 135 136 137 138
    {
      // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for
      // float16 calculation, INF may appear in QK^T if we do not scale before.
      float alpha = 1.0 / sqrt(head_dim_);
      auto q_tensor = transpose_2_out_tensor->Slice(0, 1);
      auto functor = phi::funcs::ScaleFunctor<T>(alpha);
139 140
      std::vector<const phi::DenseTensor*> ins = {&q_tensor};
      std::vector<phi::DenseTensor*> outs = {&q_tensor};
141
      phi::funcs::ElementwiseKernel<T>(dev_ctx_, ins, &outs, functor);
W
WangXi 已提交
142 143
    }

144 145 146
    // q*k^t, batched_gemm
    CBLAS_TRANSPOSE transA = CblasNoTrans;
    CBLAS_TRANSPOSE transB = CblasTrans;
L
Leo Chen 已提交
147
    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
148 149
    int gemm_batch_size = batch_size_ * num_head_;
    int gemm_m = seq_len_;
150
    int gemm_n = out_seq_len;
151
    int gemm_k = head_dim_;
W
WangXi 已提交
152
    T alpha = static_cast<T>(1.0);
153
    T beta = static_cast<T>(0.0);
154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
    int64_t stride_a = gemm_m * gemm_k;
    int64_t stride_b = gemm_k * gemm_n;
    blas.BatchedGEMM(transA,
                     transB,
                     gemm_m,
                     gemm_n,
                     gemm_k,
                     alpha,
                     q_ptr,
                     k_ptr,
                     beta,
                     qk_out_data,
                     gemm_batch_size,
                     stride_a,
                     stride_b);
    int softmax_axis = -1;
    if (src_mask_tensor != nullptr) {
      if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) {
        LaunchFusedSoftmaxMaskKernel<T>(qk_out_data,
                                        src_mask_tensor->data<T>(),
                                        softmax_out_data,
                                        batch_size_,
                                        num_head_,
                                        seq_len_,
                                        dev_ctx_.stream());
      } else {
        std::vector<const phi::DenseTensor*> ins;
        std::vector<phi::DenseTensor*> outs;
        ins.emplace_back(qk_out_tensor);
        ins.emplace_back(src_mask_tensor);
        outs.emplace_back(src_mask_out_tensor);
        int elewise_add_axis = -1;
        phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
            dev_ctx_,
            ins,
            &outs,
            elewise_add_axis,
            phi::funcs::AddFunctor<T>());

        phi::SoftmaxForwardCUDAKernelDriver<T>(
            dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
      }
    } else {
      phi::SoftmaxForwardCUDAKernelDriver<T>(
          dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor);
    }

    transB = CblasNoTrans;
    gemm_m = seq_len_;
    gemm_n = head_dim_;
    gemm_k = out_seq_len;
    alpha = static_cast<T>(1.0);
    stride_a = gemm_m * gemm_k;
    stride_b = gemm_k * gemm_n;

    if (dropout_param_.dropout_prob_) {
      DropoutFwGPUKernelDriver<T>(
          static_cast<const phi::GPUContext&>(dev_ctx_),
          dropout_param_.is_test_,
          dropout_param_.dropout_prob_,
          dropout_param_.is_upscale_in_train_,
          dropout_param_.is_fix_seed_,
          dropout_param_.seed_val_,
          static_cast<const phi::DenseTensor&>(*softmax_out_tensor),
          dropout_param_.seed_,
          dropout_mask_out_tensor,
          dropout_out_tensor,
          false);
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       dropout_out_data,
                       v_ptr,
                       beta,
                       qktv_out_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
    } else {
      // softmax_out * v, batched_gemm
      // output shape: [batch_size, num_heads, seq_len, head_dim]
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       softmax_out_data,
                       v_ptr,
                       beta,
                       qktv_out_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
    }
    // transpose: [0, 2, 1, 3]
    // output shape: [batch_size, seq_len, num_heads, head_dim]
    std::vector<int> perm_3 = {0, 2, 1, 3};
255
    phi::funcs::TransposeGPUKernelDriver<T>(
256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
        dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
  }

  void ComputeForwardWithoutTranspose(const phi::DenseTensor& qkv_input_tensor,
                                      const phi::DenseTensor* cache_kv_tensor,
                                      const phi::DenseTensor* src_mask_tensor,
                                      phi::DenseTensor* q_transpose_out_tensor,
                                      phi::DenseTensor* kv_transpose_out_tensor,
                                      phi::DenseTensor* cache_kv_out_tensor,
                                      phi::DenseTensor* qk_out_tensor,
                                      phi::DenseTensor* src_mask_out_tensor,
                                      phi::DenseTensor* softmax_out_tensor,
                                      phi::DenseTensor* dropout_mask_out_tensor,
                                      phi::DenseTensor* dropout_out_tensor,
                                      phi::DenseTensor* qktv_out_tensor,
                                      phi::DenseTensor* fmha_out_tensor) {
    // input shape: [bs, seq_len, 3, num_head, head_dim]
    // transpose with perm [2, 0, 3, 1, 4],
    // output_shape: [3, bs, num_head, seq_len, head_dim]
    T* qk_out_data = qk_out_tensor->data<T>();
    T* qktv_out_data = qktv_out_tensor->data<T>();
    T* softmax_out_data = softmax_out_tensor->data<T>();
    T* dropout_out_data = dropout_out_tensor->data<T>();
    T* fmha_out_data = fmha_out_tensor->data<T>();

    auto out_seq_len = seq_len_;
    if (cache_kv_tensor) {
      // kv [2, bs, num_head, seq_len, head_dim]
      phi::funcs::ConcatFunctor<phi::GPUContext, T> concat;
      // out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
      concat(dev_ctx_,
             {*cache_kv_tensor, *kv_transpose_out_tensor},
             3,
             cache_kv_out_tensor);
      out_seq_len = cache_kv_out_tensor->dims()[3];
    }

    int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
    T* q_ptr = q_transpose_out_tensor->data<T>();
    T* k_ptr = nullptr;
    T* v_ptr = nullptr;

    if (cache_kv_tensor) {
      int64_t k_size = cache_kv_out_tensor->numel() / 2;
      k_ptr = cache_kv_out_tensor->data<T>();
      v_ptr = k_ptr + k_size;
    } else {
      int64_t k_size = q_size;
      k_ptr = kv_transpose_out_tensor->data<T>();
      v_ptr = k_ptr + k_size;
    }

    {
      // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for
      // float16 calculation, INF may appear in QK^T if we do not scale before.
      float alpha = 1.0 / sqrt(head_dim_);
      auto functor = phi::funcs::ScaleFunctor<T>(alpha);
      std::vector<const phi::DenseTensor*> ins = {q_transpose_out_tensor};
      std::vector<phi::DenseTensor*> outs = {q_transpose_out_tensor};
      phi::funcs::ElementwiseKernel<T>(dev_ctx_, ins, &outs, functor);
    }

    // q*k^t, batched_gemm
    CBLAS_TRANSPOSE transA = CblasNoTrans;
    CBLAS_TRANSPOSE transB = CblasTrans;
    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
    int gemm_batch_size = batch_size_ * num_head_;
    int gemm_m = seq_len_;
    int gemm_n = out_seq_len;
    int gemm_k = head_dim_;
    T alpha = static_cast<T>(1.0);
    T beta = static_cast<T>(0.0);
328 329
    int64_t stride_a = gemm_m * gemm_k;
    int64_t stride_b = gemm_k * gemm_n;
330 331 332 333 334 335 336 337 338 339 340 341
    blas.BatchedGEMM(transA,
                     transB,
                     gemm_m,
                     gemm_n,
                     gemm_k,
                     alpha,
                     q_ptr,
                     k_ptr,
                     beta,
                     qk_out_data,
                     gemm_batch_size,
                     stride_a,
342 343
                     stride_b);
    int softmax_axis = -1;
344
    if (src_mask_tensor != nullptr) {
345
      if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) {
346 347 348 349 350 351 352
        LaunchFusedSoftmaxMaskKernel<T>(qk_out_data,
                                        src_mask_tensor->data<T>(),
                                        softmax_out_data,
                                        batch_size_,
                                        num_head_,
                                        seq_len_,
                                        dev_ctx_.stream());
353
      } else {
354 355
        std::vector<const phi::DenseTensor*> ins;
        std::vector<phi::DenseTensor*> outs;
356 357 358 359
        ins.emplace_back(qk_out_tensor);
        ins.emplace_back(src_mask_tensor);
        outs.emplace_back(src_mask_out_tensor);
        int elewise_add_axis = -1;
360
        phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
361 362 363 364
            dev_ctx_,
            ins,
            &outs,
            elewise_add_axis,
365
            phi::funcs::AddFunctor<T>());
366

367 368 369
        phi::SoftmaxForwardCUDAKernelDriver<T>(
            dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
      }
370
    } else {
371 372
      phi::SoftmaxForwardCUDAKernelDriver<T>(
          dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor);
373 374 375 376 377
    }

    transB = CblasNoTrans;
    gemm_m = seq_len_;
    gemm_n = head_dim_;
378
    gemm_k = out_seq_len;
379 380 381 382 383 384
    alpha = static_cast<T>(1.0);
    stride_a = gemm_m * gemm_k;
    stride_b = gemm_k * gemm_n;

    if (dropout_param_.dropout_prob_) {
      DropoutFwGPUKernelDriver<T>(
H
hong 已提交
385
          static_cast<const phi::GPUContext&>(dev_ctx_),
386 387 388 389
          dropout_param_.is_test_,
          dropout_param_.dropout_prob_,
          dropout_param_.is_upscale_in_train_,
          dropout_param_.is_fix_seed_,
390
          dropout_param_.seed_val_,
391
          static_cast<const phi::DenseTensor&>(*softmax_out_tensor),
392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408
          dropout_param_.seed_,
          dropout_mask_out_tensor,
          dropout_out_tensor,
          false);
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       dropout_out_data,
                       v_ptr,
                       beta,
                       qktv_out_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
409 410 411
    } else {
      // softmax_out * v, batched_gemm
      // output shape: [batch_size, num_heads, seq_len, head_dim]
412 413 414 415 416 417 418 419 420 421 422 423 424
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       softmax_out_data,
                       v_ptr,
                       beta,
                       qktv_out_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
425 426 427 428
    }
    // transpose: [0, 2, 1, 3]
    // output shape: [batch_size, seq_len, num_heads, head_dim]
    std::vector<int> perm_3 = {0, 2, 1, 3};
429
    phi::funcs::TransposeGPUKernelDriver<T>(
430
        dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
431 432
  }

433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448
  void ComputeBackward(const phi::DenseTensor& transpose_2_out_tensor,
                       const phi::DenseTensor* src_mask_tensor,
                       const phi::DenseTensor& softmax_out_tensor,
                       const phi::DenseTensor& dropout_mask_out_tensor,
                       const phi::DenseTensor& dropout_out_tensor,
                       const phi::DenseTensor& qk_out_tensor,
                       const phi::DenseTensor& src_mask_out_tensor,
                       const phi::DenseTensor& fmha_out_grad_tensor,
                       phi::DenseTensor* qktv_out_grad_tensor,
                       phi::DenseTensor* dropout_out_grad_tensor,
                       phi::DenseTensor* softmax_out_grad_tensor,
                       phi::DenseTensor* src_mask_out_grad_tensor,
                       phi::DenseTensor* qk_out_grad_tensor,
                       phi::DenseTensor* transpose_2_out_grad_tensor,
                       phi::DenseTensor* src_mask_grad_tensor,
                       phi::DenseTensor* qkv_input_grad_tensor) {
L
Leo Chen 已提交
449
    auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx_);
450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470
    int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
    int k_size = q_size;
    int softmax_axis = -1;

    T* qkv_grad_data = transpose_2_out_grad_tensor->data<T>();
    T* q_grad_ptr = qkv_grad_data;
    T* k_grad_ptr = q_grad_ptr + q_size;
    T* v_grad_ptr = k_grad_ptr + k_size;
    const T* qkv_data = transpose_2_out_tensor.data<T>();
    const T* q_ptr = qkv_data;
    const T* k_ptr = q_ptr + q_size;
    const T* v_ptr = k_ptr + k_size;

    const T* softmax_out_data = softmax_out_tensor.data<T>();
    T* softmax_out_grad_data = softmax_out_grad_tensor->data<T>();
    const T* dropout_out_data = dropout_out_tensor.data<T>();
    T* dropout_out_grad_data = dropout_out_grad_tensor->data<T>();
    T* qktv_out_grad_data = qktv_out_grad_tensor->data<T>();

    // transpose bw
    std::vector<int> perm_3 = {0, 2, 1, 3};
471
    phi::funcs::TransposeGPUKernelDriver<T>(
472
        dev_ctx_, fmha_out_grad_tensor, perm_3, qktv_out_grad_tensor);
473 474 475 476 477 478 479 480 481 482 483 484 485 486 487

    // recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) =
    // qktv_out_data(out)
    CBLAS_TRANSPOSE transA = CblasTrans;
    CBLAS_TRANSPOSE transB = CblasNoTrans;
    int gemm_batch_size = batch_size_ * num_head_;
    int gemm_m = seq_len_;
    int gemm_n = head_dim_;
    int gemm_k = seq_len_;
    T alpha = static_cast<T>(1.0);
    T beta = static_cast<T>(0.0);
    int64_t stride_a = gemm_m * gemm_k;
    int64_t stride_b = gemm_k * gemm_n;
    // bw: dy = x^t * dout
    if (dropout_param_.dropout_prob_) {
488 489 490 491 492 493 494 495 496 497 498 499 500
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       dropout_out_data,
                       qktv_out_grad_data,
                       beta,
                       v_grad_ptr,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
501
    } else {
502 503 504 505 506 507 508 509 510 511 512 513 514
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       softmax_out_data,
                       qktv_out_grad_data,
                       beta,
                       v_grad_ptr,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
515 516 517 518 519 520 521 522 523 524
    }
    // bw: dx = dout * y^t
    transA = CblasNoTrans;
    transB = CblasTrans;
    gemm_m = seq_len_;
    gemm_n = seq_len_;
    gemm_k = head_dim_;
    stride_a = gemm_m * gemm_k;
    stride_b = gemm_k * gemm_n;
    if (dropout_param_.dropout_prob_) {
525 526 527 528 529 530 531 532 533 534 535 536 537
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       qktv_out_grad_data,
                       v_ptr,
                       beta,
                       dropout_out_grad_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
538
    } else {
539 540 541 542 543 544 545 546 547 548 549 550 551
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       qktv_out_grad_data,
                       v_ptr,
                       beta,
                       softmax_out_grad_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
552 553 554 555
    }
    // dropout bw
    if (dropout_param_.dropout_prob_) {
      DropoutGradGPUKernelDriver<T>(
556 557 558 559
          static_cast<const phi::GPUContext&>(dev_ctx_),
          false,
          dropout_param_.dropout_prob_,
          dropout_param_.is_upscale_in_train_,
560
          static_cast<const phi::DenseTensor&>(*dropout_out_grad_tensor),
561 562 563
          dropout_mask_out_tensor,
          softmax_out_grad_tensor,
          false);
564 565
    }

566
    if (src_mask_tensor != nullptr) {
567 568 569 570 571
      phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_,
                                              softmax_out_tensor,
                                              *softmax_out_grad_tensor,
                                              softmax_axis,
                                              src_mask_out_grad_tensor);
572 573 574 575 576 577 578
      // recall LaunchElementwiseCudaKernel fw:  src_mask_out = qk_out +
      // src_mask
      // Special case when dy is not needed and dx doesn't reduce
      if (qk_out_grad_tensor != nullptr && src_mask_grad_tensor == nullptr &&
          qk_out_tensor.dims() == src_mask_out_tensor.dims()) {
        VLOG(4) << "Special case when dy is not needed and dx doesn't "
                   "reduce";
579 580 581 582
        framework::TensorCopy(*src_mask_out_grad_tensor,
                              dev_ctx_.GetPlace(),
                              dev_ctx_,
                              qk_out_grad_tensor);
583 584 585 586 587 588 589 590
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Only used for the backward elementwise_add op when"
            "dy is not needed and dx is not reduce"));
        return;
      }

    } else {
591 592
      phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_,
                                              softmax_out_tensor,
593
                                              *softmax_out_grad_tensor,
594 595
                                              softmax_axis,
                                              qk_out_grad_tensor);
596 597 598
    }

    T* qk_out_grad_data = qk_out_grad_tensor->data<T>();
W
WangXi 已提交
599 600 601
    // NOTE(wangxi): For we scale Q with 1/sqrt(Dh) in forward, so we set
    //   alpha = 1.0 in backward.
    alpha = static_cast<T>(1.0);
602 603 604 605 606 607 608 609 610
    // recall batchedgemm(nt) fw:  q_ptr * (k_ptr)^t = qk_out
    // bw: dy (seq_len * head_dim) = (dout)^t * x
    transA = CblasTrans;
    transB = CblasNoTrans;
    gemm_m = seq_len_;
    gemm_n = head_dim_;
    gemm_k = seq_len_;
    stride_a = gemm_m * gemm_k;
    stride_b = gemm_k * gemm_n;
611 612 613 614 615 616 617 618 619 620 621 622 623
    blas.BatchedGEMM(transA,
                     transB,
                     gemm_m,
                     gemm_n,
                     gemm_k,
                     alpha,
                     qk_out_grad_data,
                     q_ptr,
                     beta,
                     k_grad_ptr,
                     gemm_batch_size,
                     stride_a,
                     stride_b);
624
    // dx (seq_len * head_dim) = dout * y
W
WangXi 已提交
625
    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
626 627 628 629 630 631 632
    transA = CblasNoTrans;
    transB = CblasNoTrans;
    gemm_m = seq_len_;
    gemm_n = head_dim_;
    gemm_k = seq_len_;
    stride_a = gemm_m * gemm_k;
    stride_b = gemm_k * gemm_n;
633 634 635 636 637 638 639 640 641 642 643 644 645
    blas.BatchedGEMM(transA,
                     transB,
                     gemm_m,
                     gemm_n,
                     gemm_k,
                     alpha,
                     qk_out_grad_data,
                     k_ptr,
                     beta,
                     q_grad_ptr,
                     gemm_batch_size,
                     stride_a,
                     stride_b);
646 647 648

    // transpose bw
    std::vector<int> perm_1 = {1, 3, 0, 2, 4};
649
    phi::funcs::TransposeGPUKernelDriver<T>(
650
        dev_ctx_, *transpose_2_out_grad_tensor, perm_1, qkv_input_grad_tensor);
651 652 653
  }

 private:
L
Leo Chen 已提交
654
  const phi::GPUContext& dev_ctx_;
655 656 657 658 659 660 661 662 663 664 665

  int64_t batch_size_;
  int64_t seq_len_;
  int64_t num_head_;
  int64_t head_dim_;

  AttnDropoutParam dropout_param_;
};

}  // namespace operators
}  // namespace paddle