/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

using Tensor = phi::DenseTensor;
// Sum-reduces `tensor` across the communication group identified by
// `ring_id`, using `tensor` as both the source and the destination.
//
// Behavior, in order:
//   * ring_id == -1: returns immediately without touching `tensor` (this
//     check happens before the NCCL/RCCL build guard, so it never throws).
//   * Built with NCCL or RCCL:
//       - if a ProcessGroup is registered for `ring_id`, the reduction is
//         issued through it and this function blocks on the returned task;
//       - otherwise ncclAllReduce is called directly with the communicator
//         looked up for (ring_id, ctx.GetPlace()) on ctx.stream(). No wait is
//         issued here on that path.
//   * Built without NCCL/RCCL: throws platform::errors::Unimplemented.
//
// T must be the element type stored in `tensor`; it is used for data<T>() and
// for the Alloc<T>() byte size.
template <typename T>
static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
                      const int ring_id,
                      const phi::GPUContext &ctx) {
  if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

  if (map->has(ring_id)) {
    // ProcessGroup path: the same tensor is pushed as input and output.
    paddle::distributed::ProcessGroup *pg = map->get(ring_id);
    std::vector<phi::DenseTensor> in_tensor;
    std::vector<phi::DenseTensor> out_tensor;
    in_tensor.push_back(tensor);
    out_tensor.push_back(tensor);
    paddle::distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
    task->Wait();
  } else {
    // Raw NCCL path: send and receive buffers are both taken from `tensor`.
    auto dtype = platform::ToNCCLDataType(
        framework::TransToProtoVarType(tensor.dtype()));
    int64_t numel = tensor.numel();
    const void *sendbuff = tensor.data<T>();
    auto place = ctx.GetPlace();
    void *recvbuff = ctx.template Alloc<T>(&tensor, numel * sizeof(T));
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    auto stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
  }
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

// Forward GPU kernel of the fused attention op.
//
// Compute() chains the following stages, each delegated to a helper object:
//   1. (only if attr `pre_layer_norm`) layer norm of X into LnOut,
//   2. QKV projection (AttnMatMul) of LnOut or X into QKVOut / QKVBiasOut,
//   3. multi-head attention (FMHARef) producing FMHAOut,
//   4. output projection (AttnMatMul, no bias) into OutLinearOut,
//   5. AllReduce of OutLinearOut over `ring_id` (no-op when ring_id == -1),
//   6. bias + dropout + residual into Y; when `pre_layer_norm` is false this
//      is followed by a second layer norm and requires `add_residual` == true.
//
// T is the tensor element type; U = LayerNormParamType<T> is the type used
// for layer-norm scale/bias/mean/variance buffers.
template <typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
 public:
  // Runs the forward pass described on the class. All inputs, outputs and
  // attributes are read from `ctx` by name.
  //
  // Throws platform::errors::InvalidArgument when `pre_layer_norm` is false
  // and `add_residual` is false.
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    auto *input_x = ctx.Input<phi::DenseTensor>("X");
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
    auto *ln_bias = ctx.Input<phi::DenseTensor>("LnBias");
    auto *ln_mean = ctx.Output<phi::DenseTensor>("LnMean");
    auto *ln_var = ctx.Output<phi::DenseTensor>("LnVariance");
    auto *ln_out = ctx.Output<phi::DenseTensor>("LnOut");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
    auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
    auto *qkv_out = ctx.Output<phi::DenseTensor>("QKVOut");
    auto *qkv_bias_out = ctx.Output<phi::DenseTensor>("QKVBiasOut");

    auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
    auto *transpose_out_2 = ctx.Output<phi::DenseTensor>("TransposeOut2");
    auto *cache_kv = ctx.Input<phi::DenseTensor>("CacheKV");
    auto *cache_kv_out = ctx.Output<phi::DenseTensor>("CacheKVOut");
    auto *qk_out = ctx.Output<phi::DenseTensor>("QKOut");
    auto *qktv_out = ctx.Output<phi::DenseTensor>("QKTVOut");
    auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
    auto *attn_dropout_mask_out =
        ctx.Output<phi::DenseTensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Output<phi::DenseTensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Output<phi::DenseTensor>("SrcMaskOut");
    auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut");

    auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
    auto *out_linear_out = ctx.Output<phi::DenseTensor>("OutLinearOut");

    auto *ln_scale_2 = ctx.Input<phi::DenseTensor>("Ln2Scale");
    auto *ln_bias_2 = ctx.Input<phi::DenseTensor>("Ln2Bias");
    auto *dropout_mask_out = ctx.Output<phi::DenseTensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Output<phi::DenseTensor>("BiasDropoutResidualOut");
    auto *ln_mean_2 = ctx.Output<phi::DenseTensor>("Ln2Mean");
    auto *ln_var_2 = ctx.Output<phi::DenseTensor>("Ln2Variance");
    const float ln_epsilon = ctx.Attr<float>("ln_epsilon");

    // Attention-dropout configuration (suffix _1); the dropout applied after
    // the output projection is configured separately via DropoutParam below.
    float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
    bool is_test_1 = ctx.Attr<bool>("is_test");
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 =
        ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
    int ring_id = ctx.Attr<int>("ring_id");

    // final output.
    auto *out = ctx.Output<phi::DenseTensor>("Y");

    // NOTE(review): several pointers obtained below are never read again in
    // this function. The Alloc<T>() / data<T>() calls are kept on purpose:
    // presumably they materialize (or validate) the buffers that the helper
    // objects later access through the tensor pointers. Confirm before
    // removing any of them.

    // get data ptr for qkv part.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    auto *x_data = input_x->data<T>();
    auto *qkv_weight_data = qkv_weight->data<T>();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
    auto *qkv_out_data =
        dev_ctx.template Alloc<T>(qkv_out, qkv_out->numel() * sizeof(T));
    // QKVBiasOut is only allocated when a QKV bias is supplied.
    auto *qkv_bias_out_data =
        (qkv_bias == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(qkv_bias_out,
                                        qkv_bias_out->numel() * sizeof(T));

    // get data ptr for FMHA.
    auto *transpose_out_2_data = dev_ctx.template Alloc<T>(
        transpose_out_2, transpose_out_2->numel() * sizeof(T));
    auto *cache_kv_out_data =
        (cache_kv_out == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(cache_kv_out,
                                        cache_kv_out->numel() * sizeof(T));
    auto *qk_out_data =
        dev_ctx.template Alloc<T>(qk_out, qk_out->numel() * sizeof(T));
    auto *qktv_out_data =
        dev_ctx.template Alloc<T>(qktv_out, qktv_out->numel() * sizeof(T));
    // SrcMaskOut is only allocated when a source mask is supplied.
    auto *src_mask_out_data =
        (src_mask == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(src_mask_out,
                                        src_mask_out->numel() * sizeof(T));
    auto *softmax_out_data = dev_ctx.template Alloc<T>(
        softmax_out, softmax_out->numel() * sizeof(T));
    auto *attn_dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
        attn_dropout_mask_out,
        attn_dropout_mask_out->numel() * sizeof(uint8_t));
    auto *attn_dropout_out_data = dev_ctx.template Alloc<T>(
        attn_dropout_out, attn_dropout_out->numel() * sizeof(T));
    auto *fmha_out_data =
        dev_ctx.template Alloc<T>(fmha_out, fmha_out->numel() * sizeof(T));

    // get data ptr for out_linear.
    auto *out_linear_weight_data = out_linear_weight->data<T>();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
    auto *out_linear_out_data = dev_ctx.template Alloc<T>(
        out_linear_out, out_linear_out->numel() * sizeof(T));

    // get data ptr for bias+dropout+residual+layernorm
    auto *dropout_mask_out_data = dev_ctx.template Alloc<uint8_t>(
        dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t));
    auto *final_out_data =
        dev_ctx.template Alloc<T>(out, out->numel() * sizeof(T));

    // Problem sizes, read from X (dims 0..2) and QKVW (dims 1..2).
    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];

    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);

    // The QKV GEMM adds a bias only when QKVBias is provided.
    const bool compute_bias = (qkv_bias != nullptr);
    // (transA, transB) = (false, true); compute_bias as decided above.
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     false,
                                     true,
                                     bsz_seq,
                                     output_size,
                                     input_size,
                                     compute_bias);

    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_rate,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);

    output_size = hidden_size;
    // (transA, transB, compute_bias) = (false, false, false)
    // NOTE(Yuang Liu): For general input size == output size, change the
    // position won't have effects. For mp, the output size is mp_head * dkey
    // which is actually the input size. While the input size is hidden size,
    // which is actually the output size. So for out linear, switch the
    // input size and output size.
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            false,
                                            false,
                                            bsz_seq,
                                            input_size,
                                            output_size,
                                            false);
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
        ln_epsilon);

    // Stage 1 + 2: optional pre layer norm, then the QKV projection.
    if (pre_layer_norm) {
      auto *ln_scale_data =
          (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
      auto *ln_mean_data =
          dev_ctx.template Alloc<U>(ln_mean, ln_mean->numel() * sizeof(U));
      auto *ln_var_data =
          dev_ctx.template Alloc<U>(ln_var, ln_var->numel() * sizeof(U));
      auto *ln_out_data =
          dev_ctx.template Alloc<T>(ln_out, ln_out->numel() * sizeof(T));

      layer_norm_compute.ComputeForward(x_data,
                                        ln_scale_data,
                                        ln_bias_data,
                                        ln_out_data,
                                        ln_mean_data,
                                        ln_var_data);
      qkv_compute.ComputeForward(
          qkv_weight, ln_out, qkv_bias, qkv_out, qkv_bias_out);
    } else {
      qkv_compute.ComputeForward(
          qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out);
    }

    // Stage 3: attention. Its input is QKVOut when there is no QKV bias and
    // QKVBiasOut otherwise; only the selected tensor is dereferenced.
    phi::DenseTensor &fmha_input =
        (qkv_bias == nullptr) ? *qkv_out : *qkv_bias_out;
    fmha_ref_compute.ComputeForward(fmha_input,
                                    cache_kv,
                                    src_mask,
                                    transpose_out_2,
                                    cache_kv_out,
                                    qk_out,
                                    src_mask_out,
                                    softmax_out,
                                    attn_dropout_mask_out,
                                    attn_dropout_out,
                                    qktv_out,
                                    fmha_out);

    // Stage 4: output projection (bias is added later, in stage 6).
    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight:   [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
    out_linear_compute.ComputeForward(
        out_linear_weight, fmha_out, nullptr, out_linear_out, nullptr);
    // Stage 5: tensor model parallel
    AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());

    // Stage 6: bias + dropout (+ residual) (+ post layer norm).
    bool add_residual = ctx.Attr<bool>("add_residual");
    const T *residual_ptr = add_residual ? x_data : nullptr;
    if (pre_layer_norm) {
      // output = (residual + dropout(input + bias))
      fused_dropout_layernorm_helper.ResidualDropoutBias(
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          final_out_data,
          dropout_mask_out_data);
    } else {
      // TODO(Xreki): support post layer_norm case when add_residual is false.
      PADDLE_ENFORCE_EQ(add_residual,
                        true,
                        platform::errors::InvalidArgument(
                            "Attribute add_residual is expected to be true "
                            "when pre_layer_norm is false."));

      const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
      const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
      T *bias_dropout_residual_out_ptr = dev_ctx.template Alloc<T>(
          bias_dropout_residual_out,
          bias_dropout_residual_out->numel() * sizeof(T));
      U *ln_mean_2_ptr =
          dev_ctx.template Alloc<U>(ln_mean_2, ln_mean_2->numel() * sizeof(U));
      U *ln_var_2_ptr =
          dev_ctx.template Alloc<U>(ln_var_2, ln_var_2->numel() * sizeof(U));
      // output = layernorm(residual + dropout(input + bias))
      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          ln_scale_2_ptr,
          ln_bias_2_ptr,
          bias_dropout_residual_out_ptr,
          dropout_mask_out_data,
          final_out_data,
          ln_mean_2_ptr,
          ln_var_2_ptr);
    }
  }
};

template <typename T>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    const float ln2epsilon = ctx.Attr<float>("ln_epsilon");

    float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
376
    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
L
Li Min 已提交
377
    bool is_test_1 = ctx.Attr<bool>("is_test");
378 379 380 381
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
382 383
    auto *seed_1 =
        ctx.HasInput("Seed1") ? ctx.Input<phi::DenseTensor>("Seed1") : nullptr;
384 385
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
386
    int ring_id = ctx.Attr<int>("ring_id");
387 388

    // get inputs.
389
    auto *d_y = ctx.Input<phi::DenseTensor>(framework::GradVarName("Y"));
390 391 392
    auto *d_y_data = d_y->data<T>();

    // fw input
393 394 395
    auto *input_x = ctx.Input<phi::DenseTensor>("X");
    auto *ln_scale = ctx.Input<phi::DenseTensor>("LnScale");
    auto *ln_2_scale = ctx.Input<phi::DenseTensor>("Ln2Scale");
396 397 398 399 400
    auto *x_data = input_x->data<T>();
    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
    auto *ln_2_scale_data =
        (ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
    // fw parameters.
401 402 403 404 405
    auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
    auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVW");
    auto *qkv_bias = ctx.Input<phi::DenseTensor>("QKVBias");
    auto *out_linear_weight = ctx.Input<phi::DenseTensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");
406 407
    auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
    auto *qkv_weight_data = qkv_weight->data<T>();
408
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
409
    auto *out_linear_weight_data = out_linear_weight->data<T>();
410 411
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
412 413

    // fw output
414 415 416 417 418 419 420 421 422 423 424 425 426
    auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
    auto *transpose_out_2 = ctx.Input<phi::DenseTensor>("TransposeOut2");
    auto *qk_out = ctx.Input<phi::DenseTensor>("QKOut");
    auto *qktv_out = ctx.Input<phi::DenseTensor>("QKTVOut");
    auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
    auto *attn_dropout_mask_out =
        ctx.Input<phi::DenseTensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Input<phi::DenseTensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Input<phi::DenseTensor>("SrcMaskOut");
    auto *out_linear_out = ctx.Input<phi::DenseTensor>("OutLinearOut");
    auto *ln_2_mean = ctx.Input<phi::DenseTensor>("Ln2Mean");
    auto *ln_2_var = ctx.Input<phi::DenseTensor>("Ln2Variance");
    auto *dropout_mask_out = ctx.Input<phi::DenseTensor>("DropoutMaskOut");
427
    auto *bias_dropout_residual_out =
428
        ctx.Input<phi::DenseTensor>("BiasDropoutResidualOut");
429 430 431 432 433
    auto *fmha_out_data = fmha_out->data<T>();
    auto *transpose_out_2_data = transpose_out_2->data<T>();
    auto *qk_out_data = qk_out->data<T>();
    auto *qktv_out_data = qktv_out->data<T>();
    auto *softmax_out_data = softmax_out->data<T>();
434 435
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
436 437 438 439
    auto *out_linear_out_data = out_linear_out->data<T>();
    auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();

    // output's grad
440 441 442
    auto *d_x = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
    auto *d_qkv_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVOut"));
443
    auto *d_qkv_bias_out =
444 445 446
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBiasOut"));
    auto *d_qktv_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKTVOut"));
447
    auto *d_transpose_out_2 =
448 449 450
        ctx.Output<phi::DenseTensor>(framework::GradVarName("TransposeOut2"));
    auto *d_qk_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKOut"));
451
    auto *d_softmax_out =
452
        ctx.Output<phi::DenseTensor>(framework::GradVarName("SoftmaxOut"));
453
    auto *d_attn_dropout_out =
454
        ctx.Output<phi::DenseTensor>(framework::GradVarName("AttnDropoutOut"));
455
    auto *d_src_mask_out =
456 457 458
        ctx.Output<phi::DenseTensor>(framework::GradVarName("SrcMaskOut"));
    auto *d_fmha_out =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("FMHAOut"));
459
    auto *d_out_linear_out =
460 461 462
        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearOut"));
    auto *d_bias_dropout_residual_out = ctx.Output<phi::DenseTensor>(
        framework::GradVarName("BiasDropoutResidualOut"));
463
    auto *d_x_data = dev_ctx.template Alloc<T>(d_x, d_x->numel() * sizeof(T));
464 465 466 467
    // when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the
    // space can be reused.
    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
                               ? nullptr
468 469
                               : dev_ctx.template Alloc<T>(
                                     d_qkv_out, d_qkv_out->numel() * sizeof(T));
470 471 472
    auto *d_qkv_bias_out_data =
        (d_qkv_bias_out == nullptr)
            ? nullptr
473 474 475 476 477 478 479 480 481 482 483 484
            : dev_ctx.template Alloc<T>(d_qkv_bias_out,
                                        d_qkv_bias_out->numel() * sizeof(T));
    auto *d_qktv_out_data =
        dev_ctx.template Alloc<T>(d_qktv_out, d_qktv_out->numel() * sizeof(T));
    auto *d_transpose_out_2_data = dev_ctx.template Alloc<T>(
        d_transpose_out_2, d_transpose_out_2->numel() * sizeof(T));
    auto *d_qk_out_data =
        dev_ctx.template Alloc<T>(d_qk_out, d_qk_out->numel() * sizeof(T));
    auto *d_softmax_out_data = dev_ctx.template Alloc<T>(
        d_softmax_out, d_softmax_out->numel() * sizeof(T));
    auto *d_attn_dropout_out_data = dev_ctx.template Alloc<T>(
        d_attn_dropout_out, d_attn_dropout_out->numel() * sizeof(T));
485
    auto *d_src_mask_out_data =
486 487 488 489 490 491 492 493
        (src_mask == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_src_mask_out,
                                        d_src_mask_out->numel() * sizeof(T));
    auto *d_fmha_out_data =
        dev_ctx.template Alloc<T>(d_fmha_out, d_fmha_out->numel() * sizeof(T));
    auto *d_out_linear_out_data = dev_ctx.template Alloc<T>(
        d_out_linear_out, d_out_linear_out->numel() * sizeof(T));
494 495

    // parameter grad
496 497 498 499
    auto *d_qkv_weight =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVW"));
    auto *d_qkv_bias =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVBias"));
500
    auto *d_out_linear_weight =
501
        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearW"));
502
    auto *d_out_linear_bias =
503 504 505 506 507
        ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearBias"));
    auto *d_ln_2_scale =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Scale"));
    auto *d_ln_2_bias =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("Ln2Bias"));
508

509 510 511 512 513 514 515 516 517
    auto *d_qkv_weight_data = dev_ctx.template Alloc<T>(
        d_qkv_weight, d_qkv_weight->numel() * sizeof(T));
    auto *d_qkv_bias_data =
        (d_qkv_bias == nullptr)
            ? nullptr
            : dev_ctx.template Alloc<T>(d_qkv_bias,
                                        d_qkv_bias->numel() * sizeof(T));
    auto *d_out_linear_weight_data = dev_ctx.template Alloc<T>(
        d_out_linear_weight, d_out_linear_weight->numel() * sizeof(T));
518
    auto *d_out_linear_bias_data =
519 520
        (d_out_linear_bias == nullptr)
            ? nullptr
521 522
            : dev_ctx.template Alloc<T>(d_out_linear_bias,
                                        d_out_linear_bias->numel() * sizeof(T));
523 524 525 526 527 528 529 530 531 532 533 534 535 536 537

    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

538
    bool add_residual = ctx.Attr<bool>("add_residual");
539
    Tensor d_residual;
540 541 542
    T *d_residual_data = nullptr;
    if (add_residual) {
      d_residual.Resize(input_x_dims);
543 544
      d_residual_data = dev_ctx.template Alloc<T>(
          &d_residual, d_residual.numel() * sizeof(T));
545
    }
546 547 548

    bool transA = false;
    bool transB = true;
549
    bool compute_qkv_bias = qkv_bias ? true : false;
550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571
    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     transA,
                                     transB,
                                     bsz_seq,
                                     output_size,
                                     input_size,
                                     compute_qkv_bias);
    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_prob,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);
572 573 574
    output_size = hidden_size;
    transA = false;
    transB = false;
575
    bool compute_bias = false;
576
    // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
577 578 579 580 581 582 583
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            transA,
                                            transB,
                                            bsz_seq,
                                            input_size,
                                            output_size,
                                            compute_bias);
584 585
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
586 587 588 589
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
590 591
        ln2epsilon);

L
Li Min 已提交
592 593
    if (pre_layer_norm) {
      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
594 595 596 597 598 599
          ctx.cuda_device_context(),
          d_y_data,
          dropout_mask_out_data,
          d_out_linear_out_data,
          d_residual_data,
          d_out_linear_bias_data);
L
Li Min 已提交
600 601 602 603 604 605
    } else {
      auto *ln_2_mean_data = ln_2_mean->data<U>();
      auto *ln_2_var_data = ln_2_var->data<U>();
      auto *bias_dropout_residual_out_data =
          bias_dropout_residual_out->data<T>();
      auto *d_ln_2_scale_data =
606 607
          (d_ln_2_scale == nullptr
               ? nullptr
608 609
               : dev_ctx.template Alloc<U>(d_ln_2_scale,
                                           d_ln_2_scale->numel() * sizeof(U)));
L
Li Min 已提交
610
      auto *d_ln_2_bias_data =
611 612
          (d_ln_2_bias == nullptr
               ? nullptr
613 614 615 616 617
               : dev_ctx.template Alloc<U>(d_ln_2_bias,
                                           d_ln_2_bias->numel() * sizeof(U)));
      auto *d_bias_dropout_residual_out_data = dev_ctx.template Alloc<T>(
          d_bias_dropout_residual_out,
          d_bias_dropout_residual_out->numel() * sizeof(T));
L
Li Min 已提交
618 619

      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
620 621 622 623 624 625 626 627 628 629 630 631 632
          ctx.cuda_device_context(),
          d_y_data,
          bias_dropout_residual_out_data,
          dropout_mask_out_data,
          ln_2_scale_data,
          ln_2_mean_data,
          ln_2_var_data,
          d_bias_dropout_residual_out_data,
          d_ln_2_scale_data,
          d_ln_2_bias_data,
          d_out_linear_out_data,
          d_out_linear_bias_data,
          d_residual_data);
L
Li Min 已提交
633
    }
634

635 636 637 638 639 640
    out_linear_compute.ComputeBackward(fmha_out,
                                       out_linear_weight,
                                       d_out_linear_out,
                                       d_fmha_out,
                                       d_out_linear_weight,
                                       nullptr);
L
Li Min 已提交
641

642
    if (qkv_bias != nullptr) {
643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658
      fmha_ref_compute.ComputeBackward(*transpose_out_2,
                                       src_mask,
                                       *softmax_out,
                                       *attn_dropout_mask_out,
                                       *attn_dropout_out,
                                       *qk_out,
                                       *src_mask_out,
                                       *d_fmha_out,
                                       d_qktv_out,
                                       d_attn_dropout_out,
                                       d_softmax_out,
                                       d_src_mask_out,
                                       d_qk_out,
                                       d_transpose_out_2,
                                       nullptr,
                                       d_qkv_bias_out);
659
    } else {
660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675
      fmha_ref_compute.ComputeBackward(*transpose_out_2,
                                       src_mask,
                                       *softmax_out,
                                       *attn_dropout_mask_out,
                                       *attn_dropout_out,
                                       *qk_out,
                                       *src_mask_out,
                                       *d_fmha_out,
                                       d_qktv_out,
                                       d_attn_dropout_out,
                                       d_softmax_out,
                                       d_src_mask_out,
                                       d_qk_out,
                                       d_transpose_out_2,
                                       nullptr,
                                       d_qkv_out);
676
    }
677 678

    if (pre_layer_norm) {
679 680 681
      auto *ln_mean = ctx.Input<phi::DenseTensor>("LnMean");
      auto *ln_var = ctx.Input<phi::DenseTensor>("LnVariance");
      auto *ln_out = ctx.Input<phi::DenseTensor>("LnOut");
682 683 684 685
      auto *ln_mean_data = ln_mean->data<U>();
      auto *ln_var_data = ln_var->data<U>();
      auto *ln_out_data = ln_out->data<T>();

686 687 688 689 690 691
      auto *d_ln_out =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnOut"));
      auto *d_ln_scale =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnScale"));
      auto *d_ln_bias =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("LnBias"));
692 693
      auto *d_ln_out_data =
          dev_ctx.template Alloc<T>(d_ln_out, d_ln_out->numel() * sizeof(T));
694
      auto *d_ln_scale_data =
695 696 697 698
          (d_ln_scale == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_scale,
                                           d_ln_scale->numel() * sizeof(U)));
699
      auto *d_ln_bias_data =
700 701 702 703
          (d_ln_bias == nullptr
               ? nullptr
               : dev_ctx.template Alloc<U>(d_ln_bias,
                                           d_ln_bias->numel() * sizeof(U)));
704
      if (qkv_bias != nullptr) {
705 706 707 708 709 710
        qkv_compute.ComputeBackward(ln_out,
                                    qkv_weight,
                                    d_qkv_bias_out,
                                    d_ln_out,
                                    d_qkv_weight,
                                    d_qkv_bias);
711
      } else {
712 713
        qkv_compute.ComputeBackward(
            ln_out, qkv_weight, d_qkv_out, d_ln_out, d_qkv_weight, d_qkv_bias);
714
      }
715 716
      // tensor model parallel
      AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
717 718 719 720 721 722 723 724
      layer_norm_compute.ComputeBackward(x_data,
                                         d_ln_out_data,
                                         ln_scale_data,
                                         ln_mean_data,
                                         ln_var_data,
                                         d_x_data,
                                         d_ln_scale_data,
                                         d_ln_bias_data);
725
    } else {
726
      if (qkv_bias != nullptr) {
727 728
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_bias_out, d_x, d_qkv_weight, d_qkv_bias);
729
      } else {
730 731
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_out, d_x, d_qkv_weight, d_qkv_bias);
732
      }
733 734
      // tensor model parallel
      AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
735
    }
736 737 738 739 740

    if (add_residual) {
      // gradient accumulation
      std::vector<const Tensor *> ins = {&d_residual, d_x};
      std::vector<Tensor *> outs = {d_x};
741 742
      phi::funcs::ElementwiseKernel<T>(
          ctx.cuda_device_context(), ins, &outs, phi::funcs::AddFunctor<T>());
743
    }
744 745 746
  }
};

L
Li Min 已提交
747 748 749 750 751
}  // namespace operators
}  // namespace paddle

// Short aliases used by the kernel registrations below.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register ops::FusedAttentionOpKernel as the CUDA kernel of the
// `fused_attention` op, instantiated for float, double and plat::float16.
REGISTER_OP_CUDA_KERNEL(fused_attention,
                        ops::FusedAttentionOpKernel<float>,
                        ops::FusedAttentionOpKernel<double>,
                        ops::FusedAttentionOpKernel<plat::float16>);

// Register ops::FusedAttentionGradKernel as the CUDA kernel of the
// `fused_attention_grad` op, instantiated for the same three element types.
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
                        ops::FusedAttentionGradKernel<float>,
                        ops::FusedAttentionGradKernel<double>,
                        ops::FusedAttentionGradKernel<plat::float16>);