/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda_fp16.h>

#include <cub/cub.cuh>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

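// Sums `tensor` in place across the tensor-model-parallel group identified
// by `ring_id` (-1 means no parallelism and the call is a no-op). If a
// ProcessGroup is registered for the ring, the collective goes through it;
// otherwise the raw NCCL communicator for the ring is used directly.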
template <typename T>
static void AllReduce(framework::Tensor &tensor,  // NOLINT
                      const int ring_id,
                      const phi::GPUContext &ctx) {
  if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

  if (map->has(ring_id)) {
    paddle::distributed::ProcessGroup *pg = map->get(ring_id);
    auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg);

    std::vector<phi::DenseTensor> in_tensor;
    std::vector<phi::DenseTensor> out_tensor;
    in_tensor.push_back(tensor);
    out_tensor.push_back(tensor);
    paddle::distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg_nccl->AllReduce(in_tensor, out_tensor, opts, true, true);
    task->Wait();
  } else {
    auto dtype = platform::ToNCCLDataType(
        framework::TransToProtoVarType(tensor.dtype()));
    int64_t numel = tensor.numel();
    const void *sendbuff = tensor.data<T>();
    auto place = ctx.GetPlace();
    void *recvbuff = ctx.template Alloc<T>(&tensor, tensor.numel() * sizeof(T));
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    auto stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
  }
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

template <typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    auto *input_x = ctx.Input<Tensor>("X");
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *ln_scale = ctx.Input<Tensor>("LnScale");
    auto *ln_bias = ctx.Input<Tensor>("LnBias");
    auto *ln_mean = ctx.Output<Tensor>("LnMean");
    auto *ln_var = ctx.Output<Tensor>("LnVariance");
    auto *ln_out = ctx.Output<Tensor>("LnOut");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto *qkv_weight = ctx.Input<Tensor>("QKVW");
    auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
    auto *qkv_out = ctx.Output<Tensor>("QKVOut");
    auto *qkv_bias_out = ctx.Output<Tensor>("QKVBiasOut");

    auto *src_mask = ctx.Input<Tensor>("SrcMask");
    auto *transpose_out_2 = ctx.Output<Tensor>("TransposeOut2");
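    // Optional KV cache for incremental decoding; when present, the newly
    // computed keys/values are appended into cache_kv_out by the FMHA step.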
    auto *cache_kv = ctx.Input<Tensor>("CacheKV");
    auto *cache_kv_out = ctx.Output<Tensor>("CacheKVOut");
    auto *qk_out = ctx.Output<Tensor>("QKOut");
    auto *qktv_out = ctx.Output<Tensor>("QKTVOut");
    auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Output<Tensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Output<Tensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Output<Tensor>("SrcMaskOut");
    auto *fmha_out = ctx.Output<Tensor>("FMHAOut");

    auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
    auto *out_linear_out = ctx.Output<Tensor>("OutLinearOut");

    auto *ln_scale_2 = ctx.Input<Tensor>("Ln2Scale");
    auto *ln_bias_2 = ctx.Input<Tensor>("Ln2Bias");
    auto *dropout_mask_out = ctx.Output<Tensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Output<Tensor>("BiasDropoutResidualOut");
    auto *ln_mean_2 = ctx.Output<Tensor>("Ln2Mean");
    auto *ln_var_2 = ctx.Output<Tensor>("Ln2Variance");
    const float ln_epsilon = ctx.Attr<float>("ln_epsilon");

    float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
    bool is_test_1 = ctx.Attr<bool>("is_test");
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
    int ring_id = ctx.Attr<int>("ring_id");

    // final output.
    auto *out = ctx.Output<Tensor>("Y");

    // get data ptr for qkv part.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    auto *x_data = input_x->data<T>();
    auto *qkv_weight_data = qkv_weight->data<T>();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
    auto *qkv_out_data =
        dev_ctx.template Alloc<T>(qkv_out, qkv_out->numel() * sizeof(T));
    auto *qkv_bias_out_data =
        (qkv_bias == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(qkv_bias_out,
                                        qkv_bias_out->numel() * sizeof(T));

    // get data ptr for FMHA.
    auto *transpose_out_2_data = dev_ctx.template Alloc<T>(
        transpose_out_2, transpose_out_2->numel() * sizeof(T));
    auto *cache_kv_out_data =
        (cache_kv_out == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(cache_kv_out,
                                        cache_kv_out->numel() * sizeof(T));
    auto *qk_out_data =
        dev_ctx.template Alloc<T>(qk_out, qk_out->numel() * sizeof(T));
    auto *qktv_out_data =
        dev_ctx.template Alloc<T>(qktv_out, qktv_out->numel() * sizeof(T));
    auto *src_mask_out_data =
        (src_mask == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(src_mask_out,
                                        src_mask_out->numel() * sizeof(T));
    auto *softmax_out_data = dev_ctx.template Alloc<T>(
        softmax_out, softmax_out->numel() * sizeof(T));
    auto *attn_dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
        attn_dropout_mask_out,
        attn_dropout_mask_out->numel() * sizeof(uint8_t));
    auto *attn_dropout_out_data = dev_ctx.template Alloc<T>(
        attn_dropout_out, attn_dropout_out->numel() * sizeof(T));
    auto *fmha_out_data =
        dev_ctx.template Alloc<T>(fmha_out, fmha_out->numel() * sizeof(T));

    // get data ptr for out_linear.
    auto *out_linear_weight_data = out_linear_weight->data<T>();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
    auto *out_linear_out_data = dev_ctx.template Alloc<T>(
        out_linear_out, out_linear_out->numel() * sizeof(T));

    // get data ptr for bias+dropout+residual+layernorm
    auto *dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
        dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t));
    auto *final_out_data =
        dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];

    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;
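    // qkv_w is [3, num_head, dim_head, dim_embed], so a single GEMM over the
    // [bsz_seq, dim_embed] input produces q, k and v together
    // (output_size = 3 * hidden_size).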

    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);

    bool compute_bias = (qkv_bias != nullptr);
    // (transA, transB, compute_bias) = (false, true, qkv_bias != nullptr)
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     false,
                                     true,
                                     bsz_seq,
                                     output_size,
                                     input_size,
                                     compute_bias);

    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_rate,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);

    output_size = hidden_size;
    // (transA, transB, compute_bias) = (false, false, false)
    // NOTE(Yuang Liu): In the general case, input size == output size, so
    // swapping them has no effect. Under mp, the output size is mp_head *
    // dkey, which is actually the input size, while the input size is the
    // hidden size, which is actually the output size. So for the out linear
    // layer, switch the input size and output size.
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            false,
                                            false,
                                            bsz_seq,
                                            input_size,
                                            output_size,
                                            false);
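    // Epilogue helper fusing bias-add + dropout + residual (+ layer_norm in
    // the post-layer_norm branch). DropoutParam(ctx, 0) presumably reads the
    // op's un-suffixed dropout_* attributes.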
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
        ln_epsilon);

    if (pre_layer_norm) {
      auto *ln_scale_data =
          (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
      auto *ln_mean_data =
          dev_ctx.template Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
      auto *ln_var_data =
          dev_ctx.template Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
      auto *ln_out_data =
          dev_ctx.template Alloc<T>(ln_out, ln_out->numel() * sizeof(T));

      layer_norm_compute.ComputeForward(x_data,
                                        ln_scale_data,
                                        ln_bias_data,
                                        ln_out_data,
                                        ln_mean_data,
                                        ln_var_data);
      qkv_compute.ComputeForward(
          qkv_weight, ln_out, qkv_bias, qkv_out, qkv_bias_out);
    } else {
      qkv_compute.ComputeForward(
          qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out);
    }
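    // Reference (unfused) FMHA: transpose qkv to [3, batch, num_head,
    // seq_len, dim_head], compute q*k^T, add src_mask, softmax, attention
    // dropout, then the softmax(qk)*v matmul. The bias-added qkv feeds in
    // when qkv_bias is present; otherwise the raw qkv GEMM output does.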
    const Tensor &fmha_in = (qkv_bias == nullptr) ? *qkv_out : *qkv_bias_out;
    fmha_ref_compute.ComputeForward(fmha_in,
                                    cache_kv,
                                    src_mask,
                                    transpose_out_2,
                                    cache_kv_out,
                                    qk_out,
                                    src_mask_out,
                                    softmax_out,
                                    attn_dropout_mask_out,
                                    attn_dropout_out,
                                    qktv_out,
                                    fmha_out);

    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight:   [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
    out_linear_compute.ComputeForward(
        out_linear_weight, fmha_out, nullptr, out_linear_out, nullptr);
    // tensor model parallel
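    // (each rank computes a partial result from its shard of out_linear_w;
    // the AllReduce below sums the partials)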
    AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());

    bool add_residual = ctx.Attr<bool>("add_residual");
    const T *residual_ptr = add_residual ? x_data : nullptr;
    if (pre_layer_norm) {
      // output = (residual + dropout(input + bias))
      fused_dropout_layernorm_helper.ResidualDropoutBias(
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          final_out_data,
          dropout_mask_out_data);
    } else {
      // TODO(Xreki): support post layer_norm case when add_residual is false.
      PADDLE_ENFORCE_EQ(add_residual,
                        true,
                        platform::errors::InvalidArgument(
                            "Attribute add_residual is expected to be true "
                            "when pre_layer_norm is false."));

      const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
      const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
      T *bias_dropout_residual_out_ptr = dev_ctx.template Alloc<T>(
          bias_dropout_residual_out,
          bias_dropout_residual_out->numel() * sizeof(T));
      U *ln_mean_2_ptr =
          dev_ctx.template Alloc<U>(ln_mean_2, ln_mean_2->numel() * sizeof(U));
      U *ln_var_2_ptr =
          dev_ctx.template Alloc<U>(ln_var_2, ln_var_2->numel() * sizeof(U));
      // output = layernorm(residual + dropout(input + bias))
      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          ln_scale_2_ptr,
          ln_bias_2_ptr,
          bias_dropout_residual_out_ptr,
          dropout_mask_out_data,
          final_out_data,
          ln_mean_2_ptr,
          ln_var_2_ptr);
    }
  }
};

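// Backward kernel of the fused attention op: replays the forward stages in
// reverse, i.e. epilogue (dropout/residual/layer_norm) grad, out_linear grad,
// FMHA grad, qkv GEMM grad, then the pre-layer_norm grad, and finally
// accumulates the residual-branch gradient into d_x when add_residual is
// true.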
template <typename T>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    const float ln2epsilon = ctx.Attr<float>("ln_epsilon");

    float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    bool is_test_1 = ctx.Attr<bool>("is_test");
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
    int ring_id = ctx.Attr<int>("ring_id");

    // get inputs.
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto *d_y_data = d_y->data<T>();

    // fw input
    auto *input_x = ctx.Input<Tensor>("X");
    auto *ln_scale = ctx.Input<Tensor>("LnScale");
    auto *ln_2_scale = ctx.Input<Tensor>("Ln2Scale");
    auto *x_data = input_x->data<T>();
    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
    auto *ln_2_scale_data =
        (ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
    // fw parameters.
    auto *src_mask = ctx.Input<Tensor>("SrcMask");
    auto *qkv_weight = ctx.Input<Tensor>("QKVW");
    auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
    auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
    auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
    auto *qkv_weight_data = qkv_weight->data<T>();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
    auto *out_linear_weight_data = out_linear_weight->data<T>();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();

    // fw output
    auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
    auto *transpose_out_2 = ctx.Input<phi::DenseTensor>("TransposeOut2");
    auto *qk_out = ctx.Input<phi::DenseTensor>("QKOut");
    auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
    auto *attn_dropout_mask_out =
        ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Input<phi::DenseTensor>("SrcMaskOut");
    auto *ln_2_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
    auto *ln_2_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
    auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Input<Tensor>("BiasDropoutResidualOut");
    auto *fmha_out_data = fmha_out->data<T>();
    auto *transpose_out_2_data = transpose_out_2->data<T>();
    auto *softmax_out_data = softmax_out->data<T>();
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
    auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();

    // output's grad
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_qkv_out = ctx.Output<Tensor>(framework::GradVarName("QKVOut"));
    auto *d_qkv_bias_out =
        ctx.Output<Tensor>(framework::GradVarName("QKVBiasOut"));
    auto *d_qktv_out = ctx.Output<Tensor>(framework::GradVarName("QKTVOut"));
    auto *d_transpose_out_2 =
        ctx.Output<Tensor>(framework::GradVarName("TransposeOut2"));
    auto *d_qk_out = ctx.Output<Tensor>(framework::GradVarName("QKOut"));
    auto *d_softmax_out =
        ctx.Output<Tensor>(framework::GradVarName("SoftmaxOut"));
    auto *d_attn_dropout_out =
        ctx.Output<Tensor>(framework::GradVarName("AttnDropoutOut"));
    auto *d_src_mask_out =
        ctx.Output<Tensor>(framework::GradVarName("SrcMaskOut"));
    auto *d_fmha_out = ctx.Output<Tensor>(framework::GradVarName("FMHAOut"));
    auto *d_out_linear_out =
        ctx.Output<Tensor>(framework::GradVarName("OutLinearOut"));
    auto *d_bias_dropout_residual_out =
        ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
    auto *d_x_data = dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
    // When qkv_bias is not nullptr, d_qkv_out equals d_qkv_bias_out and the
    // space can be reused.
    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
                               ? nullptr
                               : dev_ctx.template Alloc<T>(
                                     d_qkv_out, d_qkv_out->numel() * sizeof(T));
    auto *d_qkv_bias_out_data =
        (d_qkv_bias_out == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_qkv_bias_out,
                                        d_qkv_bias_out->numel() * sizeof(T));
    auto *d_qktv_out_data =
        dev_ctx.template Alloc<T>(d_qktv_out, d_qktv_out->numel() * sizeof(T));
    auto *d_transpose_out_2_data = dev_ctx.template Alloc<T>(
        d_transpose_out_2, d_transpose_out_2->numel() * sizeof(T));
    auto *d_qk_out_data =
        dev_ctx.template Alloc<T>(d_qk_out, d_qk_out->numel() * sizeof(T));
    auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
        d_softmax_out, d_softmax_out->numel() * sizeof(T));
    auto *d_attn_dropout_out_data = dev_ctx.template Alloc<T>(
        d_attn_dropout_out, d_attn_dropout_out->numel() * sizeof(T));
    auto *d_src_mask_out_data =
        (src_mask == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_src_mask_out,
                                        d_src_mask_out->numel() * sizeof(T));
    auto *d_fmha_out_data =
        dev_ctx.template Alloc<T>(d_fmha_out, d_fmha_out->numel() * sizeof(T));
    auto *d_out_linear_out_data = dev_ctx.template Alloc<T>(
        d_out_linear_out, d_out_linear_out->numel() * sizeof(T));

    // parameter grad
    auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
    auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias"));
    auto *d_out_linear_weight =
        ctx.Output<Tensor>(framework::GradVarName("OutLinearW"));
    auto *d_out_linear_bias =
        ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
    auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
    auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));

    auto *d_qkv_weight_data = dev_ctx.template Alloc<T>(
        d_qkv_weight, d_qkv_weight->numel() * sizeof(T));
    auto *d_qkv_bias_data =
        (d_qkv_bias == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_qkv_bias,
                                        d_qkv_bias->numel() * sizeof(T));
    auto *d_out_linear_weight_data = dev_ctx.template Alloc<T>(
        d_out_linear_weight, d_out_linear_weight->numel() * sizeof(T));
    auto *d_out_linear_bias_data =
        (d_out_linear_bias == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_out_linear_bias,
                                        d_out_linear_bias->numel() * sizeof(T));

    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    bool add_residual = ctx.Attr<bool>("add_residual");
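    // Scratch buffer for the gradient flowing through the residual shortcut;
    // it is accumulated into d_x at the end of Compute.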
    Tensor d_residual;
    T *d_residual_data = nullptr;
    if (add_residual) {
      d_residual.Resize(input_x_dims);
      d_residual_data = dev_ctx.template Alloc<T>(
          &d_residual, d_residual.numel() * sizeof(T));
    }

    bool transA = false;
    bool transB = true;
    bool compute_qkv_bias = (qkv_bias != nullptr);
    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     transA,
                                     transB,
                                     bsz_seq,
                                     output_size,
                                     input_size,
                                     compute_qkv_bias);
    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_prob,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);
    output_size = hidden_size;
    transA = false;
    transB = false;
    bool compute_bias = false;
    // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            transA,
                                            transB,
                                            bsz_seq,
                                            input_size,
                                            output_size,
                                            compute_bias);
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
        ln2epsilon);

    if (pre_layer_norm) {
      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
          ctx.cuda_device_context(),
          d_y_data,
          dropout_mask_out_data,
          d_out_linear_out_data,
          d_residual_data,
          d_out_linear_bias_data);
    } else {
      auto *ln_2_mean_data = ln_2_mean->data<U>();
      auto *ln_2_var_data = ln_2_var->data<U>();
      auto *bias_dropout_residual_out_data =
          bias_dropout_residual_out->data<T>();
      auto *d_ln_2_scale_data =
          (d_ln_2_scale == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_2_scale,
                                           d_ln_2_scale->numel() * sizeof(U)));
      auto *d_ln_2_bias_data =
          (d_ln_2_bias == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_2_bias,
                                           d_ln_2_bias->numel() * sizeof(U)));
      auto *d_bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
          d_bias_dropout_residual_out,
          d_bias_dropout_residual_out->numel() * sizeof(T));

      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
          ctx.cuda_device_context(),
          d_y_data,
          bias_dropout_residual_out_data,
          dropout_mask_out_data,
          ln_2_scale_data,
          ln_2_mean_data,
          ln_2_var_data,
          d_bias_dropout_residual_out_data,
          d_ln_2_scale_data,
          d_ln_2_bias_data,
          d_out_linear_out_data,
          d_out_linear_bias_data,
          d_residual_data);
    }

    out_linear_compute.ComputeBackward(fmha_out,
                                       out_linear_weight,
                                       d_out_linear_out,
                                       d_fmha_out,
                                       d_out_linear_weight,
                                       nullptr);

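    // FMHA backward mirrors the forward stages in reverse; the grad w.r.t.
    // the (bias-added) qkv lands in d_qkv_bias_out when qkv_bias is present,
    // otherwise in d_qkv_out.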
    Tensor *d_fmha_in_grad =
        (qkv_bias != nullptr) ? d_qkv_bias_out : d_qkv_out;
    fmha_ref_compute.ComputeBackward(*transpose_out_2,
                                     src_mask,
                                     *softmax_out,
                                     *attn_dropout_mask_out,
                                     *attn_dropout_out,
                                     *qk_out,
                                     *src_mask_out,
                                     *d_fmha_out,
                                     d_qktv_out,
                                     d_attn_dropout_out,
                                     d_softmax_out,
                                     d_src_mask_out,
                                     d_qk_out,
                                     d_transpose_out_2,
                                     nullptr,
                                     d_fmha_in_grad);

    if (pre_layer_norm) {
      auto *ln_mean = ctx.Input<Tensor>("LnMean");
      auto *ln_var = ctx.Input<Tensor>("LnVariance");
      auto *ln_out = ctx.Input<Tensor>("LnOut");
      auto *ln_mean_data = ln_mean->data<U>();
      auto *ln_var_data = ln_var->data<U>();
      auto *ln_out_data = ln_out->data<T>();

      auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
      auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
      auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
      auto *d_ln_out_data =
          dev_ctx.template Alloc<T>(d_ln_out, d_ln_out->numel() * sizeof(T));
      auto *d_ln_scale_data =
          (d_ln_scale == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_scale,
                                           d_ln_scale->numel() * sizeof(U)));
      auto *d_ln_bias_data =
          (d_ln_bias == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_bias,
                                           d_ln_bias->numel() * sizeof(U)));
      if (qkv_bias != nullptr) {
        qkv_compute.ComputeBackward(ln_out,
                                    qkv_weight,
                                    d_qkv_bias_out,
                                    d_ln_out,
                                    d_qkv_weight,
                                    d_qkv_bias);
      } else {
        qkv_compute.ComputeBackward(
            ln_out, qkv_weight, d_qkv_out, d_ln_out, d_qkv_weight, d_qkv_bias);
      }
      // tensor model parallel
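      // (d_ln_out collects partial gradients from each rank's qkv weight
      // shard; sum them before the layer_norm backward)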
      AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
      layer_norm_compute.ComputeBackward(x_data,
                                         d_ln_out_data,
                                         ln_scale_data,
                                         ln_mean_data,
                                         ln_var_data,
                                         d_x_data,
                                         d_ln_scale_data,
                                         d_ln_bias_data);
    } else {
      if (qkv_bias != nullptr) {
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_bias_out, d_x, d_qkv_weight, d_qkv_bias);
      } else {
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_out, d_x, d_qkv_weight, d_qkv_bias);
      }
      // tensor model parallel
      AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
    }

    if (add_residual) {
      // gradient accumulation
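      // d_x <- d_x + d_residual, computed in place by the elementwise kernel.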
      std::vector<const Tensor *> ins = {&d_residual, d_x};
      std::vector<Tensor *> outs = {d_x};
      phi::funcs::ElementwiseKernel<T>(
          ctx.cuda_device_context(), ins, &outs, phi::funcs::AddFunctor<T>());
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_attention,
                        ops::FusedAttentionOpKernel<float>,
                        ops::FusedAttentionOpKernel<double>,
                        ops::FusedAttentionOpKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
                        ops::FusedAttentionGradKernel<float>,
                        ops::FusedAttentionGradKernel<double>,
                        ops::FusedAttentionGradKernel<plat::float16>);