/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cmath>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

namespace paddle {
namespace operators {

// Short alias for framework::Tensor, used by every signature in this header.
using Tensor = framework::Tensor;

class AttnDropoutParam {
 public:
  AttnDropoutParam() {
    is_test_ = false;
    dropout_implementation_ = "downgrade_in_infer";
    dropout_prob_ = 0.5;
    is_upscale_in_train_ = false;
    is_fix_seed_ = false;
    seed_val_ = 0;
    seed_ = nullptr;
  }
43 44 45 46 47 48 49
  AttnDropoutParam(bool is_test,
                   const std::string dropout_implementation,
                   float dropout_prob,
                   bool is_upscale_in_train,
                   bool is_fix_seed,
                   int seed_val,
                   const Tensor* seed) {
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
    is_test_ = is_test;
    dropout_implementation_ = dropout_implementation;
    dropout_prob_ = dropout_prob;
    is_upscale_in_train_ = is_upscale_in_train;
    is_fix_seed_ = is_fix_seed;
    seed_val_ = seed_val;
    seed_ = seed;
  }
  bool is_test_;
  std::string dropout_implementation_;
  float dropout_prob_;
  bool is_upscale_in_train_;
  bool is_fix_seed_;
  int seed_val_;
  const Tensor* seed_;
};

template <typename T>
class FMHARef {
 public:
70 71 72 73 74
  FMHARef(const platform::CUDADeviceContext& dev_ctx,
          int64_t batch_size,
          int64_t seq_len,
          int64_t num_head,
          int64_t head_dim,
75 76 77 78 79 80 81 82 83 84 85
          AttnDropoutParam param)
      : dev_ctx_(dev_ctx),
        batch_size_(batch_size),
        seq_len_(seq_len),
        num_head_(num_head),
        head_dim_(head_dim),
        dropout_param_(param) {}

  ~FMHARef() {}

  void ComputeForward(const Tensor& qkv_input_tensor,
86
                      const Tensor* cache_kv_tensor,
87
                      const Tensor* src_mask_tensor,
88
                      Tensor* transpose_2_out_tensor,
89 90 91 92
                      Tensor* cache_kv_out_tensor,
                      Tensor* qk_out_tensor,
                      Tensor* src_mask_out_tensor,
                      Tensor* softmax_out_tensor,
93
                      Tensor* dropout_mask_out_tensor,
94 95
                      Tensor* dropout_out_tensor,
                      Tensor* qktv_out_tensor,
96 97
                      Tensor* fmha_out_tensor) {
    // input shape: [bs, seq_len, 3, num_head, head_dim]
98
    // transpose with perm [2, 0, 3, 1, 4],
99 100
    // output_shape: [3, bs, num_head, seq_len, head_dim]
    std::vector<int> perm_1 = {2, 0, 3, 1, 4};
101
    TransposeGPUKernelDriver<T>(
102
        dev_ctx_, qkv_input_tensor, perm_1, transpose_2_out_tensor);
103 104 105 106 107 108 109
    T* qkv_data = transpose_2_out_tensor->data<T>();
    T* qk_out_data = qk_out_tensor->data<T>();
    T* qktv_out_data = qktv_out_tensor->data<T>();
    T* softmax_out_data = softmax_out_tensor->data<T>();
    T* dropout_out_data = dropout_out_tensor->data<T>();
    T* fmha_out_data = fmha_out_tensor->data<T>();

110 111 112 113 114 115 116 117 118 119 120
    auto out_seq_len = seq_len_;
    if (cache_kv_tensor) {
      // kv [2, bs, num_head, seq_len, head_dim]
      auto kv_tensor = transpose_2_out_tensor->Slice(1, 3);
      phi::funcs::ConcatFunctor<phi::GPUContext, T> concat;
      // out [2, bs, num_head, cache_seq_len + seq_len, head_dim]
      concat(dev_ctx_, {*cache_kv_tensor, kv_tensor}, 3, cache_kv_out_tensor);
      out_seq_len = cache_kv_out_tensor->dims()[3];
    }

    int64_t q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
121
    T* q_ptr = qkv_data;
122 123 124 125 126 127 128 129 130 131 132 133
    T* k_ptr = nullptr;
    T* v_ptr = nullptr;

    if (cache_kv_tensor) {
      int64_t k_size = cache_kv_out_tensor->numel() / 2;
      k_ptr = cache_kv_out_tensor->data<T>();
      v_ptr = k_ptr + k_size;
    } else {
      int64_t k_size = q_size;
      k_ptr = q_ptr + q_size;
      v_ptr = k_ptr + k_size;
    }
134

W
WangXi 已提交
135 136 137 138 139 140 141 142
    {
      // NOTE(wangxi): We scale Q with 1/sqrt(Dh) before QK^T, because for
      // float16 calculation, INF may appear in QK^T if we do not scale before.
      float alpha = 1.0 / sqrt(head_dim_);
      auto q_tensor = transpose_2_out_tensor->Slice(0, 1);
      auto functor = phi::funcs::ScaleFunctor<T>(alpha);
      std::vector<const framework::Tensor*> ins = {&q_tensor};
      std::vector<framework::Tensor*> outs = {&q_tensor};
143
      phi::funcs::ElementwiseKernel<T>(dev_ctx_, ins, &outs, functor);
W
WangXi 已提交
144 145
    }

146 147 148
    // q*k^t, batched_gemm
    CBLAS_TRANSPOSE transA = CblasNoTrans;
    CBLAS_TRANSPOSE transB = CblasTrans;
149
    auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
150 151
    int gemm_batch_size = batch_size_ * num_head_;
    int gemm_m = seq_len_;
152
    int gemm_n = out_seq_len;
153
    int gemm_k = head_dim_;
W
WangXi 已提交
154
    T alpha = static_cast<T>(1.0);
155 156 157
    T beta = static_cast<T>(0.0);
    int64_t stride_a = gemm_m * gemm_k;
    int64_t stride_b = gemm_k * gemm_n;
158 159 160 161 162 163 164 165 166 167 168 169
    blas.BatchedGEMM(transA,
                     transB,
                     gemm_m,
                     gemm_n,
                     gemm_k,
                     alpha,
                     q_ptr,
                     k_ptr,
                     beta,
                     qk_out_data,
                     gemm_batch_size,
                     stride_a,
170 171
                     stride_b);
    int softmax_axis = -1;
172
    if (src_mask_tensor != nullptr) {
173
      if (src_mask_out_tensor == nullptr && seq_len_ == out_seq_len) {
174 175 176 177 178 179 180
        LaunchFusedSoftmaxMaskKernel<T>(qk_out_data,
                                        src_mask_tensor->data<T>(),
                                        softmax_out_data,
                                        batch_size_,
                                        num_head_,
                                        seq_len_,
                                        dev_ctx_.stream());
181 182 183 184 185 186 187
      } else {
        std::vector<const Tensor*> ins;
        std::vector<Tensor*> outs;
        ins.emplace_back(qk_out_tensor);
        ins.emplace_back(src_mask_tensor);
        outs.emplace_back(src_mask_out_tensor);
        int elewise_add_axis = -1;
188
        phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
189 190 191 192
            dev_ctx_,
            ins,
            &outs,
            elewise_add_axis,
193
            phi::funcs::AddFunctor<T>());
194

195 196 197
        phi::SoftmaxForwardCUDAKernelDriver<T>(
            dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
      }
198
    } else {
199 200
      phi::SoftmaxForwardCUDAKernelDriver<T>(
          dev_ctx_, *qk_out_tensor, softmax_axis, softmax_out_tensor);
201 202 203 204 205
    }

    transB = CblasNoTrans;
    gemm_m = seq_len_;
    gemm_n = head_dim_;
206
    gemm_k = out_seq_len;
207 208 209 210 211 212
    alpha = static_cast<T>(1.0);
    stride_a = gemm_m * gemm_k;
    stride_b = gemm_k * gemm_n;

    if (dropout_param_.dropout_prob_) {
      DropoutFwGPUKernelDriver<T>(
H
hong 已提交
213
          static_cast<const phi::GPUContext&>(dev_ctx_),
214 215 216 217
          dropout_param_.is_test_,
          dropout_param_.dropout_prob_,
          dropout_param_.is_upscale_in_train_,
          dropout_param_.is_fix_seed_,
218
          dropout_param_.seed_val_,
219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236
          static_cast<const Tensor&>(*softmax_out_tensor),
          dropout_param_.seed_,
          dropout_mask_out_tensor,
          dropout_out_tensor,
          false);
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       dropout_out_data,
                       v_ptr,
                       beta,
                       qktv_out_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
237 238 239
    } else {
      // softmax_out * v, batched_gemm
      // output shape: [batch_size, num_heads, seq_len, head_dim]
240 241 242 243 244 245 246 247 248 249 250 251 252
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       softmax_out_data,
                       v_ptr,
                       beta,
                       qktv_out_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
253 254 255 256
    }
    // transpose: [0, 2, 1, 3]
    // output shape: [batch_size, seq_len, num_heads, head_dim]
    std::vector<int> perm_3 = {0, 2, 1, 3};
257
    TransposeGPUKernelDriver<T>(
258
        dev_ctx_, *qktv_out_tensor, perm_3, fmha_out_tensor);
259 260
  }

261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
  void ComputeBackward(const Tensor& transpose_2_out_tensor,
                       const Tensor* src_mask_tensor,
                       const Tensor& softmax_out_tensor,
                       const Tensor& dropout_mask_out_tensor,
                       const Tensor& dropout_out_tensor,
                       const Tensor& qk_out_tensor,
                       const Tensor& src_mask_out_tensor,
                       const Tensor& fmha_out_grad_tensor,
                       Tensor* qktv_out_grad_tensor,
                       Tensor* dropout_out_grad_tensor,
                       Tensor* softmax_out_grad_tensor,
                       Tensor* src_mask_out_grad_tensor,
                       Tensor* qk_out_grad_tensor,
                       Tensor* transpose_2_out_grad_tensor,
                       Tensor* src_mask_grad_tensor,
                       Tensor* qkv_input_grad_tensor) {
277
    auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
    int q_size = batch_size_ * seq_len_ * num_head_ * head_dim_;
    int k_size = q_size;
    int softmax_axis = -1;

    T* qkv_grad_data = transpose_2_out_grad_tensor->data<T>();
    T* q_grad_ptr = qkv_grad_data;
    T* k_grad_ptr = q_grad_ptr + q_size;
    T* v_grad_ptr = k_grad_ptr + k_size;
    const T* qkv_data = transpose_2_out_tensor.data<T>();
    const T* q_ptr = qkv_data;
    const T* k_ptr = q_ptr + q_size;
    const T* v_ptr = k_ptr + k_size;

    const T* softmax_out_data = softmax_out_tensor.data<T>();
    T* softmax_out_grad_data = softmax_out_grad_tensor->data<T>();
    const T* dropout_out_data = dropout_out_tensor.data<T>();
    T* dropout_out_grad_data = dropout_out_grad_tensor->data<T>();
    T* qktv_out_grad_data = qktv_out_grad_tensor->data<T>();

    // transpose bw
    std::vector<int> perm_3 = {0, 2, 1, 3};
299
    TransposeGPUKernelDriver<T>(
300
        dev_ctx_, fmha_out_grad_tensor, perm_3, qktv_out_grad_tensor);
301 302 303 304 305 306 307 308 309 310 311 312 313 314 315

    // recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) =
    // qktv_out_data(out)
    CBLAS_TRANSPOSE transA = CblasTrans;
    CBLAS_TRANSPOSE transB = CblasNoTrans;
    int gemm_batch_size = batch_size_ * num_head_;
    int gemm_m = seq_len_;
    int gemm_n = head_dim_;
    int gemm_k = seq_len_;
    T alpha = static_cast<T>(1.0);
    T beta = static_cast<T>(0.0);
    int64_t stride_a = gemm_m * gemm_k;
    int64_t stride_b = gemm_k * gemm_n;
    // bw: dy = x^t * dout
    if (dropout_param_.dropout_prob_) {
316 317 318 319 320 321 322 323 324 325 326 327 328
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       dropout_out_data,
                       qktv_out_grad_data,
                       beta,
                       v_grad_ptr,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
329
    } else {
330 331 332 333 334 335 336 337 338 339 340 341 342
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       softmax_out_data,
                       qktv_out_grad_data,
                       beta,
                       v_grad_ptr,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
343 344 345 346 347 348 349 350 351 352
    }
    // bw: dx = dout * y^t
    transA = CblasNoTrans;
    transB = CblasTrans;
    gemm_m = seq_len_;
    gemm_n = seq_len_;
    gemm_k = head_dim_;
    stride_a = gemm_m * gemm_k;
    stride_b = gemm_k * gemm_n;
    if (dropout_param_.dropout_prob_) {
353 354 355 356 357 358 359 360 361 362 363 364 365
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       qktv_out_grad_data,
                       v_ptr,
                       beta,
                       dropout_out_grad_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
366
    } else {
367 368 369 370 371 372 373 374 375 376 377 378 379
      blas.BatchedGEMM(transA,
                       transB,
                       gemm_m,
                       gemm_n,
                       gemm_k,
                       alpha,
                       qktv_out_grad_data,
                       v_ptr,
                       beta,
                       softmax_out_grad_data,
                       gemm_batch_size,
                       stride_a,
                       stride_b);
380 381 382 383
    }
    // dropout bw
    if (dropout_param_.dropout_prob_) {
      DropoutGradGPUKernelDriver<T>(
384 385 386 387
          static_cast<const phi::GPUContext&>(dev_ctx_),
          false,
          dropout_param_.dropout_prob_,
          dropout_param_.is_upscale_in_train_,
388
          static_cast<const Tensor&>(*dropout_out_grad_tensor),
389 390 391
          dropout_mask_out_tensor,
          softmax_out_grad_tensor,
          false);
392 393
    }

394
    if (src_mask_tensor != nullptr) {
395 396 397 398 399
      phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_,
                                              softmax_out_tensor,
                                              *softmax_out_grad_tensor,
                                              softmax_axis,
                                              src_mask_out_grad_tensor);
400 401 402 403 404 405 406
      // recall LaunchElementwiseCudaKernel fw:  src_mask_out = qk_out +
      // src_mask
      // Special case when dy is not needed and dx doesn't reduce
      if (qk_out_grad_tensor != nullptr && src_mask_grad_tensor == nullptr &&
          qk_out_tensor.dims() == src_mask_out_tensor.dims()) {
        VLOG(4) << "Special case when dy is not needed and dx doesn't "
                   "reduce";
407 408 409 410
        framework::TensorCopy(*src_mask_out_grad_tensor,
                              dev_ctx_.GetPlace(),
                              dev_ctx_,
                              qk_out_grad_tensor);
411 412 413 414 415 416 417 418
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Only used for the backward elementwise_add op when"
            "dy is not needed and dx is not reduce"));
        return;
      }

    } else {
419 420
      phi::SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_,
                                              softmax_out_tensor,
421
                                              *softmax_out_grad_tensor,
422 423
                                              softmax_axis,
                                              qk_out_grad_tensor);
424 425 426
    }

    T* qk_out_grad_data = qk_out_grad_tensor->data<T>();
W
WangXi 已提交
427 428 429
    // NOTE(wangxi): For we scale Q with 1/sqrt(Dh) in forward, so we set
    //   alpha = 1.0 in backward.
    alpha = static_cast<T>(1.0);
430 431 432 433 434 435 436 437 438
    // recall batchedgemm(nt) fw:  q_ptr * (k_ptr)^t = qk_out
    // bw: dy (seq_len * head_dim) = (dout)^t * x
    transA = CblasTrans;
    transB = CblasNoTrans;
    gemm_m = seq_len_;
    gemm_n = head_dim_;
    gemm_k = seq_len_;
    stride_a = gemm_m * gemm_k;
    stride_b = gemm_k * gemm_n;
439 440 441 442 443 444 445 446 447 448 449 450 451
    blas.BatchedGEMM(transA,
                     transB,
                     gemm_m,
                     gemm_n,
                     gemm_k,
                     alpha,
                     qk_out_grad_data,
                     q_ptr,
                     beta,
                     k_grad_ptr,
                     gemm_batch_size,
                     stride_a,
                     stride_b);
452
    // dx (seq_len * head_dim) = dout * y
W
WangXi 已提交
453
    alpha = static_cast<T>(1.0 / sqrt(head_dim_));
454 455 456 457 458 459 460
    transA = CblasNoTrans;
    transB = CblasNoTrans;
    gemm_m = seq_len_;
    gemm_n = head_dim_;
    gemm_k = seq_len_;
    stride_a = gemm_m * gemm_k;
    stride_b = gemm_k * gemm_n;
461 462 463 464 465 466 467 468 469 470 471 472 473
    blas.BatchedGEMM(transA,
                     transB,
                     gemm_m,
                     gemm_n,
                     gemm_k,
                     alpha,
                     qk_out_grad_data,
                     k_ptr,
                     beta,
                     q_grad_ptr,
                     gemm_batch_size,
                     stride_a,
                     stride_b);
474 475 476

    // transpose bw
    std::vector<int> perm_1 = {1, 3, 0, 2, 4};
477 478
    TransposeGPUKernelDriver<T>(
        dev_ctx_, *transpose_2_out_grad_tensor, perm_1, qkv_input_grad_tensor);
479 480 481 482 483 484 485 486 487 488 489 490 491 492 493
  }

 private:
  const platform::CUDADeviceContext& dev_ctx_;

  int64_t batch_size_;
  int64_t seq_len_;
  int64_t num_head_;
  int64_t head_dim_;

  AttnDropoutParam dropout_param_;
};

}  // namespace operators
}  // namespace paddle