Unverified · Commit d4906214 authored by Li Min, committed by GitHub

Fused attention op forward (#35905)

Purpose: this PR aims to improve the computational performance of the attention module.
To reduce the framework-level overhead of scheduling individual ops, this PR implements the attention module manually at the C++ level and exposes it as a single fused attention op.
To reduce memory-access overhead, this PR applies two optimizations:
(1) When computing q, k, and v, the input X is shared so that the gemm, transpose, and bias-add are reduced from three calls to one (see the NumPy sketch below);
(2) Kernel fusion is used so that data is passed between different CUDA kernels through registers.
Parent 08248db0
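A minimal NumPy sketch of optimization (1): the three q/k/v projections collapse into one gemm over the shared input X once the weights are packed into the op's qkv_weight layout [3, num_head, dim_head, dim_embed]. All sizes and values below are illustrative, not taken from the PR.

import numpy as np

# Illustrative sizes (hypothetical): num_head * dim_head == dim_embed.
batch_size, seq_len, dim_embed = 2, 4, 128
num_head, dim_head = 4, 32

x = np.random.rand(batch_size, seq_len, dim_embed).astype("float32")
# Three separate projection weights in [in_features, out_features] layout.
wq = np.random.rand(dim_embed, dim_embed).astype("float32")
wk = np.random.rand(dim_embed, dim_embed).astype("float32")
wv = np.random.rand(dim_embed, dim_embed).astype("float32")

# Unfused: three gemms over the same input x.
q, k, v = x @ wq, x @ wk, x @ wv

# Fused: pack the transposed weights into qkv_weight of shape
# [3, num_head, dim_head, dim_embed] and issue a single gemm.
qkv_weight = np.concatenate((wq.T, wk.T, wv.T)).reshape(
    (3, num_head, dim_head, dim_embed))
qkv_out = np.einsum("bse,tnde->bstnd", x, qkv_weight)
# qkv_out: [batch_size, seq_len, 3, num_head, dim_head]

assert np.allclose(
    q.reshape(batch_size, seq_len, num_head, dim_head),
    qkv_out[:, :, 0], atol=1e-5)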
......@@ -217,7 +217,7 @@ function(op_library TARGET)
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
"fused_bn_add_activation_op" "resnet_unit_op")
"fused_bn_add_activation_op" "fused_attention_op" "resnet_unit_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -34,6 +34,9 @@ inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx,
TensorCopySync(*seed, platform::CPUPlace(), &seed_cpu_tensor);
*seed_data = static_cast<uint64_t>(seed_cpu_tensor.data<int>()[0]);
*increment = offset;
} else if (seed && platform::is_cpu_place(seed->place())) {
*seed_data = *(seed->data<int>());
*increment = offset;
} else if (gen_cuda->GetIsInitPy() && (!is_fix_seed)) {
auto seed_offset = gen_cuda->IncrementOffset(offset);
*seed_data = seed_offset.first;
......
......@@ -16,6 +16,7 @@ register_operators(EXCLUDES
fusion_gru_op
fusion_lstm_op
fused_bn_add_activation_op
fused_attention_op
fused_transformer_op
resnet_unit_op)
......@@ -78,6 +79,9 @@ if (WITH_GPU OR WITH_ROCM)
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
# fused_attention_op
op_library(fused_attention_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n")
endif()
# resnet_unit needs cudnn 8.0 above
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FusedAttentionOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnMean"), "Output", "LnMean",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnVariance"), "Output", "LnVariance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("LnOut"), "Output", "LnOut",
"FusedAttentionOp");
// qkv_out: [batch_size, seq_len, 3, num_head, dim_head]
OP_INOUT_CHECK(ctx->HasOutput("QKVOut"), "Output", "QKVOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKVBiasOut"), "Output", "QKVBiasOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("TransposeOut2"), "Output", "TransposeOut2",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKOut"), "Output", "QKOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output",
"AttnDropoutMaskOut", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutOut"), "Output", "AttnDropoutOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("FMHAOut"), "Output", "FMHAOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("OutLinearOut"), "Output", "OutLinearOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output",
"BiasDropoutResidualOut", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), "Output", "DropoutMaskOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "FusedAttentionOp");
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("QKVW");
PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument(
"The dimensions of x must be 3 "
"(batch_size, seq_len, dim_embed), "
"but the received dimensions of "
"Input are [%d].",
x_dim.size()));
PADDLE_ENFORCE_EQ(y_dim.size(), 4,
platform::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4 "
"(3, num_head, dim_head, dim_embed), "
"but the received dimensions of "
"Input are [%d].",
y_dim.size()));
PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3],
platform::errors::InvalidArgument(
"ShapeError: x_dim[2] and y_dim[3] "
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s].",
x_dim, y_dim));
ctx->SetOutputDim("LnMean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnVariance", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("LnOut", ctx->GetInputDim("X"));
// [batch_size, seq_len, 3, num_head, head_size]
ctx->SetOutputDim("QKVOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
ctx->SetOutputDim("QKVBiasOut",
{x_dim[0], x_dim[1], y_dim[0], y_dim[1], y_dim[2]});
// [3, batch_size, num_head, seq_len, head_size]
ctx->SetOutputDim("TransposeOut2",
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
// [batch, num_head, seq_len, seq_len]
ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
// the same as QKOut's shape.
ctx->SetOutputDim("AttnDropoutOut",
{x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
if (ctx->Attrs().Get<bool>("attn_dropout_is_test") == false) {
ctx->SetOutputDim("AttnDropoutMaskOut",
{x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
}
ctx->SetOutputDim("SoftmaxOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
// [batch_size, num_heads, seq_len, head_dim]
ctx->SetOutputDim("QKTVOut", {x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
// [batch_size, seq_len, num_head, head_size]
ctx->SetOutputDim("FMHAOut", {x_dim[0], x_dim[1], y_dim[1], y_dim[2]});
ctx->SetOutputDim("OutLinearOut", ctx->GetInputDim("X"));
ctx->SetOutputDim("Ln2Mean", {x_dim[0] * x_dim[1]});
ctx->SetOutputDim("Ln2Variance", {x_dim[0] * x_dim[1]});
if (ctx->Attrs().Get<bool>("dropout_is_test") == false) {
ctx->SetOutputDim("DropoutMaskOut", ctx->GetInputDim("X"));
}
ctx->SetOutputDim("BiasDropoutResidualOut", ctx->GetInputDim("X"));
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto input_data_type = input->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input tensor.");
AddInput("LnScale",
"(optional) Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
AddInput("LnBias",
"(optional) Bias is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
AddInput("QKVW", "The qkv weight tensor.");
AddInput("QKVBias", "The qkv bias tensor.");
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
.AsDispensable();
AddInput("OutLinearW", "The out_linear weight tensor.");
AddInput("OutLinearBias", "The out_linear bias tensor.");
AddInput("Ln2Scale",
"(optional) Scale is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
AddInput("Ln2Bias",
"(optional) Bias is a 1-dimensional tensor of size "
"H. Here, H represents the last dimension of its input tensor.")
.AsDispensable();
AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate();
AddOutput("LnVariance", "Variance of the current mini batch.")
.AsIntermediate();
AddOutput("LnOut", "The output of pre layer_norm.").AsIntermediate();
AddOutput("QKVOut", "Result after qkv.").AsIntermediate();
AddOutput("QKVBiasOut", "Result after qkv and bias op.").AsIntermediate();
AddOutput("TransposeOut2", "Result in fmha.").AsIntermediate();
AddOutput("QKOut", "Result in fmha.").AsIntermediate();
AddOutput("QKTVOut", "Result in fmha.").AsIntermediate();
AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate();
AddOutput("AttnDropoutMaskOut", "Result in fmha.").AsIntermediate();
AddOutput("AttnDropoutOut", "Result in fmha.").AsIntermediate();
AddOutput("SrcMaskOut", "Result in fmha.").AsIntermediate();
AddOutput("FMHAOut", "Result after fmha.").AsIntermediate();
AddOutput("OutLinearOut", "Result after out_linear.").AsIntermediate();
AddOutput("DropoutMaskOut", "The random sampled dropout mask.")
.AsIntermediate();
AddOutput("Ln2Mean", "Mean of the current mini batch.").AsIntermediate();
AddOutput("Ln2Variance", "Variance of the current mini batch.")
.AsIntermediate();
AddOutput("BiasDropoutResidualOut",
"Result of residual + dropout(src + bias).")
.AsIntermediate();
AddOutput("Y", "Result after attention.");
AddAttr<bool>("pre_layer_norm",
"if true, the attention op uses pre_layer_norm architecure, "
"else, uses post_layer_norm architecuture. "
"[default false].")
.SetDefault(false);
AddAttr<float>("epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' in Op(LayerNorm) should be between"
"0.0 and 0.001, But received [%s].",
epsilon));
});
// for dropout in fmha.
AddAttr<float>("attn_dropout_rate", "Probability of setting units to zero.")
.SetDefault(.5f)
.AddCustomChecker([](const float &drop_p) {
PADDLE_ENFORCE_EQ(
drop_p >= 0.0f && drop_p <= 1.0f, true,
platform::errors::InvalidArgument(
"'attn_dropout_rate' must be between 0.0 and 1.0."));
});
AddAttr<bool>("attn_dropout_is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<bool>("attn_dropout_fix_seed",
"A flag indicating whether to use a fixed seed to generate "
"random mask. NOTE: DO NOT set this flag to true in "
"training. Setting this flag to true is only useful in "
"unittest or for debug that always the same output units "
"will be dropped.")
.SetDefault(true);
AddAttr<int>("attn_dropout_seed", "Dropout random seed.").SetDefault(0);
AddAttr<std::string>(
"attn_dropout_implementation",
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
"There are two kinds of ways to implement dropout"
"(the mask below is a tensor have the same shape with input"
"the value of mask is 0 or 1, the ratio of 0 is dropout_rate)"
"1. downgrade_in_infer(default), downgrade the outcome at inference "
"time"
" train: out = input * mask"
" inference: out = input * (1.0 - dropout_rate)"
"2. upscale_in_train, upscale the outcome at training time, do nothing "
"in inference"
" train: out = input * mask / ( 1.0 - dropout_rate )"
" inference: out = input"
" dropout op can be removed from the program. the program will be "
"efficient")
.SetDefault("upscale_in_train")
.AddCustomChecker([](const std::string &type) {
PADDLE_ENFORCE_EQ(
type == "downgrade_in_infer" || type == "upscale_in_train", true,
platform::errors::InvalidArgument(
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
.SetDefault(.5f)
.AddCustomChecker([](const float &drop_p) {
PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true,
platform::errors::InvalidArgument(
"'dropout_rate' must be between 0.0 and 1.0."));
});
AddAttr<bool>("dropout_is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<bool>("dropout_fix_seed",
"A flag indicating whether to use a fixed seed to generate "
"random mask. NOTE: DO NOT set this flag to true in "
"training. Setting this flag to true is only useful in "
"unittest or for debug that always the same output units "
"will be dropped.")
.SetDefault(true);
AddAttr<int>("dropout_seed", "Dropout random seed.").SetDefault(0);
AddAttr<std::string>(
"dropout_implementation",
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
"The meaning is the same as 'attn_dropout_implementation'.")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string &type) {
PADDLE_ENFORCE_EQ(
type == "downgrade_in_infer" || type == "upscale_in_train", true,
platform::errors::InvalidArgument(
"dropout_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
AddAttr<float>("ln_epsilon",
"Constant for numerical stability [default 1e-5].")
.SetDefault(1e-5)
.AddCustomChecker([](const float &ln_epsilon) {
PADDLE_ENFORCE_EQ(ln_epsilon >= 0.0f && ln_epsilon <= 0.001f, true,
platform::errors::InvalidArgument(
"'epsilon' of the second LayerNorm in Fused "
"attention op should be between"
"0.0 and 0.001, But received [%s].",
ln_epsilon));
});
AddComment(R"DOC(
Add fused attention op whose logic is as follows:
// @input: [batch_size, seq_len, dim_embed]
// @final_out: [batch_size, seq_len, dim_embed]
if (pre_layernorm)
out = layer_norm(input);
out = compute_qkv(out) + bias;
// fmha module
{
out = transpose(out, perm=[2, 0, 3, 1, 4]);
out = q * k^t;
out = attn_mask + out;
out = softmax(out);
out = dropout(out);
out = out * v;
out = transpose(out, perm=[0, 2, 1, 3]);
}
out = out_linear(out);
final_out = layer_norm(residual + dropout(bias + out));
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
ops::FusedAttentionOpMaker);
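To make the InferShape logic above easier to follow, here is a small Python helper that lists the intermediate shapes it sets; this is an illustrative sketch for readers of the diff, not code from the PR.

def fused_attention_shapes(batch_size, seq_len, dim_embed, num_head, dim_head):
    # Mirrors the output dims set in FusedAttentionOp::InferShape.
    return {
        "LnMean/LnVariance/Ln2Mean/Ln2Variance": (batch_size * seq_len,),
        "LnOut": (batch_size, seq_len, dim_embed),
        "QKVOut/QKVBiasOut": (batch_size, seq_len, 3, num_head, dim_head),
        "TransposeOut2": (3, batch_size, num_head, seq_len, dim_head),
        "QKOut/SrcMaskOut/SoftmaxOut/AttnDropoutOut":
            (batch_size, num_head, seq_len, seq_len),
        "QKTVOut": (batch_size, num_head, seq_len, dim_head),
        "FMHAOut": (batch_size, seq_len, num_head, dim_head),
        "OutLinearOut/BiasDropoutResidualOut/Y": (batch_size, seq_len, dim_embed),
    }

# Example: the configuration used in the Python API docstring further below.
print(fused_attention_shapes(2, 4, 128, 4, 32))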
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/fused/attention_layer_norm.h"
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class FusedAttentionOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
auto *input_x = ctx.Input<Tensor>("X");
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
auto *ln_scale = ctx.Input<Tensor>("LnScale");
auto *ln_bias = ctx.Input<Tensor>("LnBias");
auto *ln_mean = ctx.Output<Tensor>("LnMean");
auto *ln_var = ctx.Output<Tensor>("LnVariance");
auto *ln_out = ctx.Output<Tensor>("LnOut");
// x: qkv's input [batch_size, seq_len, dim_embed]
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
auto *qkv_weight = ctx.Input<Tensor>("QKVW");
auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
auto *qkv_out = ctx.Output<Tensor>("QKVOut");
auto *qkv_bias_out = ctx.Output<Tensor>("QKVBiasOut");
auto *src_mask = ctx.Input<Tensor>("SrcMask");
auto *transpose_out_2 = ctx.Output<Tensor>("TransposeOut2");
auto *qk_out = ctx.Output<Tensor>("QKOut");
auto *qktv_out = ctx.Output<Tensor>("QKTVOut");
auto *softmax_out = ctx.Output<Tensor>("SoftmaxOut");
auto *attn_dropout_mask_out = ctx.Output<Tensor>("AttnDropoutMaskOut");
auto *attn_dropout_out = ctx.Output<Tensor>("AttnDropoutOut");
auto *src_mask_out = ctx.Output<Tensor>("SrcMaskOut");
auto *fmha_out = ctx.Output<Tensor>("FMHAOut");
auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
auto *out_linear_out = ctx.Output<Tensor>("OutLinearOut");
auto *ln_scale_2 = ctx.Input<Tensor>("Ln2Scale");
auto *ln_bias_2 = ctx.Input<Tensor>("Ln2Bias");
auto *dropout_mask_out = ctx.Output<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Output<Tensor>("BiasDropoutResidualOut");
auto *ln_mean_2 = ctx.Output<Tensor>("Ln2Mean");
auto *ln_var_2 = ctx.Output<Tensor>("Ln2Variance");
const float ln_epsilon = ctx.Attr<float>("ln_epsilon");
float attn_dropout_rate = ctx.Attr<float>("attn_dropout_rate");
bool is_test_1 = ctx.Attr<bool>("attn_dropout_is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
bool is_upscale_in_train_1 =
(dropout_implementation_1 == "upscale_in_train");
auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
// final output.
auto *out = ctx.Output<Tensor>("Y");
// get data ptr for qkv part.
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data<U>());
auto *ln_mean_data = ln_mean->mutable_data<U>(ctx.GetPlace());
auto *ln_var_data = ln_var->mutable_data<U>(ctx.GetPlace());
auto *ln_out_data = ln_out->mutable_data<T>(ctx.GetPlace());
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *qkv_out_data = qkv_out->mutable_data<T>(ctx.GetPlace());
auto *qkv_bias_out_data = qkv_bias_out->mutable_data<T>(ctx.GetPlace());
// get data ptr for FMHA.
auto *transpose_out_2_data =
transpose_out_2->mutable_data<T>(ctx.GetPlace());
auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
auto *attn_dropout_mask_out_data =
attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
auto *attn_dropout_out_data =
attn_dropout_out->mutable_data<T>(ctx.GetPlace());
auto *fmha_out_data = fmha_out->mutable_data<T>(ctx.GetPlace());
// get data ptr for out_linear.
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>();
auto *out_linear_out_data = out_linear_out->mutable_data<T>(ctx.GetPlace());
// get data ptr for bias+dropout+residual+layernorm
auto *ln_scale_2_data =
(ln_scale_2 == nullptr ? nullptr : ln_scale_2->data<U>());
auto *ln_bias_2_data =
(ln_bias_2 == nullptr ? nullptr : ln_bias_2->data<U>());
auto *dropout_mask_out_data =
dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
auto *bias_dropout_residual_out_data =
bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
auto *ln_mean_2_data = ln_mean_2->mutable_data<U>(ctx.GetPlace());
auto *ln_var_2_data = ln_var_2->mutable_data<U>(ctx.GetPlace());
auto *final_out_data = out->mutable_data<T>(ctx.GetPlace());
int batch_size = input_x_dims[0];
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int num_head = qkv_w_dims[1];
int dim_head = qkv_w_dims[2];
int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed);
// (transA, transB, compute_bias) = (false, true, true)
auto qkv_compute = AttnMatMul<T>(ctx.cuda_device_context(), false, true,
bsz_seq, output_size, input_size, true);
AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_rate,
is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
auto fmha_ref_compute =
FMHARef<T>(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
dim_head, attn_dropout_param);
output_size = hidden_size;
// (transA, transB, compute_bias) = (false, false, false)
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), false, false, bsz_seq,
output_size, input_size, false);
DropoutParam dropout_param2(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
ln_epsilon);
if (pre_layer_norm) {
layer_norm_compute.ComputeForward(x_data, ln_scale_data, ln_bias_data,
ln_out_data, ln_mean_data, ln_var_data);
qkv_compute.ComputeForward(qkv_weight_data, ln_out_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
} else {
qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
}
fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
out_linear_compute.ComputeForward(out_linear_weight_data, fmha_out_data,
nullptr, out_linear_out_data, nullptr);
// output = layernorm(residual + dropout(input + bias))
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx.cuda_device_context(), out_linear_out_data, x_data,
out_linear_bias_data, ln_scale_2_data, ln_bias_2_data,
bias_dropout_residual_out_data, dropout_mask_out_data, final_out_data,
ln_mean_2_data, ln_var_2_data);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel<float>,
ops::FusedAttentionOpKernel<double>,
ops::FusedAttentionOpKernel<plat::float16>);
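The Compute method above chains five fused stages: the (optional) pre layer_norm, the shared QKV gemm + bias, the FMHA reference module, the out_linear gemm, and the fused bias + dropout + residual + layer_norm. Below is a NumPy sketch of the same forward data flow with dropout omitted; it is purely illustrative and not the CUDA implementation.

import numpy as np

def layer_norm(x, scale, bias, eps=1e-5):
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * scale + bias

def softmax(x):
    x = x - x.max(-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(-1, keepdims=True)

def fused_attention_forward(x, qkv_w, qkv_b, out_w, out_b,
                            ln_scale, ln_bias, ln2_scale, ln2_bias,
                            attn_mask=None, pre_layer_norm=True):
    # x: [b, s, e]; qkv_w: [3, n, d, e]; qkv_b: [3, n, d];
    # out_w: [e, e]; out_b, ln_scale, ln_bias, ln2_scale, ln2_bias: [e].
    residual = x
    ln_out = layer_norm(x, ln_scale, ln_bias) if pre_layer_norm else x
    qkv = np.einsum("bse,tnde->bstnd", ln_out, qkv_w) + qkv_b       # QKVBiasOut
    q, k, v = np.transpose(qkv, (2, 0, 3, 1, 4))                    # TransposeOut2
    qk = np.einsum("bnsd,bntd->bnst", q, k) / np.sqrt(q.shape[-1])  # QKOut
    if attn_mask is not None:
        qk = qk + attn_mask                                         # SrcMaskOut
    attn = softmax(qk)                # SoftmaxOut (attention dropout omitted)
    ctx = np.einsum("bnst,bntd->bnsd", attn, v)                     # QKTVOut
    b, s = x.shape[:2]
    fmha = np.transpose(ctx, (0, 2, 1, 3)).reshape(b, s, -1)        # FMHAOut
    out = fmha @ out_w                          # OutLinearOut (bias added below)
    # final_out = layer_norm(residual + dropout(out + bias)); dropout omitted.
    return layer_norm(residual + out + out_b, ln2_scale, ln2_bias)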
......@@ -66,7 +66,7 @@ struct DropoutParam {
} else {
pre_fix = pre_fix + "_";
}
dropout_prob = context.Attr<float>(pre_fix + "prob");
dropout_prob = context.Attr<float>(pre_fix + "rate");
auto& dropout_implementation =
context.Attr<std::string>(pre_fix + "implementation");
is_upscale_in_train = (dropout_implementation == "upscale_in_train");
......
......@@ -40,6 +40,9 @@
// need to manually specify them in this map.
std::map<std::string, std::set<std::string>> op_ins_map = {
{"layer_norm", {"X", "Scale", "Bias"}},
{"fused_attention",
{"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW",
"OutLinearBias", "Ln2Scale", "Ln2Bias"}},
{"instance_norm", {"X", "Scale", "Bias"}},
{"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}},
{"label_smooth", {"X", "PriorDist"}},
......@@ -92,6 +95,11 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"batch_norm",
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
{"fused_attention",
{"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2",
"QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut",
"SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean",
"Ln2Variance", "BiasDropoutResidualOut", "Y"}},
{"sync_batch_norm",
{"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
"ReserveSpace"}},
......
......@@ -97,6 +97,10 @@ foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach()
if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
endif()
if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op)
LIST(REMOVE_ITEM TEST_OPS test_c_concat)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.nn.functional as F
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test import OpTest
class TestFusedAttentionOp(OpTest):
def setUp(self):
self.config()
self.generate_input_data()
paddle.set_default_dtype(self.x_type)
self.__class__.op_type = "fused_attention"
self.q_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.k_proj = Linear(
self.kdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.v_proj = Linear(
self.vdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
self.out_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr)
paddle.set_default_dtype(np.float32)
self.norm1 = LayerNorm(self.embed_dim)
self.norm2 = LayerNorm(self.embed_dim)
paddle.set_default_dtype(self.x_type)
self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
dtype=self.attn_mask_type)
if self.attn_mask_type == np.int64:
self.attn_mask = np.tril(self.attn_mask)
elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
self.key, self.value = self.query, self.query
self.dout = np.random.random((self.batch_size, self.query_length,
self.embed_dim)).astype(self.x_type)
def GetBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
residual = tensor_query
ln1_out = tensor_query
if self.pre_layer_norm:
ln1_out = self.norm1(tensor_query)
q = self.q_proj(ln1_out)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(ln1_out)
v = self.v_proj(ln1_out)
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])
qk_out = layers.matmul(
x=q_out, y=k_out, transpose_y=True, alpha=self.head_dim**-0.5)
if attn_mask is not None:
attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
attn_mask_out = qk_out + attn_mask
softmax_out = F.softmax(attn_mask_out)
else:
softmax_out = F.softmax(qk_out)
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train")
qktv_out = tensor.matmul(dropout_out, v_out)
else:
qktv_out = tensor.matmul(softmax_out, v_out)
fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
out_linear_in = tensor.reshape(
x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]])
out = self.out_proj(out_linear_in)
residual_out = residual + self.dropout(out)
if not self.pre_layer_norm:
final_out = self.norm1(residual_out)
if self.pre_layer_norm:
final_out = self.norm2(residual_out)
return final_out
def GetFusedAttentionOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
q_proj_weight = paddle.to_tensor(
self.q_proj.weight, stop_gradient=False)
q_proj_bias = paddle.to_tensor(self.q_proj.bias, stop_gradient=False)
k_proj_weight = paddle.to_tensor(
self.k_proj.weight, stop_gradient=False)
k_proj_bias = paddle.to_tensor(self.k_proj.bias, stop_gradient=False)
v_proj_weight = paddle.to_tensor(
self.v_proj.weight, stop_gradient=False)
v_proj_bias = paddle.to_tensor(self.v_proj.bias, stop_gradient=False)
out_linear_weight = paddle.to_tensor(
self.out_proj.weight, stop_gradient=False)
out_linear_bias = paddle.to_tensor(
self.out_proj.bias, stop_gradient=False)
ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)
q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
qkv_weight = np.concatenate(
(q_proj_weight, k_proj_weight, v_proj_weight))
qkv_weight = qkv_weight.reshape(
(3, self.num_heads, self.head_dim, self.embed_dim))
qkv_bias = np.concatenate(
(q_proj_bias.numpy(), k_proj_bias.numpy(), v_proj_bias.numpy()))
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
x = paddle.to_tensor(self.query, stop_gradient=False)
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
epsilon = 1e-05
ln2_epsilon = 1e-05
if attn_mask is not None:
attn_mask = _convert_attention_mask(attn_mask, x.dtype)
final_out = F.fused_multi_head_attention(
x, qkv_weight_tensor, out_linear_weight, self.pre_layer_norm,
ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor,
out_linear_bias, attn_mask, self.dropout_prob,
self.attn_dropout_prob, ln2_epsilon)
return final_out
def test_fused_attention_op(self):
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5)
class TestFusedAttentionOpFp16(TestFusedAttentionOp):
def config(self):
self.x_type = np.float16
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
if __name__ == "__main__":
unittest.main()
......@@ -61,6 +61,7 @@ from .common import class_center_sample # noqa: F401
from .conv import conv1d # noqa: F401
from .conv import conv1d_transpose # noqa: F401
from .common import linear # noqa: F401
from .fused_transformer import fused_multi_head_attention # noqa: F401
from .conv import conv2d # noqa: F401
from .conv import conv2d_transpose # noqa: F401
from .conv import conv3d # noqa: F401
......@@ -211,5 +212,6 @@ __all__ = [ #noqa
'layer_norm',
'instance_norm',
'class_center_sample',
'fused_multi_head_attention',
'sparse_attention',
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from ...fluid.framework import in_dygraph_mode
from paddle import _C_ops
__all__ = []
def fused_multi_head_attention(x,
qkv_weight,
linear_weight,
pre_layer_norm=False,
pre_ln_scale=None,
pre_ln_bias=None,
ln_scale=None,
ln_bias=None,
pre_ln_epsilon=1e-05,
qkv_bias=None,
linear_bias=None,
attn_mask=None,
dropout_rate=0.5,
attn_dropout_rate=0.5,
ln_epsilon=1e-05,
name=None):
"""
Attention maps queries and a set of key-value pairs to outputs, and
Multi-Head Attention performs multiple attention computations in parallel to jointly attend
to information from different representation subspaces. This API only
supports self-attention. The pseudo code is as follows:
if pre_layer_norm:
out = layer_norm(x);
out = linear(out) + qkv_bias
else:
out = linear(x) + qkv_bias;
out = transpose(out, perm=[2, 0, 3, 1, 4]);
# extract q, k and v from out.
q = out[0:1,::]
k = out[1:2,::]
v = out[2:3,::]
out = q * k^t;
out = attn_mask + out;
out = softmax(out);
out = dropout(out);
out = out * v;
out = transpose(out, perm=[0, 2, 1, 3]);
out = out_linear(out);
out = layer_norm(x + dropout(linear_bias + out));
Parameters:
x (Tensor): The input tensor of fused_multi_head_attention. The shape is
`[batch\_size, sequence\_len, embed\_dim]`.
qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
pre_layer_norm (bool, optional): whether it is pre_layer_norm or post_layer_norm architecture.
Default False.
pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None.
pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None.
ln_scale (Tensor, optional): The weight tensor of layernorm. Default None.
ln_bias (Tensor, optional): The bias tensor of layernorm. Default None.
pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm
to avoid dividing by zero. Default is 1e-5.
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
attn_mask (Tensor, optional): The attention mask added to the attention weights
before softmax. Its shape should be broadcastable to
`[batch\_size, num\_heads, sequence\_len, sequence\_len]`. Default None.
dropout_rate (float, optional): The dropout probability applied after the attention
output projection, i.e. the dropout in `residual + dropout(bias + out)`.
0.0 means no dropout. Default 0.5.
attn_dropout_rate (float, optional): The dropout probability applied to the attention
weights (the softmax output) inside the attention.
0.0 means no dropout. Default 0.5.
ln_epsilon (float, optional): Small float value added to denominator of layer_norm
to avoid dividing by zero. Default is 1e-5.
Examples:
.. code-block:: python
# required: gpu
import paddle
import paddle.nn.functional as F
# input: [batch_size, seq_len, embed_dim]
x = paddle.rand(shape=(2, 4, 128), dtype="float32")
# qkv_weight: [3, num_head, dim_head, dim_embed]
qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
# qkv_bias: [3, num_head, dim_head]
qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
# linear_weight: [embed_dim, embed_dim]
linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
# linear_bias: [embed_dim]
linear_bias = paddle.rand(shape=[128], dtype="float32")
# self attention mask: [batch_size, num_heads, seq_len, seq_len]
attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32")
# output: [batch_size, seq_len, embed_dim]
output = F.fused_multi_head_attention(
x, qkv_weight, linear_weight, False,
None, None, None, None, 1e-5, qkv_bias,
linear_bias, attn_mask)
# [2, 4, 128]
print(output.shape)
"""
if in_dygraph_mode():
# pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
# qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
# linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention(
x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask,
linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm',
pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate',
dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon',
ln_epsilon)
return final_out
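As a complement to the docstring example, a usage sketch of the pre_layer_norm path; the layer-norm parameter values and dropout rates below are illustrative, and a GPU build is required.

# required: gpu
import paddle
import paddle.nn.functional as F

x = paddle.rand(shape=(2, 4, 128), dtype="float32")
qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
linear_bias = paddle.rand(shape=[128], dtype="float32")
attn_mask = paddle.rand(shape=(2, 4, 4, 4), dtype="float32")
# Scale/bias for the pre layer_norm and the final layer_norm (illustrative values).
pre_ln_scale = paddle.ones([128], dtype="float32")
pre_ln_bias = paddle.zeros([128], dtype="float32")
ln_scale = paddle.ones([128], dtype="float32")
ln_bias = paddle.zeros([128], dtype="float32")
output = F.fused_multi_head_attention(
    x, qkv_weight, linear_weight, True,
    pre_ln_scale, pre_ln_bias, ln_scale, ln_bias, 1e-5,
    qkv_bias, linear_bias, attn_mask, 0.0, 0.0, 1e-5)
print(output.shape)  # [2, 4, 128]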
......@@ -26,7 +26,7 @@ from ... import tensor
from ...fluid import layers
from .. import Layer, LayerList
from ...framework import ParamAttr
from ...fluid.data_feeder import convert_dtype
from paddle.fluid.data_feeder import convert_dtype
__all__ = []
......