// fused_gate_attention_op.cu
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fused_gate_attention.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

// Shorthand for the framework tensor type used throughout this file.
typedef framework::Tensor Tensor;

template <typename T>
struct SigmoidMultiplyFunctor {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Elementwise gate: out = sigmoid(x) * y, with
  // sigmoid(x) = 1 / (1 + exp(-x)).
  // The sigmoid is evaluated in MPType, rounded to T, and the final product
  // is taken in T.
  inline HOSTDEVICE T operator()(T x, T y) const {
    const MPType neg_x = -static_cast<MPType>(x);
    const T gate = static_cast<T>(one / (one + exp(neg_x)));
    return gate * y;
  }
};

template <typename T>
struct SigmoidMultiplyGradFunctor {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  // Gradient of Multiply:
  //  dx = dout * y
  //  dy = dout * x
  // Gradient of Sigmoid: dx = dout * out * (1 - out)
50 51
  inline HOSTDEVICE phi::Array<T, 2> operator()(const T dout,
                                                const T x,
52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
                                                T y) const {
    MPType x_mp = static_cast<MPType>(x);
    T sigmoid_out = static_cast<T>(one / (one + exp(-x_mp)));
    T d_sigmoid_out = dout * y;
    phi::Array<T, 2> outs;
    outs[0] = d_sigmoid_out * sigmoid_out *
              (static_cast<T>(1.0f) - sigmoid_out);  // dx
    outs[1] = dout * sigmoid_out;                    // dy
    return outs;
  }
};

template <typename T>
void ComputeMergedQKVMatmulForward(const framework::ExecutionContext &ctx,
                                   const GateAttentionConfig<T> &config,
                                   const Tensor *query,
                                   Tensor *qkv_out) {
  // Projects query into Q, K and V with one GEMM against the packed weight
  // (weight transposed, no bias):
  //   query:      [batch_size, seq_len_m, seq_len_r, qkv_dim]
  //   qkv_weight: [3, num_heads, head_dim, qkv_dim]
  //   qkv_out:    [batch_size, seq_len_m, seq_len_r, 3, num_heads, head_dim]
  //   qkv_out = query * qkv_weight^T
  const auto *qkv_weight = ctx.Input<Tensor>("QKVWeight");

  const int num_rows = config.batch_size * config.seq_len_m * config.seq_len_r;
  const int num_cols = 3 * config.num_heads * config.head_dim;
  const int reduce_dim = config.q_dim;

  AttnMatMul<T> qkv_gemm(ctx.cuda_device_context(),
                         /*transA=*/false,
                         /*transB=*/true,
                         num_rows,
                         num_cols,
                         reduce_dim,
                         /*compute_bias=*/false);
  qkv_gemm.ComputeForward(qkv_weight, query, nullptr, qkv_out, nullptr);
}

template <typename T>
84 85 86 87
void ComputeMergedQKVMatmulBackward(const framework::ExecutionContext &ctx,
                                    const GateAttentionGradConfig<T> &config,
                                    const Tensor *query,
                                    const Tensor *qkv_out_grad,
88 89
                                    Tensor *query_grad,
                                    bool use_addto) {
90 91 92 93 94 95 96
  auto *qkv_weight = ctx.Input<Tensor>("QKVWeight");
  auto *qkv_weight_grad =
      ctx.Output<Tensor>(framework::GradVarName("QKVWeight"));
  qkv_weight_grad->mutable_data<T>(ctx.GetPlace());

  // Gradient of GEMM(query, qkv_weight)
  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
97
  int n = 3 * config.num_heads * config.head_dim;
98 99 100
  int k = config.q_dim;
  auto qkv_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, true, m, n, k, false);
101 102 103 104 105 106 107
  qkv_compute.ComputeBackward(query,
                              qkv_weight,
                              qkv_out_grad,
                              query_grad,
                              qkv_weight_grad,
                              nullptr,
                              use_addto);
108 109 110 111 112
}

template <typename T>
void ComputeSeparatedQKVMatmulForward(const framework::ExecutionContext &ctx,
                                      const GateAttentionConfig<T> &config,
                                      const Tensor *query,
                                      const Tensor *key,
                                      Tensor *query_out,
                                      Tensor *key_out,
                                      Tensor *value_out) {
  // Three independent projections (weights not transposed, no bias):
  //   query_out = query * QueryWeight
  //     query:       [batch_size, seq_len_m, seq_len_r, q_dim]
  //     QueryWeight: [q_dim, num_heads, head_dim]
  //     query_out:   [batch_size, seq_len_m, seq_len_r, num_heads, head_dim]
  //   key_out   = key * KeyWeight
  //   value_out = key * ValueWeight
  //     key:                  [batch_size, seq_len_m, m_size, kv_dim]
  //     KeyWeight/ValueWeight: [kv_dim, num_heads, head_dim]
  //     key_out/value_out:    [batch_size, seq_len_m, m_size, num_heads,
  //                            head_dim]
  // Note that both K and V are projected from the same `key` input.
  const auto *query_weight = ctx.Input<Tensor>("QueryWeight");
  const auto *key_weight = ctx.Input<Tensor>("KeyWeight");
  const auto *value_weight = ctx.Input<Tensor>("ValueWeight");

  const int proj_width = config.num_heads * config.head_dim;

  // Q projection.
  const int q_rows = config.batch_size * config.seq_len_m * config.seq_len_r;
  const int q_inner = config.q_dim;
  AttnMatMul<T> q_gemm(ctx.cuda_device_context(),
                       /*transA=*/false,
                       /*transB=*/false,
                       q_rows,
                       proj_width,
                       q_inner,
                       /*compute_bias=*/false);
  q_gemm.ComputeForward(query_weight, query, nullptr, query_out, nullptr);

  // K and V projections share one GEMM configuration.
  const int kv_rows = config.batch_size * config.seq_len_m * config.m_size;
  const int kv_inner = config.kv_dim;
  AttnMatMul<T> kv_gemm(ctx.cuda_device_context(),
                        /*transA=*/false,
                        /*transB=*/false,
                        kv_rows,
                        proj_width,
                        kv_inner,
                        /*compute_bias=*/false);
  kv_gemm.ComputeForward(key_weight, key, nullptr, key_out, nullptr);
  kv_gemm.ComputeForward(value_weight, key, nullptr, value_out, nullptr);
}

template <typename T>
149 150
void ComputeSeparatedQKVMatmulBackward(const framework::ExecutionContext &ctx,
                                       const GateAttentionGradConfig<T> &config,
151 152
                                       const Tensor *query,
                                       const Tensor *key,
153 154 155
                                       const Tensor *query_out_grad,
                                       const Tensor *key_out_grad,
                                       const Tensor *value_out_grad,
156 157
                                       Tensor *query_grad,
                                       Tensor *key_grad,
158
                                       bool use_addto) {
159 160 161 162 163 164 165
  // Gradient of GEMM(key, k_weight)
  const auto *key_weight = ctx.Input<Tensor>("KeyWeight");
  auto *key_weight_grad =
      ctx.Output<Tensor>(framework::GradVarName("KeyWeight"));
  key_weight_grad->mutable_data<T>(ctx.GetPlace());

  int kv_m = config.batch_size * config.seq_len_m * config.m_size;
166
  int kv_n = config.num_heads * config.head_dim;
167
  int kv_k = config.kv_dim;
168 169 170 171
  auto kv_compute = AttnMatMul<T>(
      ctx.cuda_device_context(), false, false, kv_m, kv_n, kv_k, false);
  kv_compute.ComputeBackward(
      key, key_weight, key_out_grad, key_grad, key_weight_grad, nullptr, false);
172 173 174 175 176 177 178

  // Gradient of GEMM(value, v_weight)
  auto *value_weight = ctx.Input<Tensor>("ValueWeight");
  auto *value_weight_grad =
      ctx.Output<Tensor>(framework::GradVarName("ValueWeight"));
  value_weight_grad->mutable_data<T>(ctx.GetPlace());

179 180 181 182 183 184 185
  kv_compute.ComputeBackward(key,
                             value_weight,
                             value_out_grad,
                             key_grad,
                             value_weight_grad,
                             nullptr,
                             true);
186 187 188 189 190 191 192 193

  // Gradient of GEMM(query, query_weight)
  const auto *query_weight = ctx.Input<Tensor>("QueryWeight");
  auto *query_weight_grad =
      ctx.Output<Tensor>(framework::GradVarName("QueryWeight"));
  query_weight_grad->mutable_data<T>(ctx.GetPlace());

  int q_m = config.batch_size * config.seq_len_m * config.seq_len_r;
194
  int q_n = config.num_heads * config.head_dim;
195
  int q_k = config.q_dim;
196 197 198 199 200 201 202 203 204
  auto q_compute = AttnMatMul<T>(
      ctx.cuda_device_context(), false, false, q_m, q_n, q_k, false);
  q_compute.ComputeBackward(query,
                            query_weight,
                            query_out_grad,
                            query_grad,
                            query_weight_grad,
                            nullptr,
                            use_addto);
205 206 207
}

template <typename T>
208 209
void ComputeGatingLinearForward(const framework::ExecutionContext &ctx,
                                const GateAttentionConfig<T> &config,
210 211
                                const Tensor *query,
                                const Tensor *fmha_out,
212
                                Tensor *gate_out) {
213 214 215 216 217 218 219 220
  auto *gate_weight = ctx.Input<Tensor>("GateWeight");
  auto *gate_bias = ctx.Input<Tensor>("GateBias");

  // The first gate_bias_out stores the result of the multiplication,
  // and the second gate_bias_out stores the result of the multiplication +
  // bias.
  //   gate_out = GEMM(query, gate_weight) + gate_bias
  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
221
  int n = config.num_heads * config.head_dim;
222 223 224
  int k = config.q_dim;
  auto gate_attn_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
225 226
  gate_attn_compute.ComputeForward(
      gate_weight, query, gate_bias, gate_out, gate_out);
227 228 229 230

  // gate_out = sigmoid(gate_out) * fmha_out
  std::vector<const Tensor *> ins = {gate_out, fmha_out};
  std::vector<Tensor *> outs = {gate_out};
231 232
  phi::funcs::ElementwiseKernel<T>(
      ctx.cuda_device_context(), ins, &outs, SigmoidMultiplyFunctor<T>());
233 234 235
}

template <typename T>
236 237
void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx,
                                 const GateAttentionGradConfig<T> &config,
238 239
                                 const Tensor *query,
                                 const Tensor *fmha_out,
240
                                 const Tensor *gate_out_grad,
241 242
                                 Tensor *query_grad,
                                 Tensor *fmha_out_grad) {
243 244 245 246 247 248 249 250 251
  const auto *gate_weight = ctx.Input<Tensor>("GateWeight");
  const auto *gate_bias = ctx.Input<Tensor>("GateBias");

  // Re-compute gate_bias_out
  Tensor gate_bias_out;
  gate_bias_out.Resize(config.gate_out_dims);
  gate_bias_out.mutable_data<T>(ctx.GetPlace());

  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
252
  int n = config.num_heads * config.head_dim;
253 254 255
  int k = config.q_dim;
  auto gate_attn_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
256 257
  gate_attn_compute.ComputeForward(
      gate_weight, query, gate_bias, &gate_bias_out, &gate_bias_out);
258 259 260 261 262 263 264 265 266 267 268 269 270 271 272

  // Gradient of sigmoid(gate_bias_out) * fmha_out
  // Compute inplace and save gate_bias_out_grad to gate_bias_out.
  std::vector<const Tensor *> ins = {gate_out_grad, &gate_bias_out, fmha_out};
  std::vector<Tensor *> outs = {&gate_bias_out, fmha_out_grad};
  phi::funcs::ElementwiseKernel<T, SigmoidMultiplyGradFunctor<T>, 2>(
      ctx.cuda_device_context(), ins, &outs, SigmoidMultiplyGradFunctor<T>());

  // Gradient of GEMM(query, gate_weight) + gate_bias
  auto *gate_weight_grad =
      ctx.Output<Tensor>(framework::GradVarName("GateWeight"));
  auto *gate_bias_grad = ctx.Output<Tensor>(framework::GradVarName("GateBias"));
  gate_weight_grad->mutable_data<T>(ctx.GetPlace());
  gate_bias_grad->mutable_data<T>(ctx.GetPlace());

273 274 275 276 277
  gate_attn_compute.ComputeBackward(query,
                                    gate_weight,
                                    &gate_bias_out,
                                    query_grad,
                                    gate_weight_grad,
278 279 280 281
                                    gate_bias_grad);
}

template <typename T>
282 283
void ComputeOutputLinearForward(const framework::ExecutionContext &ctx,
                                const GateAttentionConfig<T> &config,
284 285
                                const Tensor *fmha_or_gate_out,
                                Tensor *out) {
286 287 288 289 290 291
  const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight");
  const auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");

  // out = GEMM(fmha_or_gate_out, out_linear_weight) + out_linear_bias
  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int n = config.q_dim;
292
  int k = config.num_heads * config.head_dim;
293 294
  auto out_linear_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
295 296
  out_linear_compute.ComputeForward(
      out_linear_weight, fmha_or_gate_out, out_linear_bias, out, out);
297 298 299
}

template <typename T>
300 301
void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx,
                                 const GateAttentionGradConfig<T> &config,
302 303
                                 const Tensor *input,
                                 Tensor *input_grad) {
304 305 306 307 308 309 310 311 312 313
  const auto *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
  const auto *out_linear_weight = ctx.Input<Tensor>("OutLinearWeight");

  auto *out_linear_weight_grad =
      ctx.Output<Tensor>(framework::GradVarName("OutLinearWeight"));
  auto *out_linear_bias_grad =
      ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));

  out_linear_weight_grad->mutable_data<T>(ctx.GetPlace());
  out_linear_bias_grad->mutable_data<T>(ctx.GetPlace());
314

315 316
  int m = config.batch_size * config.seq_len_m * config.seq_len_r;
  int n = config.q_dim;
317
  int k = config.num_heads * config.head_dim;
318 319
  auto out_linear_compute =
      AttnMatMul<T>(ctx.cuda_device_context(), false, false, m, n, k, true);
320 321 322 323 324
  out_linear_compute.ComputeBackward(input,
                                     out_linear_weight,
                                     out_grad,
                                     input_grad,
                                     out_linear_weight_grad,
325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
                                     out_linear_bias_grad);
}

template <typename T>
class FusedGateAttentionOpKernel : public framework::OpKernel<T> {
 public:
  // Forward pass of fused gate attention:
  //   1. QKV projection: one merged GEMM (merge_qkv) or three separate GEMMs.
  //   2. FMHA over the transposed Q/K/V; src_mask and nonbatched_bias are
  //      passed through to FMHAGateRef.
  //   3. Optional gating (has_gating):
  //        gate_out = sigmoid(query * GateWeight + GateBias) * fmha_out
  //   4. Output projection:
  //        out = (gate_out or fmha_out) * OutLinearWeight + OutLinearBias
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *query = ctx.Input<Tensor>("Query");
    const auto *key = ctx.Input<Tensor>("Key");
    const auto *query_weight = ctx.Input<Tensor>("QueryWeight");
    const auto *qkv_weight = ctx.Input<Tensor>("QKVWeight");

    const auto *src_mask = ctx.Input<Tensor>("SrcMask");
    const auto *nonbatched_bias = ctx.Input<Tensor>("NonbatchedBias");

    auto *q_transpose_out = ctx.Output<Tensor>("QueryTransposeOut");
    auto *k_transpose_out = ctx.Output<Tensor>("KeyTransposeOut");
    auto *v_transpose_out = ctx.Output<Tensor>("ValueTransposeOut");
    auto *qkv_transpose_out = ctx.Output<Tensor>("QKVTransposeOut");

    auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
    auto *fmha_out = ctx.Output<Tensor>("FMHAOut");
    auto *gate_out = ctx.Output<Tensor>("GateOut");
    auto *out = ctx.Output<Tensor>("Out");

    const bool merge_qkv = ctx.Attr<bool>("merge_qkv");
    const bool has_gating = ctx.Attr<bool>("has_gating");

    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    AllocWithDebugInfo<T>(dev_ctx, "softmax_out", softmax_out);
    AllocWithDebugInfo<T>(dev_ctx, "fmha_out", fmha_out);
    if (has_gating) {
      AllocWithDebugInfo<T>(dev_ctx, "gate_out", gate_out);
    }
    AllocWithDebugInfo<T>(dev_ctx, "out", out);

    // When seq_len_r = m_size, q_dim = kv_dim, QKV matmul can be merged.
    GateAttentionConfig<T> config(
        dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating);

    if (merge_qkv) {
      // The merged path projects Q, K and V from `query` alone, so `key`
      // must be absent or refer to the same data as `query`.
      PADDLE_ENFORCE_EQ(
          !key || query == key || query->data<T>() == key->data<T>(),
          true,
          platform::errors::InvalidArgument(
              "key is expected to be nullptr or the same as "
              "query, but recieved key=%p, query=%p.",
              key,
              query));

      // 1. Merged QKV Matmul:
      //    qkv_out[b, m, r, 3, h, d] = query[b, m, r, :] * qkv_weight^T
      //    with qkv_weight: [3, num_heads, head_dim, qkv_dim]
      Tensor *qkv_out = config.GetQKVOut();
      ComputeMergedQKVMatmulForward<T>(ctx, config, query, qkv_out);

      AllocWithDebugInfo<T>(dev_ctx, "qkv_transpose_out", qkv_transpose_out);
    } else {
      // 1. Separated QKV Matmul: Q from `query`, K and V from `key`.
      Tensor *query_out = config.GetQueryOut();
      Tensor *key_out = config.GetKeyOut();
      Tensor *value_out = config.GetValueOut();
      ComputeSeparatedQKVMatmulForward<T>(
          ctx, config, query, key, query_out, key_out, value_out);

      AllocWithDebugInfo<T>(dev_ctx, "q_transpose_out", q_transpose_out);
      AllocWithDebugInfo<T>(dev_ctx, "k_transpose_out", k_transpose_out);
      AllocWithDebugInfo<T>(dev_ctx, "v_transpose_out", v_transpose_out);
    }

    // 2. FMHA
    auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
    fmha_compute.ComputeForward(nonbatched_bias,
                                src_mask,
                                q_transpose_out,
                                k_transpose_out,
                                v_transpose_out,
                                qkv_transpose_out,
                                softmax_out,
                                fmha_out,
                                gate_out,
                                &config);

    // 3. Gating Linear
    if (has_gating) {
      ComputeGatingLinearForward<T>(ctx, config, query, fmha_out, gate_out);
    }

    // 4. Output Linear
    Tensor *fmha_or_gate_out = has_gating ? gate_out : fmha_out;
    ComputeOutputLinearForward<T>(ctx, config, fmha_or_gate_out, out);
  }
};

template <typename T>
class FusedGateAttentionGradKernel : public framework::OpKernel<T> {
 public:
  // Backward pass; walks the forward stages in reverse order:
  //   output linear -> (optional) gating linear -> FMHA -> QKV projection.
  void Compute(const framework::ExecutionContext &ctx) const override {
    // forward input
    const auto *query = ctx.Input<Tensor>("Query");
    const auto *key = ctx.Input<Tensor>("Key");
    const auto *query_weight = ctx.Input<Tensor>("QueryWeight");
    const auto *qkv_weight = ctx.Input<Tensor>("QKVWeight");

    // forward output, backward input
    const auto *q_transpose_out = ctx.Input<Tensor>("QueryTransposeOut");
    const auto *k_transpose_out = ctx.Input<Tensor>("KeyTransposeOut");
    const auto *v_transpose_out = ctx.Input<Tensor>("ValueTransposeOut");
    const auto *qkv_transpose_out = ctx.Input<Tensor>("QKVTransposeOut");
    const auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
    const auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
    const auto *gate_out = ctx.Input<Tensor>("GateOut");

    // backward output
    auto *query_grad = ctx.Output<Tensor>(framework::GradVarName("Query"));
    auto *nonbatched_bias_grad =
        ctx.Output<Tensor>(framework::GradVarName("NonbatchedBias"));

    bool has_gating = ctx.Attr<bool>("has_gating");
    bool merge_qkv = ctx.Attr<bool>("merge_qkv");

    auto &dev_ctx = ctx.template device_context<phi::GPUContext>();
    AllocWithDebugInfo<T>(dev_ctx, "query_grad", query_grad);

    GateAttentionGradConfig<T> config(
        dev_ctx, query, key, query_weight, qkv_weight, merge_qkv, has_gating);

    // Gradient w.r.t. fmha_out; it has the same dims as gate_out.
    Tensor fmha_out_grad;
    fmha_out_grad.Resize(config.gate_out_dims);
    AllocWithDebugInfo<T>(dev_ctx, "fmha_out_grad", &fmha_out_grad);
    if (has_gating) {
      // 1. Gradient of Output Linear: out = Linear(gate_out)
      Tensor gate_out_grad;
      gate_out_grad.Resize(config.gate_out_dims);
      AllocWithDebugInfo<T>(dev_ctx, "gate_out_grad", &gate_out_grad);
      ComputeOutputLinearBackward<T>(ctx, config, gate_out, &gate_out_grad);

      // 2. Gradient of Gating Linear
      // Forward: gate_out = Sigmoid(Linear(query)) * fmha_out
      // Writes the gating-path gradient of query into query_grad, and
      // fmha_out_grad.
      ComputeGatingLinearBackward<T>(ctx,
                                     config,
                                     query,
                                     fmha_out,
                                     &gate_out_grad,
                                     query_grad,
                                     &fmha_out_grad);
    } else {
      // 1. Gradient of Output Linear: out = Linear(fmha_out)
      ComputeOutputLinearBackward<T>(ctx, config, fmha_out, &fmha_out_grad);
    }

    // 3. Gradient of FMHA
    if (nonbatched_bias_grad) {
      AllocWithDebugInfo<T>(
          dev_ctx, "nonbatched_bias_grad", nonbatched_bias_grad);
    }

    auto fmha_compute = FMHAGateRef<T>(dev_ctx, merge_qkv);
    fmha_compute.ComputeBackward(q_transpose_out,
                                 k_transpose_out,
                                 v_transpose_out,
                                 qkv_transpose_out,
                                 softmax_out,
                                 &fmha_out_grad,
                                 nullptr,
                                 nonbatched_bias_grad,
                                 &config);

    // With gating, query_grad already holds the contribution that flows
    // through the gating linear (step 2). The QKV backward must then
    // accumulate into it instead of overwriting it.
    bool use_addto = has_gating;
    if (merge_qkv) {
      // 4. Gradient of Merged QKV Matmul
      Tensor *qkv_out_grad = config.GetQKVOutGrad();
      ComputeMergedQKVMatmulBackward<T>(
          ctx, config, query, qkv_out_grad, query_grad, use_addto);
    } else {
      // 4. Gradient of Separated QKV Matmul
      auto *key_grad = ctx.Output<Tensor>(framework::GradVarName("Key"));
      if (key_grad) {
        AllocWithDebugInfo<T>(dev_ctx, "key_grad", key_grad);
      }
      Tensor *query_out_grad = config.GetQueryOutGrad();
      Tensor *key_out_grad = config.GetKeyOutGrad();
      Tensor *value_out_grad = config.GetValueOutGrad();
      ComputeSeparatedQKVMatmulBackward<T>(ctx,
                                           config,
                                           query,
                                           key,
                                           query_out_grad,
                                           key_out_grad,
                                           value_out_grad,
                                           query_grad,
                                           key_grad,
                                           use_addto);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Kernel registration. The forward and grad kernels are registered for the
// same element types in each build configuration.
#ifdef PADDLE_WITH_HIP
// ROCm (HIP) build: float, float16 and bfloat16 only. No double kernels are
// registered here; the reason is not visible in this file.
REGISTER_OP_CUDA_KERNEL(fused_gate_attention,
                        ops::FusedGateAttentionOpKernel<float>,
                        ops::FusedGateAttentionOpKernel<plat::float16>,
                        ops::FusedGateAttentionOpKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(fused_gate_attention_grad,
                        ops::FusedGateAttentionGradKernel<float>,
                        ops::FusedGateAttentionGradKernel<plat::float16>,
                        ops::FusedGateAttentionGradKernel<plat::bfloat16>);
#else
// CUDA build: float, double, float16 and bfloat16.
REGISTER_OP_CUDA_KERNEL(fused_gate_attention,
                        ops::FusedGateAttentionOpKernel<float>,
                        ops::FusedGateAttentionOpKernel<double>,
                        ops::FusedGateAttentionOpKernel<plat::float16>,
                        ops::FusedGateAttentionOpKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(fused_gate_attention_grad,
                        ops::FusedGateAttentionGradKernel<float>,
                        ops::FusedGateAttentionGradKernel<double>,
                        ops::FusedGateAttentionGradKernel<plat::float16>,
                        ops::FusedGateAttentionGradKernel<plat::bfloat16>);
#endif