// fused_gate_attention_op.cu
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fused_gate_attention.h"
19
#include "paddle/phi/backends/gpu/gpu_device_function.h"
20 21 22 23 24
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

// Shorthand used throughout this file for phi's dense tensor type.
using Tensor = phi::DenseTensor;

template <typename T>
struct SigmoidMultiplyFunctor {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Element-wise gating: out = sigmoid(x) * y, with
  // sigmoid(x) = 1 / (1 + exp(-x)).
  // The sigmoid is evaluated in MPType (phi::dtype::MPTypeTrait<T>::Type),
  // rounded back to T, and the final multiply is done in T.
  inline HOSTDEVICE T operator()(T x, T y) const {
    const T gate = static_cast<T>(one / (one + exp(-static_cast<MPType>(x))));
    return gate * y;
  }
};

template <typename T>
struct SigmoidMultiplyGradFunctor {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Backward of out = sigmoid(x) * y.
  // With s = sigmoid(x) = 1 / (1 + exp(-x)):
  //   dy = dout * s
  //   dx = (dout * y) * s * (1 - s)   (chain rule through the sigmoid,
  //                                     whose derivative is s * (1 - s))
  // The sigmoid is evaluated in MPType (phi::dtype::MPTypeTrait<T>::Type)
  // and rounded back to T; the remaining arithmetic is done in T.
  // Returns {dx, dy}.
  inline HOSTDEVICE phi::Array<T, 2> operator()(const T dout,
                                                const T x,
                                                T y) const {
    MPType x_mp = static_cast<MPType>(x);
    T sigmoid_out = static_cast<T>(one / (one + exp(-x_mp)));
    // Gradient flowing into the sigmoid output through the multiply.
    T d_sigmoid_out = dout * y;
    phi::Array<T, 2> outs;
    outs[0] = d_sigmoid_out * sigmoid_out *
              (static_cast<T>(1.0f) - sigmoid_out);  // dx
    outs[1] = dout * sigmoid_out;                    // dy
    return outs;
  }
};

// Forward of the merged QKV projection (the merge_qkv path): a single GEMM
// of `query` against "QKVWeight" produces Q, K and V together.
//   query:      shape=[batch_size, seq_len_m, seq_len_r, qkv_dim]
//               (qkv_dim is the GEMM reduction size, config.q_dim)
//   qkv_weight: shape=[3, num_heads, head_dim, qkv_dim]
//   qkv_out:    shape=[batch_size, seq_len_m, seq_len_r, 3, num_heads,
//               head_dim]
template <typename T>
void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx,
                                   const GateAttentionConfig<T> &config,
                                   const Tensor *query,
                                   Tensor *qkv_out) {
  auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVWeight");

  // qkv_out = GEMM(query, qkv_weight^T); the weight is stored with the
  // reduction dim last, hence transB = true.
  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int n = 3 * config.num_heads * config.head_dim;
  int k = config.q_dim;
  auto qkv_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
  qkv_compute.ComputeForward(qkv_weight, query, nullptr, qkv_out, nullptr);
}

// Backward of the merged QKV projection qkv_out = GEMM(query, qkv_weight^T).
// Allocates and fills the "QKVWeight" gradient, and writes the gradient with
// respect to `query` into query_grad.
// use_addto is forwarded to AttnMatMul::ComputeBackward; the caller sets it
// when query_grad already holds a contribution (from the gating branch).
// Presumably this makes the GEMM add into query_grad instead of overwriting
// it -- confirm against attn_gemm.h.
template <typename T>
void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
                                    const GateAttentionGradConfig<T> &config,
                                    const Tensor *query,
                                    const Tensor *qkv_out_grad,
                                    Tensor *query_grad,
                                    bool use_addto) {
  auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVWeight");
  auto *qkv_weight_grad =
      ctx.Output<phi::DenseTensor>(framework::GradVarName("QKVWeight"));
  auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
  dev_ctx.Alloc<T>(qkv_weight_grad, qkv_weight_grad->numel() * sizeof(T));

  // Gradient of GEMM(query, qkv_weight^T); dimensions mirror the forward.
  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int n = 3 * config.num_heads * config.head_dim;
  int k = config.q_dim;
  auto qkv_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
  qkv_compute.ComputeBackward(query,
                              qkv_weight,
                              qkv_out_grad,
                              query_grad,
                              qkv_weight_grad,
                              nullptr,
                              use_addto);
}

// Forward of the separated Q/K/V projections (the non-merged path).
//   query: shape=[batch_size, seq_len_m, seq_len_r, q_dim]
//   key:   shape=[batch_size, seq_len_m, m_size, kv_dim]
// Both K and V are projected from `key`; there is no separate value input.
template <typename T>
void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx,
                                      const GateAttentionConfig<T> &config,
                                      const Tensor *query,
                                      const Tensor *key,
                                      Tensor *query_out,
                                      Tensor *key_out,
                                      Tensor *value_out) {
  auto *query_weight = ctx.Input<phi::DenseTensor>("QueryWeight");
  auto *key_weight = ctx.Input<phi::DenseTensor>("KeyWeight");
  auto *value_weight = ctx.Input<phi::DenseTensor>("ValueWeight");

  // query_out = GEMM(query, query_weight)
  // query_weight: shape=[q_dim, num_heads, head_dim]
  // query_out: shape=[batch_size, seq_len_m, seq_len_r, num_heads, head_dim]
  int q_m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int q_n = config.num_heads * config.head_dim;
  int q_k = config.q_dim;
  auto q_compute = AttnMatMul<T>(
      ctx.cuda_device_context(), false, false, q_m, q_n, q_k, false);
  q_compute.ComputeForward(query_weight, query, nullptr, query_out, nullptr);

  // key_out = GEMM(key, key_weight)
  // key_weight: shape=[kv_dim, num_heads, head_dim]
  // key_out: shape=[batch_size, seq_len_m, m_size, num_heads, head_dim]
  int kv_m = config.batch_size * config.seq_len_m * config.m_size;
  int kv_n = config.num_heads * config.head_dim;
  int kv_k = config.kv_dim;
  auto kv_compute = AttnMatMul<T>(
      ctx.cuda_device_context(), false, false, kv_m, kv_n, kv_k, false);
  kv_compute.ComputeForward(key_weight, key, nullptr, key_out, nullptr);

  // value_out = GEMM(key, value_weight). V is projected from `key` as well,
  // so the same GEMM configuration (kv_compute) is reused.
  kv_compute.ComputeForward(value_weight, key, nullptr, value_out, nullptr);
}

// Backward of the separated Q/K/V projections.
// Allocates and fills the KeyWeight, ValueWeight and QueryWeight gradients.
// It also writes the input gradients key_grad and query_grad.
// K and V are both projected from `key`, so key_grad is first written by the
// K-path GEMM (use_addto=false), and the V-path gradient is then added on top
// (use_addto=true).
// The caller's use_addto flag is forwarded only to the query-path GEMM.
template <typename T>
void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
                                       const GateAttentionGradConfig<T> &config,
                                       const Tensor *query,
                                       const Tensor *key,
                                       const Tensor *query_out_grad,
                                       const Tensor *key_out_grad,
                                       const Tensor *value_out_grad,
                                       Tensor *query_grad,
                                       Tensor *key_grad,
                                       bool use_addto) {
  // Gradient of GEMM(key, key_weight)
  const auto *key_weight = ctx.Input<phi::DenseTensor>("KeyWeight");
  auto *key_weight_grad =
      ctx.Output<phi::DenseTensor>(framework::GradVarName("KeyWeight"));
  auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
  dev_ctx.Alloc<T>(key_weight_grad, key_weight_grad->numel() * sizeof(T));

  int kv_m = config.batch_size * config.seq_len_m * config.m_size;
  int kv_n = config.num_heads * config.head_dim;
  int kv_k = config.kv_dim;
  auto kv_compute = AttnMatMul<T>(
      ctx.cuda_device_context(), false, false, kv_m, kv_n, kv_k, false);
  kv_compute.ComputeBackward(
      key, key_weight, key_out_grad, key_grad, key_weight_grad, nullptr, false);

  // Gradient of GEMM(key, value_weight); adds into key_grad.
  auto *value_weight = ctx.Input<phi::DenseTensor>("ValueWeight");
  auto *value_weight_grad =
      ctx.Output<phi::DenseTensor>(framework::GradVarName("ValueWeight"));
  dev_ctx.Alloc<T>(value_weight_grad, value_weight_grad->numel() * sizeof(T));

  kv_compute.ComputeBackward(key,
                             value_weight,
                             value_out_grad,
                             key_grad,
                             value_weight_grad,
                             nullptr,
                             true);

  // Gradient of GEMM(query, query_weight)
  const auto *query_weight = ctx.Input<phi::DenseTensor>("QueryWeight");
  auto *query_weight_grad =
      ctx.Output<phi::DenseTensor>(framework::GradVarName("QueryWeight"));
  dev_ctx.Alloc<T>(query_weight_grad, query_weight_grad->numel() * sizeof(T));

  int q_m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int q_n = config.num_heads * config.head_dim;
  int q_k = config.q_dim;
  auto q_compute = AttnMatMul<T>(
      ctx.cuda_device_context(), false, false, q_m, q_n, q_k, false);
  q_compute.ComputeBackward(query,
                            query_weight,
                            query_out_grad,
                            query_grad,
                            query_weight_grad,
                            nullptr,
                            use_addto);
}

// Forward of the gating branch:
//   gate_out = sigmoid(GEMM(query, gate_weight) + gate_bias) * fmha_out
// gate_out first receives the GEMM + bias result and is then overwritten in
// place by the element-wise sigmoid-multiply.
template <typename T>
void ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
                                const GateAttentionConfig<T> &config,
                                const Tensor *query,
                                const Tensor *fmha_out,
                                Tensor *gate_out) {
  auto *gate_weight = ctx.Input<phi::DenseTensor>("GateWeight");
  auto *gate_bias = ctx.Input<phi::DenseTensor>("GateBias");

  // gate_out = GEMM(query, gate_weight) + gate_bias.
  // gate_out is passed twice to ComputeForward. The first slot receives the
  // product, and the second receives product + bias. Both point to the same
  // tensor, so the bias is applied in place.
  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int n = config.num_heads * config.head_dim;
  int k = config.q_dim;
  auto gate_attn_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
  gate_attn_compute.ComputeForward(
      gate_weight, query, gate_bias, gate_out, gate_out);

  // gate_out = sigmoid(gate_out) * fmha_out, computed in place.
  std::vector<const Tensor *> ins = {gate_out, fmha_out};
  std::vector<Tensor *> outs = {gate_out};
  phi::funcs::ElementwiseKernel<T>(
      ctx.cuda_device_context(), ins, &outs, SigmoidMultiplyFunctor<T>());
}

// Backward of the gating branch
//   gate_out = sigmoid(GEMM(query, gate_weight) + gate_bias) * fmha_out.
// The forward pass overwrites the pre-activation in place, so it is
// recomputed here into a temporary (gate_bias_out).
// Writes query_grad, fmha_out_grad, and the GateWeight/GateBias gradients.
template <typename T>
void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
                                 const GateAttentionGradConfig<T> &config,
                                 const Tensor *query,
                                 const Tensor *fmha_out,
                                 const Tensor *gate_out_grad,
                                 Tensor *query_grad,
                                 Tensor *fmha_out_grad) {
  const auto *gate_weight = ctx.Input<phi::DenseTensor>("GateWeight");
  const auto *gate_bias = ctx.Input<phi::DenseTensor>("GateBias");
  auto &dev_ctx = ctx.template device_context<phi::GPUContext>();

  // Re-compute gate_bias_out = GEMM(query, gate_weight) + gate_bias.
  Tensor gate_bias_out;
  gate_bias_out.Resize(config.gate_out_dims);
  dev_ctx.Alloc<T>(&gate_bias_out, gate_bias_out.numel() * sizeof(T));

  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int n = config.num_heads * config.head_dim;
  int k = config.q_dim;
  auto gate_attn_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
  gate_attn_compute.ComputeForward(
      gate_weight, query, gate_bias, &gate_bias_out, &gate_bias_out);

  // Gradient of sigmoid(gate_bias_out) * fmha_out.
  // Computed in place: gate_bias_out is overwritten with its own gradient,
  // and the gradient w.r.t. fmha_out goes to fmha_out_grad.
  std::vector<const Tensor *> ins = {gate_out_grad, &gate_bias_out, fmha_out};
  std::vector<Tensor *> outs = {&gate_bias_out, fmha_out_grad};
  phi::funcs::ElementwiseKernel<T, SigmoidMultiplyGradFunctor<T>, 2>(
      ctx.cuda_device_context(), ins, &outs, SigmoidMultiplyGradFunctor<T>());

  // Gradient of GEMM(query, gate_weight) + gate_bias, driven by the
  // pre-activation gradient now stored in gate_bias_out.
  auto *gate_weight_grad =
      ctx.Output<phi::DenseTensor>(framework::GradVarName("GateWeight"));
  auto *gate_bias_grad =
      ctx.Output<phi::DenseTensor>(framework::GradVarName("GateBias"));
  dev_ctx.Alloc<T>(gate_weight_grad, gate_weight_grad->numel() * sizeof(T));
  dev_ctx.Alloc<T>(gate_bias_grad, gate_bias_grad->numel() * sizeof(T));

  gate_attn_compute.ComputeBackward(query,
                                    gate_weight,
                                    &gate_bias_out,
                                    query_grad,
                                    gate_weight_grad,
                                    gate_bias_grad);
}

// Forward of the output projection:
//   out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias
// The input has width num_heads * head_dim, and the output has width q_dim.
template <typename T>
void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
                                const GateAttentionConfig<T> &config,
                                const Tensor *fmha_or_gate_out,
                                Tensor *out) {
  const auto *out_linear_weight =
      ctx.Input<phi::DenseTensor>("OutLinearWeight");
  const auto *out_linear_bias = ctx.Input<phi::DenseTensor>("OutLinearBias");

  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int n = config.q_dim;
  int k = config.num_heads * config.head_dim;
  auto out_linear_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
  // `out` is passed as both the GEMM output and the bias-add output, so the
  // bias is added in place.
  out_linear_compute.ComputeForward(
      out_linear_weight, fmha_or_gate_out, out_linear_bias, out, out);
}

// Backward of the output projection
//   out = GEMM(input, out_linear_weight) + out_linear_bias.
// Reads Out@GRAD and OutLinearWeight. Allocates and fills the
// OutLinearWeight/OutLinearBias gradients, and writes the gradient with
// respect to `input` into input_grad.
template <typename T>
void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
                                 const GateAttentionGradConfig<T> &config,
                                 const Tensor *input,
                                 Tensor *input_grad) {
  auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
  const auto *out_grad =
      ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
  const auto *out_linear_weight =
      ctx.Input<phi::DenseTensor>("OutLinearWeight");

  auto *out_linear_weight_grad =
      ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearWeight"));
  auto *out_linear_bias_grad =
      ctx.Output<phi::DenseTensor>(framework::GradVarName("OutLinearBias"));

  dev_ctx.Alloc<T>(out_linear_weight_grad,
                   out_linear_weight_grad->numel() * sizeof(T));
  dev_ctx.Alloc<T>(out_linear_bias_grad,
                   out_linear_bias_grad->numel() * sizeof(T));

  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int n = config.q_dim;
  int k = config.num_heads * config.head_dim;
  auto out_linear_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
  out_linear_compute.ComputeBackward(input,
                                     out_linear_weight,
                                     out_grad,
                                     input_grad,
                                     out_linear_weight_grad,
                                     out_linear_bias_grad);
}

template <typename T>
class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
 public:
  // Forward pass of fused_gate_attention:
  //   1. QKV projection (merged or separated, per the "merge_qkv" attr),
  //   2. FMHA (attention core, via FMHAGateRef),
  //   3. optional gating (per "has_gating"):
  //        gate_out = sigmoid(Linear(query)) * fmha_out,
  //   4. output linear projection of fmha_out (or gate_out) into "Out".
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *query = ctx.Input<phi::DenseTensor>("Query");
    const auto *key = ctx.Input<phi::DenseTensor>("Key");
    const auto *query_weight = ctx.Input<phi::DenseTensor>("QueryWeight");
    const auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVWeight");

    const auto *src_mask = ctx.Input<phi::DenseTensor>("SrcMask");
    const auto *nonbatched_bias = ctx.Input<phi::DenseTensor>("NonbatchedBias");

    auto *q_transpose_out = ctx.Output<phi::DenseTensor>("QueryTransposeOut");
    auto *k_transpose_out = ctx.Output<phi::DenseTensor>("KeyTransposeOut");
    auto *v_transpose_out = ctx.Output<phi::DenseTensor>("ValueTransposeOut");
    auto *qkv_transpose_out = ctx.Output<phi::DenseTensor>("QKVTransposeOut");

    auto *softmax_out = ctx.Output<phi::DenseTensor>("SoftmaxOut");
    auto *fmha_out = ctx.Output<phi::DenseTensor>("FMHAOut");
    auto *gate_out = ctx.Output<phi::DenseTensor>("GateOut");
    auto *out = ctx.Output<phi::DenseTensor>("Out");

    const bool merge_qkv = ctx.Attr<bool>("merge_qkv");
    const bool has_gating = ctx.Attr<bool>("has_gating");

    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
    AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out);
    if (has_gating) {
      // GateOut is only allocated when the gating branch runs.
      AllocWithDebugInfo<T>(dev_ctx, "gate_out", gate_out);
    }
    AllocWithDebugInfo<T>(dev_ctx, "out", out);

    // When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
    GateAttentionConfig<T> config(
        dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating);

    if (merge_qkv) {
      // The merged path projects Q, K and V from `query` alone, so `key`
      // must be absent or refer to the same data as `query`.
      PADDLE_ENFORCE_EQ(
          !key || query == key || query->data<T>() == key->data<T>(),
          true,
          platform::errors::InvalidArgument(
              "key is expected to be nullptr or the same as "
              "query, but received key=%p, query=%p.",
              key,
              query));

      // 1. Merged QKV Matmul: qkv_out = GEMM(query, qkv_weight^T)
      Tensor *qkv_out = config.GetQKVOut();
      ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out);

      AllocWithDebugInfo<T>(dev_ctx, "qkv_transpose_out", qkv_transpose_out);
    } else {
      // 1. Separated QKV Matmul (Q from query; K and V from key)
      Tensor *query_out = config.GetQueryOut();
      Tensor *key_out = config.GetKeyOut();
      Tensor *value_out = config.GetValueOut();
      ComputeSeparatedQKVMatmulForward<T>(
          ctx, config, query, key, query_out, key_out, value_out);

      AllocWithDebugInfo<T>(dev_ctx, "q_transpose_out", q_transpose_out);
      AllocWithDebugInfo<T>(dev_ctx, "k_transpose_out", k_transpose_out);
      AllocWithDebugInfo<T>(dev_ctx, "v_transpose_out", v_transpose_out);
    }

    // 2. FMHA
    auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
    fmha_compute.ComputeForward(nonbatched_bias,
                                src_mask,
                                q_transpose_out,
                                k_transpose_out,
                                v_transpose_out,
                                qkv_transpose_out,
                                softmax_out,
                                fmha_out,
                                gate_out,
                                &config);

    // 3. Gating Linear
    if (has_gating) {
      ComputeGatingLinearForward<T>(ctx, config, query, fmha_out, gate_out);
    }

    // 4. Output Linear
    Tensor *fmha_or_gate_out = has_gating ? gate_out : fmha_out;
    ComputeOutputLinearForward<T>(ctx, config, fmha_or_gate_out, out);
  }
};

template <typename T>
class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
 public:
  // Backward pass of fused_gate_attention. Visits the forward stages in
  // reverse order:
  //   output linear -> (optional) gating linear -> FMHA -> QKV projection.
  // Weight/bias gradients are allocated and written by the helper functions.
  void Compute(const framework::ExecutionContext &ctx) const override {
    // forward input
    const auto *query = ctx.Input<phi::DenseTensor>("Query");
    const auto *key = ctx.Input<phi::DenseTensor>("Key");
    const auto *query_weight = ctx.Input<phi::DenseTensor>("QueryWeight");
    const auto *qkv_weight = ctx.Input<phi::DenseTensor>("QKVWeight");

    // forward output, backward input
    const auto *q_transpose_out =
        ctx.Input<phi::DenseTensor>("QueryTransposeOut");
    const auto *k_transpose_out =
        ctx.Input<phi::DenseTensor>("KeyTransposeOut");
    const auto *v_transpose_out =
        ctx.Input<phi::DenseTensor>("ValueTransposeOut");
    const auto *qkv_transpose_out =
        ctx.Input<phi::DenseTensor>("QKVTransposeOut");
    const auto *softmax_out = ctx.Input<phi::DenseTensor>("SoftmaxOut");
    const auto *fmha_out = ctx.Input<phi::DenseTensor>("FMHAOut");
    const auto *gate_out = ctx.Input<phi::DenseTensor>("GateOut");

    // backward output
    auto *query_grad =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("Query"));
    auto *nonbatched_bias_grad =
        ctx.Output<phi::DenseTensor>(framework::GradVarName("NonbatchedBias"));

    bool has_gating = ctx.Attr<bool>("has_gating");
    bool merge_qkv = ctx.Attr<bool>("merge_qkv");

    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad);

    GateAttentionGradConfig<T> config(
        dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating);

    Tensor fmha_out_grad;
    fmha_out_grad.Resize(config.gate_out_dims);
    AllocWithDebugInfo<T>(dev_ctx, "fmha_out_grad", &fmha_out_grad);
    if (has_gating) {
      // 1. Gradient of Output Linear: out = Linear(gate_out)
      Tensor gate_out_grad;
      gate_out_grad.Resize(config.gate_out_dims);
      AllocWithDebugInfo<T>(dev_ctx, "gate_out_grad", &gate_out_grad);
      ComputeOutputLinearBackward<T>(ctx, config, gate_out, &gate_out_grad);

      // 2. Gradient of Gating Linear
      // Forward: gate_out = Sigmoid(Linear(query)) * fmha_out
      // This writes the gating-path contribution to query_grad and produces
      // fmha_out_grad.
      ComputeGatingLinearBackward<T>(ctx,
                                     config,
                                     query,
                                     fmha_out,
                                     &gate_out_grad,
                                     query_grad,
                                     &fmha_out_grad);
    } else {
      // 1. Gradient of Output Linear: out = Linear(fmha_out)
      ComputeOutputLinearBackward<T>(ctx, config, fmha_out, &fmha_out_grad);
    }

    // 3. Gradient of FMHA
    if (nonbatched_bias_grad) {
      AllocWithDebugInfo<T>(
          dev_ctx, "nonbatched_bias_grad", nonbatched_bias_grad);
    }

    auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
    fmha_compute.ComputeBackward(q_transpose_out,
                                 k_transpose_out,
                                 v_transpose_out,
                                 qkv_transpose_out,
                                 softmax_out,
                                 &fmha_out_grad,
                                 nullptr,
                                 nonbatched_bias_grad,
                                 &config);

    // With gating, query_grad already holds the gating-path gradient, so the
    // QKV backward is asked to add into it rather than overwrite it.
    const bool use_addto = has_gating;
    if (merge_qkv) {
      // 4. Gradient of Merged QKV Matmul
      Tensor *qkv_out_grad = config.GetQKVOutGrad();
      ComputeMergedQKVMatmulBackward<T>(
          ctx, config, query, qkv_out_grad, query_grad, use_addto);
    } else {
      // 4. Gradient of Separated QKV Matmul
      auto *key_grad =
          ctx.Output<phi::DenseTensor>(framework::GradVarName("Key"));
      if (key_grad) {
        AllocWithDebugInfo<T>(dev_ctx, "key_grad", key_grad);
      }
      Tensor *query_out_grad = config.GetQueryOutGrad();
      Tensor *key_out_grad = config.GetKeyOutGrad();
      Tensor *value_out_grad = config.GetValueOutGrad();
      ComputeSeparatedQKVMatmulBackward<T>(ctx,
                                           config,
                                           query,
                                           key,
                                           query_out_grad,
                                           key_out_grad,
                                           value_out_grad,
                                           query_grad,
                                           key_grad,
                                           use_addto);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Register the forward and backward GPU kernels for each supported dtype.
// The HIP build registers float, float16 and bfloat16 only.
// The non-HIP build additionally registers double.
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(fused_gate_attention,
                        ops::FusedGateAttentionOpKernel<float>,
                        ops::FusedGateAttentionOpKernel<plat::float16>,
                        ops::FusedGateAttentionOpKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(fused_gate_attention_grad,
                        ops::FusedGateAttentionGradKernel<float>,
                        ops::FusedGateAttentionGradKernel<plat::float16>,
                        ops::FusedGateAttentionGradKernel<plat::bfloat16>);
#else
REGISTER_OP_CUDA_KERNEL(fused_gate_attention,
                        ops::FusedGateAttentionOpKernel<float>,
                        ops::FusedGateAttentionOpKernel<double>,
                        ops::FusedGateAttentionOpKernel<plat::float16>,
                        ops::FusedGateAttentionOpKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(fused_gate_attention_grad,
                        ops::FusedGateAttentionGradKernel<float>,
                        ops::FusedGateAttentionGradKernel<double>,
                        ops::FusedGateAttentionGradKernel<plat::float16>,
                        ops::FusedGateAttentionGradKernel<plat::bfloat16>);
#endif