Unverified commit 472dcca4, authored by Li Min, committed by GitHub

【fix-bug】Support attn_mask=None input cases for fused_attention_op. (#36951)

The current fused_attention_op does not support attn_mask=None as input. This PR adds support for that case and adds the corresponding unit-test logic; a minimal usage sketch is given below.
Parent commit: b7e88308
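To illustrate the behavior being added, the following is a minimal dynamic-graph sketch of the two call patterns that the new unit tests exercise: passing an additive attention mask, and passing attn_mask=None. It assumes the layer is importable as paddle.incubate.nn.FusedMultiHeadAttention (the class the tests below use) and that a CUDA device is available; shapes, dtypes, and variable names are illustrative only.

```python
import numpy as np
import paddle
# Assumed import path for the fused layer used by the tests in this PR.
from paddle.incubate.nn import FusedMultiHeadAttention

# The fused attention op only provides a GPU kernel, so a CUDA place is required.
paddle.disable_static(place=paddle.CUDAPlace(0))

batch_size, seq_len, num_heads, head_dim = 2, 4, 2, 8
embed_dim = num_heads * head_dim

query = paddle.to_tensor(
    np.random.rand(batch_size, seq_len, embed_dim).astype('float32'))

# Positional arguments as in the unit tests below:
# embed_dim, num_heads, dropout_rate, attn_dropout_rate.
fused_attn = FusedMultiHeadAttention(embed_dim, num_heads, 0.0, 0.0)

# Case 1: an additive causal mask, same construction as in the tests.
mask = (np.tril(np.ones((batch_size, num_heads, seq_len, seq_len))) - 1.0) * 1e9
out_masked = fused_attn(query, query, query,
                        paddle.to_tensor(mask.astype('float32')))

# Case 2: attn_mask=None, the input combination this PR enables.
out_unmasked = fused_attn(query, query, query, None)
```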
@@ -69,7 +69,7 @@ class FMHARef {
~FMHARef() {}
void ComputeForward(const Tensor& qkv_input_tensor,
const Tensor& src_mask_tensor,
const Tensor* src_mask_tensor,
Tensor* transpose_2_out_tensor, Tensor* qk_out_tensor,
Tensor* src_mask_out_tensor, Tensor* softmax_out_tensor,
Tensor* dropout_mask_out_tensor,
@@ -111,17 +111,17 @@ class FMHARef {
blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, q_ptr,
k_ptr, beta, qk_out_data, gemm_batch_size, stride_a,
stride_b);
int softmax_axis = -1;
if (src_mask_tensor != nullptr) {
std::vector<const Tensor*> ins;
std::vector<Tensor*> outs;
ins.emplace_back(qk_out_tensor);
ins.emplace_back(&src_mask_tensor);
ins.emplace_back(src_mask_tensor);
outs.emplace_back(src_mask_out_tensor);
int elewise_add_axis = -1;
int softmax_axis = -1;
if (&src_mask_tensor != nullptr) {
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx_, ins, &outs, elewise_add_axis, AddFunctor<T>());
SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *src_mask_out_tensor,
softmax_axis, softmax_out_tensor);
} else {
@@ -165,7 +165,7 @@ class FMHARef {
}
void ComputeBackward(
const Tensor& transpose_2_out_tensor, const Tensor& src_mask_tensor,
const Tensor& transpose_2_out_tensor, const Tensor* src_mask_tensor,
const Tensor& softmax_out_tensor, const Tensor& dropout_mask_out_tensor,
const Tensor& dropout_out_tensor, const Tensor& qk_out_tensor,
const Tensor& src_mask_out_tensor, const Tensor& fmha_out_grad_tensor,
@@ -249,7 +249,7 @@ class FMHARef {
softmax_out_grad_tensor);
}
if (&src_mask_tensor != nullptr) {
if (src_mask_tensor != nullptr) {
SoftmaxBackwardCUDAKernelDriver<T>(dev_ctx_, softmax_out_tensor,
*softmax_out_grad_tensor, softmax_axis,
src_mask_out_grad_tensor);
@@ -27,8 +27,6 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVW"), "Input", "QKVW", "FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionOp");
@@ -57,8 +55,11 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("QKTVOut"), "Output", "QKTVOut",
"FusedAttentionOp");
if (ctx->HasInput("SrcMask")) {
OP_INOUT_CHECK(ctx->HasOutput("SrcMaskOut"), "Output", "SrcMaskOut",
"FusedAttentionOp");
}
OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut",
"FusedAttentionOp");
OP_INOUT_CHECK(ctx->HasOutput("AttnDropoutMaskOut"), "Output",
@@ -119,7 +120,10 @@ class FusedAttentionOp : public framework::OperatorWithKernel {
{y_dim[0], x_dim[0], y_dim[1], x_dim[1], y_dim[2]});
// [batch, num_head, seq_len, seq_len]
ctx->SetOutputDim("QKOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
if (ctx->HasInput("SrcMask")) {
ctx->SetOutputDim("SrcMaskOut", {x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
}
// the same as QKOut's shape.
ctx->SetOutputDim("AttnDropoutOut",
{x_dim[0], y_dim[1], x_dim[1], x_dim[1]});
@@ -320,7 +324,7 @@ class FusedAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
{
out = transpose(out, perm=[2, 0, 3, 1, 4]);
out = q * k^t;
out = attn_mark + out;
out = attn_mask + out;
out = softmax(out);
out = dropout(out);
out = out * v;
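The comment above sketches the fused computation. Below is a minimal NumPy reference of the attention core with the mask addition made optional, mirroring the attn_mask=None path this PR adds; the function name and shapes are illustrative only, and dropout plus the surrounding transposes are omitted.

```python
import numpy as np

def attention_reference(q, k, v, attn_mask=None):
    # q, k, v: [batch, num_head, seq_len, head_dim];
    # attn_mask: broadcastable to [batch, num_head, seq_len, seq_len], or None.
    scores = q @ np.swapaxes(k, -1, -2)        # out = q * k^t
    if attn_mask is not None:                  # mask is now optional
        scores = attn_mask + scores            # out = attn_mask + out
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # out = softmax(out)
    return weights @ v                         # out = out * v
```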
@@ -368,8 +372,6 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("QKVBias"), "Input", "QKVBias",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("SrcMask"), "Input", "SrcMask",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearW"), "Input", "OutLinearW",
"FusedAttentionGrad");
OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
@@ -413,8 +415,11 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("SoftmaxOut"));
ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"),
ctx->GetInputDim("AttnDropoutOut"));
if (ctx->HasOutput(framework::GradVarName("SrcMaskOut"))) {
ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"),
ctx->GetInputDim("SrcMaskOut"));
}
ctx->SetOutputDim(framework::GradVarName("QKVOut"),
ctx->GetInputDim("QKVOut"));
ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"),
@@ -448,7 +453,14 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("X", this->Input("X"));
op->SetInput("QKVW", this->Input("QKVW"));
op->SetInput("QKVBias", this->Input("QKVBias"));
if (this->HasInput("SrcMask")) {
op->SetInput("SrcMask", this->Input("SrcMask"));
op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
op->SetOutput(framework::GradVarName("SrcMaskOut"),
this->OutputGrad("SrcMaskOut"));
}
op->SetInput("OutLinearW", this->Input("OutLinearW"));
op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
@@ -508,7 +520,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));
op->SetInput("AttnDropoutMaskOut", this->Output("AttnDropoutMaskOut"));
op->SetInput("AttnDropoutOut", this->Output("AttnDropoutOut"));
op->SetInput("SrcMaskOut", this->Output("SrcMaskOut"));
op->SetInput("FMHAOut", this->Output("FMHAOut"));
op->SetInput("OutLinearOut", this->Output("OutLinearOut"));
@@ -538,8 +550,7 @@ class FusedAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
this->OutputGrad("SoftmaxOut"));
op->SetOutput(framework::GradVarName("AttnDropoutOut"),
this->OutputGrad("AttnDropoutOut"));
op->SetOutput(framework::GradVarName("SrcMaskOut"),
this->OutputGrad("SrcMaskOut"));
op->SetOutput(framework::GradVarName("FMHAOut"),
this->OutputGrad("FMHAOut"));
op->SetOutput(framework::GradVarName("BiasDropoutResidualOut"),
@@ -114,7 +114,9 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
transpose_out_2->mutable_data<T>(ctx.GetPlace());
auto *qk_out_data = qk_out->mutable_data<T>(ctx.GetPlace());
auto *qktv_out_data = qktv_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data = src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr
: src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *softmax_out_data = softmax_out->mutable_data<T>(ctx.GetPlace());
auto *attn_dropout_mask_out_data =
attn_dropout_mask_out->mutable_data<uint8_t>(ctx.GetPlace());
@@ -184,10 +186,11 @@ class FusedAttentionOpKernel : public framework::OpKernel<T> {
qkv_compute.ComputeForward(qkv_weight_data, x_data, qkv_bias_data,
qkv_out_data, qkv_bias_out_data);
}
fmha_ref_compute.ComputeForward(*qkv_bias_out, *src_mask, transpose_out_2,
fmha_ref_compute.ComputeForward(*qkv_bias_out, src_mask, transpose_out_2,
qk_out, src_mask_out, softmax_out,
attn_dropout_mask_out, attn_dropout_out,
qktv_out, fmha_out);
// fmha_out: [batch_size, seq_len, num_head, head_dim]
// weight: [embed_dim, embed_dim]
// out_linear_out: [batch_size, seq_len, embed_dim]
@@ -265,7 +268,8 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *qk_out_data = qk_out->data<T>();
auto *qktv_out_data = qktv_out->data<T>();
auto *softmax_out_data = softmax_out->data<T>();
auto *src_mask_out_data = src_mask_out->data<T>();
auto *src_mask_out_data =
(src_mask == nullptr) ? nullptr : src_mask_out->data<T>();
auto *out_linear_out_data = out_linear_out->data<T>();
auto *ln_2_mean_data = ln_2_mean->data<U>();
auto *ln_2_var_data = ln_2_var->data<U>();
@@ -302,7 +306,9 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
auto *d_softmax_out_data = d_softmax_out->mutable_data<T>(ctx.GetPlace());
auto *d_attn_dropout_out_data =
d_attn_dropout_out->mutable_data<T>(ctx.GetPlace());
auto *d_src_mask_out_data = d_src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *d_src_mask_out_data =
(src_mask == nullptr) ? nullptr
: d_src_mask_out->mutable_data<T>(ctx.GetPlace());
auto *d_fmha_out_data = d_fmha_out->mutable_data<T>(ctx.GetPlace());
auto *d_out_linear_out_data =
d_out_linear_out->mutable_data<T>(ctx.GetPlace());
@@ -386,7 +392,7 @@ class FusedAttentionGradKernel : public framework::OpKernel<T> {
d_out_linear_out_data, d_fmha_out_data,
d_out_linear_weight_data, nullptr);
fmha_ref_compute.ComputeBackward(
*transpose_out_2, *src_mask, *softmax_out, *attn_dropout_mask_out,
*transpose_out_2, src_mask, *softmax_out, *attn_dropout_mask_out,
*attn_dropout_out, *qk_out, *src_mask_out, *d_fmha_out, d_qktv_out,
d_attn_dropout_out, d_softmax_out, d_src_mask_out, d_qk_out,
d_transpose_out_2, nullptr, d_qkv_bias_out);
@@ -66,6 +66,7 @@ class TestFusedAttentionOp(OpTest):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True
self.batch_size = 8
@@ -84,6 +85,7 @@ class TestFusedAttentionOp(OpTest):
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
if self.has_attn_mask:
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
@@ -93,7 +95,10 @@ class TestFusedAttentionOp(OpTest):
elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
raise ValueError(
"'attn_mask_type' should be 'int64' or 'float64'.")
else:
self.attn_mask = None
self.key, self.value = self.query, self.query
self.dout = np.random.random((self.batch_size, self.query_length,
@@ -102,7 +107,10 @@ class TestFusedAttentionOp(OpTest):
def GetBaselineOut(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
attn_mask = None
residual = tensor_query
ln1_out = tensor_query
@@ -187,7 +195,10 @@ class TestFusedAttentionOp(OpTest):
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
x = paddle.to_tensor(self.query, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False)
else:
attn_mask = None
qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False)
qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False)
epsilon = 1e-05
@@ -218,6 +229,37 @@ class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True
self.batch_size = 8
self.query_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=1e-5, atol=1e-1)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=1e-5, atol=1e-1)
class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = False
self.training = True
self.batch_size = 8
@@ -247,6 +289,7 @@ class TestFusedAttentionOpFp16(TestFusedAttentionOp):
self.x_type = np.float16
self.attn_mask_type = np.float64
self.pre_layer_norm = False
self.has_attn_mask = True
self.training = True
self.batch_size = 8
@@ -152,6 +152,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True
self.need_weight = False
@@ -172,6 +173,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
def generate_input_data(self):
self.query = np.random.rand(self.batch_size, self.query_length,
self.embed_dim).astype(self.x_type)
if self.has_attn_mask:
self.attn_mask = np.ones(
(self.batch_size, self.num_heads, self.query_length,
self.key_length),
@@ -181,10 +183,17 @@ class TestFusedAttentionAPI(unittest.TestCase):
elif self.attn_mask_type == np.float64:
self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9
else:
raise ValueError("'attn_mask_type' should be 'int64' or 'float64'.")
raise ValueError(
"'attn_mask_type' should be 'int64' or 'float64'.")
else:
self.attn_mask = None
self.key, self.value = self.query, self.query
def run_imperative(self):
if self.has_attn_mask:
attn_mask_tensor = paddle.to_tensor(self.attn_mask)
else:
attn_mask_tensor = None
fused_attn = FusedMultiHeadAttention(
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
@@ -192,7 +201,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
out = fused_attn(
paddle.to_tensor(self.query),
paddle.to_tensor(self.query),
paddle.to_tensor(self.query), paddle.to_tensor(self.attn_mask))
paddle.to_tensor(self.query), attn_mask_tensor)
ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask,
fused_attn.pre_ln_scale.numpy(),
@@ -203,7 +212,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
fused_attn.qkv_bias.numpy(),
fused_attn.linear_weight.numpy(),
fused_attn.linear_bias.numpy())
self.assertTrue(np.allclose(ref_out, out, rtol=1e-5, atol=1e-5))
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)
def run_static(self):
fused_attn = FusedMultiHeadAttention(
@@ -215,6 +224,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
name='X',
shape=[self.batch_size, self.query_length, self.embed_dim],
dtype=self.x_type)
if self.has_attn_mask:
attn_mask = paddle.static.data(
name='SrcMask',
shape=[
@@ -223,10 +233,13 @@ class TestFusedAttentionAPI(unittest.TestCase):
],
dtype=self.attn_mask_type)
final_out = fused_attn(x, x, x, attn_mask)
else:
final_out = fused_attn(x, x, x)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
if self.has_attn_mask:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
@@ -237,7 +250,16 @@ class TestFusedAttentionAPI(unittest.TestCase):
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias
def test_static_api(self):
@@ -249,14 +271,36 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
linear_weight, linear_bias)
self.assertTrue(
np.allclose(
np.array(ref_out), np.array(out), rtol=1e-5, atol=1e-5))
np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)
def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
self.run_imperative()
class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = False
self.training = True
self.need_weight = False
self.batch_size = 1
self.query_length = 2
self.head_dim = 2
self.num_heads = 2
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
if __name__ == "__main__":
unittest.main()