Unverified commit 46988e21, authored by Li Min, committed by GitHub

Fix bugs when bias add is None in static graph for fused_attention op. (#37566) (#37608)

cherry-pick of PR #37566:

Based on #37411, this PR:

    Continues to fix the bugs when bias add is None in static graph for the fused_attention op.
    Polishes and improves the unittests in test_fused_attention_op_api.py.
Parent 4066713f
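For orientation (not part of the commit), a minimal dynamic-graph sketch of the case this fix targets: FusedMultiHeadAttention built with bias_attr=False, so qkv_bias, linear_bias, and the layer-norm biases are all None. The positional constructor arguments mirror the unit test in this diff; the concrete shapes and the GPU place are assumptions for illustration.

```python
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

paddle.disable_static(place=paddle.CUDAPlace(0))

batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 2
embed_dim = num_heads * head_dim
x = paddle.rand([batch_size, seq_len, embed_dim], dtype='float32')

# Positional arguments as used in the test below: (embed_dim, num_heads, dropout,
# attn_dropout, kdim, vdim, pre_layer_norm, need_weight, weight_attr, bias_attr).
# bias_attr=False creates the layer without qkv/linear/layer-norm bias parameters.
attn = FusedMultiHeadAttention(embed_dim, num_heads, 0.0, 0.0, embed_dim, embed_dim,
                               False, False, None, False)

out = attn(x, x, x, None)  # attn_mask passed as None, matching the none-attn-mask case
print(out.shape)           # [1, 2, 4]
```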
@@ -79,8 +79,12 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
     seq_len = query.shape[1]
     embed_dim = query.shape[2]
+    has_bias = True
+    if ln_bias is None:
+        has_bias = False
     if (pre_layer_norm):
-        ln_out = layer_norm(query, True, True, ln_scale, ln_bias)
+        ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias)
     num_head = qkv_weight.shape[1]
     head_dim = qkv_weight.shape[2]
@@ -89,17 +93,24 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
     qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
                                     qkv_weight.shape[2] * qkv_weight.shape[3])
-    qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
-                                qkv_bias.shape[2])
+    if qkv_bias is not None:
+        qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
+                                    qkv_bias.shape[2])
     if (pre_layer_norm):
         ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
         qkv = fc(ln_out, qkv_weight)
-        qkv_bias_out = qkv + qkv_bias
+        if qkv_bias is not None:
+            qkv_bias_out = qkv + qkv_bias
+        else:
+            qkv_bias_out = qkv
         ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
     else:
         query = query.reshape(batch_size * seq_len, embed_dim)
         qkv = fc(query, qkv_weight)
-        qkv_bias_out = qkv + qkv_bias
+        if qkv_bias is not None:
+            qkv_bias_out = qkv + qkv_bias
+        else:
+            qkv_bias_out = qkv
         query = query.reshape(batch_size, seq_len, embed_dim)
     qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head,
@@ -140,26 +151,42 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
     out_linear_out = fc(out_linear_input, out_linear_weight)
     # bias add, dropout, residual add, layer_norm.
-    out_linear_bias_out = out_linear_out + out_linear_bias
+    if out_linear_bias is not None:
+        out_linear_bias_out = out_linear_out + out_linear_bias
+    else:
+        out_linear_bias_out = out_linear_out
     out_linear_bias_dropout_out = out_linear_bias_out
     out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out
     if not pre_layer_norm:
         out_linear_bias_dropout_residual_out = layer_norm(
-            out_linear_bias_dropout_residual_out, True, True, ln_2_scale,
+            out_linear_bias_dropout_residual_out, True, has_bias, ln_2_scale,
             ln_2_bias)
     return out_linear_bias_dropout_residual_out


 class TestFusedAttentionAPI(unittest.TestCase):
     def setUp(self):
+        self.setXType()
+        self.setPreLn()
+        self.setAttnMask()
+        self.setBiasAttr()
         self.config()
         self.generate_input_data()

-    def config(self):
+    def setAttnMask(self):
+        self.has_attn_mask = True
+
+    def setBiasAttr(self):
+        self.bias_attr = None
+
+    def setPreLn(self):
+        self.pre_layer_norm = False
+
+    def setXType(self):
         self.x_type = np.float32
+
+    def config(self):
         self.attn_mask_type = np.float64
-        self.pre_layer_norm = True
-        self.has_attn_mask = True
         self.training = True
         self.need_weight = False
@@ -172,7 +199,6 @@ class TestFusedAttentionAPI(unittest.TestCase):
         self.dropout_prob = 0.0
         self.attn_dropout_prob = 0.0
         self.weight_attr = None
-        self.bias_attr = None
         self.kdim, self.vdim = self.embed_dim, self.embed_dim
         self.key_length, self.value_length = self.query_length, self.query_length
@@ -205,23 +231,32 @@ class TestFusedAttentionAPI(unittest.TestCase):
             self.embed_dim, self.num_heads, self.dropout_prob,
             self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
             self.need_weight, self.weight_attr, self.bias_attr)
-        qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32')
-        fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
+        if self.bias_attr is not False:
+            qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype(
+                'float32')
+            fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
         out = fused_attn(
             paddle.to_tensor(self.query),
             paddle.to_tensor(self.query),
             paddle.to_tensor(self.query), attn_mask_tensor)
-        ref_out = compute_reference(self.pre_layer_norm, self.query,
-                                    self.attn_mask,
-                                    fused_attn.pre_ln_scale.numpy(),
-                                    fused_attn.pre_ln_bias.numpy(),
-                                    fused_attn.ln_scale.numpy(),
-                                    fused_attn.ln_bias.numpy(),
-                                    fused_attn.qkv_weight.numpy(),
-                                    fused_attn.qkv_bias.numpy(),
-                                    fused_attn.linear_weight.numpy(),
-                                    fused_attn.linear_bias.numpy())
-        np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)
+
+        fused_attn_qkv_bias = None
+        fused_attn_linear_bias = None
+        fused_attn_pre_ln_bias = None
+        fused_attn_ln_bias = None
+        if self.bias_attr is not False:
+            fused_attn_qkv_bias = fused_attn.qkv_bias.numpy()
+            fused_attn_linear_bias = fused_attn.linear_bias.numpy()
+            fused_attn_pre_ln_bias = fused_attn.pre_ln_bias.numpy()
+            fused_attn_ln_bias = fused_attn.ln_bias.numpy()
+
+        ref_out = compute_reference(
+            self.pre_layer_norm, self.query, self.attn_mask,
+            fused_attn.pre_ln_scale.numpy(), fused_attn_pre_ln_bias,
+            fused_attn.ln_scale.numpy(), fused_attn_ln_bias,
+            fused_attn.qkv_weight.numpy(), fused_attn_qkv_bias,
+            fused_attn.linear_weight.numpy(), fused_attn_linear_bias)
+        np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-4)

     def run_static(self):
         fused_attn = FusedMultiHeadAttention(
@@ -248,27 +283,53 @@ class TestFusedAttentionAPI(unittest.TestCase):
         place = paddle.CUDAPlace(0)
         exe = paddle.static.Executor(place)
         exe.run(paddle.static.default_startup_program())
+        qkv_bias = None
+        linear_bias = None
+        ln_bias = None
+        ln_2_bias = None
         if self.has_attn_mask:
-            out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
-                paddle.static.default_main_program(),
-                feed={"X": self.query,
-                      "SrcMask": self.attn_mask},
-                fetch_list=[
-                    final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
-                    fused_attn.linear_weight, fused_attn.linear_bias,
-                    fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
-                    fused_attn.ln_scale, fused_attn.ln_bias
-                ])
+            if self.bias_attr is False:
+                out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={"X": self.query,
+                          "SrcMask": self.attn_mask},
+                    fetch_list=[
+                        final_out, fused_attn.qkv_weight,
+                        fused_attn.linear_weight, fused_attn.pre_ln_scale,
+                        fused_attn.ln_scale
+                    ])
+            else:
+                out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={"X": self.query,
+                          "SrcMask": self.attn_mask},
+                    fetch_list=[
+                        final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
+                        fused_attn.linear_weight, fused_attn.linear_bias,
+                        fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
+                        fused_attn.ln_scale, fused_attn.ln_bias
+                    ])
         else:
-            out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
-                paddle.static.default_main_program(),
-                feed={"X": self.query, },
-                fetch_list=[
-                    final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
-                    fused_attn.linear_weight, fused_attn.linear_bias,
-                    fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
-                    fused_attn.ln_scale, fused_attn.ln_bias
-                ])
+            if self.bias_attr is False:
+                out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={"X": self.query, },
+                    fetch_list=[
+                        final_out, fused_attn.qkv_weight,
+                        fused_attn.linear_weight, fused_attn.pre_ln_scale,
+                        fused_attn.ln_scale
+                    ])
+            else:
+                out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={"X": self.query, },
+                    fetch_list=[
+                        final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
+                        fused_attn.linear_weight, fused_attn.linear_bias,
+                        fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
+                        fused_attn.ln_scale, fused_attn.ln_bias
+                    ])
         return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias

     def test_static_api(self):
@@ -280,7 +341,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
                                     self.attn_mask, ln_scale, ln_bias,
                                     ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
                                     linear_weight, linear_bias)
-        np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)
+        np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-4)

     def test_dynamic_api(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -288,27 +349,16 @@
 class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI):
-    def config(self):
-        self.x_type = np.float32
-        self.attn_mask_type = np.float64
-        self.pre_layer_norm = True
+    def setAttnMask(self):
         self.has_attn_mask = False
-        self.training = True
-        self.need_weight = False
-        self.batch_size = 1
-        self.query_length = 2
-        self.head_dim = 2
-        self.num_heads = 2
-        self.embed_dim = self.head_dim * self.num_heads
-        self.dropout_prob = 0.0
-        self.attn_dropout_prob = 0.0
-        self.weight_attr = None
-        self.bias_attr = None
-        self.kdim, self.vdim = self.embed_dim, self.embed_dim
-        self.key_length, self.value_length = self.query_length, self.query_length
+
+    def setPreLn(self):
+        self.pre_layer_norm = True
+
+
+class TestFusedAttentionAPIBiasIsNone(TestFusedAttentionAPI):
+    def setBiasAttr(self):
+        self.bias_attr = False


 if __name__ == "__main__":
......
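As a side note (not part of the diff), here is a minimal NumPy sketch of a bias-optional layer_norm helper of the kind compute_reference calls above as layer_norm(x, True, has_bias, scale, bias); the actual helper in test_fused_attention_op_api.py may differ in argument names and in how it reshapes the input.

```python
import numpy as np

def layer_norm(x, has_scale, has_bias, scale, bias, epsilon=1e-05):
    # Normalize over the last (feature) dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mean) / np.sqrt(var + epsilon)
    if has_scale:
        y = y * scale
    if has_bias:
        y = y + bias
    return y

q = np.random.rand(1, 2, 4).astype('float32')
# ln_bias is None in the bias-free case, so has_bias is False and no bias is added.
print(layer_norm(q, True, False, np.ones(4, dtype='float32'), None).shape)  # (1, 2, 4)
```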
@@ -388,10 +388,12 @@ def fused_multi_head_attention(x,
         if pre_ln_bias:
             inputs['LnBias'] = [pre_ln_bias]
         inputs['QKVW'] = [qkv_weight]
-        inputs['QKVBias'] = [qkv_bias]
+        if qkv_bias is not None:
+            inputs['QKVBias'] = [qkv_bias]
         inputs['SrcMask'] = attn_mask
         inputs['OutLinearW'] = [linear_weight]
-        inputs['OutLinearBias'] = [linear_bias]
+        if linear_bias is not None:
+            inputs['OutLinearBias'] = [linear_bias]
         if ln_scale:
             inputs['Ln2Scale'] = [ln_scale]
         if ln_bias:
......
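To see the static-graph path that this hunk guards (the 'QKVBias' and 'OutLinearBias' inputs are only fed to the op when the corresponding biases exist), here is a condensed sketch in the spirit of run_static above, again with bias_attr=False. The shapes, program setup, and GPU place are assumptions for illustration, not code from the commit.

```python
import numpy as np
import paddle
from paddle.incubate.nn import FusedMultiHeadAttention

paddle.enable_static()
batch_size, seq_len, num_heads, head_dim = 1, 2, 2, 2
embed_dim = num_heads * head_dim

main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='X', shape=[batch_size, seq_len, embed_dim],
                           dtype='float32')
    # bias_attr=False: qkv_bias and linear_bias are None, so the guarded code above
    # simply skips the 'QKVBias' / 'OutLinearBias' inputs instead of failing.
    attn = FusedMultiHeadAttention(embed_dim, num_heads, 0.0, 0.0, embed_dim,
                                   embed_dim, False, False, None, False)
    final_out = attn(x, x, x, None)

exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
query = np.random.rand(batch_size, seq_len, embed_dim).astype('float32')
out, = exe.run(main_prog, feed={'X': query}, fetch_list=[final_out])
print(out.shape)  # (1, 2, 4)
```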