未验证 提交 64643d50 编写于 作者: L Li Min 提交者: GitHub

Add fused attention op backward and python layer. (#36498) (#36752)

功能:本PR的目标是提高attention模块的计算性能。
为了减少框架层对op的调度开销,本PR通过在C++层手动实现attention模块,对外提供attention 大op;
为了减少防存开销,本PR采取了两种优化方法:
(1)在q,k,v计算时通过共享输入X,将该处的gemm,transpose和bias add从三次调用减少为一次;
(2)使用kernel融合优化技术,在不同cuda kernel之间通过寄存器传输数据;
上级 211cf208
......@@ -328,9 +328,206 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class FusedAttentionGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->Attrs().Get<bool>("attn_dropout_is_test"), false,
platform::errors::InvalidArgument(
"GradOp is only callable when attn_dropout_is_test is false"));
OP_INOUT_CHECK(ctx->HasInput("Ln2Mean"), "Input", "Ln2Mean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("Ln2Variance"), "Input", "Ln2Variance",
"FusedAttentionGrad");
if (ctx->HasOutput(framework::GradVarName("Ln2Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Scale"),
ctx->GetInputDim("Ln2Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Ln2Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Ln2Bias"),
ctx->GetInputDim("Ln2Bias"));
}
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnMean"), "Input", "LnMean",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("LnVariance"), "Input", "LnVariance",
"FusedAttentionGrad");
if (ctx->Attrs().Get<bool>("pre_layer_norm") == true) {
OP_INOUT_CHECK(ctx->HasInput("LnOut"), "Input", "LnOut",
"FusedAttentionGrad");
}
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
"FusedAttentionGrad");
if (ctx->HasOutput(framework::GradVarName("LnScale"))) {
ctx->SetOutputDim(framework::GradVarName("LnScale"),
ctx->GetInputDim("LnScale"));
}
if (ctx->HasOutput(framework::GradVarName("LnBias"))) {
ctx->SetOutputDim(framework::GradVarName("LnBias"),
ctx->GetInputDim("LnBias"));
}
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
ctx->GetInputDim("OutLinearBias"));
ctx->SetOutputDim(framework::GradVarName("OutLinearW"),
ctx->GetInputDim("OutLinearW"));
ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW"));
ctx->SetOutputDim(framework::GradVarName("QKVBias"),
ctx->GetInputDim("QKVBias"));
ctx->SetOutputDim(framework::GradVarName("LnOut"),
ctx->GetInputDim("LnOut"));
ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
ctx->GetInputDim("FMHAOut"));
ctx->SetOutputDim(framework::GradVarName("QKTVOut"),
ctx->GetInputDim("QKTVOut"));
ctx->SetOutputDim(framework::GradVarName("TransposeOut2"),
ctx->GetInputDim("TransposeOut2"));
ctx->SetOutputDim(framework::GradVarName("QKOut"),
ctx->GetInputDim("QKOut"));
ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"),
ctx->GetInputDim("SoftmaxOut"));
ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
ctx->GetInputDim("AttnDropoutOut"));
ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
ctx->GetInputDim("SrcMaskOut"));
ctx->SetOutputDim(framework::GradVarName("QKVOut"),
ctx->GetInputDim("QKVOut"));
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
ctx->GetInputDim("QKVBiasOut"));
ctx->SetOutputDim(framework::GradVarName("OutLinearOut"),
ctx->GetInputDim("OutLinearOut"));
ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"),
ctx->GetInputDim("BiasDropoutResidualOut"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto input_data_type = input->type();
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("fused_attention_grad");
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
// inputs x, parameters and their grad.
op->SetInput("X", this->Input("X"));
op->SetInput("QKVW", this->Input("QKVW"));
op->SetInput("QKVBias", this->Input("QKVBias"));
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
if (this->HasInput("LnScale")) {
op->SetInput("LnScale", this->Input("LnScale"));
op->SetOutput(framework::GradVarName("LnScale"),
this->InputGrad("LnScale"));
}
if (this->HasInput("LnBias")) {
op->SetInput("LnBias", this->Input("LnBias"));
op->SetOutput(framework::GradVarName("LnBias"),
this->InputGrad("LnBias"));
}
if (this->HasInput("Ln2Scale")) {
op->SetInput("Ln2Scale", this->Input("Ln2Scale"));
op->SetOutput(framework::GradVarName("Ln2Scale"),
this->InputGrad("Ln2Scale"));
}
if (this->HasInput("Ln2Bias")) {
op->SetInput("Ln2Bias", this->Input("Ln2Bias"));
op->SetOutput(framework::GradVarName("Ln2Bias"),
this->InputGrad("Ln2Bias"));
}
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("QKVW"), this->InputGrad("QKVW"));
op->SetOutput(framework::GradVarName("QKVBias"),
this->InputGrad("QKVBias"));
op->SetOutput(framework::GradVarName("OutLinearBias"),
this->InputGrad("OutLinearBias"));
op->SetOutput(framework::GradVarName("OutLinearW"),
this->InputGrad("OutLinearW"));
// use forward outputs as backward inputs.
op->SetInput("LnOut", this->Output("LnOut"));
op->SetInput("LnMean", this->Output("LnMean"));
op->SetInput("LnVariance", this->Output("LnVariance"));
op->SetInput("QKVOut", this->Output("QKVOut"));
op->SetInput("QKVBiasOut", this->Output("QKVBiasOut"));
op->SetInput("TransposeOut2", this->Output("TransposeOut2"));
op->SetInput("QKOut", this->Output("QKOut"));
op->SetInput("QKTVOut", this->Output("QKTVOut"));
op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
op->SetInput("FMHAOut", this->Output("FMHAOut"));
op->SetInput("OutLinearOut", this->Output("OutLinearOut"));
op->SetInput("Ln2Mean", this->Output("Ln2Mean"));
op->SetInput("Ln2Variance", this->Output("Ln2Variance"));
op->SetInput("DropoutMaskOut", this->Output("DropoutMaskOut"));
op->SetInput("BiasDropoutResidualOut",
this->Output("BiasDropoutResidualOut"));
op->SetInput("QKVOut", this->Output("QKVOut"));
// backward outputs: dinput
op->SetOutput(framework::GradVarName("LnOut"), this->OutputGrad("LnOut"));
op->SetOutput(framework::GradVarName("QKVOut"), this->OutputGrad("QKVOut"));
op->SetOutput(framework::GradVarName("QKVBiasOut"),
this->OutputGrad("QKVBiasOut"));
op->SetOutput(framework::GradVarName("QKTVOut"),
this->OutputGrad("QKTVOut"));
op->SetOutput(framework::GradVarName("TransposeOut2"),
this->OutputGrad("TransposeOut2"));
op->SetOutput(framework::GradVarName("QKOut"), this->OutputGrad("QKOut"));
op->SetOutput(framework::GradVarName("SoftmaxOut"),
this->OutputGrad("SoftmaxOut"));
op->SetOutput(framework::GradVarName("AttnDropoutOut"),
this->OutputGrad("AttnDropoutOut"));
op->SetOutput(framework::GradVarName("SrcMaskOut"),
this->OutputGrad("SrcMaskOut"));
op->SetOutput(framework::GradVarName("FMHAOut"),
this->OutputGrad("FMHAOut"));
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
this->OutputGrad("BiasDropoutResidualOut"));
op->SetOutput(framework::GradVarName("OutLinearOut"),
this->OutputGrad("OutLinearOut"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_attention, ops::FusedAttentionOp,
ops::FusedAttentionOpMaker);
ops::FusedAttentionOpMaker,
ops::FusedAttentionGradOpMaker<paddle::framework::OpDesc>,
ops::FusedAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_attention_grad, ops::FusedAttentionGradOp);
......@@ -199,6 +199,237 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class FusedAttentionGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
const float epsilon = ctx.Attr<float>("epsilon");
const float ln2epsilon = ctx.Attr<float>("ln_epsilon");
float attn_dropout_prob = ctx.Attr<float>("attn_dropout_rate");
bool is_test_1 = ctx.Attr<bool>("attn_dropout_is_test");
auto &dropout_implementation_1 =
ctx.Attr<std::string>("attn_dropout_implementation");
bool is_upscale_in_train_1 =
(dropout_implementation_1 == "upscale_in_train");
auto *seed_1 = ctx.HasInput("Seed1") ? ctx.Input<Tensor>("Seed1") : nullptr;
bool is_fix_seed_1 = ctx.Attr<bool>("attn_dropout_fix_seed");
int seed_val_1 = ctx.Attr<int>("attn_dropout_seed");
// get inputs.
auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto *d_y_data = d_y->data<T>();
// fw input
auto *input_x = ctx.Input<Tensor>("X");
auto *ln_scale = ctx.Input<Tensor>("LnScale");
auto *ln_2_scale = ctx.Input<Tensor>("Ln2Scale");
auto *x_data = input_x->data<T>();
auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data<U>());
auto *ln_2_scale_data =
(ln_2_scale == nullptr ? nullptr : ln_2_scale->data<U>());
// fw parameters.
auto *src_mask = ctx.Input<Tensor>("SrcMask");
auto *qkv_weight = ctx.Input<Tensor>("QKVW");
auto *qkv_bias = ctx.Input<Tensor>("QKVBias");
auto *out_linear_weight = ctx.Input<Tensor>("OutLinearW");
auto *out_linear_bias = ctx.Input<Tensor>("OutLinearBias");
auto *src_mask_data = (src_mask == nullptr ? nullptr : src_mask->data<T>());
auto *qkv_weight_data = qkv_weight->data<T>();
auto *qkv_bias_data = qkv_bias->data<T>();
auto *out_linear_weight_data = out_linear_weight->data<T>();
auto *out_linear_bias_data = out_linear_bias->data<T>();
// fw output
auto *ln_mean = ctx.Input<Tensor>("LnMean");
auto *ln_var = ctx.Input<Tensor>("LnVariance");
auto *ln_out = ctx.Input<Tensor>("LnOut");
auto *fmha_out = ctx.Input<Tensor>("FMHAOut");
auto *transpose_out_2 = ctx.Input<Tensor>("TransposeOut2");
auto *qk_out = ctx.Input<Tensor>("QKOut");
auto *qktv_out = ctx.Input<Tensor>("QKTVOut");
auto *softmax_out = ctx.Input<Tensor>("SoftmaxOut");
auto *attn_dropout_mask_out = ctx.Input<Tensor>("AttnDropoutMaskOut");
auto *attn_dropout_out = ctx.Input<Tensor>("AttnDropoutOut");
auto *src_mask_out = ctx.Input<Tensor>("SrcMaskOut");
auto *out_linear_out = ctx.Input<Tensor>("OutLinearOut");
auto *ln_2_mean = ctx.Input<Tensor>("Ln2Mean");
auto *ln_2_var = ctx.Input<Tensor>("Ln2Variance");
auto *dropout_mask_out = ctx.Input<Tensor>("DropoutMaskOut");
auto *bias_dropout_residual_out =
ctx.Input<Tensor>("BiasDropoutResidualOut");
auto *ln_mean_data = ln_mean->data<U>();
auto *ln_var_data = ln_var->data<U>();
auto *ln_out_data = ln_out->data<T>();
auto *fmha_out_data = fmha_out->data<T>();
auto *transpose_out_2_data = transpose_out_2->data<T>();
auto *qk_out_data = qk_out->data<T>();
auto *qktv_out_data = qktv_out->data<T>();
auto *softmax_out_data = softmax_out->data<T>();
auto *src_mask_out_data = src_mask_out->data<T>();
auto *out_linear_out_data = out_linear_out->data<T>();
auto *ln_2_mean_data = ln_2_mean->data<U>();
auto *ln_2_var_data = ln_2_var->data<U>();
auto *dropout_mask_out_data = dropout_mask_out->data<uint8_t>();
auto *bias_dropout_residual_out_data = bias_dropout_residual_out->data<T>();
// output's grad
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_ln_out = ctx.Output<Tensor>(framework::GradVarName("LnOut"));
auto *d_qkv_out = ctx.Output<Tensor>(framework::GradVarName("QKVOut"));
auto *d_qkv_bias_out =
ctx.Output<Tensor>(framework::GradVarName("QKVBiasOut"));
auto *d_qktv_out = ctx.Output<Tensor>(framework::GradVarName("QKTVOut"));
auto *d_transpose_out_2 =
ctx.Output<Tensor>(framework::GradVarName("TransposeOut2"));
auto *d_qk_out = ctx.Output<Tensor>(framework::GradVarName("QKOut"));
auto *d_softmax_out =
ctx.Output<Tensor>(framework::GradVarName("SoftmaxOut"));
auto *d_attn_dropout_out =
ctx.Output<Tensor>(framework::GradVarName("AttnDropoutOut"));
auto *d_src_mask_out =
ctx.Output<Tensor>(framework::GradVarName("SrcMaskOut"));
auto *d_fmha_out = ctx.Output<Tensor>(framework::GradVarName("FMHAOut"));
auto *d_out_linear_out =
ctx.Output<Tensor>(framework::GradVarName("OutLinearOut"));
auto *d_bias_dropout_residual_out =
ctx.Output<Tensor>(framework::GradVarName("BiasDropoutResidualOut"));
auto *d_x_data = d_x->mutable_data<T>(ctx.GetPlace());
auto *d_ln_out_data = d_ln_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_out_data = d_qkv_out->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_out_data = d_qkv_bias_out->mutable_data<T>(ctx.GetPlace());
auto *d_qktv_out_data = d_qktv_out->mutable_data<T>(ctx.GetPlace());
auto *d_transpose_out_2_data =
d_transpose_out_2->mutable_data<T>(ctx.GetPlace());
auto *d_qk_out_data = d_qk_out->mutable_data<T>(ctx.GetPlace());
auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
auto *d_attn_dropout_out_data =
d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
auto *d_src_mask_out_data = d_src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_out_data =
d_out_linear_out->mutable_data<T>(ctx.GetPlace());
auto *d_bias_dropout_residual_out_data =
d_bias_dropout_residual_out->mutable_data<T>(ctx.GetPlace());
// parameter grad
auto *d_ln_scale = ctx.Output<Tensor>(framework::GradVarName("LnScale"));
auto *d_ln_bias = ctx.Output<Tensor>(framework::GradVarName("LnBias"));
auto *d_qkv_weight = ctx.Output<Tensor>(framework::GradVarName("QKVW"));
auto *d_qkv_bias = ctx.Output<Tensor>(framework::GradVarName("QKVBias"));
auto *d_out_linear_weight =
ctx.Output<Tensor>(framework::GradVarName("OutLinearW"));
auto *d_out_linear_bias =
ctx.Output<Tensor>(framework::GradVarName("OutLinearBias"));
auto *d_ln_2_scale = ctx.Output<Tensor>(framework::GradVarName("Ln2Scale"));
auto *d_ln_2_bias = ctx.Output<Tensor>(framework::GradVarName("Ln2Bias"));
auto *d_ln_scale_data =
(d_ln_scale == nullptr ? nullptr
: d_ln_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_ln_bias_data =
(d_ln_bias == nullptr ? nullptr
: d_ln_bias->mutable_data<U>(ctx.GetPlace()));
auto *d_qkv_weight_data = d_qkv_weight->mutable_data<T>(ctx.GetPlace());
auto *d_qkv_bias_data = d_qkv_bias->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_weight_data =
d_out_linear_weight->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_bias_data =
d_out_linear_bias->mutable_data<T>(ctx.GetPlace());
auto *d_ln_2_scale_data =
(d_ln_2_scale == nullptr ? nullptr : d_ln_2_scale->mutable_data<U>(
ctx.GetPlace()));
auto *d_ln_2_bias_data =
(d_ln_2_bias == nullptr ? nullptr
: d_ln_2_bias->mutable_data<U>(ctx.GetPlace()));
const auto input_x_dims = input_x->dims();
const auto qkv_w_dims = qkv_weight->dims();
int batch_size = input_x_dims[0];
int max_seq_len = input_x_dims[1];
int dim_embed = input_x_dims[2];
int num_head = qkv_w_dims[1];
int dim_head = qkv_w_dims[2];
int bsz_seq = batch_size * max_seq_len;
int hidden_size = num_head * dim_head;
int output_size = 3 * hidden_size;
int input_size = dim_embed;
Tensor d_residual;
d_residual.Resize(input_x_dims);
T *d_residual_data = d_residual.mutable_data<T>(ctx.GetPlace());
bool transA = false;
bool transB = true;
bool compute_bias = true;
auto layer_norm_compute = AttnLayerNorm<T>(ctx.cuda_device_context(),
epsilon, bsz_seq, dim_embed);
auto qkv_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
AttnDropoutParam attn_dropout_param(
is_test_1, dropout_implementation_1, attn_dropout_prob,
is_upscale_in_train_1, is_fix_seed_1, seed_val_1, seed_1);
auto fmha_ref_compute =
FMHARef<T>(ctx.cuda_device_context(), batch_size, max_seq_len, num_head,
dim_head, attn_dropout_param);
output_size = hidden_size;
transA = false;
transB = false;
compute_bias = false;
auto out_linear_compute =
AttnMatMul<T>(ctx.cuda_device_context(), transA, transB, bsz_seq,
output_size, input_size, compute_bias);
DropoutParam dropout_param2(ctx, 0);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx.cuda_device_context(), bsz_seq, dim_embed, dropout_param2,
ln2epsilon);
fused_dropout_layernorm_helper.LayernormResidualDropoutBiasGrad(
ctx.cuda_device_context(), d_y_data, bias_dropout_residual_out_data,
dropout_mask_out_data, ln_2_scale_data, ln_2_mean_data, ln_2_var_data,
d_bias_dropout_residual_out_data, d_ln_2_scale_data, d_ln_2_bias_data,
d_out_linear_out_data, d_out_linear_bias_data, d_residual_data);
out_linear_compute.ComputeBackward(fmha_out_data, out_linear_weight_data,
d_out_linear_out_data, d_fmha_out_data,
d_out_linear_weight_data, nullptr);
fmha_ref_compute.ComputeBackward(
*transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out);
cudaMemcpyAsync(d_qkv_out_data, d_qkv_bias_out_data,
bsz_seq * 3 * num_head * dim_head * sizeof(T),
cudaMemcpyDeviceToDevice);
if (pre_layer_norm) {
qkv_compute.ComputeBackward(ln_out_data, qkv_weight_data,
d_qkv_bias_out_data, d_ln_out_data,
d_qkv_weight_data, d_qkv_bias_data);
layer_norm_compute.ComputeBackward(x_data, d_ln_out_data, ln_scale_data,
ln_mean_data, ln_var_data, d_x_data,
d_ln_scale_data, d_ln_bias_data);
} else {
qkv_compute.ComputeBackward(x_data, qkv_weight_data, d_qkv_bias_out_data,
d_x_data, d_qkv_weight_data, d_qkv_bias_data);
}
// gradient accumulation
std::vector<const Tensor *> ins;
std::vector<Tensor *> outs;
ins.emplace_back(&d_residual);
ins.emplace_back(d_x);
outs.emplace_back(d_x);
int elewise_add_axis = -1;
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
ctx.cuda_device_context(), ins, &outs, elewise_add_axis,
AddFunctor<T>());
}
};
} // namespace operators
} // namespace paddle
......@@ -207,3 +438,7 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fused_attention, ops::FusedAttentionOpKernel<float>,
ops::FusedAttentionOpKernel<double>,
ops::FusedAttentionOpKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(fused_attention_grad,
ops::FusedAttentionGradKernel<float>,
ops::FusedAttentionGradKernel<double>,
ops::FusedAttentionGradKernel<plat::float16>);
......@@ -93,6 +93,7 @@ endforeach()
if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
endif()
if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
......
......@@ -34,6 +34,8 @@ class TestFusedAttentionOp(OpTest):
self.generate_input_data()
paddle.set_default_dtype(self.x_type)
self.__class__.op_type = "fused_attention"
# use autograd to check grad in this unittest.
self.__class__.no_need_check_grad = True
self.q_proj = Linear(
self.embed_dim,
self.embed_dim,
......@@ -147,7 +149,9 @@ class TestFusedAttentionOp(OpTest):
final_out = self.norm1(residual_out)
if self.pre_layer_norm:
final_out = self.norm2(residual_out)
return final_out
paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
return final_out, tensor_query.grad
def GetFusedAttentionOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
......@@ -196,13 +200,17 @@ class TestFusedAttentionOp(OpTest):
ln1_scale, ln1_bias, ln2_scale, ln2_bias, epsilon, qkv_bias_tensor,
out_linear_bias, attn_mask, self.dropout_prob,
self.attn_dropout_prob, ln2_epsilon)
return final_out
paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True)
return final_out, x.grad
def test_fused_attention_op(self):
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedAttentionOut()
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-5)
class TestFusedAttentionOpFp16(TestFusedAttentionOp):
......@@ -226,10 +234,12 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref = self.GetBaselineOut()
final_out = self.GetFusedAttentionOut()
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
if __name__ == "__main__":
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.nn.functional as F
from paddle.incubate.nn.layer.fused_transformer import FusedMultiHeadAttention
from paddle import tensor
from paddle.fluid import layers
from paddle.static import Program, program_guard
import unittest
def fc(x, weight):
return np.matmul(x, weight)
def softmax(x):
np.seterr(invalid='ignore')
output = np.zeros(x.shape, dtype=np.float64)
for i in range(x.shape[0]):
for j in range(x.shape[1]):
for k in range(x.shape[2]):
x_curr = x[i, j, k, :]
e_x = np.exp(x_curr - np.amax(x_curr))
output[i, j, k, :] = e_x / np.sum(e_x)
return output
def batch_matmul(x, y):
assert x.shape[0] == y.shape[0]
assert x.shape[1] == y.shape[1]
retval = np.zeros(
(x.shape[0], x.shape[1], x.shape[2], y.shape[3]), dtype=np.float64)
for i in range(x.shape[0]):
for j in range(x.shape[1]):
retval[i, j, :, :] = np.matmul(x[i, j, :, :], y[i, j, :, :])
return retval
def layer_norm(x, has_scale, has_bias, weight, bias, epsilon=1e-05):
batch_size, src_len, d_model = x.shape
x = x.reshape((batch_size * src_len, d_model))
mu = np.mean(x, axis=1, keepdims=True)
sigma_squar = np.sum(np.square(x - mu), axis=1) / d_model
x1_up = (x - mu)
x1_down_1 = sigma_squar + epsilon
x1_down = np.sqrt(x1_down_1)
x1_down = x1_down.reshape((x1_down.shape[0], 1))
x1 = x1_up / x1_down
x_scaled = x1
if (has_scale):
x_scaled = weight * x1
x_scaled_bias = x_scaled
if (has_bias):
x_scaled_bias = x_scaled + bias
x_scaled_bias = x_scaled_bias.reshape((batch_size, src_len, d_model))
return x_scaled_bias
def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
out_linear_weight, out_linear_bias):
batch_size = query.shape[0]
seq_len = query.shape[1]
embed_dim = query.shape[2]
if (pre_layer_norm):
ln_out = layer_norm(query, True, True, ln_scale, ln_bias)
num_head = qkv_weight.shape[1]
head_dim = qkv_weight.shape[2]
# embed_dim, 3, num_heads, self.head_dim
qkv_weight = qkv_weight.transpose((3, 0, 1, 2))
qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
qkv_weight.shape[2] * qkv_weight.shape[3])
if (pre_layer_norm):
ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
qkv = fc(ln_out, qkv_weight)
ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
else:
query = query.reshape(batch_size * seq_len, embed_dim)
qkv = fc(query, qkv_weight)
query = query.reshape(batch_size, seq_len, embed_dim)
qkv = qkv.reshape(batch_size, seq_len, 3, num_head, head_dim)
# q*k^t
qkv = qkv.transpose(
(2, 0, 1, 3, 4)) # 3, batch_size, seq_len, num_head, head_dim
qkv = qkv.transpose(
(0, 1, 3, 2, 4)) # 3, batch_size, num_head, seq_len, head_dim
q = qkv[0:1, ::]
q = q.reshape(batch_size, num_head, seq_len, head_dim)
k = qkv[1:2, ::] #[1, batch_size, num_head, seq_len, head_dim]
k = k.reshape(batch_size, num_head, seq_len, head_dim)
v = qkv[2::]
v = v.reshape(batch_size, num_head, seq_len, head_dim)
k = k.transpose([0, 1, 3, 2]) #[batch_size, num_head, head_dim, seq_len]
qkt = batch_matmul(q, k / np.sqrt(head_dim, dtype=np.float64))
if attn_mask is not None:
if attn_mask.dtype.name == 'int64':
attn_mask = (attn_mask.astype(qkt.dtype) - 1.0) * 1e9
else:
attn_mask = attn_mask.astype(qkt.dtype)
qkt += attn_mask
# softmax
softmax_out = softmax(qkt)
attn_heads = batch_matmul(softmax_out, v)
attn_heads = attn_heads.transpose(
(0, 2, 1, 3)) # [batch_size, seq_len, num_head, head_dim]
# out_linear
out_linear_input = attn_heads.reshape(batch_size, seq_len,
num_head * head_dim)
out_linear_out = fc(out_linear_input, out_linear_weight)
# bias add, dropout, residual add, layer_norm.
out_linear_bias_out = out_linear_out + out_linear_bias
out_linear_bias_dropout_out = out_linear_bias_out
out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out
out_linear_bias_dropout_residual_ln_out = layer_norm(
out_linear_bias_dropout_residual_out, True, True, ln_2_scale, ln_2_bias)
return out_linear_bias_dropout_residual_ln_out
class TestFusedAttentionAPI(unittest.TestCase):
def setUp(self):
self.config()
self.generate_input_data()
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.training = True
self.need_weight = False
self.batch_size = 1
self.query_length = 2
self.head_dim = 2
self.num_heads = 2
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
dtype=self.attn_mask_type)
if self.attn_mask_type == np.int64:
self.attn_mask = np.tril(self.attn_mask)
elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
self.key, self.value = self.query, self.query
def run_imperative(self):
fused_attn = FusedMultiHeadAttention(
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr)
out = fused_attn(
paddle.to_tensor(self.query),
paddle.to_tensor(self.query),
paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask))
ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask,
fused_attn.pre_ln_scale.numpy(),
fused_attn.pre_ln_bias.numpy(),
fused_attn.ln_scale.numpy(),
fused_attn.ln_bias.numpy(),
fused_attn.qkv_weight.numpy(),
fused_attn.qkv_bias.numpy(),
fused_attn.linear_weight.numpy(),
fused_attn.linear_bias.numpy())
self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5))
def run_static(self):
fused_attn = FusedMultiHeadAttention(
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr)
x = paddle.static.data(
name='X',
shape=[self.batch_size, self.query_length, self.embed_dim],
dtype=self.x_type)
attn_mask = paddle.static.data(
name='SrcMask',
shape=[
self.batch_size, self.num_heads, self.query_length,
self.key_length
],
dtype=self.attn_mask_type)
final_out = fused_attn(x, x, x, attn_mask)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(Program()):
out, qkv_weight, qkv_bias, linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = self.run_static(
)
ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
linear_weight, linear_bias)
self.assertTrue(
np.allclose(
np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5))
def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_imperative()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .layer.fused_transformer import FusedMultiHeadAttention # noqa: F401
__all__ = [ #noqa
'FusedMultiHeadAttention',
]
......@@ -15,6 +15,7 @@
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid import core, dygraph_utils
from paddle import _C_ops
__all__ = []
......@@ -217,8 +218,8 @@ def fused_multi_head_attention(x,
`[batch\_size, sequence\_len, embed\_dim]`.
qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
pre_layer_norm (bool, optional): whether it is pre_layer_norm or post_layer_norm architecture.
Default False.
pre_layer_norm (bool, optional): whether it is pre_layer_norm (True) or post_layer_norm architecture
(False). Default False.
pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None.
pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None.
ln_scale (Tensor, optional): The weight tensor of layernorm. Default None.
......@@ -228,13 +229,19 @@ def fused_multi_head_attention(x,
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
attn_mask (Tensor, optional):
attn_mask (Tensor, optional): A tensor used in multi-head attention to prevents attention to
some unwanted positions, usually the paddings or the subsequent positions. It is a tensor
with shape broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`. When the
data type is bool, the unwanted positions have `False` values and the others have `True` values.
When the data type is int, the unwanted positions have 0 values and the others have 1 values.
When the data type is float, the unwanted positions have `-INF` values and the others have 0 values.
It can be None when nothing wanted or needed to be prevented attention to. Default None.
dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout after attention.
0 for no dropout. Default 0.
0 for no dropout. Default 0.5.
attn_dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout in attention.
0 for no dropout. Default 0.
0 for no dropout. Default 0.5.
ln_epsilon (float, optional): Small float value added to denominator of layer_norm
to avoid dividing by zero. Default is 1e-5.
......@@ -248,9 +255,9 @@ def fused_multi_head_attention(x,
# input: [batch_size, seq_len, embed_dim]
x = paddle.rand(shape=(2, 4, 128), dtype="float32")
# qkv_weight: [3, num_head, dim_head, dim_embed]
# qkv_weight: [3, num_head, head_dim, embed_dim]
qkv_weight = paddle.rand(shape=(3, 4, 32, 128), dtype="float32")
# qkv_bias: [3, num_head, dim_head]
# qkv_bias: [3, num_head, head_dim]
qkv_bias = paddle.rand(shape=(3, 4, 32), dtype="float32")
# linear_weight: [embed_dim, embed_dim]
linear_weight = paddle.rand(shape=(128, 128), dtype="float32")
......@@ -271,6 +278,12 @@ def fused_multi_head_attention(x,
# pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
# qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
# linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
assert len(qkv_weight.shape
) == 4, "The dims of the shape of qkv_weight should be 4."
assert qkv_weight.shape[
0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
assert qkv_weight.shape[3] == x.shape[
2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention(
x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask,
linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm',
......@@ -278,3 +291,95 @@ def fused_multi_head_attention(x,
dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon',
ln_epsilon)
return final_out
else:
helper = LayerHelper('fused_multi_head_attention', **locals())
dtype = x.dtype
# check dtypes
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'fused_multihead_attention')
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
'fused_multi_head_attention')
# set inputs
inputs = dict()
inputs['X'] = [x]
if pre_ln_scale:
inputs['LnScale'] = [pre_ln_scale]
if pre_ln_bias:
inputs['LnBias'] = [pre_ln_bias]
inputs['QKVW'] = [qkv_weight]
inputs['QKVBias'] = [qkv_bias]
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = [linear_weight]
inputs['OutLinearBias'] = [linear_bias]
if ln_scale:
inputs['Ln2Scale'] = [ln_scale]
if ln_bias:
inputs['Ln2Bias'] = [ln_bias]
# set attrs
attrs = {
'pre_layer_norm': pre_layer_norm,
'epsilon': pre_ln_epsilon,
'ln_epsilon': ln_epsilon,
'dropout_rate': dropout_rate,
'attn_dropout_rate': attn_dropout_rate
}
# set outputs
pre_ln_mean_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
pre_ln_variance_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
pre_ln_out = helper.create_variable_for_type_inference(dtype=dtype)
qkv_out = helper.create_variable_for_type_inference(dtype=dtype)
qkv_bias_out = helper.create_variable_for_type_inference(dtype=dtype)
transpose_out = helper.create_variable_for_type_inference(dtype=dtype)
qk_out = helper.create_variable_for_type_inference(dtype=dtype)
qktv_out = helper.create_variable_for_type_inference(dtype=dtype)
softmax_out = helper.create_variable_for_type_inference(dtype=dtype)
attn_dropout_mask_out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
attn_dropout_out = helper.create_variable_for_type_inference(
dtype=dtype)
attn_mask_out = helper.create_variable_for_type_inference(dtype=dtype)
fmha_out = helper.create_variable_for_type_inference(dtype=dtype)
out_linear_out = helper.create_variable_for_type_inference(dtype=dtype)
dropout_mask_out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True)
ln_mean_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
ln_variance_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
bias_dropout_residual_out = helper.create_variable_for_type_inference(
dtype=dtype)
final_out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='fused_attention',
inputs=inputs,
outputs={
"LnMean": pre_ln_mean_out,
"LnVariance": pre_ln_variance_out,
"LnOut": pre_ln_out,
"QKVOut": qkv_out,
"QKVBiasOut": qkv_bias_out,
"TransposeOut2": transpose_out,
"QKOut": qk_out,
"QKTVOut": qktv_out,
"SoftmaxOut": softmax_out,
"AttnDropoutMaskOut": attn_dropout_mask_out,
"AttnDropoutOut": attn_dropout_out,
"SrcMaskOut": attn_mask_out,
"FMHAOut": fmha_out,
"OutLinearOut": out_linear_out,
"DropoutMaskOut": dropout_mask_out,
"Ln2Mean": ln_mean_out,
"Ln2Variance": ln_variance_out,
"BiasDropoutResidualOut": bias_dropout_residual_out,
'Y': final_out
},
attrs=attrs)
return final_out
......@@ -12,27 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from paddle.nn import functional as F
from paddle.incubate.nn import functional as incubate_f
from paddle.nn import Layer
from paddle.framework import ParamAttr
import paddle
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle.nn.initializer import Constant
import collections
class FusedMultiHeadAttention(Layer):
"""
Attention mapps queries and a set of key-value pairs to outputs, and
Multi-Head Attention performs multiple parallel attention to jointly attending
to information from different representation subspaces.
Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
for more details.
Parameters:
embed_dim (int): The expected feature size in the input and output.
num_heads (int): The number of heads in multi-head attention.
dropout (float, optional): The dropout probability used on attention
weights to drop some attention targets. 0 for no dropout. Default 0
dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout after attention.
0 for no dropout. Default 0.5.
attn_dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout in attention.
0 for no dropout. Default 0.5.
kdim (int, optional): The feature size in key. If None, assumed equal to
`embed_dim`. Default None.
vdim (int, optional): The feature size in value. If None, assumed equal to
`embed_dim`. Default None.
normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True)
or post_layer_norm architecture (False). Default False.
need_weights (bool, optional): Indicate whether to return the attention
weights. Default False.
weights. Now, only False is supported. Default False.
weight_attr(ParamAttr, optional): To specify the weight parameter property.
Default: None, which means the default weight parameter property is used.
See usage for details in :code:`ParamAttr` .
......@@ -40,35 +55,84 @@ class FusedMultiHeadAttention(Layer):
Default: None, which means the default bias parameter property is used.
If it is set to False, this layer will not have trainable bias parameter.
See usage for details in :code:`ParamAttr` .
Examples:
.. code-block:: python
import paddle
# encoder input: [batch_size, sequence_length, d_model]
# input: [batch_size, sequence_length, embed_dim]
query = paddle.rand((2, 4, 128))
# self attention mask: [batch_size, num_heads, query_len, query_len]
attn_mask = paddle.rand((2, 2, 4, 4))
multi_head_attn = paddle.nn.MultiHeadAttention(128, 2)
multi_head_attn = paddle.incubate.nn.FusedMultiHeadAttention(128, 2)
output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128]
"""
Cache = collections.namedtuple("Cache", ["k", "v"])
StaticCache = collections.namedtuple("StaticCache", ["k", "v"])
def __init__(self,
embed_dim,
num_heads,
dropout=0.,
dropout_rate=0.5,
attn_dropout_rate=0.5,
kdim=None,
vdim=None,
normalize_before=False,
need_weights=False,
weight_attr=None,
bias_attr=None):
bias_attr=None,
name=None):
super(FusedMultiHeadAttention, self).__init__()
raise NotImplementedError()
assert embed_dim > 0, ("Expected embed_dim to be greater than 0, "
"but recieved {}".format(embed_dim))
assert num_heads > 0, ("Expected nhead to be greater than 0, "
"but recieved {}".format(num_heads))
attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate
self.normalize_before = normalize_before
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
assert need_weights == False, "Only support need_weight is False now."
self.qkv_weight = self.create_parameter(
shape=[3, num_heads, self.head_dim, embed_dim],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.qkv_bias = self.create_parameter(
shape=[3, num_heads, self.head_dim],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True)
self.linear_weight = self.create_parameter(
shape=[embed_dim, embed_dim],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.linear_bias = self.create_parameter(
shape=[embed_dim],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True)
self.pre_ln_scale = self.create_parameter(
attr=self._weight_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.pre_ln_bias = self.create_parameter(
attr=self._bias_attr, shape=[embed_dim], is_bias=True)
self.ln_scale = self.create_parameter(
attr=self._weight_attr,
shape=[embed_dim],
default_initializer=Constant(value=1.0))
self.ln_bias = self.create_parameter(
attr=self._bias_attr, shape=[embed_dim], is_bias=True)
self.dropout_rate = dropout_rate
self.attn_dropout_rate = attn_dropout_rate
self.name = name
def forward(self, query, key=None, value=None, attn_mask=None, cache=None):
"""
......@@ -97,30 +161,34 @@ class FusedMultiHeadAttention(Layer):
`-INF` values and the others have 0 values. It can be None when
nothing wanted or needed to be prevented attention to. Default None.
cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional):
It is a namedtuple with `k` and `v` as fields, and stores tensors
shaped `[batch_size, num_heads, length, embed_dim]` which are results
of linear projection, reshape and transpose calculations in
MultiHeadAttention. If it is an instance of `Cache`, `k` and `v`
fields reserve intermediate results of previous positions, which
mostly used for decoder self attention. If it is an instance of
`StaticCache`, `key` and `value` args would be ignored, `k` and
`v` fields would be used as calculated results on `key` and
`value`, which mostly used for decoder-encoder cross attention.
It is only used for inference and should be None for training.
Default None.
Now, only None is supported. Default None.
Returns:
Tensor|tuple: It is a tensor that has the same shape and data type \
as `query`, representing attention output. Or a tuple if \
`need_weights` is True or `cache` is not None. If `need_weights` \
is True, except for attention output, the tuple also includes \
the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \
If `cache` is not None, the tuple then includes the new cache \
having the same type as `cache`, and if it is `StaticCache`, it \
is same as the input `cache`, if it is `Cache`, the new cache \
reserves tensors concatanating raw tensors with intermediate \
results of current query.
as `query`, representing attention output.
"""
raise NotImplementedError()
if attn_mask is not None:
# Support bool or int mask
attn_mask = _convert_attention_mask(attn_mask, query.dtype)
assert cache == None, "Only support cache is None now."
out = incubate_f.fused_multi_head_attention(
x=query,
qkv_weight=self.qkv_weight,
linear_weight=self.linear_weight,
pre_layer_norm=self.normalize_before,
pre_ln_scale=self.pre_ln_scale,
pre_ln_bias=self.pre_ln_bias,
ln_scale=self.ln_scale,
ln_bias=self.ln_bias,
pre_ln_epsilon=1e-05,
qkv_bias=self.qkv_bias,
linear_bias=self.linear_bias,
attn_mask=attn_mask,
dropout_rate=self.dropout_rate,
attn_dropout_rate=self.attn_dropout_rate,
ln_epsilon=1e-05)
return out
class FusedFeedForward(Layer):
......@@ -187,6 +255,7 @@ class FusedTransformerEncoderLayer(Layer):
.. code-block:: python
# required: gpu
import paddle
from paddle.nn import TransformerEncoderLayer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册