Unverified commit 097e098d, authored by Li Min, committed by GitHub

Fix bugs when bias add none in static graph for fused_attention op. (#37566)

* Fix bugs when bias is none for static graph for fused_attention op.
Parent dcb91fd7
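The configuration this commit fixes is a FusedMultiHeadAttention layer built with bias_attr=False, so that qkv_bias, linear_bias and the layer-norm biases are all None. Below is a minimal sketch of that setup, assuming paddle.incubate.nn exposes FusedMultiHeadAttention and a CUDA build of Paddle (the fused kernel is GPU-only); the positional arguments mirror the test's own call order (embed_dim, num_heads, dropout, attn_dropout, kdim, vdim, pre_layer_norm, need_weight, weight_attr, bias_attr).

    import numpy as np
    import paddle
    from paddle.incubate.nn import FusedMultiHeadAttention

    paddle.disable_static(place=paddle.CUDAPlace(0))

    # bias_attr=False (last argument): every bias parameter of the layer is None.
    fused_attn = FusedMultiHeadAttention(4, 2, 0.0, 0.0, 4, 4, False, False,
                                         None, False)

    query = paddle.to_tensor(np.random.random((1, 2, 4)).astype('float32'))
    out = fused_attn(query, query, query)  # attn_mask omitted
    print(out.shape)  # [1, 2, 4]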
@@ -79,8 +79,12 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
seq_len = query.shape[1]
embed_dim = query.shape[2]
has_bias = True
if ln_bias is None:
has_bias = False
if (pre_layer_norm):
ln_out = layer_norm(query, True, True, ln_scale, ln_bias)
ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias)
num_head = qkv_weight.shape[1]
head_dim = qkv_weight.shape[2]
@@ -89,17 +93,24 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] *
qkv_weight.shape[2] * qkv_weight.shape[3])
qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
qkv_bias.shape[2])
if qkv_bias is not None:
qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] *
qkv_bias.shape[2])
if (pre_layer_norm):
ln_out = ln_out.reshape(batch_size * seq_len, embed_dim)
qkv = fc(ln_out, qkv_weight)
qkv_bias_out = qkv + qkv_bias
if qkv_bias is not None:
qkv_bias_out = qkv + qkv_bias
else:
qkv_bias_out = qkv
ln_out = ln_out.reshape(batch_size, seq_len, embed_dim)
else:
query = query.reshape(batch_size * seq_len, embed_dim)
qkv = fc(query, qkv_weight)
qkv_bias_out = qkv + qkv_bias
if qkv_bias is not None:
qkv_bias_out = qkv + qkv_bias
else:
qkv_bias_out = qkv
query = query.reshape(batch_size, seq_len, embed_dim)
qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head,
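Every bias use in the reference path is now guarded the same way: the reshape and the add are skipped when the bias is None instead of broadcasting a zero bias. A standalone numpy sketch of that guard follows; the helper name add_bias_if_present is hypothetical and does not appear in the test file.

    import numpy as np

    def add_bias_if_present(x, bias):
        # Mirrors the reference path: return x untouched when bias is None.
        if bias is None:
            return x
        return x + bias

    qkv = np.random.random((2, 12)).astype('float32')
    assert np.array_equal(add_bias_if_present(qkv, None), qkv)
    assert add_bias_if_present(qkv, np.ones(12, 'float32')).shape == (2, 12)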
@@ -140,26 +151,42 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias,
out_linear_out = fc(out_linear_input, out_linear_weight)
# bias add, dropout, residual add, layer_norm.
out_linear_bias_out = out_linear_out + out_linear_bias
if out_linear_bias is not None:
out_linear_bias_out = out_linear_out + out_linear_bias
else:
out_linear_bias_out = out_linear_out
out_linear_bias_dropout_out = out_linear_bias_out
out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out
if not pre_layer_norm:
out_linear_bias_dropout_residual_out = layer_norm(
out_linear_bias_dropout_residual_out, True, True, ln_2_scale,
out_linear_bias_dropout_residual_out, True, has_bias, ln_2_scale,
ln_2_bias)
return out_linear_bias_dropout_residual_out
class TestFusedAttentionAPI(unittest.TestCase):
def setUp(self):
self.setXType()
self.setPreLn()
self.setAttnMask()
self.setBiasAttr()
self.config()
self.generate_input_data()
def config(self):
def setAttnMask(self):
self.has_attn_mask = True
def setBiasAttr(self):
self.bias_attr = None
def setPreLn(self):
self.pre_layer_norm = False
def setXType(self):
self.x_type = np.float32
def config(self):
self.attn_mask_type = np.float64
self.pre_layer_norm = True
self.has_attn_mask = True
self.training = True
self.need_weight = False
@@ -172,7 +199,6 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
@@ -205,23 +231,32 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.embed_dim, self.num_heads, self.dropout_prob,
self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm,
self.need_weight, self.weight_attr, self.bias_attr)
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32')
fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
if self.bias_attr is not False:
qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype(
'float32')
fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias))
out = fused_attn(
paddle.to_tensor(self.query),
paddle.to_tensor(self.query),
paddle.to_tensor(self.query), attn_mask_tensor)
ref_out = compute_reference(self.pre_layer_norm, self.query,
self.attn_mask,
fused_attn.pre_ln_scale.numpy(),
fused_attn.pre_ln_bias.numpy(),
fused_attn.ln_scale.numpy(),
fused_attn.ln_bias.numpy(),
fused_attn.qkv_weight.numpy(),
fused_attn.qkv_bias.numpy(),
fused_attn.linear_weight.numpy(),
fused_attn.linear_bias.numpy())
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5)
fused_attn_qkv_bias = None
fused_attn_linear_bias = None
fused_attn_pre_ln_bias = None
fused_attn_ln_bias = None
if self.bias_attr is not False:
fused_attn_qkv_bias = fused_attn.qkv_bias.numpy()
fused_attn_linear_bias = fused_attn.linear_bias.numpy()
fused_attn_pre_ln_bias = fused_attn.pre_ln_bias.numpy()
fused_attn_ln_bias = fused_attn.ln_bias.numpy()
ref_out = compute_reference(
self.pre_layer_norm, self.query, self.attn_mask,
fused_attn.pre_ln_scale.numpy(), fused_attn_pre_ln_bias,
fused_attn.ln_scale.numpy(), fused_attn_ln_bias,
fused_attn.qkv_weight.numpy(), fused_attn_qkv_bias,
fused_attn.linear_weight.numpy(), fused_attn_linear_bias)
np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-4)
def run_static(self):
fused_attn = FusedMultiHeadAttention(
@@ -248,27 +283,53 @@ class TestFusedAttentionAPI(unittest.TestCase):
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
qkv_bias = None
linear_bias = None
ln_bias = None
ln_2_bias = None
if self.has_attn_mask:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
if self.bias_attr is False:
out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.linear_weight, fused_attn.pre_ln_scale,
fused_attn.ln_scale
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query,
"SrcMask": self.attn_mask},
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
if self.bias_attr is False:
out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight,
fused_attn.linear_weight, fused_attn.pre_ln_scale,
fused_attn.ln_scale
])
else:
out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run(
paddle.static.default_main_program(),
feed={"X": self.query, },
fetch_list=[
final_out, fused_attn.qkv_weight, fused_attn.qkv_bias,
fused_attn.linear_weight, fused_attn.linear_bias,
fused_attn.pre_ln_scale, fused_attn.pre_ln_bias,
fused_attn.ln_scale, fused_attn.ln_bias
])
return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias
def test_static_api(self):
@@ -280,7 +341,7 @@ class TestFusedAttentionAPI(unittest.TestCase):
self.attn_mask, ln_scale, ln_bias,
ln_2_scale, ln_2_bias, qkv_weight, qkv_bias,
linear_weight, linear_bias)
np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-4)
def test_dynamic_api(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -288,27 +349,16 @@ class TestFusedAttentionAPI(unittest.TestCase):
class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI):
def config(self):
self.x_type = np.float32
self.attn_mask_type = np.float64
self.pre_layer_norm = True
def setAttnMask(self):
self.has_attn_mask = False
self.training = True
self.need_weight = False
self.batch_size = 1
self.query_length = 2
self.head_dim = 2
self.num_heads = 2
self.embed_dim = self.head_dim * self.num_heads
def setPreLn(self):
self.pre_layer_norm = True
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = self.query_length, self.query_length
class TestFusedAttentionAPIBiasIsNone(TestFusedAttentionAPI):
def setBiasAttr(self):
self.bias_attr = False
if __name__ == "__main__":
......
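To exercise the new bias-is-None case on its own, the subclass can be loaded directly with unittest. The module name test_fused_attention_op_api is an assumption (the file path is not shown in this diff), and a CUDA device is required.

    import unittest

    # Module name assumed; adjust to the actual test file location.
    from test_fused_attention_op_api import TestFusedAttentionAPIBiasIsNone

    suite = unittest.TestLoader().loadTestsFromTestCase(
        TestFusedAttentionAPIBiasIsNone)
    unittest.TextTestRunner(verbosity=2).run(suite)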
@@ -388,10 +388,12 @@ def fused_multi_head_attention(x,
if pre_ln_bias:
inputs['LnBias'] = [pre_ln_bias]
inputs['QKVW'] = [qkv_weight]
inputs['QKVBias'] = [qkv_bias]
if qkv_bias is not None:
inputs['QKVBias'] = [qkv_bias]
inputs['SrcMask'] = attn_mask
inputs['OutLinearW'] = [linear_weight]
inputs['OutLinearBias'] = [linear_bias]
if linear_bias is not None:
inputs['OutLinearBias'] = [linear_bias]
if ln_scale:
inputs['Ln2Scale'] = [ln_scale]
if ln_bias:
......
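The static-graph fix in fused_multi_head_attention follows the same rule on the op-construction side: an optional tensor is only registered in the op's input map when it is not None, rather than inserting a None entry. A minimal, Paddle-free sketch of that pattern is below; build_inputs is an illustrative helper, not the actual API.

    def build_inputs(qkv_weight, out_linear_weight, qkv_bias=None, linear_bias=None):
        # Register optional bias inputs only when they exist; a None entry in the
        # static-graph input map is what triggered the original failure.
        inputs = {'QKVW': [qkv_weight], 'OutLinearW': [out_linear_weight]}
        if qkv_bias is not None:
            inputs['QKVBias'] = [qkv_bias]
        if linear_bias is not None:
            inputs['OutLinearBias'] = [linear_bias]
        return inputs

    print(build_inputs('qkv_w', 'out_w'))
    # {'QKVW': ['qkv_w'], 'OutLinearW': ['out_w']}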