/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
// Sums `tensor` element-wise across all ranks of communication ring
// `ring_id`, writing the result back into `tensor` (used for tensor model
// parallelism). A ring_id of -1 disables communication entirely.
//
// If a ProcessGroup is registered for the ring it is used (and waited on);
// otherwise ncclAllReduce is issued directly on `ctx`'s stream with the
// tensor's own storage as both send and receive buffer.
// Builds without NCCL/RCCL raise an Unimplemented error.
static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
                      const int ring_id,
                      const phi::GPUContext &ctx) {
  if (ring_id == -1) {
    return;
  }
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto group_map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

  if (!group_map->has(ring_id)) {
    // No process group for this ring: call NCCL on the raw communicator.
    const auto nccl_dtype = platform::ToNCCLDataType(
        framework::TransToProtoVarType(tensor.dtype()));
    const int64_t element_count = tensor.numel();
    const void *send_ptr = tensor.data<T>();
    const auto dev_place = ctx.GetPlace();
    // The receive buffer is obtained from the same tensor, so the reduction
    // happens in place.
    void *recv_ptr =
        ctx.template Alloc<T>(&tensor, tensor.numel() * sizeof(T));
    auto nccl_comm =
        platform::NCCLCommContext::Instance().Get(ring_id, dev_place);
    auto cuda_stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(send_ptr,
                                                                recv_ptr,
                                                                element_count,
                                                                nccl_dtype,
                                                                ncclSum,
                                                                nccl_comm->comm(),
                                                                cuda_stream));
    return;
  }

  // A process group owns this ring: delegate and block until it finishes.
  paddle::distributed::ProcessGroup *group = group_map->get(ring_id);
  auto *nccl_group = static_cast<distributed::ProcessGroupNCCL *>(group);
  paddle::distributed::AllreduceOptions reduce_opts;
  reduce_opts.reduce_op = distributed::ReduceOp::SUM;
  auto pending = nccl_group->AllReduce(&tensor, tensor, reduce_opts, true, true);
  pending->Wait();
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

template <typename T>
// GPU forward kernel of the fused multi-head attention op.
//
// Stages, in execution order (names are the op's inputs/outputs):
//   1. [pre_layer_norm only] LnOut = LayerNorm(X; LnScale, LnBias, epsilon),
//      also producing LnMean / LnVariance.
//   2. QKVOut = (LnOut or X) x QKVW. When QKVBias is given, QKVBiasOut also
//      receives the bias-added result.
//   3. FMHARef consumes QKVBiasOut (or QKVOut when there is no bias) plus the
//      optional CacheKV / SrcMask. It writes TransposeOut2, QKOut, SoftmaxOut,
//      the attention-dropout outputs, QKTVOut and FMHAOut.
//   4. OutLinearOut = FMHAOut x OutLinearW, then summed across `ring_id`
//      (tensor model parallelism).
//   5. pre_layer_norm:  Y = residual + dropout(OutLinearOut + OutLinearBias)
//      post layer norm: Y = LayerNorm(residual + dropout(OutLinearOut +
//                           OutLinearBias); Ln2Scale, Ln2Bias, ln_epsilon).
//      BiasDropoutResidualOut, Ln2Mean and Ln2Variance are also filled;
//      BiasDropoutResidualOut presumably holds the pre-normalization sum.
//   The residual is X when `add_residual` is true. The post layer norm path
//   requires add_residual == true.
//
// QKVW layout depends on the `transpose_qkv_wb` attribute:
//   false: [3, num_head, dim_head, dim_embed]; num_head / dim_head are read
//          from the weight shape.
//   true:  [dim_embed, 3 * dim_embed / nranks]; num_head comes from the
//          `num_heads` attribute and nranks is inferred from the weight shape.
class FusedAttentionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    auto *input_x = ctx.Input<phi::DenseTensor>("X");
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();

    // Stage 1: optional pre layer norm.
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
    auto *ln_bias = ctx.Input<phi::DenseTensor>("LnBias");
    auto *ln_mean = ctx.Output<phi::DenseTensor>("LnMean");
    auto *ln_var = ctx.Output<phi::DenseTensor>("LnVariance");
    auto *ln_out = ctx.Output<phi::DenseTensor>("LnOut");

    const auto num_heads = ctx.Attr<int>("num_heads");
    const auto transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");

    // Stage 2: qkv projection.
    // x: qkv's input [batch_size, seq_len, dim_embed]
    // if transpose_qkv_wb is False
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    // if transpose_qkv_wb is True
    // y: qkv's weight: [dim_embed, 3 * dim_embed]
    auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
    auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
    auto *qkv_out = ctx.Output<phi::DenseTensor>("QKVOut");
    auto *qkv_bias_out = ctx.Output<phi::DenseTensor>("QKVBiasOut");

    // Stage 3: attention core (FMHA).
    auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
    auto *transpose_out_2 = ctx.Output<phi::DenseTensor>("TransposeOut2");
    auto *cache_kv = ctx.Input<phi::DenseTensor>("CacheKV");
    auto *cache_kv_out = ctx.Output<phi::DenseTensor>("CacheKVOut");
    auto *qk_out = ctx.Output<phi::DenseTensor>("QKOut");
    auto *qktv_out = ctx.Output<phi::DenseTensor>("QKTVOut");
    auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
    auto *attn_dropout_mask_out =
        ctx.Output<phi::DenseTensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Output<phi::DenseTensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Output<phi::DenseTensor>("SrcMaskOut");
    auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut");

    // Stage 4: output projection.
    auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
    auto *out_linear_out = ctx.Output<phi::DenseTensor>("OutLinearOut");

    // Stage 5: bias + dropout + residual (+ post layer norm).
    auto *ln_scale_2 = ctx.Input<phi::DenseTensor>("Ln2Scale");
    auto *ln_bias_2 = ctx.Input<phi::DenseTensor>("Ln2Bias");
    auto *dropout_mask_out = ctx.Output<phi::DenseTensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Output<phi::DenseTensor>("BiasDropoutResidualOut");
    auto *ln_mean_2 = ctx.Output<phi::DenseTensor>("Ln2Mean");
    auto *ln_var_2 = ctx.Output<phi::DenseTensor>("Ln2Variance");
    const float ln_epsilon = ctx.Attr<float>("ln_epsilon");

    // A zero rate means the corresponding dropout mask is never produced.
    float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
    const bool has_attn_dropout = (attn_dropout_rate != 0.0f);
    DropoutParam dropout_param2(ctx, 0);
    const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);

    // Attention-dropout configuration forwarded to FMHARef.
    bool is_test_1 = ctx.Attr<bool>("is_test");
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 =
        ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
    int ring_id = ctx.Attr<int>("ring_id");

    // final output.
    auto *out = ctx.Output<phi::DenseTensor>("Y");

    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();
    auto *x_data = input_x->data<T>();

    // Allocate every intermediate output up front. The qkv and FMHA helpers
    // below receive the tensors themselves, so only the allocation matters
    // here; the raw pointers are kept only where they are used later.
    dev_ctx.template Alloc<T>(qkv_out, qkv_out->numel() * sizeof(T));
    if (qkv_bias != nullptr) {
      dev_ctx.template Alloc<T>(qkv_bias_out,
                                qkv_bias_out->numel() * sizeof(T));
    }
    dev_ctx.template Alloc<T>(transpose_out_2,
                              transpose_out_2->numel() * sizeof(T));
    if (cache_kv_out != nullptr) {
      dev_ctx.template Alloc<T>(cache_kv_out,
                                cache_kv_out->numel() * sizeof(T));
    }
    dev_ctx.template Alloc<T>(qk_out, qk_out->numel() * sizeof(T));
    dev_ctx.template Alloc<T>(qktv_out, qktv_out->numel() * sizeof(T));
    if (src_mask != nullptr) {
      dev_ctx.template Alloc<T>(src_mask_out,
                                src_mask_out->numel() * sizeof(T));
    }
    dev_ctx.template Alloc<T>(softmax_out, softmax_out->numel() * sizeof(T));
    if (has_attn_dropout) {
      dev_ctx.template Alloc<uint8_t>(
          attn_dropout_mask_out,
          attn_dropout_mask_out->numel() * sizeof(uint8_t));
      dev_ctx.template Alloc<T>(attn_dropout_out,
                                attn_dropout_out->numel() * sizeof(T));
    }
    dev_ctx.template Alloc<T>(fmha_out, fmha_out->numel() * sizeof(T));

    // get data ptr for out_linear.
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
    auto *out_linear_out_data = dev_ctx.template Alloc<T>(
        out_linear_out, out_linear_out->numel() * sizeof(T));

    // get data ptr for bias+dropout+residual+layernorm
    auto *dropout_mask_out_data =
        has_dropout
            ? dev_ctx.template Alloc<uint8_t>(
                  dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t))
            : nullptr;
    auto *final_out_data =
        dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];

    // get num_head and dim_head in two different ways
    int num_head;
    int dim_head;
    int nranks = 1;
    if (!transpose_qkv_wb) {
      num_head = qkv_w_dims[1];
      dim_head = qkv_w_dims[2];
    } else {
      // The weight is [dim_embed, 3 * dim_embed / nranks] on each rank.
      nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
      num_head = num_heads;
      dim_head = dim_embed / (num_head * nranks);
    }

    const int bsz_seq = batch_size * max_seq_len;
    const int hidden_size = num_head * dim_head;
    const int qkv_output_size = 3 * hidden_size;
    const int input_size = dim_embed;

    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);

    // (transA, transB, compute_bias) = (false, !transpose_qkv_wb, has bias)
    const bool compute_bias = (qkv_bias != nullptr);
    const bool transB = !transpose_qkv_wb;
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     false,
                                     transB,
                                     bsz_seq,
                                     qkv_output_size,
                                     input_size,
                                     compute_bias);

    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_rate,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);

    // (transA, transB, compute_bias) = (false, false, false)
    // NOTE(Yuang Liu): For general input size == output size, change the
    // position won't have effects. For mp, the output size is mp_head * dkey
    // which is actually the input size. While the input size is hidden size,
    // which is actually the output size. So for out linear, switch the
    // input size and output size.
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            false,
                                            false,
                                            bsz_seq,
                                            input_size,
                                            hidden_size,
                                            false);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
        ln_epsilon);

    // Stages 1 + 2.
    if (pre_layer_norm) {
      auto *ln_scale_data =
          (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
      auto *ln_mean_data =
          dev_ctx.template Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
      auto *ln_var_data =
          dev_ctx.template Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
      auto *ln_out_data =
          dev_ctx.template Alloc<T>(ln_out, ln_out->numel() * sizeof(T));

      layer_norm_compute.ComputeForward(x_data,
                                        ln_scale_data,
                                        ln_bias_data,
                                        ln_out_data,
                                        ln_mean_data,
                                        ln_var_data);
      qkv_compute.ComputeForward(
          qkv_weight, ln_out, qkv_bias, qkv_out, qkv_bias_out);
    } else {
      qkv_compute.ComputeForward(
          qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out);
    }

    // Stage 3. With the transposed weight layout, the qkv result is
    // temporarily viewed as 5-D for FMHA and restored afterwards.
    // QKVBiasOut is only touched when it exists.
    if (transpose_qkv_wb) {
      // resize the output for fmha compute
      qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
      if (qkv_bias_out != nullptr) {
        qkv_bias_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
      }
    }

    // FMHA reads the bias-added projection when a bias exists.
    phi::DenseTensor &fmha_qkv_input =
        (qkv_bias == nullptr) ? *qkv_out : *qkv_bias_out;
    fmha_ref_compute.ComputeForward(fmha_qkv_input,
                                    cache_kv,
                                    src_mask,
                                    transpose_out_2,
                                    cache_kv_out,
                                    qk_out,
                                    src_mask_out,
                                    softmax_out,
                                    attn_dropout_mask_out,
                                    attn_dropout_out,
                                    qktv_out,
                                    fmha_out);

    if (transpose_qkv_wb) {
      // resize the output back to make the shape compatible with infer shape
      qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
      if (qkv_bias_out != nullptr) {
        qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
      }
    }

    // Stage 4.
    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight:   [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
    out_linear_compute.ComputeForward(
        out_linear_weight, fmha_out, nullptr, out_linear_out, nullptr);
    // tensor model parallel
    AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());

    // Stage 5.
    bool add_residual = ctx.Attr<bool>("add_residual");
    const T *residual_ptr = add_residual ? x_data : nullptr;
    if (pre_layer_norm) {
      // output = (residual + dropout(input + bias))
      fused_dropout_layernorm_helper.ResidualDropoutBias(
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          final_out_data,
          dropout_mask_out_data);
    } else {
      // TODO(Xreki): support post layer_norm case when add_residual is false.
      PADDLE_ENFORCE_EQ(add_residual,
                        true,
                        platform::errors::InvalidArgument(
                            "Attribute add_residual is expected to be true "
                            "when pre_layer_norm is false."));

      const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
      const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
      T *bias_dropout_residual_out_ptr = dev_ctx.template Alloc<T>(
          bias_dropout_residual_out,
          bias_dropout_residual_out->numel() * sizeof(T));
      U *ln_mean_2_ptr =
          dev_ctx.template Alloc<U>(ln_mean_2, ln_mean_2->numel() * sizeof(U));
      U *ln_var_2_ptr =
          dev_ctx.template Alloc<U>(ln_var_2, ln_var_2->numel() * sizeof(U));
      // output = layernorm(residual + dropout(input + bias))
      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          ln_scale_2_ptr,
          ln_bias_2_ptr,
          bias_dropout_residual_out_ptr,
          dropout_mask_out_data,
          final_out_data,
          ln_mean_2_ptr,
          ln_var_2_ptr);
    }
  }
};

template <typename T>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
410 411
    const int num_heads = ctx.Attr<int>("num_heads");
    const bool transpose_qkv_wb = ctx.Attr<bool>("transpose_qkv_wb");
412 413 414 415
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    const float ln2epsilon = ctx.Attr<float>("ln_epsilon");

416 417 418 419 420
    const float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
    const bool has_attn_dropout = (attn_dropout_prob != 0.0f);
    DropoutParam dropout_param2(ctx, 0);
    const bool has_dropout = (dropout_param2.dropout_prob != 0.0f);

421
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
L
Li Min 已提交
422
    bool is_test_1 = ctx.Attr<bool>("is_test");
423 424 425 426
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
427 428
    auto *seed_1 =
        ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
429 430
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
431
    int ring_id = ctx.Attr<int>("ring_id");
432 433

    // get inputs.
434
    auto *d_y = ctx.Input<phi::DenseTensor>(framework::GradVarName("Y"));
435 436 437
    auto *d_y_data = d_y->data<T>();

    // fw input
438 439 440
    auto *input_x = ctx.Input<phi::DenseTensor>("X");
    auto *ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
    auto *ln_2_scale = ctx.Input<phi::DenseTensor>("Ln2Scale");
441 442 443 444 445
    auto *x_data = input_x->data<T>();
    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
    auto *ln_2_scale_data =
        (ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
    // fw parameters.
446 447 448 449 450
    auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
    auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
    auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
    auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
451
    auto *qkv_weight_data = qkv_weight->data<T>();
452
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
453
    auto *out_linear_weight_data = out_linear_weight->data<T>();
454 455
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
456 457

    // fw output
458 459 460 461 462 463 464 465 466 467 468
    auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
    auto *transpose_out_2 = ctx.Input<phi::DenseTensor>("TransposeOut2");
    auto *qk_out = ctx.Input<phi::DenseTensor>("QKOut");
    auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
    auto *attn_dropout_mask_out =
        ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Input<phi::DenseTensor>("SrcMaskOut");
    auto *ln_2_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
    auto *ln_2_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
    auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
469
    auto *bias_dropout_residual_out =
470
        ctx.Input<phi::DenseTensor>("BiasDropoutResidualOut");
471 472 473
    auto *fmha_out_data = fmha_out->data<T>();
    auto *transpose_out_2_data = transpose_out_2->data<T>();
    auto *softmax_out_data = softmax_out->data<T>();
474 475
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
476 477
    auto *dropout_mask_out_data =
        has_dropout ? dropout_mask_out->data<uint8_t>() : nullptr;
478 479

    // output's grad
480 481 482
    auto *d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
    auto *d_qkv_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVOut"));
483
    auto *d_qkv_bias_out =
484 485 486
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBiasOut"));
    auto *d_qktv_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKTVOut"));
487
    auto *d_transpose_out_2 =
488 489 490
        ctx.Output<phi::DenseTensor>(framework::GradVarName("TransposeOut2"));
    auto *d_qk_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKOut"));
491
    auto *d_softmax_out =
492
        ctx.Output<phi::DenseTensor>(framework::GradVarName("SoftmaxOut"));
493
    auto *d_attn_dropout_out =
494
        ctx.Output<phi::DenseTensor>(framework::GradVarName("AttnDropoutOut"));
495
    auto *d_src_mask_out =
496 497 498
        ctx.Output<phi::DenseTensor>(framework::GradVarName("SrcMaskOut"));
    auto *d_fmha_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("FMHAOut"));
499
    auto *d_out_linear_out =
500 501 502
        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearOut"));
    auto *d_bias_dropout_residual_out = ctx.Output<phi::DenseTensor>(
        framework::GradVarName("BiasDropoutResidualOut"));
503
    auto *d_x_data = dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
504 505 506 507
    // when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the
    // space can be reused.
    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
                               ? nullptr
508 509
                               : dev_ctx.template Alloc<T>(
                                     d_qkv_out, d_qkv_out->numel() * sizeof(T));
510 511 512
    auto *d_qkv_bias_out_data =
        (d_qkv_bias_out == nullptr)
            ? nullptr
513 514 515 516 517 518 519 520 521 522
            : dev_ctx.template Alloc<T>(d_qkv_bias_out,
                                        d_qkv_bias_out->numel() * sizeof(T));
    auto *d_qktv_out_data =
        dev_ctx.template Alloc<T>(d_qktv_out, d_qktv_out->numel() * sizeof(T));
    auto *d_transpose_out_2_data = dev_ctx.template Alloc<T>(
        d_transpose_out_2, d_transpose_out_2->numel() * sizeof(T));
    auto *d_qk_out_data =
        dev_ctx.template Alloc<T>(d_qk_out, d_qk_out->numel() * sizeof(T));
    auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
        d_softmax_out, d_softmax_out->numel() * sizeof(T));
523 524 525 526 527
    auto *d_attn_dropout_out_data =
        has_attn_dropout
            ? dev_ctx.template Alloc<T>(d_attn_dropout_out,
                                        d_attn_dropout_out->numel() * sizeof(T))
            : nullptr;
528
    auto *d_src_mask_out_data =
529 530 531 532 533 534 535 536
        (src_mask == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_src_mask_out,
                                        d_src_mask_out->numel() * sizeof(T));
    auto *d_fmha_out_data =
        dev_ctx.template Alloc<T>(d_fmha_out, d_fmha_out->numel() * sizeof(T));
    auto *d_out_linear_out_data = dev_ctx.template Alloc<T>(
        d_out_linear_out, d_out_linear_out->numel() * sizeof(T));
537 538

    // parameter grad
539 540 541 542
    auto *d_qkv_weight =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVW"));
    auto *d_qkv_bias =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBias"));
543
    auto *d_out_linear_weight =
544
        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearW"));
545
    auto *d_out_linear_bias =
546 547 548 549 550
        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearBias"));
    auto *d_ln_2_scale =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Scale"));
    auto *d_ln_2_bias =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Bias"));
551

552 553 554 555 556 557
    auto *d_qkv_weight_data =
        (d_qkv_weight == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_qkv_weight,
                                        d_qkv_weight->numel() * sizeof(T));

558 559 560 561 562
    auto *d_qkv_bias_data =
        (d_qkv_bias == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_qkv_bias,
                                        d_qkv_bias->numel() * sizeof(T));
563 564 565 566 567 568 569
    auto *d_out_linear_weight_data =
        (d_out_linear_weight == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(
                  d_out_linear_weight,
                  d_out_linear_weight->numel() * sizeof(T));

570
    auto *d_out_linear_bias_data =
571 572
        (d_out_linear_bias == nullptr)
            ? nullptr
573 574
            : dev_ctx.template Alloc<T>(d_out_linear_bias,
                                        d_out_linear_bias->numel() * sizeof(T));
575 576 577 578 579 580 581

    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
582 583
    int num_head;
    int dim_head;
584
    int nranks = 1;
585 586 587 588
    if (!transpose_qkv_wb) {
      num_head = qkv_w_dims[1];
      dim_head = qkv_w_dims[2];
    } else {
589
      nranks = (qkv_w_dims[0] * 3) / qkv_w_dims[1];
590
      num_head = num_heads;
591
      dim_head = dim_embed / (num_head * nranks);
592
    }
593 594 595 596 597 598

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

599
    bool add_residual = ctx.Attr<bool>("add_residual");
600
    phi::DenseTensor d_residual;
601 602 603
    T *d_residual_data = nullptr;
    if (add_residual) {
      d_residual.Resize(input_x_dims);
604 605
      d_residual_data = dev_ctx.template Alloc<T>(
          &d_residual, d_residual.numel() * sizeof(T));
606
    }
607 608

    bool transA = false;
609
    bool transB = transpose_qkv_wb ? false : true;
610
    bool compute_qkv_bias = qkv_bias ? true : false;
611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632
    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     transA,
                                     transB,
                                     bsz_seq,
                                     output_size,
                                     input_size,
                                     compute_qkv_bias);
    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_prob,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);
633 634 635
    output_size = hidden_size;
    transA = false;
    transB = false;
636
    bool compute_bias = false;
637
    // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
638 639 640 641 642 643 644
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            transA,
                                            transB,
                                            bsz_seq,
                                            input_size,
                                            output_size,
                                            compute_bias);
645
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
646 647 648 649
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
650 651
        ln2epsilon);

L
Li Min 已提交
652 653
    if (pre_layer_norm) {
      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
654 655 656 657 658 659
          ctx.cuda_device_context(),
          d_y_data,
          dropout_mask_out_data,
          d_out_linear_out_data,
          d_residual_data,
          d_out_linear_bias_data);
L
Li Min 已提交
660 661 662 663 664 665
    } else {
      auto *ln_2_mean_data = ln_2_mean->data<U>();
      auto *ln_2_var_data = ln_2_var->data<U>();
      auto *bias_dropout_residual_out_data =
          bias_dropout_residual_out->data<T>();
      auto *d_ln_2_scale_data =
666 667
          (d_ln_2_scale == nullptr
               ? nullptr
668 669
               : dev_ctx.template Alloc<U>(d_ln_2_scale,
                                           d_ln_2_scale->numel() * sizeof(U)));
L
Li Min 已提交
670
      auto *d_ln_2_bias_data =
671 672
          (d_ln_2_bias == nullptr
               ? nullptr
673 674 675 676 677
               : dev_ctx.template Alloc<U>(d_ln_2_bias,
                                           d_ln_2_bias->numel() * sizeof(U)));
      auto *d_bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
          d_bias_dropout_residual_out,
          d_bias_dropout_residual_out->numel() * sizeof(T));
L
Li Min 已提交
678 679

      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
680 681 682 683 684 685 686 687 688 689 690 691 692
          ctx.cuda_device_context(),
          d_y_data,
          bias_dropout_residual_out_data,
          dropout_mask_out_data,
          ln_2_scale_data,
          ln_2_mean_data,
          ln_2_var_data,
          d_bias_dropout_residual_out_data,
          d_ln_2_scale_data,
          d_ln_2_bias_data,
          d_out_linear_out_data,
          d_out_linear_bias_data,
          d_residual_data);
L
Li Min 已提交
693
    }
694

695 696 697 698 699 700
    out_linear_compute.ComputeBackward(fmha_out,
                                       out_linear_weight,
                                       d_out_linear_out,
                                       d_fmha_out,
                                       d_out_linear_weight,
                                       nullptr);
L
Li Min 已提交
701

702 703 704 705 706 707 708 709 710
    if (transpose_qkv_wb) {
      if (compute_qkv_bias) {
        d_qkv_bias_out->Resize(
            {batch_size, max_seq_len, 3, num_head, dim_head});
      } else {
        d_qkv_out->Resize({batch_size, max_seq_len, 3, num_head, dim_head});
      }
    }

711
    if (qkv_bias != nullptr) {
712
      fmha_ref_compute.ComputeBackward(*transpose_out_2,
713
                                       has_attn_dropout ? src_mask : nullptr,
714 715 716 717 718 719 720 721 722 723 724 725 726 727
                                       *softmax_out,
                                       *attn_dropout_mask_out,
                                       *attn_dropout_out,
                                       *qk_out,
                                       *src_mask_out,
                                       *d_fmha_out,
                                       d_qktv_out,
                                       d_attn_dropout_out,
                                       d_softmax_out,
                                       d_src_mask_out,
                                       d_qk_out,
                                       d_transpose_out_2,
                                       nullptr,
                                       d_qkv_bias_out);
728
    } else {
729
      fmha_ref_compute.ComputeBackward(*transpose_out_2,
730
                                       has_attn_dropout ? src_mask : nullptr,
731 732 733 734 735 736 737 738 739 740 741 742 743 744
                                       *softmax_out,
                                       *attn_dropout_mask_out,
                                       *attn_dropout_out,
                                       *qk_out,
                                       *src_mask_out,
                                       *d_fmha_out,
                                       d_qktv_out,
                                       d_attn_dropout_out,
                                       d_softmax_out,
                                       d_src_mask_out,
                                       d_qk_out,
                                       d_transpose_out_2,
                                       nullptr,
                                       d_qkv_out);
745
    }
746

747 748 749 750 751 752 753 754
    if (transpose_qkv_wb) {
      if (compute_qkv_bias) {
        d_qkv_bias_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
      } else {
        d_qkv_out->Resize({batch_size, max_seq_len, 3 * hidden_size});
      }
    }

755
    if (pre_layer_norm) {
756 757 758
      auto *ln_mean = ctx.Input<phi::DenseTensor>("LnMean");
      auto *ln_var = ctx.Input<phi::DenseTensor>("LnVariance");
      auto *ln_out = ctx.Input<phi::DenseTensor>("LnOut");
759 760 761 762
      auto *ln_mean_data = ln_mean->data<U>();
      auto *ln_var_data = ln_var->data<U>();
      auto *ln_out_data = ln_out->data<T>();

763 764 765 766 767 768
      auto *d_ln_out =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnOut"));
      auto *d_ln_scale =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnScale"));
      auto *d_ln_bias =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnBias"));
769 770
      auto *d_ln_out_data =
          dev_ctx.template Alloc<T>(d_ln_out, d_ln_out->numel() * sizeof(T));
771
      auto *d_ln_scale_data =
772 773 774 775
          (d_ln_scale == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_scale,
                                           d_ln_scale->numel() * sizeof(U)));
776
      auto *d_ln_bias_data =
777 778 779 780
          (d_ln_bias == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_bias,
                                           d_ln_bias->numel() * sizeof(U)));
781
      if (qkv_bias != nullptr) {
782 783 784 785 786 787
        qkv_compute.ComputeBackward(ln_out,
                                    qkv_weight,
                                    d_qkv_bias_out,
                                    d_ln_out,
                                    d_qkv_weight,
                                    d_qkv_bias);
788
      } else {
789 790
        qkv_compute.ComputeBackward(
            ln_out, qkv_weight, d_qkv_out, d_ln_out, d_qkv_weight, d_qkv_bias);
791
      }
792 793
      // tensor model parallel
      AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
794 795 796 797 798 799 800 801
      layer_norm_compute.ComputeBackward(x_data,
                                         d_ln_out_data,
                                         ln_scale_data,
                                         ln_mean_data,
                                         ln_var_data,
                                         d_x_data,
                                         d_ln_scale_data,
                                         d_ln_bias_data);
802
    } else {
803
      if (qkv_bias != nullptr) {
804 805
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_bias_out, d_x, d_qkv_weight, d_qkv_bias);
806
      } else {
807 808
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_out, d_x, d_qkv_weight, d_qkv_bias);
809
      }
810 811
      // tensor model parallel
      AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
812
    }
813 814 815

    if (add_residual) {
      // gradient accumulation
816 817
      std::vector<const phi::DenseTensor *> ins = {&d_residual, d_x};
      std::vector<phi::DenseTensor *> outs = {d_x};
818 819
      phi::funcs::ElementwiseKernel<T>(
          ctx.cuda_device_context(), ins, &outs, phi::funcs::AddFunctor<T>());
820
    }
821 822 823
  }
};

L
Li Min 已提交
824 825 826 827 828
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
829 830
REGISTER_OP_CUDA_KERNEL(fused_attention,
                        ops::FusedAttentionOpKernel<float>,
L
Li Min 已提交
831 832
                        ops::FusedAttentionOpKernel<double>,
                        ops::FusedAttentionOpKernel<plat::float16>);
833 834 835 836
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
                        ops::FusedAttentionGradKernel<float>,
                        ops::FusedAttentionGradKernel<double>,
                        ops::FusedAttentionGradKernel<plat::float16>);