fused_attention_op.cu 25.3 KB
Newer Older
L
Li Min 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda_fp16.h>
16

L
Li Min 已提交
17
#include <cub/cub.cuh>
18

L
Li Min 已提交
19 20
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
21 22 23 24
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
25 26
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
27 28
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
29
#include "paddle/phi/kernels/funcs/math_function.h"
L
Li Min 已提交
30

31 32 33 34 35
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#endif

L
Li Min 已提交
36 37 38 39 40
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
// Sum-reduces `tensor` across the ranks of communicator `ring_id`, writing the
// result back into the same tensor (send and receive buffers both come from
// `tensor`). The collective is enqueued on the stream of `ctx`.
//
// tensor:  tensor to reduce; updated in place.
// ring_id: communicator ring id; -1 makes this function a no-op.
// ctx:     device context providing the place and the stream.
//
// Throws Unimplemented when ring_id != -1 and the build has neither NCCL nor
// RCCL enabled.
template <typename T>
static void AllReduce(framework::Tensor &tensor,  // NOLINT
                      const int ring_id,
                      const platform::CUDADeviceContext &ctx) {
  // ring_id == -1: nothing to reduce.
  if (ring_id == -1) {
    return;
  }
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  const auto place = ctx.GetPlace();
  const int64_t count = tensor.numel();
  const auto nccl_dtype =
      platform::ToNCCLDataType(framework::TransToProtoVarType(tensor.dtype()));

  // Read pointer is taken before the mutable pointer, as both refer to the
  // same tensor.
  const void *src = tensor.data<T>();
  void *dst = tensor.mutable_data<T>(place);

  auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
      src, dst, count, nccl_dtype, ncclSum, comm->comm(), ctx.stream()));
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

L
Li Min 已提交
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
// Forward kernel of the fused_attention op.
//
// Pipeline, as executed by Compute():
//   1. (pre_layer_norm only) layer norm of X into LnOut.
//   2. QKV projection of LnOut (pre_layer_norm) or X, with the bias added only
//      when QKVBias is provided.
//   3. FMHARef attention forward on the QKV result.
//   4. Output linear projection (no bias inside the matmul).
//   5. AllReduce of OutLinearOut when ring_id != -1 (tensor model parallel).
//   6. pre_layer_norm:  Y = residual + dropout(out_linear_out + bias)
//      otherwise:       Y = layernorm(residual + dropout(out_linear_out + bias))
template <typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    auto *input_x = ctx.Input<Tensor>("X");

    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    auto *ln_scale = ctx.Input<Tensor>("LnScale");
    auto *ln_bias = ctx.Input<Tensor>("LnBias");
    auto *ln_mean = ctx.Output<Tensor>("LnMean");
    auto *ln_var = ctx.Output<Tensor>("LnVariance");
    auto *ln_out = ctx.Output<Tensor>("LnOut");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto *qkv_weight = ctx.Input<Tensor>("QKVW");
    auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
    auto *qkv_out = ctx.Output<Tensor>("QKVOut");
    auto *qkv_bias_out = ctx.Output<Tensor>("QKVBiasOut");

    auto *src_mask = ctx.Input<Tensor>("SrcMask");
    auto *transpose_out_2 = ctx.Output<Tensor>("TransposeOut2");
    auto *cache_kv = ctx.Input<Tensor>("CacheKV");
    auto *cache_kv_out = ctx.Output<Tensor>("CacheKVOut");
    auto *qk_out = ctx.Output<Tensor>("QKOut");
    auto *qktv_out = ctx.Output<Tensor>("QKTVOut");
    auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Output<Tensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Output<Tensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Output<Tensor>("SrcMaskOut");
    auto *fmha_out = ctx.Output<Tensor>("FMHAOut");

    auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
    auto *out_linear_out = ctx.Output<Tensor>("OutLinearOut");

    auto *ln_scale_2 = ctx.Input<Tensor>("Ln2Scale");
    auto *ln_bias_2 = ctx.Input<Tensor>("Ln2Bias");
    auto *dropout_mask_out = ctx.Output<Tensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Output<Tensor>("BiasDropoutResidualOut");
    auto *ln_mean_2 = ctx.Output<Tensor>("Ln2Mean");
    auto *ln_var_2 = ctx.Output<Tensor>("Ln2Variance");
    const float ln_epsilon = ctx.Attr<float>("ln_epsilon");

    // Attention-dropout configuration (suffix _1).
    float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
    bool is_test_1 = ctx.Attr<bool>("is_test");
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
    int ring_id = ctx.Attr<int>("ring_id");

    // final output.
    auto *out = ctx.Output<Tensor>("Y");

    // NOTE(review): the mutable_data<...>(place) calls below presumably
    // allocate the output buffers. Several of the returned pointers are not
    // read again in this function; the calls are kept regardless.

    // get data ptr for qkv part.
    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    auto *x_data = input_x->data<T>();
    auto *qkv_weight_data = qkv_weight->data<T>();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
    auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
    // QKVBiasOut is only touched when a QKV bias is provided.
    auto *qkv_bias_out_data =
        (qkv_bias == nullptr) ? nullptr
                              : qkv_bias_out->mutable_data<T>(ctx.GetPlace());

    // get data ptr for FMHA.
    auto *transpose_out_2_data =
        transpose_out_2->mutable_data<T>(ctx.GetPlace());
    auto *cache_kv_out_data =
        (cache_kv_out == nullptr)
            ? nullptr
            : cache_kv_out->mutable_data<T>(ctx.GetPlace());
    auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
    auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
    // SrcMaskOut is only touched when a source mask is provided.
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr
                              : src_mask_out->mutable_data<T>(ctx.GetPlace());
    auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
    auto *attn_dropout_mask_out_data =
        attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
    auto *attn_dropout_out_data =
        attn_dropout_out->mutable_data<T>(ctx.GetPlace());
    auto *fmha_out_data = fmha_out->mutable_data<T>(ctx.GetPlace());

    // get data ptr for out_linear.
    auto *out_linear_weight_data = out_linear_weight->data<T>();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();
    auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());

    // get data ptr for bias+dropout+residual+layernorm
    auto *dropout_mask_out_data =
        dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
    auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];

    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
                                               epsilon, bsz_seq, dim_embed);

    // The QKV matmul adds the bias only when QKVBias is provided.
    const bool compute_bias = (qkv_bias != nullptr);
    // (transA, transB) = (false, true)
    auto qkv_compute =
        AttnMatMul<T>(ctx.cuda_device_context(), false, true, bsz_seq,
                      output_size, input_size, compute_bias);

    AttnDropoutParam attn_dropout_param(
        is_test_1, dropout_implementation_1, attn_dropout_rate,
        is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
    auto fmha_ref_compute =
        FMHARef<T>(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
                   dim_head, attn_dropout_param);

    output_size = hidden_size;
    // (transA, transB, compute_bias) = (false, false, false)
    // NOTE(Yuang Liu): For general input size == output size, change the
    // position won't have effects. For mp, the output size is mp_head * dkey
    // which is actually the input size. While the input size is hidden size,
    // which is actually the output size. So for out linear, switch the
    // input size and output size.
    auto out_linear_compute =
        AttnMatMul<T>(ctx.cuda_device_context(), false, false, bsz_seq,
                      input_size, output_size, false);
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
        ln_epsilon);

    if (pre_layer_norm) {
      // Normalize X first, then project the normalized activations.
      auto *ln_scale_data =
          (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
      auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
      auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
      auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
      auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());

      layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
                                        ln_out_data, ln_mean_data, ln_var_data);
      qkv_compute.ComputeForward(qkv_weight, ln_out, qkv_bias, qkv_out,
                                 qkv_bias_out);
    } else {
      qkv_compute.ComputeForward(qkv_weight, input_x, qkv_bias, qkv_out,
                                 qkv_bias_out);
    }

    // Attention consumes QKVOut when there is no QKV bias, QKVBiasOut
    // otherwise. Only the selected pointer is dereferenced.
    auto &fmha_input = (qkv_bias == nullptr) ? *qkv_out : *qkv_bias_out;
    fmha_ref_compute.ComputeForward(
        fmha_input, cache_kv, src_mask, transpose_out_2, cache_kv_out, qk_out,
        src_mask_out, softmax_out, attn_dropout_mask_out, attn_dropout_out,
        qktv_out, fmha_out);

    // fmha_out: [batch_size, seq_len, num_head, head_dim]
    // weight:   [embed_dim, embed_dim]
    // out_linear_out: [batch_size, seq_len, embed_dim]
    // The bias is not added here; out_linear_bias_data is handed to the fused
    // dropout helper below instead.
    out_linear_compute.ComputeForward(out_linear_weight, fmha_out, nullptr,
                                      out_linear_out, nullptr);
    // tensor model parallel
    AllReduce<T>(*out_linear_out, ring_id, ctx.cuda_device_context());

    if (pre_layer_norm) {
      // output = (residual + dropout(input + bias))
      fused_dropout_layernorm_helper.ResidualDropoutBias(
          ctx.cuda_device_context(), out_linear_out_data, x_data,
          out_linear_bias_data, final_out_data, dropout_mask_out_data);
    } else {
      auto *ln_scale_2_data =
          (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
      auto *ln_bias_2_data =
          (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
      auto *bias_dropout_residual_out_data =
          bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
      auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
      auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
      // output = layernorm(residual + dropout(input + bias))
      fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
          ctx.cuda_device_context(), out_linear_out_data, x_data,
          out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
          bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
          ln_mean_2_data, ln_var_2_data);
    }
  }
};

273 274 275 276 277 278 279 280 281 282
// Backward kernel of the fused_attention op.
//
// Walks the forward pipeline in reverse, as executed by Compute():
//   1. Gradient of bias + dropout + residual (+ layer norm when
//      !pre_layer_norm); also produces the residual-path gradient of X in a
//      temporary tensor.
//   2. Gradient of the output linear projection.
//   3. Gradient of FMHARef attention.
//   4. Gradient of the QKV projection, followed by AllReduce of the resulting
//      input gradient when ring_id != -1 (tensor model parallel).
//   5. (pre_layer_norm only) gradient of the first layer norm.
//   6. d_x = d_x + residual-path gradient.
template <typename T>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    using U = LayerNormParamType<T>;
    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
    const float epsilon = ctx.Attr<float>("epsilon");
    const float ln2epsilon = ctx.Attr<float>("ln_epsilon");

    // Attention-dropout configuration (suffix _1).
    float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
    bool is_test_1 = ctx.Attr<bool>("is_test");
    auto &dropout_implementation_1 =
        ctx.Attr<std::string>("attn_dropout_implementation");
    bool is_upscale_in_train_1 =
        (dropout_implementation_1 == "upscale_in_train");
    auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
    bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
    int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
    int ring_id = ctx.Attr<int>("ring_id");

    // get inputs.
    auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto *d_y_data = d_y->data<T>();

    // fw input
    auto *input_x = ctx.Input<Tensor>("X");
    auto *ln_scale = ctx.Input<Tensor>("LnScale");
    auto *ln_2_scale = ctx.Input<Tensor>("Ln2Scale");
    auto *x_data = input_x->data<T>();
    auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
    auto *ln_2_scale_data =
        (ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
    // fw parameters.
    auto *src_mask = ctx.Input<Tensor>("SrcMask");
    auto *qkv_weight = ctx.Input<Tensor>("QKVW");
    auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
    auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
    auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
    auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
    auto *qkv_weight_data = qkv_weight->data<T>();
    auto *qkv_bias_data = (qkv_bias == nullptr) ? nullptr : qkv_bias->data<T>();
    auto *out_linear_weight_data = out_linear_weight->data<T>();
    auto *out_linear_bias_data =
        (out_linear_bias == nullptr) ? nullptr : out_linear_bias->data<T>();

    // fw output
    auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
    auto *transpose_out_2 = ctx.Input<Tensor>("TransposeOut2");
    auto *qk_out = ctx.Input<Tensor>("QKOut");
    auto *qktv_out = ctx.Input<Tensor>("QKTVOut");
    auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
    auto *attn_dropout_mask_out = ctx.Input<Tensor>("AttnDropoutMaskOut");
    auto *attn_dropout_out = ctx.Input<Tensor>("AttnDropoutOut");
    auto *src_mask_out = ctx.Input<Tensor>("SrcMaskOut");
    auto *out_linear_out = ctx.Input<Tensor>("OutLinearOut");
    auto *ln_2_mean = ctx.Input<Tensor>("Ln2Mean");
    auto *ln_2_var = ctx.Input<Tensor>("Ln2Variance");
    auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
    auto *bias_dropout_residual_out =
        ctx.Input<Tensor>("BiasDropoutResidualOut");
    auto *fmha_out_data = fmha_out->data<T>();
    auto *transpose_out_2_data = transpose_out_2->data<T>();
    auto *qk_out_data = qk_out->data<T>();
    auto *qktv_out_data = qktv_out->data<T>();
    auto *softmax_out_data = softmax_out->data<T>();
    auto *src_mask_out_data =
        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
    auto *out_linear_out_data = out_linear_out->data<T>();
    auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();

    // output's grad
    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *d_qkv_out = ctx.Output<Tensor>(framework::GradVarName("QKVOut"));
    auto *d_qkv_bias_out =
        ctx.Output<Tensor>(framework::GradVarName("QKVBiasOut"));
    auto *d_qktv_out = ctx.Output<Tensor>(framework::GradVarName("QKTVOut"));
    auto *d_transpose_out_2 =
        ctx.Output<Tensor>(framework::GradVarName("TransposeOut2"));
    auto *d_qk_out = ctx.Output<Tensor>(framework::GradVarName("QKOut"));
    auto *d_softmax_out =
        ctx.Output<Tensor>(framework::GradVarName("SoftmaxOut"));
    auto *d_attn_dropout_out =
        ctx.Output<Tensor>(framework::GradVarName("AttnDropoutOut"));
    auto *d_src_mask_out =
        ctx.Output<Tensor>(framework::GradVarName("SrcMaskOut"));
    auto *d_fmha_out = ctx.Output<Tensor>(framework::GradVarName("FMHAOut"));
    auto *d_out_linear_out =
        ctx.Output<Tensor>(framework::GradVarName("OutLinearOut"));
    auto *d_bias_dropout_residual_out =
        ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));

    // NOTE(review): the mutable_data<...>(place) calls below presumably
    // allocate the gradient buffers. Several of the returned pointers are not
    // read again in this function; the calls are kept regardless.
    auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
    // when qkv_bias is not nullptr, d_qkv_out is equals to d_qkv_bias_out, the
    // space can be reused.
    auto *d_qkv_out_data = (d_qkv_bias_out != nullptr)
                               ? nullptr
                               : d_qkv_out->mutable_data<T>(ctx.GetPlace());
    auto *d_qkv_bias_out_data =
        (d_qkv_bias_out == nullptr)
            ? nullptr
            : d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
    auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
    auto *d_transpose_out_2_data =
        d_transpose_out_2->mutable_data<T>(ctx.GetPlace());
    auto *d_qk_out_data = d_qk_out->mutable_data<T>(ctx.GetPlace());
    auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
    auto *d_attn_dropout_out_data =
        d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
    auto *d_src_mask_out_data =
        (src_mask == nullptr) ? nullptr
                              : d_src_mask_out->mutable_data<T>(ctx.GetPlace());
    auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
    auto *d_out_linear_out_data =
        d_out_linear_out->mutable_data<T>(ctx.GetPlace());

    // parameter grad
    auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
    auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias"));
    auto *d_out_linear_weight =
        ctx.Output<Tensor>(framework::GradVarName("OutLinearW"));
    auto *d_out_linear_bias =
        ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
    auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
    auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));

    auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
    auto *d_qkv_bias_data = (d_qkv_bias == nullptr)
                                ? nullptr
                                : d_qkv_bias->mutable_data<T>(ctx.GetPlace());
    auto *d_out_linear_weight_data =
        d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
    auto *d_out_linear_bias_data =
        (d_out_linear_bias == nullptr)
            ? nullptr
            : d_out_linear_bias->mutable_data<T>(ctx.GetPlace());

    const auto input_x_dims = input_x->dims();
    const auto qkv_w_dims = qkv_weight->dims();

    int batch_size = input_x_dims[0];
    int max_seq_len = input_x_dims[1];
    int dim_embed = input_x_dims[2];
    int num_head = qkv_w_dims[1];
    int dim_head = qkv_w_dims[2];

    int bsz_seq = batch_size * max_seq_len;
    int hidden_size = num_head * dim_head;
    int output_size = 3 * hidden_size;
    int input_size = dim_embed;

    // Temporary holding the residual-path gradient of X; it is added to d_x
    // at the end of this function.
    Tensor d_residual;
    d_residual.Resize(input_x_dims);
    T *d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());

    bool transA = false;
    bool transB = true;
    // The QKV matmul handles a bias only when QKVBias is provided.
    const bool compute_qkv_bias = (qkv_bias != nullptr);
    auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
                                               epsilon, bsz_seq, dim_embed);
    auto qkv_compute =
        AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
                      output_size, input_size, compute_qkv_bias);
    AttnDropoutParam attn_dropout_param(
        is_test_1, dropout_implementation_1, attn_dropout_prob,
        is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
    auto fmha_ref_compute =
        FMHARef<T>(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
                   dim_head, attn_dropout_param);
    output_size = hidden_size;
    transA = false;
    transB = false;
    bool compute_bias = false;
    // (b*s, num_head * dim_head) * (num_head * dim_head, dim_embed)
    auto out_linear_compute =
        AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
                      input_size, output_size, compute_bias);
    DropoutParam dropout_param2(ctx, 0);
    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
        ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
        ln2epsilon);

    // Step 1: backward through bias + dropout + residual (+ layer norm).
    if (pre_layer_norm) {
      fused_dropout_layernorm_helper.ResidualDropoutBiasGrad(
          ctx.cuda_device_context(), d_y_data, dropout_mask_out_data,
          d_out_linear_out_data, d_residual_data, d_out_linear_bias_data);
    } else {
      auto *ln_2_mean_data = ln_2_mean->data<U>();
      auto *ln_2_var_data = ln_2_var->data<U>();
      auto *bias_dropout_residual_out_data =
          bias_dropout_residual_out->data<T>();
      auto *d_ln_2_scale_data =
          (d_ln_2_scale == nullptr
               ? nullptr
               : d_ln_2_scale->mutable_data<U>(ctx.GetPlace()));
      auto *d_ln_2_bias_data =
          (d_ln_2_bias == nullptr
               ? nullptr
               : d_ln_2_bias->mutable_data<U>(ctx.GetPlace()));
      auto *d_bias_dropout_residual_out_data =
          d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());

      fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
          ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
          dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
          d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
          d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
    }

    // Step 2: backward through the output linear projection. No bias gradient
    // is requested here; d_out_linear_bias_data was handed to the fused
    // dropout helper above.
    out_linear_compute.ComputeBackward(fmha_out, out_linear_weight,
                                       d_out_linear_out, d_fmha_out,
                                       d_out_linear_weight, nullptr);

    // The attention backward writes, and the QKV backward reads, the gradient
    // of QKVBiasOut when a QKV bias exists and of QKVOut otherwise.
    auto *d_qkv_result = (qkv_bias != nullptr) ? d_qkv_bias_out : d_qkv_out;

    // Step 3: backward through attention.
    fmha_ref_compute.ComputeBackward(
        *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
        *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
        d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
        d_transpose_out_2, nullptr, d_qkv_result);

    // Steps 4-5: backward through the QKV projection (and the first layer
    // norm when pre_layer_norm is set).
    if (pre_layer_norm) {
      auto *ln_mean = ctx.Input<Tensor>("LnMean");
      auto *ln_var = ctx.Input<Tensor>("LnVariance");
      auto *ln_out = ctx.Input<Tensor>("LnOut");
      auto *ln_mean_data = ln_mean->data<U>();
      auto *ln_var_data = ln_var->data<U>();
      auto *ln_out_data = ln_out->data<T>();

      auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
      auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
      auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
      auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
      auto *d_ln_scale_data =
          (d_ln_scale == nullptr ? nullptr
                                 : d_ln_scale->mutable_data<U>(ctx.GetPlace()));
      auto *d_ln_bias_data =
          (d_ln_bias == nullptr ? nullptr
                                : d_ln_bias->mutable_data<U>(ctx.GetPlace()));
      qkv_compute.ComputeBackward(ln_out, qkv_weight, d_qkv_result, d_ln_out,
                                  d_qkv_weight, d_qkv_bias);
      // tensor model parallel
      AllReduce<T>(*d_ln_out, ring_id, ctx.cuda_device_context());
      layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
                                         ln_mean_data, ln_var_data, d_x_data,
                                         d_ln_scale_data, d_ln_bias_data);
    } else {
      qkv_compute.ComputeBackward(input_x, qkv_weight, d_qkv_result, d_x,
                                  d_qkv_weight, d_qkv_bias);
      // tensor model parallel
      AllReduce<T>(*d_x, ring_id, ctx.cuda_device_context());
    }

    // gradient accumulation: d_x = d_residual + d_x
    std::vector<const Tensor *> ins;
    std::vector<Tensor *> outs;
    ins.emplace_back(&d_residual);
    ins.emplace_back(d_x);
    outs.emplace_back(d_x);
    int elewise_add_axis = -1;
    phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
        ctx.cuda_device_context(), ins, &outs, elewise_add_axis,
        phi::funcs::AddFunctor<T>());
  }
};

L
Li Min 已提交
555 556 557 558 559 560 561 562
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the forward and backward CUDA kernels for float, double and
// float16.
REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel<float>,
                        ops::FusedAttentionOpKernel<double>,
                        ops::FusedAttentionOpKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
                        ops::FusedAttentionGradKernel<float>,
                        ops::FusedAttentionGradKernel<double>,
                        ops::FusedAttentionGradKernel<plat::float16>);