未验证 提交 64643d50 编写于 作者: L Li Min 提交者: GitHub

Add fused attention op backward and python layer. (#36498) (#36752)

功能:本PR的目标是提高attention模块的计算性能。
为了减少框架层对op的调度开销,本PR通过在C++层手动实现attention模块,对外提供attention 大op;
为了减少防存开销,本PR采取了两种优化方法:
(1)在q,k,v计算时通过共享输入X,将该处的gemm,transpose和bias add从三次调用减少为一次;
(2)使用kernel融合优化技术,在不同cuda kernel之间通过寄存器传输数据;
上级 211cf208
...@@ -328,9 +328,206 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -328,9 +328,206 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class FusedAttentionGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->Attrs().Get<bool>("attn_dropout_is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when attn_dropout_is_test is false"));
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
"FusedAttentionGrad");
if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
ctx->GetInputDim("Ln2Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
ctx->GetInputDim("Ln2Bias"));
}
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
"FusedAttentionGrad");
}
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionGrad");
if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
ctx->SetOutputDim(framework::GradVarName("LnScale"),
ctx->GetInputDim("LnScale"));
}
if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
ctx->SetOutputDim(framework::GradVarName("LnBias"),
ctx->GetInputDim("LnBias"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias"));
ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
ctx->GetInputDim("OutLinearW"));
ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias"));
ctx->SetOutputDim(framework::GradVarName("LnOut"),
ctx->GetInputDim("LnOut"));
ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut"));
ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
ctx->GetInputDim("QKTVOut"));
ctx->SetOutputDim(framework::GradVarName("TransposeOut2"),
ctx->GetInputDim("TransposeOut2"));
ctx->SetOutputDim(framework::GradVarName("QKOut"),
ctx->GetInputDim("QKOut"));
ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
ctx->GetInputDim("SoftmaxOut"));
ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
ctx->GetInputDim("AttnDropoutOut"));
ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
ctx->GetInputDim("SrcMaskOut"));
ctx->SetOutputDim(framework::GradVarName("QKVOut"),
ctx->GetInputDim("QKVOut"));
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
ctx->GetInputDim("QKVBiasOut"));
ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
ctx->GetInputDim("OutLinearOut"));
ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
ctx->GetInputDim("BiasDropoutResidualOut"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto input_data_type = input->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("fused_attention_grad");
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
// inputs x, parameters and their grad.
op->SetInput("X", this->Input("X"));
op->SetInput("QKVW", this->Input("QKVW"));
op->SetInput("QKVBias", this->Input("QKVBias"));
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
if (this->HasInput("LnScale")) {
op->SetInput("LnScale", this->Input("LnScale"));
op->SetOutput(framework::GradVarName("LnScale"),
this->InputGrad("LnScale"));
}
if (this->HasInput("LnBias")) {
op->SetInput("LnBias", this->Input("LnBias"));
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));
}
if (this->HasInput("Ln2Scale")) {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
this->InputGrad("Ln2Scale"));
}
if (this->HasInput("Ln2Bias")) {
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
}
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
op->SetOutput(framework::GradVarName("QKVBias"),
this->InputGrad("QKVBias"));
op->SetOutput(framework::GradVarName("OutLinearBias"),
this->InputGrad("OutLinearBias"));
op->SetOutput(framework::GradVarName("OutLinearW"),
this->InputGrad("OutLinearW"));
// use forward outputs as backward inputs.
op->SetInput("LnOut", this->Output("LnOut"));
op->SetInput("LnMean", this->Output("LnMean"));
op->SetInput("LnVariance", this->Output("LnVariance"));
op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
op->SetInput("QKOut", this->Output("QKOut"));
op->SetInput("QKTVOut", this->Output("QKTVOut"));
op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
op->SetInput("FMHAOut", this->Output("FMHAOut"));
op->SetInput("OutLinearOut", this->Output("OutLinearOut"));
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
op->SetInput("BiasDropoutResidualOut",
this->Output("BiasDropoutResidualOut"));
op->SetInput("QKVOut", this->Output("QKVOut"));
// backward outputs: dinput
op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut"));
op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
op->SetOutput(framework::GradVarName("QKTVOut"),
this->OutputGrad("QKTVOut"));
op->SetOutput(framework::GradVarName("TransposeOut2"),
this->OutputGrad("TransposeOut2"));
op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut"));
op->SetOutput(framework::GradVarName("SoftmaxOut"),
this->OutputGrad("SoftmaxOut"));
op->SetOutput(framework::GradVarName("AttnDropoutOut"),
this->OutputGrad("AttnDropoutOut"));
op->SetOutput(framework::GradVarName("SrcMaskOut"),
this->OutputGrad("SrcMaskOut"));
op->SetOutput(framework::GradVarName("FMHAOut"),
this->OutputGrad("FMHAOut"));
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
this->OutputGrad("BiasDropoutResidualOut"));
op->SetOutput(framework::GradVarName("OutLinearOut"),
this->OutputGrad("OutLinearOut"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp, REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
ops::FusedAttentionOpMaker); ops::FusedAttentionOpMaker,
ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
...@@ -199,6 +199,237 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> { ...@@ -199,6 +199,237 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
} }
}; };
template <typename T>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
const float ln2epsilon = ctx.Attr<float>("ln_epsilon");
float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
bool is_test_1 = ctx.Attr<bool>("attn_dropout_is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
bool is_upscale_in_train_1 =
(dropout_implementation_1 == "upscale_in_train");
auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
// get inputs.
auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto *d_y_data = d_y->data<T>();
// fw input
auto *input_x = ctx.Input<Tensor>("X");
auto *ln_scale = ctx.Input<Tensor>("LnScale");
auto *ln_2_scale = ctx.Input<Tensor>("Ln2Scale");
auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_2_scale_data =
(ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
// fw parameters.
auto *src_mask = ctx.Input<Tensor>("SrcMask");
auto *qkv_weight = ctx.Input<Tensor>("QKVW");
auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>();
// fw output
auto *ln_mean = ctx.Input<Tensor>("LnMean");
auto *ln_var = ctx.Input<Tensor>("LnVariance");
auto *ln_out = ctx.Input<Tensor>("LnOut");
auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
auto *transpose_out_2 = ctx.Input<Tensor>("TransposeOut2");
auto *qk_out = ctx.Input<Tensor>("QKOut");
auto *qktv_out = ctx.Input<Tensor>("QKTVOut");
auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
auto *attn_dropout_mask_out = ctx.Input<Tensor>("AttnDropoutMaskOut");
auto *attn_dropout_out = ctx.Input<Tensor>("AttnDropoutOut");
auto *src_mask_out = ctx.Input<Tensor>("SrcMaskOut");
auto *out_linear_out = ctx.Input<Tensor>("OutLinearOut");
auto *ln_2_mean = ctx.Input<Tensor>("Ln2Mean");
auto *ln_2_var = ctx.Input<Tensor>("Ln2Variance");
auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Input<Tensor>("BiasDropoutResidualOut");
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *ln_out_data = ln_out->data<T>();
auto *fmha_out_data = fmha_out->data<T>();
auto *transpose_out_2_data = transpose_out_2->data<T>();
auto *qk_out_data = qk_out->data<T>();
auto *qktv_out_data = qktv_out->data<T>();
auto *softmax_out_data = softmax_out->data<T>();
auto *src_mask_out_data = src_mask_out->data<T>();
auto *out_linear_out_data = out_linear_out->data<T>();
auto *ln_2_mean_data = ln_2_mean->data<U>();
auto *ln_2_var_data = ln_2_var->data<U>();
auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data<T>();
// output's grad
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
auto *d_qkv_out = ctx.Output<Tensor>(framework::GradVarName("QKVOut"));
auto *d_qkv_bias_out =
ctx.Output<Tensor>(framework::GradVarName("QKVBiasOut"));
auto *d_qktv_out = ctx.Output<Tensor>(framework::GradVarName("QKTVOut"));
auto *d_transpose_out_2 =
ctx.Output<Tensor>(framework::GradVarName("TransposeOut2"));
auto *d_qk_out = ctx.Output<Tensor>(framework::GradVarName("QKOut"));
auto *d_softmax_out =
ctx.Output<Tensor>(framework::GradVarName("SoftmaxOut"));
auto *d_attn_dropout_out =
ctx.Output<Tensor>(framework::GradVarName("AttnDropoutOut"));
auto *d_src_mask_out =
ctx.Output<Tensor>(framework::GradVarName("SrcMaskOut"));
auto *d_fmha_out = ctx.Output<Tensor>(framework::GradVarName("FMHAOut"));
auto *d_out_linear_out =
ctx.Output<Tensor>(framework::GradVarName("OutLinearOut"));
auto *d_bias_dropout_residual_out =
ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
auto *d_transpose_out_2_data =
d_transpose_out_2->mutable_data<T>(ctx.GetPlace());
auto *d_qk_out_data = d_qk_out->mutable_data<T>(ctx.GetPlace());
auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
auto *d_attn_dropout_out_data =
d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
auto *d_src_mask_out_data = d_src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_out_data =
d_out_linear_out->mutable_data<T>(ctx.GetPlace());
auto *d_bias_dropout_residual_out_data =
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
// parameter grad
auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias"));
auto *d_out_linear_weight =
ctx.Output<Tensor>(framework::GradVarName("OutLinearW"));
auto *d_out_linear_bias =
ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
auto *d_ln_scale_data =
(d_ln_scale == nullptr ? nullptr
: d_ln_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));
auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_weight_data =
d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_bias_data =
d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
auto *d_ln_2_scale_data =
(d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
ctx.GetPlace()));
auto *d_ln_2_bias_data =
(d_ln_2_bias == nullptr ? nullptr
: d_ln_2_bias->mutable_data<U>(ctx.GetPlace()));
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
int batch_size = input_x_dims[0];
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int num_head = qkv_w_dims[1];
int dim_head = qkv_w_dims[2];
int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
Tensor d_residual;
d_residual.Resize(input_x_dims);
T *d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
bool transA = false;
bool transB = true;
bool compute_bias = true;
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed);
auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_prob,
is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
auto fmha_ref_compute =
FMHARef<T>(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
dim_head, attn_dropout_param);
output_size = hidden_size;
transA = false;
transB = false;
compute_bias = false;
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
DropoutParam dropout_param2(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
ln2epsilon);
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
d_out_linear_out_data, d_fmha_out_data,
d_out_linear_weight_data, nullptr);
fmha_ref_compute.ComputeBackward(
*transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out);
cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data,
bsz_seq * 3 * num_head * dim_head * sizeof(T),
cudaMemcpyDeviceToDevice);
if (pre_layer_norm) {
qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
d_qkv_bias_out_data, d_ln_out_data,
d_qkv_weight_data, d_qkv_bias_data);
layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
ln_mean_data, ln_var_data, d_x_data,
d_ln_scale_data, d_ln_bias_data);
} else {
qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data,
d_x_data, d_qkv_weight_data, d_qkv_bias_data);
}
// gradient accumulation
std::vector<const Tensor *> ins;
std::vector<Tensor *> outs;
ins.emplace_back(&d_residual);
ins.emplace_back(d_x);
outs.emplace_back(d_x);
int elewise_add_axis = -1;
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx.cuda_device_context(), ins, &outs, elewise_add_axis,
AddFunctor<T>());
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -207,3 +438,7 @@ namespace plat = paddle::platform; ...@@ -207,3 +438,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel<float>, REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel<float>,
ops::FusedAttentionOpKernel<double>, ops::FusedAttentionOpKernel<double>,
ops::FusedAttentionOpKernel<plat::float16>); ops::FusedAttentionOpKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
ops::FusedAttentionGradKernel<float>,
ops::FusedAttentionGradKernel<double>,
ops::FusedAttentionGradKernel<plat::float16>);
...@@ -93,6 +93,7 @@ endforeach() ...@@ -93,6 +93,7 @@ endforeach()
if(NOT WITH_GPU) if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op) LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
endif() endif()
if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
......
...@@ -34,6 +34,8 @@ class TestFusedAttentionOp(OpTest): ...@@ -34,6 +34,8 @@ class TestFusedAttentionOp(OpTest):
self.generate_input_data() self.generate_input_data()
paddle.set_default_dtype(self.x_type) paddle.set_default_dtype(self.x_type)
self.__class__.op_type = "fused_attention" self.__class__.op_type = "fused_attention"
# use autograd to check grad in this unittest.
self.__class__.no_need_check_grad = True
self.q_proj = Linear( self.q_proj = Linear(
self.embed_dim, self.embed_dim,
self.embed_dim, self.embed_dim,
...@@ -147,7 +149,9 @@ class TestFusedAttentionOp(OpTest): ...@@ -147,7 +149,9 @@ class TestFusedAttentionOp(OpTest):
final_out = self.norm1(residual_out) final_out = self.norm1(residual_out)
if self.pre_layer_norm: if self.pre_layer_norm:
final_out = self.norm2(residual_out) final_out = self.norm2(residual_out)
return final_out paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
return final_out, tensor_query.grad
def GetFusedAttentionOut(self): def GetFusedAttentionOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0)) paddle.disable_static(place=paddle.CUDAPlace(0))
...@@ -196,13 +200,17 @@ class TestFusedAttentionOp(OpTest): ...@@ -196,13 +200,17 @@ class TestFusedAttentionOp(OpTest):
ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor, ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor,
out_linear_bias, attn_mask, self.dropout_prob, out_linear_bias, attn_mask, self.dropout_prob,
self.attn_dropout_prob, ln2_epsilon) self.attn_dropout_prob, ln2_epsilon)
return final_out paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
return final_out, x.grad
def test_fused_attention_op(self): def test_fused_attention_op(self):
final_out_ref = self.GetBaselineOut() final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out = self.GetFusedAttentionOut() final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose( np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5) final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)
class TestFusedAttentionOpFp16(TestFusedAttentionOp): class TestFusedAttentionOpFp16(TestFusedAttentionOp):
...@@ -226,10 +234,12 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp): ...@@ -226,10 +234,12 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
self.key_length, self.value_length = self.query_length, self.query_length self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self): def test_fused_attention_op(self):
final_out_ref = self.GetBaselineOut() final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out = self.GetFusedAttentionOut() final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose( np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1) final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.nn.functional as F
from paddle.incubate.nn.layer.fused_transformer import FusedMultiHeadAttention
from paddle import tensor
from paddle.fluid import layers
from paddle.static import Program, program_guard
import unittest
def fc(x, weight):
return np.matmul(x, weight)
def softmax(x):
np.seterr(invalid='ignore')
output = np.zeros(x.shape, dtype=np.float64)
for i in range(x.shape[0]):
for j in range(x.shape[1]):
for k in range(x.shape[2]):
x_curr = x[i, j, k, :]
e_x = np.exp(x_curr - np.amax(x_curr))
output[i, j, k, :] = e_x / np.sum(e_x)
return output
def batch_matmul(x, y):
assert x.shape[0] == y.shape[0]
assert x.shape[1] == y.shape[1]
retval = np.zeros(
(x.shape[0], x.shape[1], x.shape[2], y.shape[3]), dtype=np.float64)
for i in range(x.shape[0]):
for j in range(x.shape[1]):
retval[i, j, :, :] = np.matmul(x[i, j, :, :], y[i, j, :, :])
return retval
def layer_norm(x, has_scale, has_bias, weight, bias, epsilon=1e-05):
batch_size, src_len, d_model = x.shape
x = x.reshape((batch_size * src_len, d_model))
mu = np.mean(x, axis=1, keepdims=True)
sigma_squar = np.sum(np.square(x - mu), axis=1) / d_model
x1_up = (x - mu)
x1_down_1 = sigma_squar + epsilon
x1_down = np.sqrt(x1_down_1)
x1_down = x1_down.reshape((x1_down.shape[0], 1))
x1 = x1_up / x1_down
x_scaled = x1
if (has_scale):
x_scaled = weight * x1
x_scaled_bias = x_scaled
if (has_bias):
x_scaled_bias = x_scaled + bias
x_scaled_bias = x_scaled_bias.reshape((batch_size, src_len, d_model))
return x_scaled_bias
def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
out_linear_weight, out_linear_bias):
batch_size = query.shape[0]
seq_len = query.shape[1]
embed_dim = query.shape[2]
if (pre_layer_norm):
ln_out = layer_norm(query, True, True, ln_scale, ln_bias)
num_head = qkv_weight.shape[1]
head_dim = qkv_weight.shape[2]
# embed_dim, 3, num_heads, self.head_dim
qkv_weight = qkv_weight.transpose((3, 0, 1, 2))
qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
qkv_weight.shape[2] * qkv_weight.shape[3])
if (pre_layer_norm):
ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
qkv = fc(ln_out, qkv_weight)
ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
else:
query = query.reshape(batch_size * seq_len, embed_dim)
qkv = fc(query, qkv_weight)
query = query.reshape(batch_size, seq_len, embed_dim)
qkv = qkv.reshape(batch_size, seq_len, 3, num_head, head_dim)
# q*k^t
qkv = qkv.transpose(
(2, 0, 1, 3, 4)) # 3, batch_size, seq_len, num_head, head_dim
qkv = qkv.transpose(
(0, 1, 3, 2, 4)) # 3, batch_size, num_head, seq_len, head_dim
q = qkv[0:1, ::]
q = q.reshape(batch_size, num_head, seq_len, head_dim)
k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim]
k = k.reshape(batch_size, num_head, seq_len, head_dim)
v = qkv[2::]
v = v.reshape(batch_size, num_head, seq_len, head_dim)
k = k.transpose([0, 1, 3, 2]) #[batch_size, num_head, head_dim, seq_len]
qkt = batch_matmul(q, k / np.sqrt(head_dim, dtype=np.float64))
if attn_mask is not None:
if attn_mask.dtype.name == 'int64':
attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9
else:
attn_mask = attn_mask.astype(qkt.dtype)
qkt += attn_mask
# softmax
softmax_out = softmax(qkt)
attn_heads = batch_matmul(softmax_out, v)
attn_heads = attn_heads.transpose(
(0, 2, 1, 3)) # [batch_size, seq_len, num_head, head_dim]
# out_linear
out_linear_input = attn_heads.reshape(batch_size, seq_len,
num_head * head_dim)
out_linear_out = fc(out_linear_input, out_linear_weight)
# bias add, dropout, residual add, layer_norm.
out_linear_bias_out = out_linear_out + out_linear_bias
out_linear_bias_dropout_out = out_linear_bias_out
out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out
out_linear_bias_dropout_residual_ln_out = layer_norm(
out_linear_bias_dropout_residual_out, True, True, ln_2_scale, ln_2_bias)
return out_linear_bias_dropout_residual_ln_out
class TestFusedAttentionAPI(unittest.TestCase):
def setUp(self):
self.config()
self.generate_input_data()
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.training = True
self.need_weight = False
self.batch_size = 1
self.query_length = 2
self.head_dim = 2
self.num_heads = 2
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
dtype=self.attn_mask_type)
if self.attn_mask_type == np.int64:
self.attn_mask = np.tril(self.attn_mask)
elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
self.key, self.value = self.query, self.query
def run_imperative(self):
fused_attn = FusedMultiHeadAttention(
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr)
out = fused_attn(
paddle.to_tensor(self.query),
paddle.to_tensor(self.query),
paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask))
ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask,
fused_attn.pre_ln_scale.numpy(),
fused_attn.pre_ln_bias.numpy(),
fused_attn.ln_scale.numpy(),
fused_attn.ln_bias.numpy(),
fused_attn.qkv_weight.numpy(),
fused_attn.qkv_bias.numpy(),
fused_attn.linear_weight.numpy(),
fused_attn.linear_bias.numpy())
self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5))
def run_static(self):
fused_attn = FusedMultiHeadAttention(
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr)
x = paddle.static.data(
name='X',
shape=[self.batch_size, self.query_length, self.embed_dim],
dtype=self.x_type)
attn_mask = paddle.static.data(
name='SrcMask',
shape=[
self.batch_size, self.num_heads, self.query_length,
self.key_length
],
dtype=self.attn_mask_type)
final_out = fused_attn(x, x, x, attn_mask)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(Program()):
out, qkv_weight, qkv_bias, linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = self.run_static(
)
ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
linear_weight, linear_bias)
self.assertTrue(
np.allclose(
np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5))
def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_imperative()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401
__all__ = [ #noqa
'FusedMultiHeadAttention',
]
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid import core, dygraph_utils
from paddle import _C_ops from paddle import _C_ops
__all__ = [] __all__ = []
...@@ -217,8 +218,8 @@ def fused_multi_head_attention(x, ...@@ -217,8 +218,8 @@ def fused_multi_head_attention(x,
`[batch\_size, sequence\_len, embed\_dim]`. `[batch\_size, sequence\_len, embed\_dim]`.
qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`. qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`. linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
pre_layer_norm (bool, optional): whether it is pre_layer_norm or post_layer_norm architecture. pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
Default False. (False). Default False.
pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None. pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None.
pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None. pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None.
ln_scale (Tensor, optional): The weight tensor of layernorm. Default None. ln_scale (Tensor, optional): The weight tensor of layernorm. Default None.
...@@ -228,13 +229,19 @@ def fused_multi_head_attention(x, ...@@ -228,13 +229,19 @@ def fused_multi_head_attention(x,
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`. qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
Default None. Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None. linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
attn_mask (Tensor, optional): attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
data type is bool, the unwanted positions have `False` values and the others have `True` values.
When the data type is int, the unwanted positions have 0 values and the others have 1 values.
When the data type is float, the unwanted positions have `-INF` values and the others have 0 values.
It can be None when nothing wanted or needed to be prevented attention to. Default None.
dropout_rate (float, optional): The dropout probability used on attention dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout after attention. weights to drop some attention targets for the dropout after attention.
0 for no dropout. Default 0. 0 for no dropout. Default 0.5.
attn_dropout_rate (float, optional): The dropout probability used on attention attn_dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout in attention. weights to drop some attention targets for the dropout in attention.
0 for no dropout. Default 0. 0 for no dropout. Default 0.5.
ln_epsilon (float, optional): Small float value added to denominator of layer_norm ln_epsilon (float, optional): Small float value added to denominator of layer_norm
to avoid dividing by zero. Default is 1e-5. to avoid dividing by zero. Default is 1e-5.
...@@ -248,9 +255,9 @@ def fused_multi_head_attention(x, ...@@ -248,9 +255,9 @@ def fused_multi_head_attention(x,
# input: [batch_size, seq_len, embed_dim] # input: [batch_size, seq_len, embed_dim]
x = paddle.rand(shape=(2, 4, 128), dtype="float32") x = paddle.rand(shape=(2, 4, 128), dtype="float32")
# qkv_weight: [3, num_head, dim_head, dim_embed] # qkv_weight: [3, num_head, head_dim, embed_dim]
qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32") qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
# qkv_bias: [3, num_head, dim_head] # qkv_bias: [3, num_head, head_dim]
qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32") qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
# linear_weight: [embed_dim, embed_dim] # linear_weight: [embed_dim, embed_dim]
linear_weight = paddle.rand(shape=(128, 128), dtype="float32") linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
...@@ -271,6 +278,12 @@ def fused_multi_head_attention(x, ...@@ -271,6 +278,12 @@ def fused_multi_head_attention(x,
# pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
# qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
# linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
assert len(qkv_weight.shape
) == 4, "The dims of the shape of qkv_weight should be 4."
assert qkv_weight.shape[
0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
assert qkv_weight.shape[3] == x.shape[
2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention( _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention(
x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask, x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask,
linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm',
...@@ -278,3 +291,95 @@ def fused_multi_head_attention(x, ...@@ -278,3 +291,95 @@ def fused_multi_head_attention(x,
dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon', dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon',
ln_epsilon) ln_epsilon)
return final_out return final_out
else:
helper = LayerHelper('fused_multi_head_attention', **locals())
dtype = x.dtype
# check dtypes
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'fused_multihead_attention')
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
'fused_multi_head_attention')
# set inputs
inputs = dict()
inputs['X'] = [x]
if pre_ln_scale:
inputs['LnScale'] = [pre_ln_scale]
if pre_ln_bias:
inputs['LnBias'] = [pre_ln_bias]
inputs['QKVW'] = [qkv_weight]
inputs['QKVBias'] = [qkv_bias]
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = [linear_weight]
inputs['OutLinearBias'] = [linear_bias]
if ln_scale:
inputs['Ln2Scale'] = [ln_scale]
if ln_bias:
inputs['Ln2Bias'] = [ln_bias]
# set attrs
attrs = {
'pre_layer_norm': pre_layer_norm,
'epsilon': pre_ln_epsilon,
'ln_epsilon': ln_epsilon,
'dropout_rate': dropout_rate,
'attn_dropout_rate': attn_dropout_rate
}
# set outputs
pre_ln_mean_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
pre_ln_variance_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
pre_ln_out = helper.create_variable_for_type_inference(dtype=dtype)
qkv_out = helper.create_variable_for_type_inference(dtype=dtype)
qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype)
transpose_out = helper.create_variable_for_type_inference(dtype=dtype)
qk_out = helper.create_variable_for_type_inference(dtype=dtype)
qktv_out = helper.create_variable_for_type_inference(dtype=dtype)
softmax_out = helper.create_variable_for_type_inference(dtype=dtype)
attn_dropout_mask_out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
attn_dropout_out = helper.create_variable_for_type_inference(
dtype=dtype)
attn_mask_out = helper.create_variable_for_type_inference(dtype=dtype)
fmha_out = helper.create_variable_for_type_inference(dtype=dtype)
out_linear_out = helper.create_variable_for_type_inference(dtype=dtype)
dropout_mask_out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
ln_mean_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
ln_variance_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
bias_dropout_residual_out = helper.create_variable_for_type_inference(
dtype=dtype)
final_out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='fused_attention',
inputs=inputs,
outputs={
"LnMean": pre_ln_mean_out,
"LnVariance": pre_ln_variance_out,
"LnOut": pre_ln_out,
"QKVOut": qkv_out,
"QKVBiasOut": qkv_bias_out,
"TransposeOut2": transpose_out,
"QKOut": qk_out,
"QKTVOut": qktv_out,
"SoftmaxOut": softmax_out,
"AttnDropoutMaskOut": attn_dropout_mask_out,
"AttnDropoutOut": attn_dropout_out,
"SrcMaskOut": attn_mask_out,
"FMHAOut": fmha_out,
"OutLinearOut": out_linear_out,
"DropoutMaskOut": dropout_mask_out,
"Ln2Mean": ln_mean_out,
"Ln2Variance": ln_variance_out,
"BiasDropoutResidualOut": bias_dropout_residual_out,
'Y': final_out
},
attrs=attrs)
return final_out
...@@ -12,27 +12,42 @@ ...@@ -12,27 +12,42 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
from paddle.nn import functional as F
from paddle.incubate.nn import functional as incubate_f
from paddle.nn import Layer
from paddle.framework import ParamAttr
import paddle
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle.nn.initializer import Constant
import collections
class FusedMultiHeadAttention(Layer): class FusedMultiHeadAttention(Layer):
""" """
Attention mapps queries and a set of key-value pairs to outputs, and Attention mapps queries and a set of key-value pairs to outputs, and
Multi-Head Attention performs multiple parallel attention to jointly attending Multi-Head Attention performs multiple parallel attention to jointly attending
to information from different representation subspaces. to information from different representation subspaces.
Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_ Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
for more details. for more details.
Parameters: Parameters:
embed_dim (int): The expected feature size in the input and output. embed_dim (int): The expected feature size in the input and output.
num_heads (int): The number of heads in multi-head attention. num_heads (int): The number of heads in multi-head attention.
dropout (float, optional): The dropout probability used on attention dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets. 0 for no dropout. Default 0 weights to drop some attention targets for the dropout after attention.
0 for no dropout. Default 0.5.
attn_dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout in attention.
0 for no dropout. Default 0.5.
kdim (int, optional): The feature size in key. If None, assumed equal to kdim (int, optional): The feature size in key. If None, assumed equal to
`embed_dim`. Default None. `embed_dim`. Default None.
vdim (int, optional): The feature size in value. If None, assumed equal to vdim (int, optional): The feature size in value. If None, assumed equal to
`embed_dim`. Default None. `embed_dim`. Default None.
normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True)
or post_layer_norm architecture (False). Default False.
need_weights (bool, optional): Indicate whether to return the attention need_weights (bool, optional): Indicate whether to return the attention
weights. Default False. weights. Now, only False is supported. Default False.
weight_attr(ParamAttr, optional): To specify the weight parameter property. weight_attr(ParamAttr, optional): To specify the weight parameter property.
Default: None, which means the default weight parameter property is used. Default: None, which means the default weight parameter property is used.
See usage for details in :code:`ParamAttr` . See usage for details in :code:`ParamAttr` .
...@@ -40,35 +55,84 @@ class FusedMultiHeadAttention(Layer): ...@@ -40,35 +55,84 @@ class FusedMultiHeadAttention(Layer):
Default: None, which means the default bias parameter property is used. Default: None, which means the default bias parameter property is used.
If it is set to False, this layer will not have trainable bias parameter. If it is set to False, this layer will not have trainable bias parameter.
See usage for details in :code:`ParamAttr` . See usage for details in :code:`ParamAttr` .
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle import paddle
# input: [batch_size, sequence_length, embed_dim]
# encoder input: [batch_size, sequence_length, d_model]
query = paddle.rand((2, 4, 128)) query = paddle.rand((2, 4, 128))
# self attention mask: [batch_size, num_heads, query_len, query_len] # self attention mask: [batch_size, num_heads, query_len, query_len]
attn_mask = paddle.rand((2, 2, 4, 4)) attn_mask = paddle.rand((2, 2, 4, 4))
multi_head_attn = paddle.nn.MultiHeadAttention(128, 2) multi_head_attn = paddle.incubate.nn.FusedMultiHeadAttention(128, 2)
output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128] output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128]
""" """
Cache = collections.namedtuple("Cache", ["k", "v"])
StaticCache = collections.namedtuple("StaticCache", ["k", "v"])
def __init__(self, def __init__(self,
embed_dim, embed_dim,
num_heads, num_heads,
dropout=0., dropout_rate=0.5,
attn_dropout_rate=0.5,
kdim=None, kdim=None,
vdim=None, vdim=None,
normalize_before=False,
need_weights=False, need_weights=False,
weight_attr=None, weight_attr=None,
bias_attr=None): bias_attr=None,
name=None):
super(FusedMultiHeadAttention, self).__init__() super(FusedMultiHeadAttention, self).__init__()
raise NotImplementedError()
assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
"but recieved {}".format(embed_dim))
assert num_heads > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(num_heads))
attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate
self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
assert need_weights == False, "Only support need_weight is False now."
self.qkv_weight = self.create_parameter(
shape=[3, num_heads, self.head_dim, embed_dim],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.qkv_bias = self.create_parameter(
shape=[3, num_heads, self.head_dim],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True)
self.linear_weight = self.create_parameter(
shape=[embed_dim, embed_dim],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.linear_bias = self.create_parameter(
shape=[embed_dim],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True)
self.pre_ln_scale = self.create_parameter(
attr=self._weight_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.pre_ln_bias = self.create_parameter(
attr=self._bias_attr, shape=[embed_dim], is_bias=True)
self.ln_scale = self.create_parameter(
attr=self._weight_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.ln_bias = self.create_parameter(
attr=self._bias_attr, shape=[embed_dim], is_bias=True)
self.dropout_rate = dropout_rate
self.attn_dropout_rate = attn_dropout_rate
self.name = name
def forward(self, query, key=None, value=None, attn_mask=None, cache=None): def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
""" """
...@@ -97,30 +161,34 @@ class FusedMultiHeadAttention(Layer): ...@@ -97,30 +161,34 @@ class FusedMultiHeadAttention(Layer):
`-INF` values and the others have 0 values. It can be None when `-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None. nothing wanted or needed to be prevented attention to. Default None.
cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional): cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional):
It is a namedtuple with `k` and `v` as fields, and stores tensors Now, only None is supported. Default None.
shaped `[batch_size, num_heads, length, embed_dim]` which are results
of linear projection, reshape and transpose calculations in
MultiHeadAttention. If it is an instance of `Cache`, `k` and `v`
fields reserve intermediate results of previous positions, which
mostly used for decoder self attention. If it is an instance of
`StaticCache`, `key` and `value` args would be ignored, `k` and
`v` fields would be used as calculated results on `key` and
`value`, which mostly used for decoder-encoder cross attention.
It is only used for inference and should be None for training.
Default None.
Returns: Returns:
Tensor|tuple: It is a tensor that has the same shape and data type \ Tensor|tuple: It is a tensor that has the same shape and data type \
as `query`, representing attention output. Or a tuple if \ as `query`, representing attention output.
`need_weights` is True or `cache` is not None. If `need_weights` \
is True, except for attention output, the tuple also includes \
the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \
If `cache` is not None, the tuple then includes the new cache \
having the same type as `cache`, and if it is `StaticCache`, it \
is same as the input `cache`, if it is `Cache`, the new cache \
reserves tensors concatanating raw tensors with intermediate \
results of current query.
""" """
raise NotImplementedError() if attn_mask is not None:
# Support bool or int mask
attn_mask = _convert_attention_mask(attn_mask, query.dtype)
assert cache == None, "Only support cache is None now."
out = incubate_f.fused_multi_head_attention(
x=query,
qkv_weight=self.qkv_weight,
linear_weight=self.linear_weight,
pre_layer_norm=self.normalize_before,
pre_ln_scale=self.pre_ln_scale,
pre_ln_bias=self.pre_ln_bias,
ln_scale=self.ln_scale,
ln_bias=self.ln_bias,
pre_ln_epsilon=1e-05,
qkv_bias=self.qkv_bias,
linear_bias=self.linear_bias,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
attn_dropout_rate=self.attn_dropout_rate,
ln_epsilon=1e-05)
return out
class FusedFeedForward(Layer): class FusedFeedForward(Layer):
...@@ -186,7 +254,8 @@ class FusedTransformerEncoderLayer(Layer): ...@@ -186,7 +254,8 @@ class FusedTransformerEncoderLayer(Layer):
Examples: Examples:
.. code-block:: python .. code-block:: python
# required: gpu
import paddle import paddle
from paddle.nn import TransformerEncoderLayer from paddle.nn import TransformerEncoderLayer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册