From d4906214656754cb7216279d07b327b49c3ee617 Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Fri, 22 Oct 2021 15:33:53 +0800 Subject: [PATCH] Fused attention op forward (#35905) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 功能:本PR的目标是提高attention模块的计算性能。 为了减少框架层对op的调度开销,本PR通过在C++层手动实现attention模块,对外提供attention 大op; 为了减少防存开销,本PR采取了两种优化方法: (1)在q,k,v计算时通过共享输入X,将该处的gemm,transpose和bias add从三次调用减少为一次; (2)使用kernel融合优化技术,在不同cuda kernel之间通过寄存器传输数据; --- cmake/operators.cmake | 2 +- paddle/fluid/operators/dropout_impl_util.h | 3 + paddle/fluid/operators/fused/CMakeLists.txt | 4 + .../operators/fused/fused_attention_op.cc | 336 ++++++++++++++++++ .../operators/fused/fused_attention_op.cu | 209 +++++++++++ .../operators/fused/fused_dropout_helper.h | 2 +- paddle/fluid/pybind/op_function_generator.cc | 8 + .../fluid/tests/unittests/CMakeLists.txt | 4 + .../unittests/test_fused_attention_op.py | 235 ++++++++++++ python/paddle/nn/functional/__init__.py | 2 + .../paddle/nn/functional/fused_transformer.py | 127 +++++++ python/paddle/nn/layer/transformer.py | 2 +- 12 files changed, 931 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/fused/fused_attention_op.cc create mode 100644 paddle/fluid/operators/fused/fused_attention_op.cu create mode 100644 python/paddle/fluid/tests/unittests/test_fused_attention_op.py create mode 100644 python/paddle/nn/functional/fused_transformer.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 5eecbefa2fc..a396af570f3 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -217,7 +217,7 @@ function(op_library TARGET) "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op" -"fused_bn_add_activation_op" "resnet_unit_op") +"fused_bn_add_activation_op" "fused_attention_op" "resnet_unit_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/operators/dropout_impl_util.h b/paddle/fluid/operators/dropout_impl_util.h index f2038d12528..e11640d0706 100644 --- a/paddle/fluid/operators/dropout_impl_util.h +++ b/paddle/fluid/operators/dropout_impl_util.h @@ -34,6 +34,9 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx, TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor); *seed_data = static_cast(seed_cpu_tensor.data()[0]); *increment = offset; + } else if (seed && platform::is_cpu_place(seed->place())) { + *seed_data = *(seed->data()); + *increment = offset; } else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) { auto seed_offset = gen_cuda->IncrementOffset(offset); *seed_data = seed_offset.first; diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 2286aaaf859..845e5659a88 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -16,6 +16,7 @@ register_operators(EXCLUDES fusion_gru_op fusion_lstm_op fused_bn_add_activation_op + fused_attention_op fused_transformer_op resnet_unit_op) @@ -78,6 +79,9 @@ if (WITH_GPU OR WITH_ROCM) nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) + # fused_attention_op + op_library(fused_attention_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n") endif() # resnet_unit needs cudnn 8.0 above if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000)) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc new file mode 100644 index 00000000000..a286c39f7f8 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -0,0 +1,336 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class FusedAttentionOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias", + "FusedAttentionOp"); + + OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut", + "FusedAttentionOp"); + // qkv_out: [batch_size, seq_len, 3, num_head, dim_head] + OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output", + "AttnDropoutMaskOut", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutOut"), "Output", "AttnDropoutOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("FMHAOut"), "Output", "FMHAOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", + "BiasDropoutResidualOut", "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut", + "FusedAttentionOp"); + OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp"); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputDim("QKVW"); + PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument( + "The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ(y_dim.size(), 4, + platform::errors::InvalidArgument( + "The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3], + platform::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3]" + "must be equal. But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, y_dim)); + + ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("LnOut", ctx->GetInputDim("X")); + // [batch_size, seq_len, 3, num_head, head_size] + ctx->SetOutputDim("QKVOut", + {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + ctx->SetOutputDim("QKVBiasOut", + {x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]}); + // [3, batch_size, num_head, seq_len, head_size] + ctx->SetOutputDim("TransposeOut2", + {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); + // [batch, num_head, seq_len, seq_len] + ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + // the same as QKOut's shape. + ctx->SetOutputDim("AttnDropoutOut", + {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + if (ctx->Attrs().Get("attn_dropout_is_test") == false) { + ctx->SetOutputDim("AttnDropoutMaskOut", + {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + } + ctx->SetOutputDim("SoftmaxOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]}); + // [batch_size, num_heads, seq_len, head_dim] + ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]}); + // [batch_size, seq_len, number of heads*head size] + ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]}); + ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X")); + + ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]}); + ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]}); + if (ctx->Attrs().Get("dropout_is_test") == false) { + ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X")); + } + ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X")); + ctx->SetOutputDim("Y", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input = ctx.Input("X"); + auto input_data_type = input->type(); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor."); + AddInput("LnScale", + "(optional) Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddInput("LnBias", + "(optional) Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddInput("QKVW", "The qkv weight tensor."); + AddInput("QKVBias", "The qkv bias tensor."); + AddInput("SrcMask", "(optional) The attention mask tensor in fmha.") + .AsDispensable(); + AddInput("OutLinearW", "The out_linear weight tensor."); + AddInput("OutLinearBias", "The out_linear bias tensor."); + AddInput("Ln2Scale", + "(optional) Scale is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddInput("Ln2Bias", + "(optional) Bias is a 1-dimensional tensor of size " + "H. Here, H represents the last dimension of its input tensor.") + .AsDispensable(); + AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("LnVariance", "Variance of the current mini batch.") + .AsIntermediate(); + AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate(); + AddOutput("QKVOut", "Result after qkv.").AsIntermediate(); + AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate(); + AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate(); + AddOutput("QKOut", "Result in fmha.").AsIntermediate(); + AddOutput("QKTVOut", "Result in fmha.").AsIntermediate(); + AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate(); + AddOutput("AttnDropoutMaskOut", "Result in fmha.").AsIntermediate(); + AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate(); + AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate(); + AddOutput("FMHAOut", "Result after fmha.").AsIntermediate(); + AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate(); + AddOutput("DropoutMaskOut", "The random sampled dropout mask.") + .AsIntermediate(); + AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate(); + AddOutput("Ln2Variance", "Variance of the current mini batch.") + .AsIntermediate(); + AddOutput("BiasDropoutResidualOut", + "Result of residual + dropout(src + bias).") + .AsIntermediate(); + AddOutput("Y", "Result after attention."); + + AddAttr("pre_layer_norm", + "if true, the attention op uses pre_layer_norm architecure, " + "else, uses post_layer_norm architecuture. " + "[default false].") + .SetDefault(false); + AddAttr("epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true, + platform::errors::InvalidArgument( + "'epsilon' in Op(LayerNorm) should be between" + "0.0 and 0.001, But received [%s].", + epsilon)); + }); + + // for dropout in fmha. + AddAttr("attn_dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ( + drop_p >= 0.0f && drop_p <= 1.0f, true, + platform::errors::InvalidArgument( + "'attn_dropout_rate' must be between 0.0 and 1.0.")); + }); + AddAttr("attn_dropout_is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("attn_dropout_fix_seed", + "A flag indicating whether to use a fixed seed to generate " + "random mask. NOTE: DO NOT set this flag to true in " + "training. Setting this flag to true is only useful in " + "unittest or for debug that always the same output units " + "will be dropped.") + .SetDefault(true); + AddAttr("attn_dropout_seed", "Dropout random seed.").SetDefault(0); + AddAttr( + "attn_dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "There are two kinds of ways to implement dropout" + "(the mask below is a tensor have the same shape with input" + "the value of mask is 0 or 1, the ratio of 0 is dropout_rate)" + "1. downgrade_in_infer(default), downgrade the outcome at inference " + "time" + " train: out = input * mask" + " inference: out = input * (1.0 - dropout_rate)" + "2. upscale_in_train, upscale the outcome at training time, do nothing " + "in inference" + " train: out = input * mask / ( 1.0 - dropout_rate )" + " inference: out = input" + " dropout op can be removed from the program. the program will be " + "efficient") + .SetDefault("upscale_in_train") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + + AddAttr("dropout_rate", "Probability of setting units to zero.") + .SetDefault(.5f) + .AddCustomChecker([](const float &drop_p) { + PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true, + platform::errors::InvalidArgument( + "'dropout_rate' must be between 0.0 and 1.0.")); + }); + + AddAttr("dropout_is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("dropout_fix_seed", + "A flag indicating whether to use a fixed seed to generate " + "random mask. NOTE: DO NOT set this flag to true in " + "training. Setting this flag to true is only useful in " + "unittest or for debug that always the same output units " + "will be dropped.") + .SetDefault(true); + AddAttr("dropout_seed", "Dropout random seed.").SetDefault(0); + AddAttr( + "dropout_implementation", + "[\"downgrade_in_infer\"|\"upscale_in_train\"]" + "The meaning is the same as 'attn_dropout_implementation'.") + .SetDefault("downgrade_in_infer") + .AddCustomChecker([](const std::string &type) { + PADDLE_ENFORCE_EQ( + type == "downgrade_in_infer" || type == "upscale_in_train", true, + platform::errors::InvalidArgument( + "dropout_implementation can only be downgrade_in_infer or " + "upscale_in_train")); + }); + AddAttr("ln_epsilon", + "Constant for numerical stability [default 1e-5].") + .SetDefault(1e-5) + .AddCustomChecker([](const float &ln_epsilon) { + PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true, + platform::errors::InvalidArgument( + "'epsilon' of the second LayerNorm in Fused " + "attention op should be between" + "0.0 and 0.001, But received [%s].", + ln_epsilon)); + }); + + AddComment(R"DOC( + Add fused attention op whose logic is as follows: + // @input: [batch_size, seq_len, 3, num_head, head_dim] + // @final_out: [batch_size, seq_len, num_heads, head_dim] + if (pre_layernorm) + out = layer_norm(input); + out = compute_qkv(out) + bias; + // fmha module + { + out = transpose(out, perm=[2, 0, 3, 1, 4]); + out = q * k^t; + out = attn_mark + out; + out = softmax(out); + out = dropout(out); + out = out * v; + out = transpose(out, perm=[0, 2, 1, 3]); + + } + out = out_linear(out); + final_out = layer_norm(residual + dropout(bias + out)); + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp, + ops::FusedAttentionOpMaker); diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu new file mode 100644 index 00000000000..18a42b5c2ce --- /dev/null +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -0,0 +1,209 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/cuda_device_function.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +#include "paddle/fluid/operators/fused/attention_layer_norm.h" +#include "paddle/fluid/operators/fused/attn_gemm.h" +#include "paddle/fluid/operators/fused/fmha_ref.h" +#include "paddle/fluid/operators/fused/fused_dropout_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class FusedAttentionOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + using U = LayerNormParamType; + auto *input_x = ctx.Input("X"); + + const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); + const float epsilon = ctx.Attr("epsilon"); + auto *ln_scale = ctx.Input("LnScale"); + auto *ln_bias = ctx.Input("LnBias"); + auto *ln_mean = ctx.Output("LnMean"); + auto *ln_var = ctx.Output("LnVariance"); + auto *ln_out = ctx.Output("LnOut"); + + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto *qkv_weight = ctx.Input("QKVW"); + auto *qkv_bias = ctx.Input("QKVBias"); + auto *qkv_out = ctx.Output("QKVOut"); + auto *qkv_bias_out = ctx.Output("QKVBiasOut"); + + auto *src_mask = ctx.Input("SrcMask"); + auto *transpose_out_2 = ctx.Output("TransposeOut2"); + auto *qk_out = ctx.Output("QKOut"); + auto *qktv_out = ctx.Output("QKTVOut"); + auto *softmax_out = ctx.Output("SoftmaxOut"); + auto *attn_dropout_mask_out = ctx.Output("AttnDropoutMaskOut"); + auto *attn_dropout_out = ctx.Output("AttnDropoutOut"); + auto *src_mask_out = ctx.Output("SrcMaskOut"); + auto *fmha_out = ctx.Output("FMHAOut"); + + auto *out_linear_weight = ctx.Input("OutLinearW"); + auto *out_linear_bias = ctx.Input("OutLinearBias"); + auto *out_linear_out = ctx.Output("OutLinearOut"); + + auto *ln_scale_2 = ctx.Input("Ln2Scale"); + auto *ln_bias_2 = ctx.Input("Ln2Bias"); + auto *dropout_mask_out = ctx.Output("DropoutMaskOut"); + auto *bias_dropout_residual_out = + ctx.Output("BiasDropoutResidualOut"); + auto *ln_mean_2 = ctx.Output("Ln2Mean"); + auto *ln_var_2 = ctx.Output("Ln2Variance"); + const float ln_epsilon = ctx.Attr("ln_epsilon"); + + float attn_dropout_rate = ctx.Attr("attn_dropout_rate"); + bool is_test_1 = ctx.Attr("attn_dropout_is_test"); + auto &dropout_implementation_1 = + ctx.Attr("attn_dropout_implementation"); + bool is_upscale_in_train_1 = + (dropout_implementation_1 == "upscale_in_train"); + auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input("Seed1") : nullptr; + bool is_fix_seed_1 = ctx.Attr("attn_dropout_fix_seed"); + int seed_val_1 = ctx.Attr("attn_dropout_seed"); + + // final output. + auto *out = ctx.Output("Y"); + + // get data ptr for qkv part. + const auto input_x_dims = input_x->dims(); + const auto qkv_w_dims = qkv_weight->dims(); + + auto *x_data = input_x->data(); + auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); + auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); + auto *ln_mean_data = ln_mean->mutable_data(ctx.GetPlace()); + auto *ln_var_data = ln_var->mutable_data(ctx.GetPlace()); + auto *ln_out_data = ln_out->mutable_data(ctx.GetPlace()); + + auto *qkv_weight_data = qkv_weight->data(); + auto *qkv_bias_data = qkv_bias->data(); + auto *qkv_out_data = qkv_out->mutable_data(ctx.GetPlace()); + auto *qkv_bias_out_data = qkv_bias_out->mutable_data(ctx.GetPlace()); + + // get data ptr for FMHA. + auto *transpose_out_2_data = + transpose_out_2->mutable_data(ctx.GetPlace()); + auto *qk_out_data = qk_out->mutable_data(ctx.GetPlace()); + auto *qktv_out_data = qktv_out->mutable_data(ctx.GetPlace()); + auto *src_mask_out_data = src_mask_out->mutable_data(ctx.GetPlace()); + auto *softmax_out_data = softmax_out->mutable_data(ctx.GetPlace()); + auto *attn_dropout_mask_out_data = + attn_dropout_mask_out->mutable_data(ctx.GetPlace()); + auto *attn_dropout_out_data = + attn_dropout_out->mutable_data(ctx.GetPlace()); + auto *fmha_out_data = fmha_out->mutable_data(ctx.GetPlace()); + + // get data ptr for out_linear. + auto *out_linear_weight_data = out_linear_weight->data(); + auto *out_linear_bias_data = out_linear_bias->data(); + auto *out_linear_out_data = out_linear_out->mutable_data(ctx.GetPlace()); + + // get data ptr for bias+dropout+residual+layernorm + auto *ln_scale_2_data = + (ln_scale_2 == nullptr ? nullptr : ln_scale_2->data()); + auto *ln_bias_2_data = + (ln_bias_2 == nullptr ? nullptr : ln_bias_2->data()); + auto *dropout_mask_out_data = + dropout_mask_out->mutable_data(ctx.GetPlace()); + auto *bias_dropout_residual_out_data = + bias_dropout_residual_out->mutable_data(ctx.GetPlace()); + auto *ln_mean_2_data = ln_mean_2->mutable_data(ctx.GetPlace()); + auto *ln_var_2_data = ln_var_2->mutable_data(ctx.GetPlace()); + auto *final_out_data = out->mutable_data(ctx.GetPlace()); + + int batch_size = input_x_dims[0]; + int max_seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + + int num_head = qkv_w_dims[1]; + int dim_head = qkv_w_dims[2]; + + int bsz_seq = batch_size * max_seq_len; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + auto layer_norm_compute = AttnLayerNorm(ctx.cuda_device_context(), + epsilon, bsz_seq, dim_embed); + // (transA, transB, compute_bias) = (false, true, true) + auto qkv_compute = AttnMatMul(ctx.cuda_device_context(), false, true, + bsz_seq, output_size, input_size, true); + + AttnDropoutParam attn_dropout_param( + is_test_1, dropout_implementation_1, attn_dropout_rate, + is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1); + auto fmha_ref_compute = + FMHARef(ctx.cuda_device_context(), batch_size, max_seq_len, num_head, + dim_head, attn_dropout_param); + + output_size = hidden_size; + // (transA, transB, compute_bias) = (false, false, false) + auto out_linear_compute = + AttnMatMul(ctx.cuda_device_context(), false, false, bsz_seq, + output_size, input_size, false); + DropoutParam dropout_param2(ctx, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2, + ln_epsilon); + + if (pre_layer_norm) { + layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data, + ln_out_data, ln_mean_data, ln_var_data); + qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data, + qkv_out_data, qkv_bias_out_data); + } else { + qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data, + qkv_out_data, qkv_bias_out_data); + } + fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2, + qk_out, src_mask_out, softmax_out, + attn_dropout_mask_out, attn_dropout_out, + qktv_out, fmha_out); + // fmha_out: [batch_size, seq_len, num_head, head_dim] + // weight: [embed_dim, embed_dim] + // out_linear_out: [batch_size, seq_len, embed_dim] + out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data, + nullptr, out_linear_out_data, nullptr); + // output = layernorm(residual + dropout(input + bias)) + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + ctx.cuda_device_context(), out_linear_out_data, x_data, + out_linear_bias_data, ln_scale_2_data, ln_bias_2_data, + bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data, + ln_mean_2_data, ln_var_2_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel, + ops::FusedAttentionOpKernel, + ops::FusedAttentionOpKernel); diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index fcfa405a52f..33fde64164d 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -66,7 +66,7 @@ struct DropoutParam { } else { pre_fix = pre_fix + "_"; } - dropout_prob = context.Attr(pre_fix + "prob"); + dropout_prob = context.Attr(pre_fix + "rate"); auto& dropout_implementation = context.Attr(pre_fix + "implementation"); is_upscale_in_train = (dropout_implementation == "upscale_in_train"); diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index d031709b765..08ab1d7d344 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -40,6 +40,9 @@ // need to manually specify them in this map. std::map> op_ins_map = { {"layer_norm", {"X", "Scale", "Bias"}}, + {"fused_attention", + {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW", + "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, {"instance_norm", {"X", "Scale", "Bias"}}, {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, {"label_smooth", {"X", "PriorDist"}}, @@ -92,6 +95,11 @@ std::map> op_outs_map = { {"batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, + {"fused_attention", + {"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2", + "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut", + "SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean", + "Ln2Variance", "BiasDropoutResidualOut", "Y"}}, {"sync_batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 1c9ce2bef5e..c6d90ee404f 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -97,6 +97,10 @@ foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() +if(NOT WITH_GPU) + LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) +endif() + if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op) LIST(REMOVE_ITEM TEST_OPS test_c_concat) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py new file mode 100644 index 00000000000..a5578d71c5c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -0,0 +1,235 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.fluid.core as core +import paddle.nn.functional as F +from paddle.nn.layer.norm import LayerNorm +from paddle.nn.layer.common import Linear, Dropout +from paddle.nn.layer.transformer import _convert_attention_mask +from paddle import tensor +from paddle.fluid import layers +import unittest +from op_test import OpTest + + +class TestFusedAttentionOp(OpTest): + def setUp(self): + self.config() + self.generate_input_data() + paddle.set_default_dtype(self.x_type) + self.__class__.op_type = "fused_attention" + self.q_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.k_proj = Linear( + self.kdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.v_proj = Linear( + self.vdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + self.out_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr) + paddle.set_default_dtype(np.float32) + self.norm1 = LayerNorm(self.embed_dim) + self.norm2 = LayerNorm(self.embed_dim) + paddle.set_default_dtype(self.x_type) + self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") + + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def generate_input_data(self): + self.query = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.x_type) + self.attn_mask = np.ones( + (self.batch_size, self.num_heads, self.query_length, + self.key_length), + dtype=self.attn_mask_type) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + else: + raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.") + self.key, self.value = self.query, self.query + + self.dout = np.random.random((self.batch_size, self.query_length, + self.embed_dim)).astype(self.x_type) + + def GetBaselineOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + tensor_query = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + residual = tensor_query + + ln1_out = tensor_query + if self.pre_layer_norm: + ln1_out = self.norm1(tensor_query) + + q = self.q_proj(ln1_out) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + k = self.k_proj(ln1_out) + v = self.v_proj(ln1_out) + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + qk_out = layers.matmul( + x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5) + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) + attn_mask_out = qk_out + attn_mask + softmax_out = F.softmax(attn_mask_out) + else: + softmax_out = F.softmax(qk_out) + + if self.dropout_prob: + dropout_out = F.dropout( + softmax_out, + self.dropout_prob, + training=self.training, + mode="upscale_in_train") + qktv_out = tensor.matmul(dropout_out, v_out) + else: + qktv_out = tensor.matmul(softmax_out, v_out) + + fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) + out_linear_in = tensor.reshape( + x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]) + out = self.out_proj(out_linear_in) + + residual_out = residual + self.dropout(out) + if not self.pre_layer_norm: + final_out = self.norm1(residual_out) + if self.pre_layer_norm: + final_out = self.norm2(residual_out) + return final_out + + def GetFusedAttentionOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + q_proj_weight = paddle.to_tensor( + self.q_proj.weight, stop_gradient=False) + q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False) + k_proj_weight = paddle.to_tensor( + self.k_proj.weight, stop_gradient=False) + k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False) + v_proj_weight = paddle.to_tensor( + self.v_proj.weight, stop_gradient=False) + v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False) + out_linear_weight = paddle.to_tensor( + self.out_proj.weight, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False) + + ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) + ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) + ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) + ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) + + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight)) + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim)) + + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy())) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + + x = paddle.to_tensor(self.query, stop_gradient=False) + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + epsilon = 1e-05 + ln2_epsilon = 1e-05 + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, x.dtype) + final_out = F.fused_multi_head_attention( + x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm, + ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, + out_linear_bias, attn_mask, self.dropout_prob, + self.attn_dropout_prob, ln2_epsilon) + return final_out + + def test_fused_attention_op(self): + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) + + +class TestFusedAttentionOpFp16(TestFusedAttentionOp): + def config(self): + self.x_type = np.float16 + self.attn_mask_type = np.float64 + self.pre_layer_norm = True + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = self.query_length, self.query_length + + def test_fused_attention_op(self): + final_out_ref = self.GetBaselineOut() + final_out = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index 1af53e0826b..8daae3d0ca9 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -61,6 +61,7 @@ from .common import class_center_sample # noqa: F401 from .conv import conv1d # noqa: F401 from .conv import conv1d_transpose # noqa: F401 from .common import linear # noqa: F401 +from .fused_transformer import fused_multi_head_attention # noqa: F401 from .conv import conv2d # noqa: F401 from .conv import conv2d_transpose # noqa: F401 from .conv import conv3d # noqa: F401 @@ -211,5 +212,6 @@ __all__ = [ #noqa 'layer_norm', 'instance_norm', 'class_center_sample', + 'fused_multi_head_attention', 'sparse_attention', ] diff --git a/python/paddle/nn/functional/fused_transformer.py b/python/paddle/nn/functional/fused_transformer.py new file mode 100644 index 00000000000..565ef223a96 --- /dev/null +++ b/python/paddle/nn/functional/fused_transformer.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from ...fluid.framework import in_dygraph_mode +from paddle import _C_ops + +__all__ = [] + + +def fused_multi_head_attention(x, + qkv_weight, + linear_weight, + pre_layer_norm=False, + pre_ln_scale=None, + pre_ln_bias=None, + ln_scale=None, + ln_bias=None, + pre_ln_epsilon=1e-05, + qkv_bias=None, + linear_bias=None, + attn_mask=None, + dropout_rate=0.5, + attn_dropout_rate=0.5, + ln_epsilon=1e-05, + name=None): + """ + Attention mapps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attention to jointly attending + to information from different representation subspaces. This API only + support self_attention. The pseudo code is as follows: + if pre_layer_norm: + out = layer_norm(x); + out = linear(out) + qkv)bias + else: + out = linear(x) + bias; + out = transpose(out, perm=[2, 0, 3, 1, 4]); + # extract q, k and v from out. + q = out[0:1,::] + k = out[1:2,::] + v = out[2:3,::] + out = q * k^t; + out = attn_mask + out; + out = softmax(out); + out = dropout(out); + out = out * v; + out = transpose(out, perm=[0, 2, 1, 3]); + out = out_linear(out); + out = layer_norm(x + dropout(linear_bias + out)); + + Parameters: + x (Tensor): The input tensor of fused_multi_head_attention. The shape is + `[batch\_size, sequence\_len, embed\_dim]`. + qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`. + linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`. + pre_layer_norm (bool, optional): whether it is pre_layer_norm or post_layer_norm architecture. + Default False. + pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None. + pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None. + ln_scale (Tensor, optional): The weight tensor of layernorm. Default None. + ln_bias (Tensor, optional): The bias tensor of layernorm. Default None. + pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm + to avoid dividing by zero. Default is 1e-5. + qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`. + Default None. + linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None. + attn_mask (Tensor, optional): + dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout after attention. + 0 for no dropout. Default 0. + attn_dropout_rate (float, optional): The dropout probability used on attention + weights to drop some attention targets for the dropout in attention. + 0 for no dropout. Default 0. + ln_epsilon (float, optional): Small float value added to denominator of layer_norm + to avoid dividing by zero. Default is 1e-5. + + Examples: + + .. code-block:: python + + # required: gpu + import paddle + import paddle.nn.functional as F + + # input: [batch_size, seq_len, embed_dim] + x = paddle.rand(shape=(2, 4, 128), dtype="float32") + # qkv_weight: [3, num_head, dim_head, dim_embed] + qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") + # qkv_bias: [3, num_head, dim_head] + qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") + # linear_weight: [embed_dim, embed_dim] + linear_weight = paddle.rand(shape=(128, 128), dtype="float32") + # linear_bias: [embed_dim] + linear_bias = paddle.rand(shape=[128], dtype="float32") + # self attention mask: [batch_size, num_heads, seq_len, seq_len] + attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32") + + # output: [batch_size, seq_len, embed_dim] + output = F.fused_multi_head_attention( + x, qkv_weight, linear_weight, False, + None, None, None, None, 1e-5, qkv_bias, + linear_bias, attn_mask) + # [2, 4, 128] + print(output.shape) + """ + if in_dygraph_mode(): + # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, + # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, + # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out + _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( + x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, + linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', + pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate', + dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon', + ln_epsilon) + return final_out diff --git a/python/paddle/nn/layer/transformer.py b/python/paddle/nn/layer/transformer.py index eacf5aac9da..36bc8364796 100644 --- a/python/paddle/nn/layer/transformer.py +++ b/python/paddle/nn/layer/transformer.py @@ -26,7 +26,7 @@ from ... import tensor from ...fluid import layers from .. import Layer, LayerList from ...framework import ParamAttr -from ...fluid.data_feeder import convert_dtype +from paddle.fluid.data_feeder import convert_dtype __all__ = [] -- GitLab