/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda_fp16.h>

#include <cub/cub.cuh>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

43 44 45 46 47 48
template <typename T>
static void AllReduce(framework::Tensor &tensor,  // NOLINT
                      const int ring_id,
                      const platform::CUDADeviceContext &ctx) {
  if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

  if (map->has(ring_id)) {
    paddle::distributed::ProcessGroup *pg = map->get(ring_id);
    std::vector<phi::DenseTensor> in_tensor;
    std::vector<phi::DenseTensor> out_tensor;
    in_tensor.push_back(tensor);
    out_tensor.push_back(tensor);
    paddle::distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg->AllReduce(in_tensor, out_tensor, opts);
    task->Wait();
  } else {
    auto dtype = platform::ToNCCLDataType(
        framework::TransToProtoVarType(tensor.dtype()));
    int64_t numel = tensor.numel();
    const void *sendbuff = tensor.data<T>();
    auto place = ctx.GetPlace();
    void *recvbuff = tensor.mutable_data<T>(place);
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    auto stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, numel, dtype, ncclSum, comm->comm(), stream));
  }
73 74 75 76 77 78 79
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

L
Li Min 已提交
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
template <typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    auto *input_x = ctx.Input<Tensor>("X");

    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *ln_scale = ctx.Input<Tensor>("LnScale");
    auto *ln_bias = ctx.Input<Tensor>("LnBias");
    auto *ln_mean = ctx.Output<Tensor>("LnMean");
    auto *ln_var = ctx.Output<Tensor>("LnVariance");
    auto *ln_out = ctx.Output<Tensor>("LnOut");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto *qkv_weight = ctx.Input<Tensor>("QKVW");
    auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
    auto *qkv_out = ctx.Output<Tensor>("QKVOut");
    auto *qkv_bias_out = ctx.Output<Tensor>("QKVBiasOut");

    auto *src_mask = ctx.Input<Tensor>("SrcMask");
    auto *transpose_out_2 = ctx.Output<Tensor>("TransposeOut2");
104 105
    auto *cache_kv = ctx.Input<Tensor>("CacheKV");
    auto *cache_kv_out = ctx.Output<Tensor>("CacheKVOut");
L
Li Min 已提交
106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
    auto *qk_out = ctx.Output<Tensor>("QKOut");
    auto *qktv_out = ctx.Output<Tensor>("QKTVOut");
    auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Output<Tensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Output<Tensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Output<Tensor>("SrcMaskOut");
    auto *fmha_out = ctx.Output<Tensor>("FMHAOut");

    auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
    auto *out_linear_out = ctx.Output<Tensor>("OutLinearOut");

    auto *ln_scale_2 = ctx.Input<Tensor>("Ln2Scale");
    auto *ln_bias_2 = ctx.Input<Tensor>("Ln2Bias");
    auto *dropout_mask_out = ctx.Output<Tensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Output<Tensor>("BiasDropoutResidualOut");
    auto *ln_mean_2 = ctx.Output<Tensor>("Ln2Mean");
    auto *ln_var_2 = ctx.Output<Tensor>("Ln2Variance");
    const float ln_epsilon = ctx.Attr<float>("ln_epsilon");

    float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
L
Li Min 已提交
128
    bool is_test_1 = ctx.Attr<bool>("is_test");
L
Li Min 已提交
129 130 131 132 133 134 135
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
136
    int ring_id = ctx.Attr<int>("ring_id");
L
Li Min 已提交
137 138 139 140 141 142 143 144 145 146

    // final output.
    auto *out = ctx.Output<Tensor>("Y");

    // get data ptr for qkv part.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    auto *x_data = input_x->data<T>();
    auto *qkv_weight_data = qkv_weight->data<T>();
147
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
L
Li Min 已提交
148
    auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
149 150 151
    auto *qkv_bias_out_data =
        (qkv_bias == nullptr) ? nullptr
                              : qkv_bias_out->mutable_data<T>(ctx.GetPlace());
L
Li Min 已提交
152 153 154 155

    // get data ptr for FMHA.
    auto *transpose_out_2_data =
        transpose_out_2->mutable_data<T>(ctx.GetPlace());
156 157 158 159
    auto *cache_kv_out_data =
        (cache_kv_out == nullptr)
            ? nullptr
            : cache_kv_out->mutable_data<T>(ctx.GetPlace());
L
Li Min 已提交
160 161
    auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
    auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
162 163 164
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr
                              : src_mask_out->mutable_data<T>(ctx.GetPlace());
L
Li Min 已提交
165 166 167 168 169 170 171 172 173
    auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
    auto *attn_dropout_mask_out_data =
        attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
    auto *attn_dropout_out_data =
        attn_dropout_out->mutable_data<T>(ctx.GetPlace());
    auto *fmha_out_data = fmha_out->mutable_data<T>(ctx.GetPlace());

    // get data ptr for out_linear.
    auto *out_linear_weight_data = out_linear_weight->data<T>();
174 175
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
L
Li Min 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
    auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());

    // get data ptr for bias+dropout+residual+layernorm
    auto *dropout_mask_out_data =
        dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
    auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];

    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

195 196
    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
197 198 199 200 201

    bool compute_bias = true;
    if (qkv_bias == nullptr) {
      compute_bias = false;
    }
L
Li Min 已提交
202
    // (transA, transB, compute_bias) = (false, true, true)
203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     false,
                                     true,
                                     bsz_seq,
                                     output_size,
                                     input_size,
                                     compute_bias);

    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_rate,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);
L
Li Min 已提交
224 225 226

    output_size = hidden_size;
    // (transA, transB, compute_bias) = (false, false, false)
227 228 229 230 231
    // NOTE(Yuang Liu): For general input size == output size, change the
    // position won't have effects. For mp, the output size is mp_head * dkey
    // which is actually the input size. While the input size is hidden size,
    // which is actually the output size. So for out linear, switch the
    // input size and output size.
232 233 234 235 236 237 238
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            false,
                                            false,
                                            bsz_seq,
                                            input_size,
                                            output_size,
                                            false);
L
Li Min 已提交
239 240
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
241 242 243 244
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
L
Li Min 已提交
245 246 247
        ln_epsilon);

    if (pre_layer_norm) {
L
Li Min 已提交
248 249 250 251 252 253 254
      auto *ln_scale_data =
          (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
      auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
      auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
      auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());

255 256 257 258 259 260 261 262
      layer_norm_compute.ComputeForward(x_data,
                                        ln_scale_data,
                                        ln_bias_data,
                                        ln_out_data,
                                        ln_mean_data,
                                        ln_var_data);
      qkv_compute.ComputeForward(
          qkv_weight, ln_out, qkv_bias, qkv_out, qkv_bias_out);
L
Li Min 已提交
263
    } else {
264 265
      qkv_compute.ComputeForward(
          qkv_weight, input_x, qkv_bias, qkv_out, qkv_bias_out);
L
Li Min 已提交
266
    }
267
    if (qkv_bias == nullptr) {
268 269 270 271 272 273 274 275 276 277 278 279
      fmha_ref_compute.ComputeForward(*qkv_out,
                                      cache_kv,
                                      src_mask,
                                      transpose_out_2,
                                      cache_kv_out,
                                      qk_out,
                                      src_mask_out,
                                      softmax_out,
                                      attn_dropout_mask_out,
                                      attn_dropout_out,
                                      qktv_out,
                                      fmha_out);
280
    } else {
281 282 283 284 285 286 287 288 289 290 291 292
      fmha_ref_compute.ComputeForward(*qkv_bias_out,
                                      cache_kv,
                                      src_mask,
                                      transpose_out_2,
                                      cache_kv_out,
                                      qk_out,
                                      src_mask_out,
                                      softmax_out,
                                      attn_dropout_mask_out,
                                      attn_dropout_out,
                                      qktv_out,
                                      fmha_out);
293
    }
294

L
Li Min 已提交
295 296 297
    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight:   [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
298 299
    out_linear_compute.ComputeForward(
        out_linear_weight, fmha_out, nullptr, out_linear_out, nullptr);
300 301 302
    // tensor model parallel
    AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());

303 304
    bool add_residual = ctx.Attr<bool>("add_residual");
    const T *residual_ptr = add_residual ? x_data : nullptr;
L
Li Min 已提交
305 306 307
    if (pre_layer_norm) {
      // output = (residual + dropout(input + bias))
      fused_dropout_layernorm_helper.ResidualDropoutBias(
308 309 310 311 312 313
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          final_out_data,
          dropout_mask_out_data);
L
Li Min 已提交
314
    } else {
315
      // TODO(Xreki): support post layer_norm case when add_residual is false.
316 317
      PADDLE_ENFORCE_EQ(add_residual,
                        true,
318 319 320 321 322 323 324
                        platform::errors::InvalidArgument(
                            "Attribute add_residual is expected to be true "
                            "when pre_layer_norm is false."));

      const U *ln_scale_2_ptr = ln_scale_2 ? ln_scale_2->data<U>() : nullptr;
      const U *ln_bias_2_ptr = ln_bias_2 ? ln_bias_2->data<U>() : nullptr;
      T *bias_dropout_residual_out_ptr =
L
Li Min 已提交
325
          bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
326 327
      U *ln_mean_2_ptr = ln_mean_2->mutable_data<U>(ctx.GetPlace());
      U *ln_var_2_ptr = ln_var_2->mutable_data<U>(ctx.GetPlace());
L
Li Min 已提交
328 329
      // output = layernorm(residual + dropout(input + bias))
      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
330 331 332 333 334 335 336 337 338 339 340
          ctx.cuda_device_context(),
          out_linear_out_data,
          residual_ptr,
          out_linear_bias_data,
          ln_scale_2_ptr,
          ln_bias_2_ptr,
          bias_dropout_residual_out_ptr,
          dropout_mask_out_data,
          final_out_data,
          ln_mean_2_ptr,
          ln_var_2_ptr);
L
Li Min 已提交
341
    }
L
Li Min 已提交
342 343 344
  }
};

345 346 347 348 349 350 351 352 353 354
template <typename T>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    const float ln2epsilon = ctx.Attr<float>("ln_epsilon");

    float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
L
Li Min 已提交
355
    bool is_test_1 = ctx.Attr<bool>("is_test");
356 357 358 359 360 361 362
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
363
    int ring_id = ctx.Attr<int>("ring_id");
364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384

    // get inputs.
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto *d_y_data = d_y->data<T>();

    // fw input
    auto *input_x = ctx.Input<Tensor>("X");
    auto *ln_scale = ctx.Input<Tensor>("LnScale");
    auto *ln_2_scale = ctx.Input<Tensor>("Ln2Scale");
    auto *x_data = input_x->data<T>();
    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
    auto *ln_2_scale_data =
        (ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
    // fw parameters.
    auto *src_mask = ctx.Input<Tensor>("SrcMask");
    auto *qkv_weight = ctx.Input<Tensor>("QKVW");
    auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
    auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
    auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
    auto *qkv_weight_data = qkv_weight->data<T>();
385
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
386
    auto *out_linear_weight_data = out_linear_weight->data<T>();
387 388
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409

    // fw output
    auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
    auto *transpose_out_2 = ctx.Input<Tensor>("TransposeOut2");
    auto *qk_out = ctx.Input<Tensor>("QKOut");
    auto *qktv_out = ctx.Input<Tensor>("QKTVOut");
    auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Input<Tensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Input<Tensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Input<Tensor>("SrcMaskOut");
    auto *out_linear_out = ctx.Input<Tensor>("OutLinearOut");
    auto *ln_2_mean = ctx.Input<Tensor>("Ln2Mean");
    auto *ln_2_var = ctx.Input<Tensor>("Ln2Variance");
    auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Input<Tensor>("BiasDropoutResidualOut");
    auto *fmha_out_data = fmha_out->data<T>();
    auto *transpose_out_2_data = transpose_out_2->data<T>();
    auto *qk_out_data = qk_out->data<T>();
    auto *qktv_out_data = qktv_out->data<T>();
    auto *softmax_out_data = softmax_out->data<T>();
410 411
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435
    auto *out_linear_out_data = out_linear_out->data<T>();
    auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();

    // output's grad
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_qkv_out = ctx.Output<Tensor>(framework::GradVarName("QKVOut"));
    auto *d_qkv_bias_out =
        ctx.Output<Tensor>(framework::GradVarName("QKVBiasOut"));
    auto *d_qktv_out = ctx.Output<Tensor>(framework::GradVarName("QKTVOut"));
    auto *d_transpose_out_2 =
        ctx.Output<Tensor>(framework::GradVarName("TransposeOut2"));
    auto *d_qk_out = ctx.Output<Tensor>(framework::GradVarName("QKOut"));
    auto *d_softmax_out =
        ctx.Output<Tensor>(framework::GradVarName("SoftmaxOut"));
    auto *d_attn_dropout_out =
        ctx.Output<Tensor>(framework::GradVarName("AttnDropoutOut"));
    auto *d_src_mask_out =
        ctx.Output<Tensor>(framework::GradVarName("SrcMaskOut"));
    auto *d_fmha_out = ctx.Output<Tensor>(framework::GradVarName("FMHAOut"));
    auto *d_out_linear_out =
        ctx.Output<Tensor>(framework::GradVarName("OutLinearOut"));
    auto *d_bias_dropout_residual_out =
        ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
    auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
436 437 438 439 440 441 442 443 444
    // when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the
    // space can be reused.
    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
                               ? nullptr
                               : d_qkv_out->mutable_data<T>(ctx.GetPlace());
    auto *d_qkv_bias_out_data =
        (d_qkv_bias_out == nullptr)
            ? nullptr
            : d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
445 446 447 448 449 450 451
    auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
    auto *d_transpose_out_2_data =
        d_transpose_out_2->mutable_data<T>(ctx.GetPlace());
    auto *d_qk_out_data = d_qk_out->mutable_data<T>(ctx.GetPlace());
    auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
    auto *d_attn_dropout_out_data =
        d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
452 453 454
    auto *d_src_mask_out_data =
        (src_mask == nullptr) ? nullptr
                              : d_src_mask_out->mutable_data<T>(ctx.GetPlace());
455 456 457 458 459 460 461 462 463 464 465 466 467
    auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
    auto *d_out_linear_out_data =
        d_out_linear_out->mutable_data<T>(ctx.GetPlace());

    // parameter grad
    auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
    auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias"));
    auto *d_out_linear_weight =
        ctx.Output<Tensor>(framework::GradVarName("OutLinearW"));
    auto *d_out_linear_bias =
        ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
    auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
    auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
468

469
    auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
470 471 472
    auto *d_qkv_bias_data = (d_qkv_bias == nullptr)
                                ? nullptr
                                : d_qkv_bias->mutable_data<T>(ctx.GetPlace());
473 474 475
    auto *d_out_linear_weight_data =
        d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
    auto *d_out_linear_bias_data =
476 477 478
        (d_out_linear_bias == nullptr)
            ? nullptr
            : d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
479 480 481 482 483 484 485 486 487 488 489 490 491 492 493

    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

494
    bool add_residual = ctx.Attr<bool>("add_residual");
495
    Tensor d_residual;
496 497 498 499 500
    T *d_residual_data = nullptr;
    if (add_residual) {
      d_residual.Resize(input_x_dims);
      d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
    }
501 502 503

    bool transA = false;
    bool transB = true;
504
    bool compute_qkv_bias = qkv_bias ? true : false;
505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526
    auto layer_norm_compute = AttnLayerNorm<T>(
        ctx.cuda_device_context(), epsilon, bsz_seq, dim_embed);
    auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                     transA,
                                     transB,
                                     bsz_seq,
                                     output_size,
                                     input_size,
                                     compute_qkv_bias);
    AttnDropoutParam attn_dropout_param(is_test_1,
                                        dropout_implementation_1,
                                        attn_dropout_prob,
                                        is_upscale_in_train_1,
                                        is_fix_seed_1,
                                        seed_val_1,
                                        seed_1);
    auto fmha_ref_compute = FMHARef<T>(ctx.cuda_device_context(),
                                       batch_size,
                                       max_seq_len,
                                       num_head,
                                       dim_head,
                                       attn_dropout_param);
527 528 529
    output_size = hidden_size;
    transA = false;
    transB = false;
530
    bool compute_bias = false;
531
    // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
532 533 534 535 536 537 538
    auto out_linear_compute = AttnMatMul<T>(ctx.cuda_device_context(),
                                            transA,
                                            transB,
                                            bsz_seq,
                                            input_size,
                                            output_size,
                                            compute_bias);
539 540
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
541 542 543 544
        ctx.cuda_device_context(),
        bsz_seq,
        dim_embed,
        dropout_param2,
545 546
        ln2epsilon);

L
Li Min 已提交
547 548
    if (pre_layer_norm) {
      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
549 550 551 552 553 554
          ctx.cuda_device_context(),
          d_y_data,
          dropout_mask_out_data,
          d_out_linear_out_data,
          d_residual_data,
          d_out_linear_bias_data);
L
Li Min 已提交
555 556 557 558 559 560
    } else {
      auto *ln_2_mean_data = ln_2_mean->data<U>();
      auto *ln_2_var_data = ln_2_var->data<U>();
      auto *bias_dropout_residual_out_data =
          bias_dropout_residual_out->data<T>();
      auto *d_ln_2_scale_data =
561 562 563
          (d_ln_2_scale == nullptr
               ? nullptr
               : d_ln_2_scale->mutable_data<U>(ctx.GetPlace()));
L
Li Min 已提交
564
      auto *d_ln_2_bias_data =
565 566 567
          (d_ln_2_bias == nullptr
               ? nullptr
               : d_ln_2_bias->mutable_data<U>(ctx.GetPlace()));
L
Li Min 已提交
568 569 570 571
      auto *d_bias_dropout_residual_out_data =
          d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());

      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
572 573 574 575 576 577 578 579 580 581 582 583 584
          ctx.cuda_device_context(),
          d_y_data,
          bias_dropout_residual_out_data,
          dropout_mask_out_data,
          ln_2_scale_data,
          ln_2_mean_data,
          ln_2_var_data,
          d_bias_dropout_residual_out_data,
          d_ln_2_scale_data,
          d_ln_2_bias_data,
          d_out_linear_out_data,
          d_out_linear_bias_data,
          d_residual_data);
L
Li Min 已提交
585
    }
586

587 588 589 590 591 592
    out_linear_compute.ComputeBackward(fmha_out,
                                       out_linear_weight,
                                       d_out_linear_out,
                                       d_fmha_out,
                                       d_out_linear_weight,
                                       nullptr);
L
Li Min 已提交
593

594
    if (qkv_bias != nullptr) {
595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610
      fmha_ref_compute.ComputeBackward(*transpose_out_2,
                                       src_mask,
                                       *softmax_out,
                                       *attn_dropout_mask_out,
                                       *attn_dropout_out,
                                       *qk_out,
                                       *src_mask_out,
                                       *d_fmha_out,
                                       d_qktv_out,
                                       d_attn_dropout_out,
                                       d_softmax_out,
                                       d_src_mask_out,
                                       d_qk_out,
                                       d_transpose_out_2,
                                       nullptr,
                                       d_qkv_bias_out);
611
    } else {
612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627
      fmha_ref_compute.ComputeBackward(*transpose_out_2,
                                       src_mask,
                                       *softmax_out,
                                       *attn_dropout_mask_out,
                                       *attn_dropout_out,
                                       *qk_out,
                                       *src_mask_out,
                                       *d_fmha_out,
                                       d_qktv_out,
                                       d_attn_dropout_out,
                                       d_softmax_out,
                                       d_src_mask_out,
                                       d_qk_out,
                                       d_transpose_out_2,
                                       nullptr,
                                       d_qkv_out);
628
    }
629 630

    if (pre_layer_norm) {
631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647
      auto *ln_mean = ctx.Input<Tensor>("LnMean");
      auto *ln_var = ctx.Input<Tensor>("LnVariance");
      auto *ln_out = ctx.Input<Tensor>("LnOut");
      auto *ln_mean_data = ln_mean->data<U>();
      auto *ln_var_data = ln_var->data<U>();
      auto *ln_out_data = ln_out->data<T>();

      auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
      auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
      auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
      auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
      auto *d_ln_scale_data =
          (d_ln_scale == nullptr ? nullptr
                                 : d_ln_scale->mutable_data<U>(ctx.GetPlace()));
      auto *d_ln_bias_data =
          (d_ln_bias == nullptr ? nullptr
                                : d_ln_bias->mutable_data<U>(ctx.GetPlace()));
648
      if (qkv_bias != nullptr) {
649 650 651 652 653 654
        qkv_compute.ComputeBackward(ln_out,
                                    qkv_weight,
                                    d_qkv_bias_out,
                                    d_ln_out,
                                    d_qkv_weight,
                                    d_qkv_bias);
655
      } else {
656 657
        qkv_compute.ComputeBackward(
            ln_out, qkv_weight, d_qkv_out, d_ln_out, d_qkv_weight, d_qkv_bias);
658
      }
659 660
      // tensor model parallel
      AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
661 662 663 664 665 666 667 668
      layer_norm_compute.ComputeBackward(x_data,
                                         d_ln_out_data,
                                         ln_scale_data,
                                         ln_mean_data,
                                         ln_var_data,
                                         d_x_data,
                                         d_ln_scale_data,
                                         d_ln_bias_data);
669
    } else {
670
      if (qkv_bias != nullptr) {
671 672
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_bias_out, d_x, d_qkv_weight, d_qkv_bias);
673
      } else {
674 675
        qkv_compute.ComputeBackward(
            input_x, qkv_weight, d_qkv_out, d_x, d_qkv_weight, d_qkv_bias);
676
      }
677 678
      // tensor model parallel
      AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
679
    }
680 681 682 683 684

    if (add_residual) {
      // gradient accumulation
      std::vector<const Tensor *> ins = {&d_residual, d_x};
      std::vector<Tensor *> outs = {d_x};
685 686
      phi::funcs::ElementwiseKernel<T>(
          ctx.cuda_device_context(), ins, &outs, phi::funcs::AddFunctor<T>());
687
    }
688 689 690
  }
};

L
Li Min 已提交
691 692 693 694 695
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
696 697
REGISTER_OP_CUDA_KERNEL(fused_attention,
                        ops::FusedAttentionOpKernel<float>,
L
Li Min 已提交
698 699
                        ops::FusedAttentionOpKernel<double>,
                        ops::FusedAttentionOpKernel<plat::float16>);
700 701 702 703
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
                        ops::FusedAttentionGradKernel<float>,
                        ops::FusedAttentionGradKernel<double>,
                        ops::FusedAttentionGradKernel<plat::float16>);