fused_attention_op.cc 30.6 KB
Newer Older
L
Li Min 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>
17

L
Li Min 已提交
18
#include "paddle/fluid/framework/op_registry.h"
19
#include "paddle/fluid/framework/op_version_registry.h"
L
Li Min 已提交
20 21 22 23 24 25 26 27 28 29 30

namespace paddle {
namespace operators {

class FusedAttentionOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Validates the inputs/outputs of fused_attention and infers every output
  // shape from X ([batch_size, seq_len, dim_embed]), QKVW
  // ([3, num_head, dim_head, dim_embed]) and, when present, CacheKV
  // ([2, batch_size, num_head, cache_seq_len, dim_head]).
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
    OP_INOUT_CHECK(
        ctx->HasInput("OutLinearW"), "Input", "OutLinearW", "FusedAttentionOp");

    // Pre-LN produces the Ln* intermediates; post-LN produces the Ln2* ones.
    const bool pre_layer_norm = ctx->Attrs().Get<bool>("pre_layer_norm");
    if (pre_layer_norm) {
      OP_INOUT_CHECK(
          ctx->HasOutput("LnMean"), "Output", "LnMean", "FusedAttentionOp");
      OP_INOUT_CHECK(ctx->HasOutput("LnVariance"),
                     "Output",
                     "LnVariance",
                     "FusedAttentionOp");
      OP_INOUT_CHECK(
          ctx->HasOutput("LnOut"), "Output", "LnOut", "FusedAttentionOp");
    } else {
      OP_INOUT_CHECK(
          ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", "FusedAttentionOp");
      OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"),
                     "Output",
                     "Ln2Variance",
                     "FusedAttentionOp");
      OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"),
                     "Output",
                     "BiasDropoutResidualOut",
                     "FusedAttentionOp");
    }

    // qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
    OP_INOUT_CHECK(
        ctx->HasOutput("QKVOut"), "Output", "QKVOut", "FusedAttentionOp");
    if (ctx->HasInput("QKVBias")) {
      OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"),
                     "Output",
                     "QKVBiasOut",
                     "FusedAttentionOp");
    }
    OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"),
                   "Output",
                   "TransposeOut2",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(
        ctx->HasOutput("QKOut"), "Output", "QKOut", "FusedAttentionOp");
    OP_INOUT_CHECK(
        ctx->HasOutput("QKTVOut"), "Output", "QKTVOut", "FusedAttentionOp");

    if (ctx->HasInput("CacheKV")) {
      OP_INOUT_CHECK(ctx->HasOutput("CacheKVOut"),
                     "Output",
                     "CacheKVOut",
                     "FusedAttentionOp");
    }
    if (ctx->HasInput("SrcMask")) {
      OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"),
                     "Output",
                     "SrcMaskOut",
                     "FusedAttentionOp");
    }
    OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"),
                   "Output",
                   "SoftmaxOut",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"),
                   "Output",
                   "AttnDropoutMaskOut",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutOut"),
                   "Output",
                   "AttnDropoutOut",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(
        ctx->HasOutput("FMHAOut"), "Output", "FMHAOut", "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"),
                   "Output",
                   "OutLinearOut",
                   "FusedAttentionOp");

    OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"),
                   "Output",
                   "DropoutMaskOut",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto x_dim = ctx->GetInputDim("X");
    auto y_dim = ctx->GetInputDim("QKVW");
    PADDLE_ENFORCE_EQ(
        x_dim.size(),
        3,
        platform::errors::InvalidArgument("The dimensions of x must be 3 "
                                          "(batch_size, seq_len, dim_embed), "
                                          "but received dimensions of "
                                          "Input is [%d]",
                                          x_dim.size()));
    PADDLE_ENFORCE_EQ(y_dim.size(),
                      4,
                      platform::errors::InvalidArgument(
                          "The dimensions of qkv_weight must be 4 "
                          "(3, num_head, dim_head, dim_embed), "
                          "but received dimensions of "
                          "Input is [%d]",
                          y_dim.size()));
    PADDLE_ENFORCE_EQ(x_dim[2],
                      y_dim[3],
                      platform::errors::InvalidArgument(
                          "ShapeError: the dimension of x_dim[2] and y_dim[3] "
                          "must be equal. But received: the shape "
                          "of input x = [%s], and the shape of "
                          "input qkv_weight = [%s]",
                          x_dim,
                          y_dim));

    // The full-head constraint is only enforced when ring_id == -1.
    // NOTE(review): presumably a tensor-parallel QKVW holds only a shard of
    // the heads, so num_head * dim_head need not equal dim_embed -- confirm.
    if (ctx->Attrs().Get<int>("ring_id") == -1) {
      PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
                        y_dim[3],
                        platform::errors::InvalidArgument(
                            "The dimensions of qkv_weight must be 4 "
                            "(3, num_head, dim_head, dim_embed), "
                            "and must satisfy the limitations: "
                            "(num_head * dim_head == dim_embed)"));
    }

    if (pre_layer_norm) {
      ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
      ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
      ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
    } else {
      ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
      ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
      ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
    }
    // [batch_size, seq_len, 3, num_head, head_size]
    ctx->SetOutputDim("QKVOut",
                      {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});

    if (ctx->HasInput("QKVBias")) {
      ctx->SetOutputDim("QKVBiasOut",
                        {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
    }
    // [3, batch_size, num_head, seq_len, head_size]
    ctx->SetOutputDim("TransposeOut2",
                      {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});

    // cache_seq_len + seq_len if cache else seq_len
    auto out_seq_len = x_dim[1];
    if (ctx->HasInput("CacheKV")) {
      // [2, batch_size, num_head, cache_seq_len, head_size]
      auto c_dim = ctx->GetInputDim("CacheKV");

      PADDLE_ENFORCE_EQ(
          c_dim.size(),
          5,
          paddle::platform::errors::InvalidArgument(
              "The CacheKV must be 5 dims, but got %d", c_dim.size()));
      PADDLE_ENFORCE_EQ(c_dim[0],
                        2,
                        paddle::platform::errors::InvalidArgument(
                            "The first dim of CacheKV must be 2, but got %d",
                            c_dim[0]));  // 2
      PADDLE_ENFORCE_EQ(c_dim[1],
                        x_dim[0],
                        paddle::platform::errors::InvalidArgument(
                            "The second dim of CacheKV must be equal with "
                            "batch size %d, but got %d",
                            x_dim[0],
                            c_dim[1]));  // batch_size
      PADDLE_ENFORCE_EQ(c_dim[2],
                        y_dim[1],
                        paddle::platform::errors::InvalidArgument(
                            "The third dim of CacheKV must be equal with num "
                            "head %d, but got %d",
                            y_dim[1],
                            c_dim[2]));  // num_head
      // In compile stage, input seq_len can be -1, in that case
      // c_dim[3] may < 0 in while
      if (ctx->IsRuntime()) {
        PADDLE_ENFORCE_GE(
            c_dim[3],
            0,
            paddle::platform::errors::InvalidArgument(
                "The fourth dim of CacheKV must be greater than or equal to "
                "0, but got %d",
                c_dim[3]));  // cache_seq_len
      }
      PADDLE_ENFORCE_EQ(c_dim[4],
                        y_dim[2],
                        paddle::platform::errors::InvalidArgument(
                            "The fifth dim of CacheKV must be equal with head "
                            "size %d, but got %d",
                            y_dim[2],
                            c_dim[4]));  // head_size

      out_seq_len += c_dim[3];
      // [2, batch_size, num_head, cache_seq_len + seq_len, head_size]
      ctx->SetOutputDim("CacheKVOut",
                        {c_dim[0], c_dim[1], c_dim[2], out_seq_len, c_dim[4]});
    }

    // [batch, num_head, seq_len, out_seq_len]
    ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], out_seq_len});

    if (ctx->HasInput("SrcMask")) {
      ctx->SetOutputDim("SrcMaskOut",
                        {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
    }
    // the same as QKOut's shape.
    ctx->SetOutputDim("AttnDropoutOut",
                      {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
    // Dropout masks are only produced for training.
    const bool is_test = ctx->Attrs().Get<bool>("is_test");
    if (!is_test) {
      ctx->SetOutputDim("AttnDropoutMaskOut",
                        {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
    }
    ctx->SetOutputDim("SoftmaxOut",
                      {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
    // [batch_size, num_heads, seq_len, head_dim]
    ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
    // [batch_size, seq_len, num_heads, head_dim]
    ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
    ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));

    if (!is_test) {
      ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
    }

    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
  }

 protected:
  // The kernel data type follows the data type of input X.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input = ctx.Input<phi::DenseTensor>("X");
    auto input_data_type = framework::TransToProtoVarType(input->dtype());
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the inputs, outputs (mostly intermediates kept for backward) and
  // attributes of fused_attention, together with their validation checkers.
  void Make() override {
    AddInput("X", "The input tensor.");
    AddInput("LnScale",
             "(optional) Scale is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDispensable();
    AddInput("LnBias",
             "(optional) Bias is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDispensable();
    AddInput("QKVW", "The qkv weight tensor.");
    AddInput("QKVBias", "The qkv bias tensor.").AsDispensable();
    AddInput("CacheKV", "(optional) The cached KV for generation inference.")
        .AsDispensable();
    AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
        .AsDispensable();
    AddInput("OutLinearW", "The out_linear weight tensor.");
    AddInput("OutLinearBias", "The out_linear bias tensor.").AsDispensable();
    AddInput("Ln2Scale",
             "(optional) Scale is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDispensable();
    AddInput("Ln2Bias",
             "(optional) Bias is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDispensable();
    AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate();
    AddOutput("LnVariance", "Variance of the current mini batch.")
        .AsIntermediate();
    AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate();
    AddOutput("QKVOut", "Result after qkv.").AsIntermediate();
    AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate();
    AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate();
    AddOutput("QKOut", "Result in fmha.").AsIntermediate();
    AddOutput("QKTVOut", "Result in fmha.").AsIntermediate();
    AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate();
    AddOutput("AttnDropoutMaskOut", "Result in fmha.").AsIntermediate();
    AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate();
    AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate();
    AddOutput("FMHAOut", "Result after fmha.").AsIntermediate();
    AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate();
    AddOutput("DropoutMaskOut", "The random sampled dropout mask.")
        .AsIntermediate();
    AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate();
    AddOutput("Ln2Variance", "Variance of the current mini batch.")
        .AsIntermediate();
    AddOutput("BiasDropoutResidualOut",
              "Result of residual + dropout(src + bias).")
        .AsIntermediate();
    AddOutput("CacheKVOut", "The updated cache KV.");
    AddOutput("Y", "Result after attention.");

    AddAttr<bool>("pre_layer_norm",
                  "if true, the attention op uses pre_layer_norm architecture, "
                  "else, uses post_layer_norm architecture. "
                  "[default false].")
        .SetDefault(false);
    AddAttr<float>("epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &epsilon) {
          PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f,
                            true,
                            platform::errors::InvalidArgument(
                                "'epsilon' in Op(LayerNorm) should be between "
                                "0.0 and 0.001, But received [%s].",
                                epsilon));
        });

    // for dropout in fmha.
    AddAttr<float>("attn_dropout_rate", "Probability of setting units to zero.")
        .SetDefault(.5f)
        .AddCustomChecker([](const float &drop_p) {
          PADDLE_ENFORCE_EQ(
              drop_p >= 0.0f && drop_p <= 1.0f,
              true,
              platform::errors::InvalidArgument(
                  "'attn_dropout_rate' must be between 0.0 and 1.0."));
        });
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddAttr<bool>("attn_dropout_fix_seed",
                  "A flag indicating whether to use a fixed seed to generate "
                  "random mask. NOTE: DO NOT set this flag to true in "
                  "training. Setting this flag to true is only useful in "
                  "unittest or for debug that always the same output units "
                  "will be dropped.")
        .SetDefault(true);
    AddAttr<int>("attn_dropout_seed", "Dropout random seed.").SetDefault(0);
    // The default below is "upscale_in_train"; the description says so.
    AddAttr<std::string>(
        "attn_dropout_implementation",
        "[\"downgrade_in_infer\"|\"upscale_in_train\"] "
        "There are two kinds of ways to implement dropout "
        "(the mask below is a tensor that has the same shape as the input; "
        "the value of mask is 0 or 1, and the ratio of 0 is dropout_rate). "
        "1. downgrade_in_infer: downgrade the outcome at inference time. "
        "   train: out = input * mask; "
        "   inference: out = input * (1.0 - dropout_rate). "
        "2. upscale_in_train (default): upscale the outcome at training "
        "time, do nothing at inference. "
        "   train: out = input * mask / ( 1.0 - dropout_rate ); "
        "   inference: out = input. "
        "   The dropout op can then be removed from the inference program, "
        "which is more efficient.")
        .SetDefault("upscale_in_train")
        .AddCustomChecker([](const std::string &type) {
          PADDLE_ENFORCE_EQ(
              type == "downgrade_in_infer" || type == "upscale_in_train",
              true,
              platform::errors::InvalidArgument(
                  "dropout_implementation can only be downgrade_in_infer or "
                  "upscale_in_train"));
        });

    AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
        .SetDefault(.5f)
        .AddCustomChecker([](const float &drop_p) {
          PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f,
                            true,
                            platform::errors::InvalidArgument(
                                "'dropout_rate' must be between 0.0 and 1.0."));
        });
    AddAttr<bool>("dropout_fix_seed",
                  "A flag indicating whether to use a fixed seed to generate "
                  "random mask. NOTE: DO NOT set this flag to true in "
                  "training. Setting this flag to true is only useful in "
                  "unittest or for debug that always the same output units "
                  "will be dropped.")
        .SetDefault(true);
    AddAttr<int>("dropout_seed", "Dropout random seed.").SetDefault(0);
    AddAttr<std::string>(
        "dropout_implementation",
        "[\"downgrade_in_infer\"|\"upscale_in_train\"] "
        "The meaning is the same as 'attn_dropout_implementation'.")
        .SetDefault("downgrade_in_infer")
        .AddCustomChecker([](const std::string &type) {
          PADDLE_ENFORCE_EQ(
              type == "downgrade_in_infer" || type == "upscale_in_train",
              true,
              platform::errors::InvalidArgument(
                  "dropout_implementation can only be downgrade_in_infer or "
                  "upscale_in_train"));
        });
    AddAttr<float>("ln_epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &ln_epsilon) {
          PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f,
                            true,
                            platform::errors::InvalidArgument(
                                "'epsilon' of the second LayerNorm in Fused "
                                "attention op should be between "
                                "0.0 and 0.001, But received [%s].",
                                ln_epsilon));
        });
    AddAttr<bool>("add_residual", "Whether to add residual.").SetDefault(true);
    AddAttr<int>("ring_id",
                 "Ring id for tensor model parallelism in distributed "
                 "training and inference. [default -1]")
        .SetDefault(-1);

    AddComment(R"DOC(
  The fused_attention operator is the same as following pseudo codes:

  // @input: [batch_size, seq_len, embed_dim]
  // @final_out: [batch_size, seq_len, embed_dim]
  residual = input
  if (pre_layernorm)
    query = layer_norm(input);
  else
    query = input;
  out = compute_qkv(query) + qkv_bias;
  // fmha module
  {
    out = transpose(out, perm=[2, 0, 3, 1, 4]);
    out = q * k^t;
    out = attn_mask + out;
    out = softmax(out);
    out = dropout(out);
    out = out * v;
    out = transpose(out, perm=[0, 2, 1, 3]);
  }
  // out linear
  out = linear(out);
  if add_residual:
    out = residual + dropout(out);
  else:
    out = dropout(out);
  if (!pre_layernorm)
    out = layer_norm(out);
    )DOC");
  }
};

462 463 464 465 466
class FusedAttentionGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Verifies the forward tensors fused_attention_grad consumes, then gives
  // each gradient output the dims of the forward variable it differentiates.
  // Optional gradients are only shaped when the grad op actually has them.
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"),
                      false,
                      platform::errors::InvalidArgument(
                          "GradOp is only callable when is_test is false"));

    // d(name) gets exactly the dims of forward input `name`.
    const auto copy_dim_to_grad = [ctx](const char *name) {
      ctx->SetOutputDim(framework::GradVarName(name), ctx->GetInputDim(name));
    };
    // Like copy_dim_to_grad, but a no-op when d(name) is not an output.
    const auto copy_dim_to_grad_if_wanted = [ctx, &copy_dim_to_grad](
                                                const char *name) {
      if (ctx->HasOutput(framework::GradVarName(name))) {
        copy_dim_to_grad(name);
      }
    };

    const bool pre_layer_norm = ctx->Attrs().Get<bool>("pre_layer_norm");
    if (pre_layer_norm) {
      OP_INOUT_CHECK(
          ctx->HasInput("LnMean"), "Input", "LnMean", "FusedAttentionGrad");
      OP_INOUT_CHECK(ctx->HasInput("LnVariance"),
                     "Input",
                     "LnVariance",
                     "FusedAttentionGrad");
      OP_INOUT_CHECK(
          ctx->HasInput("LnOut"), "Input", "LnOut", "FusedAttentionGrad");
    } else {
      OP_INOUT_CHECK(
          ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean", "FusedAttentionGrad");
      OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"),
                     "Input",
                     "Ln2Variance",
                     "FusedAttentionGrad");
      copy_dim_to_grad_if_wanted("Ln2Scale");
      copy_dim_to_grad_if_wanted("Ln2Bias");
    }

    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
    OP_INOUT_CHECK(
        ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionGrad");
    OP_INOUT_CHECK(ctx->HasInput("OutLinearW"),
                   "Input",
                   "OutLinearW",
                   "FusedAttentionGrad");

    if (pre_layer_norm) {
      copy_dim_to_grad_if_wanted("LnScale");
      copy_dim_to_grad_if_wanted("LnBias");
    }
    for (const char *name : {"X", "OutLinearBias"}) {
      copy_dim_to_grad_if_wanted(name);
    }
    copy_dim_to_grad("OutLinearW");
    copy_dim_to_grad("QKVW");
    copy_dim_to_grad_if_wanted("QKVBias");

    // The layer-norm-side intermediate differs between pre- and post-LN.
    copy_dim_to_grad(pre_layer_norm ? "LnOut" : "BiasDropoutResidualOut");

    for (const char *name :
         {"FMHAOut", "QKTVOut", "TransposeOut2", "QKOut", "SoftmaxOut"}) {
      copy_dim_to_grad(name);
    }
    for (const char *name : {"AttnDropoutOut", "SrcMaskOut"}) {
      copy_dim_to_grad_if_wanted(name);
    }
    copy_dim_to_grad("QKVOut");
    copy_dim_to_grad_if_wanted("QKVBiasOut");
    copy_dim_to_grad("OutLinearOut");
  }

 protected:
  // Backward runs with the same data type as forward input X.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto *x_tensor = ctx.Input<phi::DenseTensor>("X");
    return framework::OpKernelType(
        framework::TransToProtoVarType(x_tensor->dtype()), ctx.GetPlace());
  }
};

template <typename T>
class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Builds the fused_attention_grad op description. It wires in dY, the
  // forward inputs and the forward intermediates used by backward, and
  // declares the gradient outputs. Optional variables (QKVBias, SrcMask,
  // OutLinearBias, layer-norm params) are wired only when the forward op
  // had them.
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("fused_attention_grad");
    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));

    // inputs x, parameters and their grad.
    op->SetInput("X", this->Input("X"));
    op->SetInput("QKVW", this->Input("QKVW"));

    if (this->HasInput("QKVBias")) {
      op->SetInput("QKVBias", this->Input("QKVBias"));
      op->SetOutput(framework::GradVarName("QKVBias"),
                    this->InputGrad("QKVBias"));
      op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
      op->SetOutput(framework::GradVarName("QKVBiasOut"),
                    this->OutputGrad("QKVBiasOut"));
    }

    if (this->HasInput("SrcMask")) {
      op->SetInput("SrcMask", this->Input("SrcMask"));
      op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
      op->SetOutput(framework::GradVarName("SrcMaskOut"),
                    this->OutputGrad("SrcMaskOut"));
    }

    op->SetInput("OutLinearW", this->Input("OutLinearW"));
    if (this->HasInput("OutLinearBias")) {
      op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
      op->SetOutput(framework::GradVarName("OutLinearBias"),
                    this->InputGrad("OutLinearBias"));
    }

    op->SetAttrMap(this->Attrs());
    const bool is_pre_layer_norm =
        PADDLE_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
    if (is_pre_layer_norm) {
      if (this->HasInput("LnScale")) {
        op->SetInput("LnScale", this->Input("LnScale"));
        op->SetOutput(framework::GradVarName("LnScale"),
                      this->InputGrad("LnScale"));
      }
      if (this->HasInput("LnBias")) {
        op->SetInput("LnBias", this->Input("LnBias"));
        op->SetOutput(framework::GradVarName("LnBias"),
                      this->InputGrad("LnBias"));
      }
    } else {
      if (this->HasInput("Ln2Scale")) {
        op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
        op->SetOutput(framework::GradVarName("Ln2Scale"),
                      this->InputGrad("Ln2Scale"));
      }
      if (this->HasInput("Ln2Bias")) {
        op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
        op->SetOutput(framework::GradVarName("Ln2Bias"),
                      this->InputGrad("Ln2Bias"));
      }
    }

    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
    op->SetOutput(framework::GradVarName("OutLinearW"),
                  this->InputGrad("OutLinearW"));

    // use forward outputs as backward inputs.
    if (is_pre_layer_norm) {
      if (this->HasOutput("LnOut")) {
        op->SetInput("LnOut", this->Output("LnOut"));
      }
      if (this->HasOutput("LnMean")) {
        op->SetInput("LnMean", this->Output("LnMean"));
      }
      if (this->HasOutput("LnVariance")) {
        op->SetInput("LnVariance", this->Output("LnVariance"));
      }
    } else {
      op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
      op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
      op->SetInput("BiasDropoutResidualOut",
                   this->Output("BiasDropoutResidualOut"));
    }
    // QKVOut is wired exactly once; the original set it twice.
    op->SetInput("QKVOut", this->Output("QKVOut"));

    op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
    op->SetInput("QKOut", this->Output("QKOut"));
    op->SetInput("QKTVOut", this->Output("QKTVOut"));
    op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
    op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
    op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));

    op->SetInput("FMHAOut", this->Output("FMHAOut"));
    op->SetInput("OutLinearOut", this->Output("OutLinearOut"));
    op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));

    // backward outputs: dinput
    if (is_pre_layer_norm) {
      if (this->HasOutput("LnOut")) {
        op->SetOutput(framework::GradVarName("LnOut"),
                      this->OutputGrad("LnOut"));
      }
    } else {
      op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
                    this->OutputGrad("BiasDropoutResidualOut"));
    }

    op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));

    op->SetOutput(framework::GradVarName("QKTVOut"),
                  this->OutputGrad("QKTVOut"));
    op->SetOutput(framework::GradVarName("TransposeOut2"),
                  this->OutputGrad("TransposeOut2"));
    op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut"));
    op->SetOutput(framework::GradVarName("SoftmaxOut"),
                  this->OutputGrad("SoftmaxOut"));
    op->SetOutput(framework::GradVarName("AttnDropoutOut"),
                  this->OutputGrad("AttnDropoutOut"));

    op->SetOutput(framework::GradVarName("FMHAOut"),
                  this->OutputGrad("FMHAOut"));
    op->SetOutput(framework::GradVarName("OutLinearOut"),
                  this->OutputGrad("OutLinearOut"));
  }
};

707 708 709 710 711
// Variables registered here are declared to the framework as not needing
// their data buffers in fused_attention_grad.
// NOTE(review): presumably the grad kernel only reads their metadata
// (shape/dtype) -- confirm against the kernel implementation.
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
    FusedAttentionGradNoNeedBufferInferer,
    "QKVBiasOut", "QKVOut", "QKOut", "QKTVOut",  // fmha intermediates
    "OutLinearOut",                              // out_linear result
    "SrcMask");                                  // attention mask input
714

L
Li Min 已提交
715 716 717 718
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
719 720
REGISTER_OPERATOR(fused_attention,
                  ops::FusedAttentionOp,
721 722 723
                  ops::FusedAttentionOpMaker,
                  ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                  ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
724 725 726
REGISTER_OPERATOR(fused_attention_grad,
                  ops::FusedAttentionGradOp,
                  ops::FusedAttentionGradNoNeedBufferInferer);
727 728 729 730 731 732

REGISTER_OP_VERSION(fused_attention)
    .AddCheckpoint(
        R"ROC(
              Add a new attribute [add_residual] )ROC",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
733 734
            "add_residual",
            "A flag to indicate whether to add residual.",
735
            true));