/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda_fp16.h>

#include <cub/cub.cuh>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

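// Sums `tensor` in place across the tensor-model-parallel group identified by
// `ring_id` (no-op when ring_id == -1). If a ProcessGroup has been registered
// for ring_id, the collective goes through that group; otherwise it falls back
// to the legacy NCCLCommContext / ncclAllReduce path on the current stream.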
template <typename T>
static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
                      const int ring_id,
                      const phi::GPUContext &ctx) {
  if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

  if (map->has(ring_id)) {
    paddle::distributed::ProcessGroup *pg = map->get(ring_id);
    auto pg_nccl = static_cast<distributed::ProcessGroupNCCL *>(pg);
    paddle::distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg_nccl->AllReduce(&tensor, tensor, opts, true, true);
    task->Wait();
  } else {
    auto dtype = platform::ToNCCLDataType(
        framework::TransToProtoVarType(tensor.dtype()));
    int64_t numel = tensor.numel();
    const void *sendbuff = tensor.data<T>();
    auto place = ctx.GetPlace();
    void *recvbuff = ctx.template Alloc<T>(&tensor, tensor.numel() * sizeof(T));
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    auto stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
  }
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

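// Forward kernel of the fused attention op. The computation pipeline is:
//   (optional pre-LayerNorm) -> QKV GEMM (+ bias) -> FMHA (transpose, QK^T,
//   src_mask add, softmax, attention dropout, QK^T * V) -> output linear GEMM
//   -> tensor-model-parallel all-reduce -> bias + dropout + residual
//   (+ post-LayerNorm when pre_layer_norm is false).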
template <typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    auto *input_x = ctx.Input<phi::DenseTensor>("X");
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
    auto *ln_bias = ctx.Input<phi::DenseTensor>("LnBias");
    auto *ln_mean = ctx.Output<phi::DenseTensor>("LnMean");
    auto *ln_var = ctx.Output<phi::DenseTensor>("LnVariance");
    auto *ln_out = ctx.Output<phi::DenseTensor>("LnOut");

    const auto num_heads = ctx.Attr<int>("num_heads");
    const auto transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // if transpose_qkv_wb is false:
    //   y: qkv's weight [3, num_head, dim_head, dim_embed]
    // if transpose_qkv_wb is true:
    //   y: qkv's weight [dim_embed, 3 * dim_embed]
    auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
    auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
    auto *qkv_out = ctx.Output<phi::DenseTensor>("QKVOut");
    auto *qkv_bias_out = ctx.Output<phi::DenseTensor>("QKVBiasOut");

    auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
    auto *transpose_out_2 = ctx.Output<phi::DenseTensor>("TransposeOut2");
    auto *cache_kv = ctx.Input<phi::DenseTensor>("CacheKV");
    auto *cache_kv_out = ctx.Output<phi::DenseTensor>("CacheKVOut");
    auto *qk_out = ctx.Output<phi::DenseTensor>("QKOut");
    auto *qktv_out = ctx.Output<phi::DenseTensor>("QKTVOut");
    auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
    auto *attn_dropout_mask_out =
        ctx.Output<phi::DenseTensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Output<phi::DenseTensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Output<phi::DenseTensor>("SrcMaskOut");
    auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut");

    auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
    auto *out_linear_out = ctx.Output<phi::DenseTensor>("OutLinearOut");

    auto *ln_scale_2 = ctx.Input<phi::DenseTensor>("Ln2Scale");
    auto *ln_bias_2 = ctx.Input<phi::DenseTensor>("Ln2Bias");
    auto *dropout_mask_out = ctx.Output<phi::DenseTensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Output<phi::DenseTensor>("BiasDropoutResidualOut");
    auto *ln_mean_2 = ctx.Output<phi::DenseTensor>("Ln2Mean");
    auto *ln_var_2 = ctx.Output<phi::DenseTensor>("Ln2Variance");
    const float ln_epsilon = ctx.Attr<float>("ln_epsilon");

    float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
    const bool has_attn_dropout = (attn_dropout_rate != 0.0f);
    DropoutParam dropout_param2(ctx, 0);
    const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);

    bool is_test_1 = ctx.Attr<bool>("is_test");
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 =
        ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
    int ring_id = ctx.Attr<int>("ring_id");

    // final output.
    auto *out = ctx.Output<phi::DenseTensor>("Y");

    // get data ptr for qkv part.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    auto *x_data = input_x->data<T>();
    auto *qkv_weight_data = qkv_weight->data<T>();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
    auto *qkv_out_data =
        dev_ctx.template Alloc<T>(qkv_out, qkv_out->numel() * sizeof(T));
    auto *qkv_bias_out_data =
        (qkv_bias == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(qkv_bias_out,
                                        qkv_bias_out->numel() * sizeof(T));

    // get data ptr for FMHA.
    auto *transpose_out_2_data = dev_ctx.template Alloc<T>(
        transpose_out_2, transpose_out_2->numel() * sizeof(T));
    auto *cache_kv_out_data =
        (cache_kv_out == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(cache_kv_out,
                                        cache_kv_out->numel() * sizeof(T));
    auto *qk_out_data =
        dev_ctx.template Alloc<T>(qk_out, qk_out->numel() * sizeof(T));
    auto *qktv_out_data =
        dev_ctx.template Alloc<T>(qktv_out, qktv_out->numel() * sizeof(T));
    auto *src_mask_out_data =
        (src_mask == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(src_mask_out,
                                        src_mask_out->numel() * sizeof(T));
    auto *softmax_out_data = dev_ctx.template Alloc<T>(
        softmax_out, softmax_out->numel() * sizeof(T));
    auto *attn_dropout_mask_out_data =
        has_attn_dropout ? dev_ctx.template Alloc<uint8_t>(
                               attn_dropout_mask_out,
                               attn_dropout_mask_out->numel() * sizeof(uint8_t))
                         : nullptr;
    auto *attn_dropout_out_data =
        has_attn_dropout
            ? dev_ctx.template Alloc<T>(attn_dropout_out,
                                        attn_dropout_out->numel() * sizeof(T))
            : nullptr;
    auto *fmha_out_data =
        dev_ctx.template Alloc<T>(fmha_out, fmha_out->numel() * sizeof(T));

    // get data ptr for out_linear.
    auto *out_linear_weight_data = out_linear_weight->data<T>();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
    auto *out_linear_out_data = dev_ctx.template Alloc<T>(
        out_linear_out, out_linear_out->numel() * sizeof(T));

    // get data ptr for bias+dropout+residual+layernorm
    auto *dropout_mask_out_data =
        has_dropout
            ? dev_ctx.template Alloc<uint8_t>(
                  dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t))
            : nullptr;
    auto *final_out_data =
        dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];

    int num_head;
    int dim_head;
    // get num_head and dim_head in two different ways
    if (!transpose_qkv_wb) {
      num_head = qkv_w_dims[1];
      dim_head = qkv_w_dims[2];
    } else {
      num_head = num_heads;
      dim_head = dim_embed / num_head;
    }

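    // Flattened sizes used by the GEMM helpers below: bsz_seq rows of
    // dim_embed features go in, and the QKV projection produces
    // 3 * hidden_size (= 3 * num_head * dim_head) features per row.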
    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);

    bool compute_bias = true;
    if (qkv_bias == nullptr) {
      compute_bias = false;
    }
    // transA is false; transB and compute_bias depend on transpose_qkv_wb and
    // on whether a QKV bias is provided.
    bool transB = transpose_qkv_wb ? false : true;
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     false,
                                     transB,
                                     bsz_seq,
                                     output_size,
                                     input_size,
                                     compute_bias);

    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_rate,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);

    output_size = hidden_size;
    // (transA, transB, compute_bias) = (false, false, false)
    // NOTE(Yuang Liu): For general input size == output size, change the
    // position won't have effects. For mp, the output size is mp_head * dkey
    // which is actually the input size. While the input size is hidden size,
    // which is actually the output size. So for out linear, switch the
    // input size and output size.
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            false,
                                            false,
                                            bsz_seq,
                                            input_size,
                                            output_size,
                                            false);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
        ln_epsilon);

    if (pre_layer_norm) {
      auto *ln_scale_data =
          (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
      auto *ln_mean_data =
          dev_ctx.template Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
      auto *ln_var_data =
          dev_ctx.template Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
      auto *ln_out_data =
          dev_ctx.template Alloc<T>(ln_out, ln_out->numel() * sizeof(T));

      layer_norm_compute.ComputeForward(x_data,
                                        ln_scale_data,
                                        ln_bias_data,
                                        ln_out_data,
                                        ln_mean_data,
                                        ln_var_data);
      qkv_compute.ComputeForward(
          qkv_weight, ln_out, qkv_bias, qkv_out, qkv_bias_out);
    } else {
      qkv_compute.ComputeForward(
          qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out);
    }

    if (transpose_qkv_wb) {
      // resize the output for fmha compute
      qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
      qkv_bias_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
    }

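    // FMHA consumes the bias-added QKV projection when a QKV bias is present,
    // otherwise the raw projection output.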
    if (qkv_bias == nullptr) {
      fmha_ref_compute.ComputeForward(*qkv_out,
                                      cache_kv,
                                      src_mask,
                                      transpose_out_2,
                                      cache_kv_out,
                                      qk_out,
                                      src_mask_out,
                                      softmax_out,
                                      attn_dropout_mask_out,
                                      attn_dropout_out,
                                      qktv_out,
                                      fmha_out);
    } else {
      fmha_ref_compute.ComputeForward(*qkv_bias_out,
                                      cache_kv,
                                      src_mask,
                                      transpose_out_2,
                                      cache_kv_out,
                                      qk_out,
                                      src_mask_out,
                                      softmax_out,
                                      attn_dropout_mask_out,
                                      attn_dropout_out,
                                      qktv_out,
                                      fmha_out);
    }

    if (transpose_qkv_wb) {
      // resize the output back to make the shape compatible with infer shape
      qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
      qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
    }

    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight:   [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
    out_linear_compute.ComputeForward(
        out_linear_weight, fmha_out, nullptr, out_linear_out, nullptr);
    // tensor model parallel
    AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());

    bool add_residual = ctx.Attr<bool>("add_residual");
    const T *residual_ptr = add_residual ? x_data : nullptr;
    if (pre_layer_norm) {
      // output = (residual + dropout(input + bias))
      fused_dropout_layernorm_helper.ResidualDropoutBias(
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          final_out_data,
          dropout_mask_out_data);
    } else {
      // TODO(Xreki): support post layer_norm case when add_residual is false.
      PADDLE_ENFORCE_EQ(add_residual,
                        true,
                        platform::errors::InvalidArgument(
                            "Attribute add_residual is expected to be true "
                            "when pre_layer_norm is false."));

      const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
      const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
      T *bias_dropout_residual_out_ptr = dev_ctx.template Alloc<T>(
          bias_dropout_residual_out,
          bias_dropout_residual_out->numel() * sizeof(T));
      U *ln_mean_2_ptr =
          dev_ctx.template Alloc<U>(ln_mean_2, ln_mean_2->numel() * sizeof(U));
      U *ln_var_2_ptr =
          dev_ctx.template Alloc<U>(ln_var_2, ln_var_2->numel() * sizeof(U));
      // output = layernorm(residual + dropout(input + bias))
      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          ln_scale_2_ptr,
          ln_bias_2_ptr,
          bias_dropout_residual_out_ptr,
          dropout_mask_out_data,
          final_out_data,
          ln_mean_2_ptr,
          ln_var_2_ptr);
    }
  }
};

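// Backward kernel of the fused attention op. Gradients are computed in the
// reverse order of the forward pipeline: bias + dropout + residual
// (+ post-LayerNorm) grad -> output linear GEMM grad -> FMHA grad -> QKV GEMM
// grad (-> pre-LayerNorm grad), with a tensor-model-parallel all-reduce on the
// activation gradient and a final accumulation of the residual-path gradient
// into d_x.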
template <typename T>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const int num_heads = ctx.Attr<int>("num_heads");
    const bool transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    const float ln2epsilon = ctx.Attr<float>("ln_epsilon");

    const float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
    const bool has_attn_dropout = (attn_dropout_prob != 0.0f);
    DropoutParam dropout_param2(ctx, 0);
    const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);

    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    bool is_test_1 = ctx.Attr<bool>("is_test");
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 =
        ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
    int ring_id = ctx.Attr<int>("ring_id");

    // get inputs.
    auto *d_y = ctx.Input<phi::DenseTensor>(framework::GradVarName("Y"));
    auto *d_y_data = d_y->data<T>();

    // fw input
    auto *input_x = ctx.Input<phi::DenseTensor>("X");
    auto *ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
    auto *ln_2_scale = ctx.Input<phi::DenseTensor>("Ln2Scale");
    auto *x_data = input_x->data<T>();
    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
    auto *ln_2_scale_data =
        (ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
    // fw parameters.
    auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
    auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
    auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
    auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
    auto *qkv_weight_data = qkv_weight->data<T>();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
    auto *out_linear_weight_data = out_linear_weight->data<T>();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();

    // fw output
    auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
    auto *transpose_out_2 = ctx.Input<phi::DenseTensor>("TransposeOut2");
    auto *qk_out = ctx.Input<phi::DenseTensor>("QKOut");
    auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
    auto *attn_dropout_mask_out =
        ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Input<phi::DenseTensor>("SrcMaskOut");
    auto *ln_2_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
    auto *ln_2_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
    auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Input<phi::DenseTensor>("BiasDropoutResidualOut");
    auto *fmha_out_data = fmha_out->data<T>();
    auto *transpose_out_2_data = transpose_out_2->data<T>();
    auto *softmax_out_data = softmax_out->data<T>();
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
    auto *dropout_mask_out_data =
        has_dropout ? dropout_mask_out->data<uint8_t>() : nullptr;

    // output's grad
    auto *d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
    auto *d_qkv_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVOut"));
    auto *d_qkv_bias_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBiasOut"));
    auto *d_qktv_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKTVOut"));
    auto *d_transpose_out_2 =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("TransposeOut2"));
    auto *d_qk_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKOut"));
    auto *d_softmax_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("SoftmaxOut"));
    auto *d_attn_dropout_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("AttnDropoutOut"));
    auto *d_src_mask_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("SrcMaskOut"));
    auto *d_fmha_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("FMHAOut"));
    auto *d_out_linear_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearOut"));
    auto *d_bias_dropout_residual_out = ctx.Output<phi::DenseTensor>(
        framework::GradVarName("BiasDropoutResidualOut"));
    auto *d_x_data = dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
    // When qkv_bias is not nullptr, d_qkv_out is equal to d_qkv_bias_out, so
    // the space can be reused.
    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
                               ? nullptr
506 507
                               : dev_ctx.template Alloc<T>(
                                     d_qkv_out, d_qkv_out->numel() * sizeof(T));
508 509 510
    auto *d_qkv_bias_out_data =
        (d_qkv_bias_out == nullptr)
            ? nullptr
511 512 513 514 515 516 517 518 519 520
            : dev_ctx.template Alloc<T>(d_qkv_bias_out,
                                        d_qkv_bias_out->numel() * sizeof(T));
    auto *d_qktv_out_data =
        dev_ctx.template Alloc<T>(d_qktv_out, d_qktv_out->numel() * sizeof(T));
    auto *d_transpose_out_2_data = dev_ctx.template Alloc<T>(
        d_transpose_out_2, d_transpose_out_2->numel() * sizeof(T));
    auto *d_qk_out_data =
        dev_ctx.template Alloc<T>(d_qk_out, d_qk_out->numel() * sizeof(T));
    auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
        d_softmax_out, d_softmax_out->numel() * sizeof(T));
    auto *d_attn_dropout_out_data =
        has_attn_dropout
            ? dev_ctx.template Alloc<T>(d_attn_dropout_out,
                                        d_attn_dropout_out->numel() * sizeof(T))
            : nullptr;
    auto *d_src_mask_out_data =
        (src_mask == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_src_mask_out,
                                        d_src_mask_out->numel() * sizeof(T));
    auto *d_fmha_out_data =
        dev_ctx.template Alloc<T>(d_fmha_out, d_fmha_out->numel() * sizeof(T));
    auto *d_out_linear_out_data = dev_ctx.template Alloc<T>(
        d_out_linear_out, d_out_linear_out->numel() * sizeof(T));

    // parameter grad
    auto *d_qkv_weight =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVW"));
    auto *d_qkv_bias =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBias"));
    auto *d_out_linear_weight =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearW"));
    auto *d_out_linear_bias =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearBias"));
    auto *d_ln_2_scale =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Scale"));
    auto *d_ln_2_bias =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Bias"));

    auto *d_qkv_weight_data =
        (d_qkv_weight == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_qkv_weight,
                                        d_qkv_weight->numel() * sizeof(T));

    auto *d_qkv_bias_data =
        (d_qkv_bias == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_qkv_bias,
                                        d_qkv_bias->numel() * sizeof(T));
    auto *d_out_linear_weight_data =
        (d_out_linear_weight == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(
                  d_out_linear_weight,
                  d_out_linear_weight->numel() * sizeof(T));

    auto *d_out_linear_bias_data =
        (d_out_linear_bias == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_out_linear_bias,
                                        d_out_linear_bias->numel() * sizeof(T));

    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int num_head;
    int dim_head;
    if (!transpose_qkv_wb) {
      num_head = qkv_w_dims[1];
      dim_head = qkv_w_dims[2];
    } else {
      num_head = num_heads;
      dim_head = dim_embed / num_head;
    }

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    bool add_residual = ctx.Attr<bool>("add_residual");
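    // When the residual connection is enabled, the gradient flowing through
    // the shortcut path is collected in d_residual and added into d_x at the
    // end of this kernel.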
    phi::DenseTensor d_residual;
    T *d_residual_data = nullptr;
    if (add_residual) {
      d_residual.Resize(input_x_dims);
      d_residual_data = dev_ctx.template Alloc<T>(
          &d_residual, d_residual.numel() * sizeof(T));
    }

    bool transA = false;
    bool transB = transpose_qkv_wb ? false : true;
    bool compute_qkv_bias = qkv_bias ? true : false;
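    // Rebuild the same GEMM / FMHA / dropout-layernorm helpers that the
    // forward kernel used, so the backward shapes and dropout parameters
    // match the forward pass exactly.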
    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     transA,
                                     transB,
                                     bsz_seq,
                                     output_size,
                                     input_size,
                                     compute_qkv_bias);
    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_prob,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);
    output_size = hidden_size;
    transA = false;
    transB = false;
    bool compute_bias = false;
    // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            transA,
                                            transB,
                                            bsz_seq,
                                            input_size,
                                            output_size,
                                            compute_bias);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
        ln2epsilon);

    if (pre_layer_norm) {
      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
          ctx.cuda_device_context(),
          d_y_data,
          dropout_mask_out_data,
          d_out_linear_out_data,
          d_residual_data,
          d_out_linear_bias_data);
    } else {
      auto *ln_2_mean_data = ln_2_mean->data<U>();
      auto *ln_2_var_data = ln_2_var->data<U>();
      auto *bias_dropout_residual_out_data =
          bias_dropout_residual_out->data<T>();
      auto *d_ln_2_scale_data =
          (d_ln_2_scale == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_2_scale,
                                           d_ln_2_scale->numel() * sizeof(U)));
      auto *d_ln_2_bias_data =
          (d_ln_2_bias == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_2_bias,
                                           d_ln_2_bias->numel() * sizeof(U)));
      auto *d_bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
          d_bias_dropout_residual_out,
          d_bias_dropout_residual_out->numel() * sizeof(T));

      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
          ctx.cuda_device_context(),
          d_y_data,
          bias_dropout_residual_out_data,
          dropout_mask_out_data,
          ln_2_scale_data,
          ln_2_mean_data,
          ln_2_var_data,
          d_bias_dropout_residual_out_data,
          d_ln_2_scale_data,
          d_ln_2_bias_data,
          d_out_linear_out_data,
          d_out_linear_bias_data,
          d_residual_data);
    }

    out_linear_compute.ComputeBackward(fmha_out,
                                       out_linear_weight,
                                       d_out_linear_out,
                                       d_fmha_out,
                                       d_out_linear_weight,
                                       nullptr);

    if (transpose_qkv_wb) {
      if (compute_qkv_bias) {
        d_qkv_bias_out->Resize(
            {batch_size, max_seq_len, 3, num_head, dim_head});
      } else {
        d_qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
      }
    }

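    // Mirrors the forward branch: backpropagate through FMHA into the biased
    // QKV gradient buffer when a QKV bias is present, otherwise into the raw
    // QKV gradient buffer.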
    if (qkv_bias != nullptr) {
      fmha_ref_compute.ComputeBackward(*transpose_out_2,
                                       has_attn_dropout ? src_mask : nullptr,
                                       *softmax_out,
                                       *attn_dropout_mask_out,
                                       *attn_dropout_out,
                                       *qk_out,
                                       *src_mask_out,
                                       *d_fmha_out,
                                       d_qktv_out,
                                       d_attn_dropout_out,
                                       d_softmax_out,
                                       d_src_mask_out,
                                       d_qk_out,
                                       d_transpose_out_2,
                                       nullptr,
                                       d_qkv_bias_out);
    } else {
      fmha_ref_compute.ComputeBackward(*transpose_out_2,
                                       has_attn_dropout ? src_mask : nullptr,
                                       *softmax_out,
                                       *attn_dropout_mask_out,
                                       *attn_dropout_out,
                                       *qk_out,
                                       *src_mask_out,
                                       *d_fmha_out,
                                       d_qktv_out,
                                       d_attn_dropout_out,
                                       d_softmax_out,
                                       d_src_mask_out,
                                       d_qk_out,
                                       d_transpose_out_2,
                                       nullptr,
                                       d_qkv_out);
    }

    if (transpose_qkv_wb) {
      if (compute_qkv_bias) {
        d_qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
      } else {
        d_qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
      }
    }

    if (pre_layer_norm) {
      auto *ln_mean = ctx.Input<phi::DenseTensor>("LnMean");
      auto *ln_var = ctx.Input<phi::DenseTensor>("LnVariance");
      auto *ln_out = ctx.Input<phi::DenseTensor>("LnOut");
      auto *ln_mean_data = ln_mean->data<U>();
      auto *ln_var_data = ln_var->data<U>();
      auto *ln_out_data = ln_out->data<T>();

      auto *d_ln_out =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnOut"));
      auto *d_ln_scale =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnScale"));
      auto *d_ln_bias =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnBias"));
      auto *d_ln_out_data =
          dev_ctx.template Alloc<T>(d_ln_out, d_ln_out->numel() * sizeof(T));
      auto *d_ln_scale_data =
          (d_ln_scale == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_scale,
                                           d_ln_scale->numel() * sizeof(U)));
      auto *d_ln_bias_data =
          (d_ln_bias == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_bias,
                                           d_ln_bias->numel() * sizeof(U)));
      if (qkv_bias != nullptr) {
        qkv_compute.ComputeBackward(ln_out,
                                    qkv_weight,
                                    d_qkv_bias_out,
                                    d_ln_out,
                                    d_qkv_weight,
                                    d_qkv_bias);
      } else {
        qkv_compute.ComputeBackward(
            ln_out, qkv_weight, d_qkv_out, d_ln_out, d_qkv_weight, d_qkv_bias);
      }
      // tensor model parallel
      AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
      layer_norm_compute.ComputeBackward(x_data,
                                         d_ln_out_data,
                                         ln_scale_data,
                                         ln_mean_data,
                                         ln_var_data,
                                         d_x_data,
                                         d_ln_scale_data,
                                         d_ln_bias_data);
    } else {
      if (qkv_bias != nullptr) {
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_bias_out, d_x, d_qkv_weight, d_qkv_bias);
      } else {
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_out, d_x, d_qkv_weight, d_qkv_bias);
      }
      // tensor model parallel
      AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
    }

    if (add_residual) {
      // gradient accumulation: d_x = d_x + d_residual
      std::vector<const phi::DenseTensor *> ins = {&d_residual, d_x};
      std::vector<phi::DenseTensor *> outs = {d_x};
      phi::funcs::ElementwiseKernel<T>(
          ctx.cuda_device_context(), ins, &outs, phi::funcs::AddFunctor<T>());
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_attention,
                        ops::FusedAttentionOpKernel<float>,
                        ops::FusedAttentionOpKernel<double>,
                        ops::FusedAttentionOpKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
                        ops::FusedAttentionGradKernel<float>,
                        ops::FusedAttentionGradKernel<double>,
                        ops::FusedAttentionGradKernel<plat::float16>);