Unverified commit 472dcca4, authored by Li Min, committed by GitHub

【fix-bug】Support attn_mask=None input cases for fused_attention_op. (#36951)

The current fused_attention_op does not accept attn_mask=None as an input. This PR adds support for that case, together with the corresponding unit-test logic.
Parent b7e88308
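For context, the usage this change enables looks roughly like the sketch below in dynamic-graph mode. It is a minimal example based on the unit tests further down; the import path paddle.incubate.nn.FusedMultiHeadAttention and the use of constructor defaults are assumptions, and a CUDA device is required since the fused op only has a GPU kernel.

# Minimal sketch: calling the fused attention layer with and without attn_mask.
# Assumes paddle.incubate.nn.FusedMultiHeadAttention and a CUDA device.
import numpy as np
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

paddle.disable_static(place=paddle.CUDAPlace(0))

batch_size, seq_len, num_heads, head_dim = 2, 4, 2, 8
embed_dim = num_heads * head_dim

x = paddle.to_tensor(
    np.random.rand(batch_size, seq_len, embed_dim).astype('float32'))
fused_attn = FusedMultiHeadAttention(embed_dim, num_heads)

# attn_mask=None is the case this PR adds support for.
out_no_mask = fused_attn(x, x, x, attn_mask=None)

# An additive [batch, num_heads, seq_len, seq_len] mask still works as before.
mask = paddle.zeros([batch_size, num_heads, seq_len, seq_len], dtype='float32')
out_masked = fused_attn(x, x, x, attn_mask=mask)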
@@ -69,7 +69,7 @@ class FMHARef {
   ~FMHARef() {}

   void ComputeForward(const Tensor& qkv_input_tensor,
-                      const Tensor& src_mask_tensor,
+                      const Tensor* src_mask_tensor,
                       Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor,
                       Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor,
                       Tensor* dropout_mask_out_tensor,

@@ -111,17 +111,17 @@ class FMHARef {
     blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr,
                      k_ptr, beta, qk_out_data, gemm_batch_size, stride_a,
                      stride_b);
+    int softmax_axis = -1;
+    if (src_mask_tensor != nullptr) {
       std::vector<const Tensor*> ins;
       std::vector<Tensor*> outs;
       ins.emplace_back(qk_out_tensor);
-      ins.emplace_back(&src_mask_tensor);
+      ins.emplace_back(src_mask_tensor);
       outs.emplace_back(src_mask_out_tensor);
       int elewise_add_axis = -1;
-    int softmax_axis = -1;
-    if (&src_mask_tensor != nullptr) {
       LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
           dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
       SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
                                         softmax_axis, softmax_out_tensor);
     } else {

@@ -165,7 +165,7 @@ class FMHARef {
   }

   void ComputeBackward(
-      const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor,
+      const Tensor& transpose_2_out_tensor, const Tensor* src_mask_tensor,
       const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor,
       const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor,
       const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor,

@@ -249,7 +249,7 @@ class FMHARef {
                                        softmax_out_grad_tensor);
     }
-    if (&src_mask_tensor != nullptr) {
+    if (src_mask_tensor != nullptr) {
       SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
                                          *softmax_out_grad_tensor, softmax_axis,
                                          src_mask_out_grad_tensor);
@@ -27,8 +27,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
-    OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
-                   "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
                    "FusedAttentionOp");

@@ -57,8 +55,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut",
                    "FusedAttentionOp");
+    if (ctx->HasInput("SrcMask")) {
       OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
                      "FusedAttentionOp");
+    }
     OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut",
                    "FusedAttentionOp");
     OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output",

@@ -119,7 +120,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
                       {y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
     // [batch, num_head, seq_len, seq_len]
     ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+    if (ctx->HasInput("SrcMask")) {
       ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
+    }
     // the same as QKOut's shape.
     ctx->SetOutputDim("AttnDropoutOut",
                       {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});

@@ -320,7 +324,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
     {
       out = transpose(out, perm=[2, 0, 3, 1, 4]);
       out = q * k^t;
-      out = attn_mark + out;
+      out = attn_mask + out;
       out = softmax(out);
       out = dropout(out);
       out = out * v;
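As a reading aid, the computation described in the OpMaker comment above, with the mask treated as optional (the point of this change), corresponds roughly to the NumPy reference below. This is an illustration only, not the operator's implementation; the attention scaling and dropout are omitted and the function name is made up.

# Reference-only sketch of softmax(q @ k^T + attn_mask) @ v with an optional mask.
# Illustrative; omits the attention scaling and dropout performed by the fused op.
import numpy as np

def attention_reference(q, k, v, attn_mask=None):
    # q, k, v: [batch, num_head, seq_len, head_dim]
    qk = np.matmul(q, k.transpose(0, 1, 3, 2))   # [batch, num_head, seq_len, seq_len]
    if attn_mask is not None:                    # when the mask is None, the add is skipped
        qk = qk + attn_mask
    qk = qk - qk.max(axis=-1, keepdims=True)     # numerically stable softmax
    probs = np.exp(qk)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.matmul(probs, v)                   # [batch, num_head, seq_len, head_dim]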
@@ -368,8 +372,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
                    "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
                    "FusedAttentionGrad");
-    OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
-                   "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
                    "FusedAttentionGrad");
     OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",

@@ -413,8 +415,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
                       ctx->GetInputDim("SoftmaxOut"));
     ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
                       ctx->GetInputDim("AttnDropoutOut"));
+    if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) {
       ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
                         ctx->GetInputDim("SrcMaskOut"));
+    }
     ctx->SetOutputDim(framework::GradVarName("QKVOut"),
                       ctx->GetInputDim("QKVOut"));
     ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),

@@ -448,7 +453,14 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("X", this->Input("X"));
     op->SetInput("QKVW", this->Input("QKVW"));
     op->SetInput("QKVBias", this->Input("QKVBias"));
+    if (this->HasInput("SrcMask")) {
       op->SetInput("SrcMask", this->Input("SrcMask"));
+      op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
+      op->SetOutput(framework::GradVarName("SrcMaskOut"),
+                    this->OutputGrad("SrcMaskOut"));
+    }
     op->SetInput("OutLinearW", this->Input("OutLinearW"));
     op->SetInput("OutLinearBias", this->Input("OutLinearBias"));

@@ -508,7 +520,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
     op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
     op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
-    op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
     op->SetInput("FMHAOut", this->Output("FMHAOut"));
     op->SetInput("OutLinearOut", this->Output("OutLinearOut"));

@@ -538,8 +550,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
                   this->OutputGrad("SoftmaxOut"));
     op->SetOutput(framework::GradVarName("AttnDropoutOut"),
                   this->OutputGrad("AttnDropoutOut"));
-    op->SetOutput(framework::GradVarName("SrcMaskOut"),
-                  this->OutputGrad("SrcMaskOut"));
     op->SetOutput(framework::GradVarName("FMHAOut"),
                   this->OutputGrad("FMHAOut"));
     op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
@@ -114,7 +114,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
     transpose_out_2->mutable_data<T>(ctx.GetPlace());
     auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
     auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
-    auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
+    auto *src_mask_out_data =
+        (src_mask == nullptr) ? nullptr
+                              : src_mask_out->mutable_data<T>(ctx.GetPlace());
     auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
     auto *attn_dropout_mask_out_data =
         attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());

@@ -184,10 +186,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
       qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
                                  qkv_out_data, qkv_bias_out_data);
     }
-    fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
+    fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
                                     qk_out, src_mask_out, softmax_out,
                                     attn_dropout_mask_out, attn_dropout_out,
                                     qktv_out, fmha_out);
     // fmha_out: [batch_size, seq_len, num_head, head_dim]
     // weight: [embed_dim, embed_dim]
     // out_linear_out: [batch_size, seq_len, embed_dim]

@@ -265,7 +268,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *qk_out_data = qk_out->data<T>();
     auto *qktv_out_data = qktv_out->data<T>();
     auto *softmax_out_data = softmax_out->data<T>();
-    auto *src_mask_out_data = src_mask_out->data<T>();
+    auto *src_mask_out_data =
+        (src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
     auto *out_linear_out_data = out_linear_out->data<T>();
     auto *ln_2_mean_data = ln_2_mean->data<U>();
     auto *ln_2_var_data = ln_2_var->data<U>();

@@ -302,7 +306,9 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
     auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
     auto *d_attn_dropout_out_data =
         d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
-    auto *d_src_mask_out_data = d_src_mask_out->mutable_data<T>(ctx.GetPlace());
+    auto *d_src_mask_out_data =
+        (src_mask == nullptr) ? nullptr
+                              : d_src_mask_out->mutable_data<T>(ctx.GetPlace());
     auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
     auto *d_out_linear_out_data =
         d_out_linear_out->mutable_data<T>(ctx.GetPlace());

@@ -386,7 +392,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
                                     d_out_linear_out_data, d_fmha_out_data,
                                     d_out_linear_weight_data, nullptr);
     fmha_ref_compute.ComputeBackward(
-        *transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out,
+        *transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
         *attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
         d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
         d_transpose_out_2, nullptr, d_qkv_bias_out);
@@ -66,6 +66,7 @@ class TestFusedAttentionOp(OpTest):
         self.x_type = np.float32
         self.attn_mask_type = np.float64
         self.pre_layer_norm = False
+        self.has_attn_mask = True
         self.training = True
         self.batch_size = 8

@@ -84,6 +85,7 @@ class TestFusedAttentionOp(OpTest):
     def generate_input_data(self):
         self.query = np.random.rand(self.batch_size, self.query_length,
                                     self.embed_dim).astype(self.x_type)
+        if self.has_attn_mask:
            self.attn_mask = np.ones(
                (self.batch_size, self.num_heads, self.query_length,
                 self.key_length),

@@ -93,7 +95,10 @@ class TestFusedAttentionOp(OpTest):
            elif self.attn_mask_type == np.float64:
                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
            else:
-               raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
+               raise ValueError(
+                   "'attn_mask_type' should be 'int64' or 'float64'.")
+        else:
+            self.attn_mask = None
         self.key, self.value = self.query, self.query
         self.dout = np.random.random((self.batch_size, self.query_length,

@@ -102,7 +107,10 @@ class TestFusedAttentionOp(OpTest):
     def GetBaselineOut(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
         tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
+        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        else:
+            attn_mask = None
         residual = tensor_query
         ln1_out = tensor_query

@@ -187,7 +195,10 @@ class TestFusedAttentionOp(OpTest):
         qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
         x = paddle.to_tensor(self.query, stop_gradient=False)
+        if self.has_attn_mask:
            attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
+        else:
+            attn_mask = None
         qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
         qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
         epsilon = 1e-05

@@ -218,6 +229,37 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
         self.x_type = np.float32
         self.attn_mask_type = np.float64
         self.pre_layer_norm = True
+        self.has_attn_mask = True
+        self.training = True
+        self.batch_size = 8
+        self.query_length = 128
+        self.head_dim = 64
+        self.num_heads = 16
+        self.embed_dim = self.head_dim * self.num_heads
+        self.dropout_prob = 0.0
+        self.attn_dropout_prob = 0.0
+        self.weight_attr = None
+        self.bias_attr = None
+        self.kdim, self.vdim = self.embed_dim, self.embed_dim
+        self.key_length, self.value_length = self.query_length, self.query_length
+
+    def test_fused_attention_op(self):
+        final_out_ref, x_grad_ref = self.GetBaselineOut()
+        final_out, x_grad = self.GetFusedAttentionOut()
+        np.testing.assert_allclose(
+            final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
+        np.testing.assert_allclose(
+            x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
+
+
+class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
+    def config(self):
+        self.x_type = np.float32
+        self.attn_mask_type = np.float64
+        self.pre_layer_norm = True
+        self.has_attn_mask = False
         self.training = True
         self.batch_size = 8

@@ -247,6 +289,7 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
         self.x_type = np.float16
         self.attn_mask_type = np.float64
         self.pre_layer_norm = False
+        self.has_attn_mask = True
         self.training = True
         self.batch_size = 8
@@ -152,6 +152,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
         self.x_type = np.float32
         self.attn_mask_type = np.float64
         self.pre_layer_norm = True
+        self.has_attn_mask = True
         self.training = True
         self.need_weight = False

@@ -172,6 +173,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
     def generate_input_data(self):
         self.query = np.random.rand(self.batch_size, self.query_length,
                                     self.embed_dim).astype(self.x_type)
+        if self.has_attn_mask:
            self.attn_mask = np.ones(
                (self.batch_size, self.num_heads, self.query_length,
                 self.key_length),

@@ -181,10 +183,17 @@ class TestFusedAttentionAPI(unittest.TestCase):
            elif self.attn_mask_type == np.float64:
                self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
            else:
-               raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
+               raise ValueError(
+                   "'attn_mask_type' should be 'int64' or 'float64'.")
+        else:
+            self.attn_mask = None
         self.key, self.value = self.query, self.query

     def run_imperative(self):
+        if self.has_attn_mask:
+            attn_mask_tensor = paddle.to_tensor(self.attn_mask)
+        else:
+            attn_mask_tensor = None
         fused_attn = FusedMultiHeadAttention(
             self.embed_dim, self.num_heads, self.dropout_prob,
             self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,

@@ -192,7 +201,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
         out = fused_attn(
             paddle.to_tensor(self.query),
             paddle.to_tensor(self.query),
-            paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask))
+            paddle.to_tensor(self.query), attn_mask_tensor)
         ref_out = compute_reference(self.pre_layer_norm, self.query,
                                     self.attn_mask,
                                     fused_attn.pre_ln_scale.numpy(),

@@ -203,7 +212,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
                                     fused_attn.qkv_bias.numpy(),
                                     fused_attn.linear_weight.numpy(),
                                     fused_attn.linear_bias.numpy())
-        self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5))
+        np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)

     def run_static(self):
         fused_attn = FusedMultiHeadAttention(

@@ -215,6 +224,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
             name='X',
             shape=[self.batch_size, self.query_length, self.embed_dim],
             dtype=self.x_type)
+        if self.has_attn_mask:
            attn_mask = paddle.static.data(
                name='SrcMask',
                shape=[

@@ -223,10 +233,13 @@ class TestFusedAttentionAPI(unittest.TestCase):
                ],
                dtype=self.attn_mask_type)
            final_out = fused_attn(x, x, x, attn_mask)
+        else:
+            final_out = fused_attn(x, x, x)
         place = paddle.CUDAPlace(0)
         exe = paddle.static.Executor(place)
         exe.run(paddle.static.default_startup_program())
+        if self.has_attn_mask:
            out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
                paddle.static.default_main_program(),
                feed={"X": self.query,

@@ -237,7 +250,16 @@ class TestFusedAttentionAPI(unittest.TestCase):
                    fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
                    fused_attn.ln_scale, fused_attn.ln_bias
                ])
+        else:
+            out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
+                paddle.static.default_main_program(),
+                feed={"X": self.query, },
+                fetch_list=[
+                    final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
+                    fused_attn.linear_weight, fused_attn.linear_bias,
+                    fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
+                    fused_attn.ln_scale, fused_attn.ln_bias
+                ])
         return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias

     def test_static_api(self):

@@ -249,14 +271,36 @@ class TestFusedAttentionAPI(unittest.TestCase):
                                     self.attn_mask, ln_scale, ln_bias,
                                     ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
                                     linear_weight, linear_bias)
-        self.assertTrue(
-            np.allclose(
-                np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5))
+        np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)

     def test_dynamic_api(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
         self.run_imperative()

+
+class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI):
+    def config(self):
+        self.x_type = np.float32
+        self.attn_mask_type = np.float64
+        self.pre_layer_norm = True
+        self.has_attn_mask = False
+        self.training = True
+        self.need_weight = False
+        self.batch_size = 1
+        self.query_length = 2
+        self.head_dim = 2
+        self.num_heads = 2
+        self.embed_dim = self.head_dim * self.num_heads
+        self.dropout_prob = 0.0
+        self.attn_dropout_prob = 0.0
+        self.weight_attr = None
+        self.bias_attr = None
+        self.kdim, self.vdim = self.embed_dim, self.embed_dim
+        self.key_length, self.value_length = self.query_length, self.query_length

 if __name__ == "__main__":
     unittest.main()