fused_attention_op.cc 29.0 KB
Newer Older
L
Li Min 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class FusedAttentionOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
                   "FusedAttentionOp");

34 35 36 37 38 39 40
    if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
      OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
                     "FusedAttentionOp");
      OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance",
                     "FusedAttentionOp");
      OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
                     "FusedAttentionOp");
L
Li Min 已提交
41 42 43 44 45 46 47
    } else {
      OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
                     "FusedAttentionOp");
      OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
                     "FusedAttentionOp");
      OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
                     "BiasDropoutResidualOut", "FusedAttentionOp");
48 49
    }

L
Li Min 已提交
50 51 52
    // qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
    OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut",
                   "FusedAttentionOp");
53 54 55 56
    if (ctx->HasInput("QKVBias")) {
      OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut",
                     "FusedAttentionOp");
    }
L
Li Min 已提交
57 58 59 60 61 62
    OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut",
                   "FusedAttentionOp");
63

64 65 66 67
    if (ctx->HasInput("CacheKV")) {
      OP_INOUT_CHECK(ctx->HasOutput("CacheKVOut"), "Output", "CacheKVOut",
                     "FusedAttentionOp");
    }
68 69 70 71
    if (ctx->HasInput("SrcMask")) {
      OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
                     "FusedAttentionOp");
    }
L
Li Min 已提交
72 73 74 75 76 77 78 79 80 81
    OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output",
                   "AttnDropoutMaskOut", "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutOut"), "Output", "AttnDropoutOut",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("FMHAOut"), "Output", "FMHAOut",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut",
                   "FusedAttentionOp");
L
Li Min 已提交
82

L
Li Min 已提交
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
    OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut",
                   "FusedAttentionOp");
    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp");

    // x: qkv's input [batch_size, seq_len, dim_embed]
    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
    auto x_dim = ctx->GetInputDim("X");
    auto y_dim = ctx->GetInputDim("QKVW");
    PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument(
                                           "The dimensions of x must be 3"
                                           "(batch_size, seq_len, dim_embed),"
                                           "but received dimensions of"
                                           "Input is [%d]",
                                           x_dim.size()));
    PADDLE_ENFORCE_EQ(y_dim.size(), 4,
                      platform::errors::InvalidArgument(
                          "The dimensions of qkv_weight must be 4"
                          "(3, num_head, dim_head, dim_embed),"
                          "but received dimensions of"
                          "Input is [%d]",
                          y_dim.size()));
    PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3],
                      platform::errors::InvalidArgument(
                          "ShapeError: the dimension of x_dim[2] and y_dim[3]"
                          "must be equal. But received: the shape "
                          "of input x = [%s], and the shape of "
                          "input qkv_weight = [%s]",
                          x_dim, y_dim));

112 113 114 115 116 117 118 119
    if (ctx->Attrs().Get<int>("ring_id") == -1) {
      PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
                        platform::errors::InvalidArgument(
                            "The dimensions of qkv_weight must be 4"
                            "(3, num_head, dim_head, dim_embed),"
                            "and must satisfy the limitations: "
                            "(num_head * dim_head == dim_embed)"));
    }
120

121 122 123 124
    if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
      ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
      ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
      ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
L
Li Min 已提交
125 126 127 128
    } else {
      ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
      ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
      ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
129
    }
L
Li Min 已提交
130 131 132
    // [batch_size, seq_len, 3, num_head, head_size]
    ctx->SetOutputDim("QKVOut",
                      {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
133 134 135 136 137

    if (ctx->HasInput("QKVBias")) {
      ctx->SetOutputDim("QKVBiasOut",
                        {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
    }
L
Li Min 已提交
138 139 140
    // [3, batch_size, num_head, seq_len, head_size]
    ctx->SetOutputDim("TransposeOut2",
                      {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165

    // cache_seq_len + seq_len if cache else seq_len
    auto out_seq_len = x_dim[1];
    if (ctx->HasInput("CacheKV")) {
      // [2, batch_size, num_head, cache_seq_len, head_size]
      auto c_dim = ctx->GetInputDim("CacheKV");

      PADDLE_ENFORCE_EQ(
          c_dim.size(), 5,
          paddle::platform::errors::InvalidArgument(
              "The CacheKV must be 5 dims, but got %d", c_dim.size()));
      PADDLE_ENFORCE_EQ(c_dim[0], 2,
                        paddle::platform::errors::InvalidArgument(
                            "The first dim of CacheKV must be 2, but got %d",
                            c_dim[0]));  // 2
      PADDLE_ENFORCE_EQ(c_dim[1], x_dim[0],
                        paddle::platform::errors::InvalidArgument(
                            "The second dim of CacheKV must be equal with "
                            "batch size %d, but got %d",
                            x_dim[0], c_dim[1]));  // batch_size
      PADDLE_ENFORCE_EQ(c_dim[2], y_dim[1],
                        paddle::platform::errors::InvalidArgument(
                            "The third dim of CacheKV must be equal with num "
                            "head %d, but got %d",
                            y_dim[1], c_dim[2]));  // num_head
166 167 168 169 170 171 172 173 174
      // In compile stage, input seq_len can be -1, in that case
      // c_dim[3] may < 0 in while
      if (ctx->IsRuntime()) {
        PADDLE_ENFORCE_GE(
            c_dim[3], 0,
            paddle::platform::errors::InvalidArgument(
                "The forth dim of CacheKV must be greater than 0, but got %d",
                c_dim[3]));  // cache_seq_len
      }
175 176 177 178 179 180 181 182 183 184 185 186 187 188
      PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2],
                        paddle::platform::errors::InvalidArgument(
                            "The fifth dim of CacheKV must be equal with head "
                            "size %d, but got %d",
                            y_dim[2], c_dim[4]));  // head_size

      out_seq_len += c_dim[3];
      // [3, batch_size, num_head, cache_seq_len + seq_len, head_size]
      ctx->SetOutputDim("CacheKVOut",
                        {c_dim[0], c_dim[1], c_dim[2], out_seq_len, c_dim[4]});
    }

    // [batch, num_head, seq_len, out_seq_len]
    ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
189 190

    if (ctx->HasInput("SrcMask")) {
191 192
      ctx->SetOutputDim("SrcMaskOut",
                        {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
193
    }
L
Li Min 已提交
194 195
    // the same as QKOut's shape.
    ctx->SetOutputDim("AttnDropoutOut",
196
                      {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
L
Li Min 已提交
197 198
    if (ctx->Attrs().Get<bool>("attn_dropout_is_test") == false) {
      ctx->SetOutputDim("AttnDropoutMaskOut",
199
                        {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
L
Li Min 已提交
200
    }
201 202
    ctx->SetOutputDim("SoftmaxOut",
                      {x_dim[0], y_dim[1], x_dim[1], out_seq_len});
L
Li Min 已提交
203 204 205 206 207 208 209 210 211
    // [batch_size, num_heads, seq_len, head_dim]
    ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
    // [batch_size, seq_len, number of heads*head size]
    ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
    ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));

    if (ctx->Attrs().Get<bool>("dropout_is_test") == false) {
      ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
    }
L
Li Min 已提交
212

L
Li Min 已提交
213 214 215 216 217 218 219
    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input = ctx.Input<Tensor>("X");
220
    auto input_data_type = framework::TransToProtoVarType(input->dtype());
L
Li Min 已提交
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input tensor.");
    AddInput("LnScale",
             "(optional) Scale is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDispensable();
    AddInput("LnBias",
             "(optional) Bias is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDispensable();
    AddInput("QKVW", "The qkv weight tensor.");
238
    AddInput("QKVBias", "The qkv bias tensor.").AsDispensable();
239 240
    AddInput("CacheKV", "(optional) The cached KV for generation inference.")
        .AsDispensable();
L
Li Min 已提交
241 242 243
    AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
        .AsDispensable();
    AddInput("OutLinearW", "The out_linear weight tensor.");
244
    AddInput("OutLinearBias", "The out_linear bias tensor.").AsDispensable();
L
Li Min 已提交
245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
    AddInput("Ln2Scale",
             "(optional) Scale is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDispensable();
    AddInput("Ln2Bias",
             "(optional) Bias is a 1-dimensional tensor of size "
             "H. Here, H represents the last dimension of its input tensor.")
        .AsDispensable();
    AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate();
    AddOutput("LnVariance", "Variance of the current mini batch.")
        .AsIntermediate();
    AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate();
    AddOutput("QKVOut", "Result after qkv.").AsIntermediate();
    AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate();
    AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate();
    AddOutput("QKOut", "Result in fmha.").AsIntermediate();
    AddOutput("QKTVOut", "Result in fmha.").AsIntermediate();
    AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate();
    AddOutput("AttnDropoutMaskOut", "Result in fmha.").AsIntermediate();
    AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate();
    AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate();
    AddOutput("FMHAOut", "Result after fmha.").AsIntermediate();
    AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate();
    AddOutput("DropoutMaskOut", "The random sampled dropout mask.")
        .AsIntermediate();
    AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate();
    AddOutput("Ln2Variance", "Variance of the current mini batch.")
        .AsIntermediate();
    AddOutput("BiasDropoutResidualOut",
              "Result of residual + dropout(src + bias).")
        .AsIntermediate();
276
    AddOutput("CacheKVOut", "The udpated cache KV.");
L
Li Min 已提交
277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383
    AddOutput("Y", "Result after attention.");

    AddAttr<bool>("pre_layer_norm",
                  "if true, the attention op uses pre_layer_norm architecure, "
                  "else, uses post_layer_norm architecuture. "
                  "[default false].")
        .SetDefault(false);
    AddAttr<float>("epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &epsilon) {
          PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
                            platform::errors::InvalidArgument(
                                "'epsilon' in Op(LayerNorm) should be between"
                                "0.0 and 0.001, But received [%s].",
                                epsilon));
        });

    // for dropout in fmha.
    AddAttr<float>("attn_dropout_rate", "Probability of setting units to zero.")
        .SetDefault(.5f)
        .AddCustomChecker([](const float &drop_p) {
          PADDLE_ENFORCE_EQ(
              drop_p >= 0.0f && drop_p <= 1.0f, true,
              platform::errors::InvalidArgument(
                  "'attn_dropout_rate' must be between 0.0 and 1.0."));
        });
    AddAttr<bool>("attn_dropout_is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddAttr<bool>("attn_dropout_fix_seed",
                  "A flag indicating whether to use a fixed seed to generate "
                  "random mask. NOTE: DO NOT set this flag to true in "
                  "training. Setting this flag to true is only useful in "
                  "unittest or for debug that always the same output units "
                  "will be dropped.")
        .SetDefault(true);
    AddAttr<int>("attn_dropout_seed", "Dropout random seed.").SetDefault(0);
    AddAttr<std::string>(
        "attn_dropout_implementation",
        "[\"downgrade_in_infer\"|\"upscale_in_train\"]"
        "There are two kinds of ways to implement dropout"
        "(the mask below is a tensor have the same shape with input"
        "the value of mask is 0 or 1, the ratio of 0 is dropout_rate)"
        "1. downgrade_in_infer(default), downgrade the outcome at inference "
        "time"
        "   train: out = input * mask"
        "   inference: out = input * (1.0 - dropout_rate)"
        "2. upscale_in_train, upscale the outcome at training time, do nothing "
        "in inference"
        "   train: out = input * mask / ( 1.0 - dropout_rate )"
        "   inference: out = input"
        "   dropout op can be removed from the program. the program will be "
        "efficient")
        .SetDefault("upscale_in_train")
        .AddCustomChecker([](const std::string &type) {
          PADDLE_ENFORCE_EQ(
              type == "downgrade_in_infer" || type == "upscale_in_train", true,
              platform::errors::InvalidArgument(
                  "dropout_implementation can only be downgrade_in_infer or "
                  "upscale_in_train"));
        });

    AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
        .SetDefault(.5f)
        .AddCustomChecker([](const float &drop_p) {
          PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true,
                            platform::errors::InvalidArgument(
                                "'dropout_rate' must be between 0.0 and 1.0."));
        });

    AddAttr<bool>("dropout_is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddAttr<bool>("dropout_fix_seed",
                  "A flag indicating whether to use a fixed seed to generate "
                  "random mask. NOTE: DO NOT set this flag to true in "
                  "training. Setting this flag to true is only useful in "
                  "unittest or for debug that always the same output units "
                  "will be dropped.")
        .SetDefault(true);
    AddAttr<int>("dropout_seed", "Dropout random seed.").SetDefault(0);
    AddAttr<std::string>(
        "dropout_implementation",
        "[\"downgrade_in_infer\"|\"upscale_in_train\"]"
        "The meaning is the same as 'attn_dropout_implementation'.")
        .SetDefault("downgrade_in_infer")
        .AddCustomChecker([](const std::string &type) {
          PADDLE_ENFORCE_EQ(
              type == "downgrade_in_infer" || type == "upscale_in_train", true,
              platform::errors::InvalidArgument(
                  "dropout_implementation can only be downgrade_in_infer or "
                  "upscale_in_train"));
        });
    AddAttr<float>("ln_epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &ln_epsilon) {
          PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true,
                            platform::errors::InvalidArgument(
                                "'epsilon' of the second LayerNorm in Fused "
                                "attention op should be between"
                                "0.0 and 0.001, But received [%s].",
                                ln_epsilon));
        });
384 385 386 387
    AddAttr<int>(
        "ring_id",
        "ring id for tensor model parallel. distributed training and inference")
        .SetDefault(-1);
L
Li Min 已提交
388 389

    AddComment(R"DOC(
L
Li Min 已提交
390 391 392 393 394
  Add fused attention op whose logic is as follows:
  // @input: [batch_size, seq_len, 3, num_head, head_dim] 
  // @final_out: [batch_size, seq_len, num_heads, head_dim] 
  if (pre_layernorm)
    out = layer_norm(input);
L
Li Min 已提交
395 396
	out = compute_qkv(out) + bias;
	// fmha module
L
Li Min 已提交
397 398 399 400 401 402 403 404
  {
    out = transpose(out, perm=[2, 0, 3, 1, 4]);
    out = q * k^t;
    out = attn_mask + out;
    out = softmax(out);
    out = dropout(out);
    out = out * v;
    out = transpose(out, perm=[0, 2, 1, 3]);
L
Li Min 已提交
405
                
L
Li Min 已提交
406
  }
L
Li Min 已提交
407
	out = out_linear(out);
L
Li Min 已提交
408 409 410 411
  if (pre_layernorm)
    final_out = residual + dropout(bias + out);
  else
    final_out = layer_norm(residual + dropout(bias + out));
L
Li Min 已提交
412 413 414 415
    )DOC");
  }
};

416 417 418 419 420 421 422 423 424 425
class FusedAttentionGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->Attrs().Get<bool>("attn_dropout_is_test"), false,
        platform::errors::InvalidArgument(
            "GradOp is only callable when attn_dropout_is_test is false"));

L
Li Min 已提交
426 427 428 429 430 431 432 433 434 435 436 437 438 439
    if (ctx->Attrs().Get<bool>("pre_layer_norm") == false) {
      OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
                     "FusedAttentionGrad");
      OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
                     "FusedAttentionGrad");
      if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
        ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
                          ctx->GetInputDim("Ln2Scale"));
      }
      if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
        ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
                          ctx->GetInputDim("Ln2Bias"));
      }
    } else {
440 441 442 443
      OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
                     "FusedAttentionGrad");
      OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
                     "FusedAttentionGrad");
444 445 446
      OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
                     "FusedAttentionGrad");
    }
L
Li Min 已提交
447 448

    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
449 450 451 452 453
    OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
                   "FusedAttentionGrad");
    OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
                   "FusedAttentionGrad");

454 455 456 457 458 459 460 461 462
    if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
      if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
        ctx->SetOutputDim(framework::GradVarName("LnScale"),
                          ctx->GetInputDim("LnScale"));
      }
      if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
        ctx->SetOutputDim(framework::GradVarName("LnBias"),
                          ctx->GetInputDim("LnBias"));
      }
463 464 465 466
    }
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }
467 468 469 470
    if (ctx->HasOutput(framework::GradVarName("OutLinearBias"))) {
      ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
                        ctx->GetInputDim("OutLinearBias"));
    }
471 472 473
    ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
                      ctx->GetInputDim("OutLinearW"));
    ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
474 475 476 477
    if (ctx->HasOutput(framework::GradVarName("QKVBias"))) {
      ctx->SetOutputDim(framework::GradVarName("QKVBias"),
                        ctx->GetInputDim("QKVBias"));
    }
478

479 480 481
    if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
      ctx->SetOutputDim(framework::GradVarName("LnOut"),
                        ctx->GetInputDim("LnOut"));
L
Li Min 已提交
482 483 484
    } else {
      ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
                        ctx->GetInputDim("BiasDropoutResidualOut"));
485
    }
486 487 488 489 490 491 492 493 494 495 496 497
    ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
                      ctx->GetInputDim("FMHAOut"));
    ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
                      ctx->GetInputDim("QKTVOut"));
    ctx->SetOutputDim(framework::GradVarName("TransposeOut2"),
                      ctx->GetInputDim("TransposeOut2"));
    ctx->SetOutputDim(framework::GradVarName("QKOut"),
                      ctx->GetInputDim("QKOut"));
    ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
                      ctx->GetInputDim("SoftmaxOut"));
    ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
                      ctx->GetInputDim("AttnDropoutOut"));
498 499 500 501 502

    if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) {
      ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
                        ctx->GetInputDim("SrcMaskOut"));
    }
503 504
    ctx->SetOutputDim(framework::GradVarName("QKVOut"),
                      ctx->GetInputDim("QKVOut"));
505 506 507 508
    if (ctx->HasOutput(framework::GradVarName("QKVBiasOut"))) {
      ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
                        ctx->GetInputDim("QKVBiasOut"));
    }
509 510 511 512 513 514 515 516
    ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
                      ctx->GetInputDim("OutLinearOut"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input = ctx.Input<Tensor>("X");
517
    auto input_data_type = framework::TransToProtoVarType(input->dtype());
518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }
};

template <typename T>
class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("fused_attention_grad");
    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));

    // inputs x, parameters and their grad.
    op->SetInput("X", this->Input("X"));
    op->SetInput("QKVW", this->Input("QKVW"));
535 536 537 538 539 540 541 542 543

    if (this->HasInput("QKVBias")) {
      op->SetInput("QKVBias", this->Input("QKVBias"));
      op->SetOutput(framework::GradVarName("QKVBias"),
                    this->InputGrad("QKVBias"));
      op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
      op->SetOutput(framework::GradVarName("QKVBiasOut"),
                    this->OutputGrad("QKVBiasOut"));
    }
544 545 546 547 548 549 550 551

    if (this->HasInput("SrcMask")) {
      op->SetInput("SrcMask", this->Input("SrcMask"));
      op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
      op->SetOutput(framework::GradVarName("SrcMaskOut"),
                    this->OutputGrad("SrcMaskOut"));
    }

552
    op->SetInput("OutLinearW", this->Input("OutLinearW"));
553 554 555 556 557
    if (this->HasInput("OutLinearBias")) {
      op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
      op->SetOutput(framework::GradVarName("OutLinearBias"),
                    this->InputGrad("OutLinearBias"));
    }
558 559 560 561 562 563 564 565 566 567 568 569 570 571 572

    op->SetAttrMap(this->Attrs());
    bool is_pre_layer_norm =
        BOOST_GET_CONST(bool, op->GetAttr("pre_layer_norm"));
    if (is_pre_layer_norm) {
      if (this->HasInput("LnScale")) {
        op->SetInput("LnScale", this->Input("LnScale"));
        op->SetOutput(framework::GradVarName("LnScale"),
                      this->InputGrad("LnScale"));
      }
      if (this->HasInput("LnBias")) {
        op->SetInput("LnBias", this->Input("LnBias"));
        op->SetOutput(framework::GradVarName("LnBias"),
                      this->InputGrad("LnBias"));
      }
L
Li Min 已提交
573 574 575 576 577 578 579 580 581 582 583
    } else {
      if (this->HasInput("Ln2Scale")) {
        op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
        op->SetOutput(framework::GradVarName("Ln2Scale"),
                      this->InputGrad("Ln2Scale"));
      }
      if (this->HasInput("Ln2Bias")) {
        op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
        op->SetOutput(framework::GradVarName("Ln2Bias"),
                      this->InputGrad("Ln2Bias"));
      }
584 585 586 587
    }

    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
588

589 590 591 592
    op->SetOutput(framework::GradVarName("OutLinearW"),
                  this->InputGrad("OutLinearW"));

    // use forward outputs as backward inputs.
593 594 595 596 597 598 599 600 601 602
    if (is_pre_layer_norm) {
      if (this->HasOutput("LnOut")) {
        op->SetInput("LnOut", this->Output("LnOut"));
      }
      if (this->HasOutput("LnMean")) {
        op->SetInput("LnMean", this->Output("LnMean"));
      }
      if (this->HasOutput("LnVariance")) {
        op->SetInput("LnVariance", this->Output("LnVariance"));
      }
L
Li Min 已提交
603 604 605 606 607
    } else {
      op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
      op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
      op->SetInput("BiasDropoutResidualOut",
                   this->Output("BiasDropoutResidualOut"));
608
    }
609
    op->SetInput("QKVOut", this->Output("QKVOut"));
610

611 612 613 614 615 616
    op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
    op->SetInput("QKOut", this->Output("QKOut"));
    op->SetInput("QKTVOut", this->Output("QKTVOut"));
    op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
    op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
    op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
617

618 619 620 621 622 623
    op->SetInput("FMHAOut", this->Output("FMHAOut"));
    op->SetInput("OutLinearOut", this->Output("OutLinearOut"));
    op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
    op->SetInput("QKVOut", this->Output("QKVOut"));

    // backward outputs: dinput
624 625 626 627 628
    if (is_pre_layer_norm) {
      if (this->HasOutput("LnOut")) {
        op->SetOutput(framework::GradVarName("LnOut"),
                      this->OutputGrad("LnOut"));
      }
L
Li Min 已提交
629 630 631
    } else {
      op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
                    this->OutputGrad("BiasDropoutResidualOut"));
632
    }
L
Li Min 已提交
633

634
    op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
635

636 637 638 639 640 641 642 643 644
    op->SetOutput(framework::GradVarName("QKTVOut"),
                  this->OutputGrad("QKTVOut"));
    op->SetOutput(framework::GradVarName("TransposeOut2"),
                  this->OutputGrad("TransposeOut2"));
    op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut"));
    op->SetOutput(framework::GradVarName("SoftmaxOut"),
                  this->OutputGrad("SoftmaxOut"));
    op->SetOutput(framework::GradVarName("AttnDropoutOut"),
                  this->OutputGrad("AttnDropoutOut"));
645

646 647 648 649 650 651 652
    op->SetOutput(framework::GradVarName("FMHAOut"),
                  this->OutputGrad("FMHAOut"));
    op->SetOutput(framework::GradVarName("OutLinearOut"),
                  this->OutputGrad("OutLinearOut"));
  }
};

L
Li Min 已提交
653 654 655 656 657
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
658 659 660 661
                  ops::FusedAttentionOpMaker,
                  ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
                  ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);