diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py index 70c2ff5cbc8f23b09a5622795610b5800b4dfe9b..bdaf32ee0726dcbcf362fe1864913126db4904f0 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op_api.py @@ -79,8 +79,12 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, seq_len = query.shape[1] embed_dim = query.shape[2] + has_bias = True + if ln_bias is None: + has_bias = False + if (pre_layer_norm): - ln_out = layer_norm(query, True, True, ln_scale, ln_bias) + ln_out = layer_norm(query, True, has_bias, ln_scale, ln_bias) num_head = qkv_weight.shape[1] head_dim = qkv_weight.shape[2] @@ -89,17 +93,24 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, qkv_weight = qkv_weight.reshape(qkv_weight.shape[0], qkv_weight.shape[1] * qkv_weight.shape[2] * qkv_weight.shape[3]) - qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] * - qkv_bias.shape[2]) + if qkv_bias is not None: + qkv_bias = qkv_bias.reshape(qkv_bias.shape[0] * qkv_bias.shape[1] * + qkv_bias.shape[2]) if (pre_layer_norm): ln_out = ln_out.reshape(batch_size * seq_len, embed_dim) qkv = fc(ln_out, qkv_weight) - qkv_bias_out = qkv + qkv_bias + if qkv_bias is not None: + qkv_bias_out = qkv + qkv_bias + else: + qkv_bias_out = qkv ln_out = ln_out.reshape(batch_size, seq_len, embed_dim) else: query = query.reshape(batch_size * seq_len, embed_dim) qkv = fc(query, qkv_weight) - qkv_bias_out = qkv + qkv_bias + if qkv_bias is not None: + qkv_bias_out = qkv + qkv_bias + else: + qkv_bias_out = qkv query = query.reshape(batch_size, seq_len, embed_dim) qkv_bias_out = qkv_bias_out.reshape(batch_size, seq_len, 3, num_head, @@ -140,26 +151,42 @@ def compute_reference(pre_layer_norm, query, attn_mask, ln_scale, ln_bias, out_linear_out = fc(out_linear_input, out_linear_weight) # bias add, dropout, residual add, layer_norm. - out_linear_bias_out = out_linear_out + out_linear_bias + if out_linear_bias is not None: + out_linear_bias_out = out_linear_out + out_linear_bias + else: + out_linear_bias_out = out_linear_out out_linear_bias_dropout_out = out_linear_bias_out out_linear_bias_dropout_residual_out = query + out_linear_bias_dropout_out if not pre_layer_norm: out_linear_bias_dropout_residual_out = layer_norm( - out_linear_bias_dropout_residual_out, True, True, ln_2_scale, + out_linear_bias_dropout_residual_out, True, has_bias, ln_2_scale, ln_2_bias) return out_linear_bias_dropout_residual_out class TestFusedAttentionAPI(unittest.TestCase): def setUp(self): + self.setXType() + self.setPreLn() + self.setAttnMask() + self.setBiasAttr() self.config() self.generate_input_data() - def config(self): + def setAttnMask(self): + self.has_attn_mask = True + + def setBiasAttr(self): + self.bias_attr = None + + def setPreLn(self): + self.pre_layer_norm = False + + def setXType(self): self.x_type = np.float32 + + def config(self): self.attn_mask_type = np.float64 - self.pre_layer_norm = True - self.has_attn_mask = True self.training = True self.need_weight = False @@ -172,7 +199,6 @@ class TestFusedAttentionAPI(unittest.TestCase): self.dropout_prob = 0.0 self.attn_dropout_prob = 0.0 self.weight_attr = None - self.bias_attr = None self.kdim, self.vdim = self.embed_dim, self.embed_dim self.key_length, self.value_length = self.query_length, self.query_length @@ -205,23 +231,32 @@ class TestFusedAttentionAPI(unittest.TestCase): self.embed_dim, self.num_heads, self.dropout_prob, self.attn_dropout_prob, self.kdim, self.vdim, self.pre_layer_norm, self.need_weight, self.weight_attr, self.bias_attr) - qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype('float32') - fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias)) + if self.bias_attr is not False: + qkv_bias = np.random.random(fused_attn.qkv_bias.shape).astype( + 'float32') + fused_attn.qkv_bias.set_value(paddle.to_tensor(qkv_bias)) out = fused_attn( paddle.to_tensor(self.query), paddle.to_tensor(self.query), paddle.to_tensor(self.query), attn_mask_tensor) - ref_out = compute_reference(self.pre_layer_norm, self.query, - self.attn_mask, - fused_attn.pre_ln_scale.numpy(), - fused_attn.pre_ln_bias.numpy(), - fused_attn.ln_scale.numpy(), - fused_attn.ln_bias.numpy(), - fused_attn.qkv_weight.numpy(), - fused_attn.qkv_bias.numpy(), - fused_attn.linear_weight.numpy(), - fused_attn.linear_bias.numpy()) - np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-5) + + fused_attn_qkv_bias = None + fused_attn_linear_bias = None + fused_attn_pre_ln_bias = None + fused_attn_ln_bias = None + if self.bias_attr is not False: + fused_attn_qkv_bias = fused_attn.qkv_bias.numpy() + fused_attn_linear_bias = fused_attn.linear_bias.numpy() + fused_attn_pre_ln_bias = fused_attn.pre_ln_bias.numpy() + fused_attn_ln_bias = fused_attn.ln_bias.numpy() + + ref_out = compute_reference( + self.pre_layer_norm, self.query, self.attn_mask, + fused_attn.pre_ln_scale.numpy(), fused_attn_pre_ln_bias, + fused_attn.ln_scale.numpy(), fused_attn_ln_bias, + fused_attn.qkv_weight.numpy(), fused_attn_qkv_bias, + fused_attn.linear_weight.numpy(), fused_attn_linear_bias) + np.testing.assert_allclose(ref_out, out.numpy(), rtol=1e-5, atol=1e-4) def run_static(self): fused_attn = FusedMultiHeadAttention( @@ -248,27 +283,53 @@ class TestFusedAttentionAPI(unittest.TestCase): place = paddle.CUDAPlace(0) exe = paddle.static.Executor(place) exe.run(paddle.static.default_startup_program()) + + qkv_bias = None + linear_bias = None + ln_bias = None + ln_2_bias = None if self.has_attn_mask: - out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( - paddle.static.default_main_program(), - feed={"X": self.query, - "SrcMask": self.attn_mask}, - fetch_list=[ - final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, - fused_attn.linear_weight, fused_attn.linear_bias, - fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, - fused_attn.ln_scale, fused_attn.ln_bias - ]) + if self.bias_attr is False: + out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run( + paddle.static.default_main_program(), + feed={"X": self.query, + "SrcMask": self.attn_mask}, + fetch_list=[ + final_out, fused_attn.qkv_weight, + fused_attn.linear_weight, fused_attn.pre_ln_scale, + fused_attn.ln_scale + ]) + else: + out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( + paddle.static.default_main_program(), + feed={"X": self.query, + "SrcMask": self.attn_mask}, + fetch_list=[ + final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, + fused_attn.linear_weight, fused_attn.linear_bias, + fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, + fused_attn.ln_scale, fused_attn.ln_bias + ]) else: - out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( - paddle.static.default_main_program(), - feed={"X": self.query, }, - fetch_list=[ - final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, - fused_attn.linear_weight, fused_attn.linear_bias, - fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, - fused_attn.ln_scale, fused_attn.ln_bias - ]) + if self.bias_attr is False: + out, qkv_weight, out_linear_weight, ln_scale, ln_2_scale = exe.run( + paddle.static.default_main_program(), + feed={"X": self.query, }, + fetch_list=[ + final_out, fused_attn.qkv_weight, + fused_attn.linear_weight, fused_attn.pre_ln_scale, + fused_attn.ln_scale + ]) + else: + out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias = exe.run( + paddle.static.default_main_program(), + feed={"X": self.query, }, + fetch_list=[ + final_out, fused_attn.qkv_weight, fused_attn.qkv_bias, + fused_attn.linear_weight, fused_attn.linear_bias, + fused_attn.pre_ln_scale, fused_attn.pre_ln_bias, + fused_attn.ln_scale, fused_attn.ln_bias + ]) return out, qkv_weight, qkv_bias, out_linear_weight, linear_bias, ln_scale, ln_bias, ln_2_scale, ln_2_bias def test_static_api(self): @@ -280,7 +341,7 @@ class TestFusedAttentionAPI(unittest.TestCase): self.attn_mask, ln_scale, ln_bias, ln_2_scale, ln_2_bias, qkv_weight, qkv_bias, linear_weight, linear_bias) - np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-5) + np.testing.assert_allclose(ref_out, out, rtol=1e-5, atol=1e-4) def test_dynamic_api(self): paddle.disable_static(place=paddle.CUDAPlace(0)) @@ -288,27 +349,16 @@ class TestFusedAttentionAPI(unittest.TestCase): class TestFusedAttentionAPINoneAttnMask(TestFusedAttentionAPI): - def config(self): - self.x_type = np.float32 - self.attn_mask_type = np.float64 - self.pre_layer_norm = True + def setAttnMask(self): self.has_attn_mask = False - self.training = True - self.need_weight = False - self.batch_size = 1 - self.query_length = 2 - self.head_dim = 2 - self.num_heads = 2 - self.embed_dim = self.head_dim * self.num_heads + def setPreLn(self): + self.pre_layer_norm = True - self.dropout_prob = 0.0 - self.attn_dropout_prob = 0.0 - self.weight_attr = None - self.bias_attr = None - self.kdim, self.vdim = self.embed_dim, self.embed_dim - self.key_length, self.value_length = self.query_length, self.query_length +class TestFusedAttentionAPIBiasIsNone(TestFusedAttentionAPI): + def setBiasAttr(self): + self.bias_attr = False if __name__ == "__main__": diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index eafefd98298f542e0d16162e8e8c2bcc861ec622..3569d372fa6dc7ef89b6d1f8e9e0f675ab89dde9 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -388,10 +388,12 @@ def fused_multi_head_attention(x, if pre_ln_bias: inputs['LnBias'] = [pre_ln_bias] inputs['QKVW'] = [qkv_weight] - inputs['QKVBias'] = [qkv_bias] + if qkv_bias is not None: + inputs['QKVBias'] = [qkv_bias] inputs['SrcMask'] = attn_mask inputs['OutLinearW'] = [linear_weight] - inputs['OutLinearBias'] = [linear_bias] + if linear_bias is not None: + inputs['OutLinearBias'] = [linear_bias] if ln_scale: inputs['Ln2Scale'] = [ln_scale] if ln_bias: