Unverified commit 7696ae02, authored by YUNSHEN XIE, committed by GitHub

[Cherry-pick] add condition of skipif (#49407)

* resolve conflict

* fix format error
Parent 2a438b0a
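Aside from the Black-style reformatting shown in the diff below, the functional change in this cherry-pick is an extra GPU compute-capability condition on the skipIf guard of the bfloat16 test case (see the hunk at @@ -332,18 +397,18 @@). The following is a minimal, self-contained sketch of that guard; the get_cuda_version() stand-in and the ExampleBF16Test class are illustrative assumptions for this note, not code from the test file itself.

import unittest

import paddle
from paddle.fluid import core


def get_cuda_version():
    # Illustrative stand-in for the helper defined in the real test file:
    # returns the build's CUDA version as an integer, e.g. 11020 for CUDA 11.2,
    # or -1 when Paddle was built without CUDA.
    version = paddle.version.cuda()  # e.g. '11.2', or 'False' for CPU builds
    if version == 'False':
        return -1
    major, minor = version.split('.')[:2]
    return int(major) * 1000 + int(minor) * 10


@unittest.skipIf(
    # Mirrors the guard added by this commit: skip unless the build has CUDA,
    # the CUDA version is at least 11.0, and the GPU reports compute
    # capability >= 8.0, which the bfloat16 kernels require.
    not core.is_compiled_with_cuda()
    or get_cuda_version() < 11000
    or paddle.device.cuda.get_device_capability()[0] < 8,
    "requires CUDA >= 11.0 and a GPU with compute capability >= 8.0",
)
class ExampleBF16Test(unittest.TestCase):  # hypothetical class, for illustration only
    def test_placeholder(self):
        self.assertTrue(True)


if __name__ == '__main__':
    unittest.main()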
@@ -30,10 +30,10 @@ from paddle.fluid.framework import default_main_program
 from paddle.fluid import core
 
 
-@unittest.skipIf(not core.is_compiled_with_cuda(),
-                 "Paddle is not compiled with CUDA")
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "Paddle is not compiled with CUDA"
+)
 class TestFusedGateAttentionOp(OpTest):
-
     def setUp(self):
         self.__class__.op_type = "fused_gate_attention"
         # use autograd to check grad in this unittest.
@@ -57,7 +57,6 @@ class TestFusedGateAttentionOp(OpTest):
         self.bias_attr = True
 
     def generate_input_data(self):
-
         def _random(shape):
             if self.dtype == "bfloat16":
                 data = np.random.random(shape).astype("float32")
@@ -67,7 +66,8 @@ class TestFusedGateAttentionOp(OpTest):
 
         np.random.seed(123)
         self.query = _random(
-            (self.batch_size, self.msa_len, self.res_len, self.q_dim))
+            (self.batch_size, self.msa_len, self.res_len, self.q_dim)
+        )
         self.q_weight = _random((self.q_dim, self.num_heads, self.head_dim))
         self.k_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
         self.v_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
@@ -80,15 +80,18 @@ class TestFusedGateAttentionOp(OpTest):
             self.qkv_weight = np.stack([q_weight_t, k_weight_t, v_weight_t])
         else:
             self.key = _random(
-                (self.batch_size, self.msa_len, self.m_size, self.kv_dim))
+                (self.batch_size, self.msa_len, self.m_size, self.kv_dim)
+            )
             self.qkv_weight = None
 
         self.attn_mask = _random(
-            (self.batch_size, self.msa_len, 1, 1, self.m_size))
+            (self.batch_size, self.msa_len, 1, 1, self.m_size)
+        )
 
         if self.bias_attr:
             self.nonbatched_bias = _random(
-                (self.batch_size, 1, self.num_heads, self.res_len, self.m_size))
+                (self.batch_size, 1, self.num_heads, self.res_len, self.m_size)
+            )
 
         if self.has_gating:
             self.gating_w = _random((self.q_dim, self.num_heads, self.head_dim))
@@ -98,12 +101,17 @@ class TestFusedGateAttentionOp(OpTest):
         self.output_b = _random((self.out_dim))
 
         self.dout = _random(
-            (self.batch_size, self.msa_len, self.res_len, self.q_dim))
+            (self.batch_size, self.msa_len, self.res_len, self.q_dim)
+        )
 
     def collect_outputs(self, query, key, softmax_out, fmha_out, gate_out, out):
         outputs = [
-            softmax_out, fmha_out, gate_out if self.has_gating else None, out,
-            query.grad, None if self.merge_qkv else key.grad
+            softmax_out,
+            fmha_out,
+            gate_out if self.has_gating else None,
+            out,
+            query.grad,
+            None if self.merge_qkv else key.grad,
         ]
         return outputs
@@ -111,14 +119,17 @@ class TestFusedGateAttentionOp(OpTest):
         paddle.disable_static(place=paddle.CUDAPlace(0))
 
         query = paddle.to_tensor(self.query, stop_gradient=False)
-        key = query if self.merge_qkv else paddle.to_tensor(self.key,
-                                                            stop_gradient=False)
+        key = (
+            query
+            if self.merge_qkv
+            else paddle.to_tensor(self.key, stop_gradient=False)
+        )
         q_weight = paddle.to_tensor(self.q_weight, stop_gradient=False)
         k_weight = paddle.to_tensor(self.k_weight, stop_gradient=False)
         v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
         src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
 
-        c = self.head_dim**(-0.5)
+        c = self.head_dim ** (-0.5)
         # [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim]
         # -> [batch_size, msa_len, res_len, num_heads, head_dim]
         q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c
@@ -136,8 +147,9 @@ class TestFusedGateAttentionOp(OpTest):
         # -> [batch_size, msa_len, num_heads, res_len, m_size]
         logits = logits + src_mask
         if self.bias_attr:
-            nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
-                                               stop_gradient=False)
+            nonbatched_bias = paddle.to_tensor(
+                self.nonbatched_bias, stop_gradient=False
+            )
             # [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, 1, num_heads, res_len, m_size]
             # -> [batch_size, msa_len, num_heads, res_len, m_size]
             logits = logits + nonbatched_bias
@@ -159,14 +171,22 @@ class TestFusedGateAttentionOp(OpTest):
             # gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
             #                             gating_w) + gating_b
             gating_w_2d = paddle.reshape(
-                gating_w, shape=[self.q_dim, self.num_heads * self.head_dim])
+                gating_w, shape=[self.q_dim, self.num_heads * self.head_dim]
+            )
             gate_values_4d = paddle.matmul(query, gating_w_2d)
-            gate_values = paddle.reshape(
-                gate_values_4d,
-                shape=[
-                    self.batch_size, self.msa_len, self.res_len, self.num_heads,
-                    self.head_dim
-                ]) + gating_b
+            gate_values = (
+                paddle.reshape(
+                    gate_values_4d,
+                    shape=[
+                        self.batch_size,
+                        self.msa_len,
+                        self.res_len,
+                        self.num_heads,
+                        self.head_dim,
+                    ],
+                )
+                + gating_b
+            )
             gate_values = nn.functional.sigmoid(gate_values)
             gate_out = fmha_out * gate_values
         else:
@@ -183,20 +203,32 @@ class TestFusedGateAttentionOp(OpTest):
             gate_out,
             shape=[
                 self.batch_size * self.msa_len * self.res_len,
-                self.num_heads * self.head_dim
-            ])
+                self.num_heads * self.head_dim,
+            ],
+        )
         output_w_2d = paddle.reshape(
-            output_w, shape=[self.num_heads * self.head_dim, self.out_dim])
+            output_w, shape=[self.num_heads * self.head_dim, self.out_dim]
+        )
         out_2d = paddle.matmul(gate_out_2d, output_w_2d)
-        out = paddle.reshape(
-            out_2d,
-            shape=[self.batch_size, self.msa_len, self.res_len, self.out_dim
-                   ]) + output_b
+        out = (
+            paddle.reshape(
+                out_2d,
+                shape=[
+                    self.batch_size,
+                    self.msa_len,
+                    self.res_len,
+                    self.out_dim,
+                ],
+            )
+            + output_b
+        )
 
-        paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
-                                 retain_graph=True)
-        return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
-                                    out)
+        paddle.autograd.backward(
+            [out], [paddle.to_tensor(self.dout)], retain_graph=True
+        )
+        return self.collect_outputs(
+            query, key, softmax_out, fmha_out, gate_out, out
+        )
 
     def get_fused_gate_attention_out(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -218,8 +250,9 @@ class TestFusedGateAttentionOp(OpTest):
         src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
 
         if self.bias_attr:
-            nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
-                                               stop_gradient=False)
+            nonbatched_bias = paddle.to_tensor(
+                self.nonbatched_bias, stop_gradient=False
+            )
         else:
             nonbatched_bias = None
         if self.has_gating:
@@ -232,18 +265,42 @@ class TestFusedGateAttentionOp(OpTest):
         output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
         output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
 
-        _, _, _, _, softmax_out, fmha_out, gate_out, out = _legacy_C_ops.fused_gate_attention(
-            query, key, q_weight, k_weight, v_weight, qkv_weight,
-            nonbatched_bias, src_mask, gating_w, gating_b, output_w, output_b,
-            'has_gating', self.has_gating, 'merge_qkv', self.merge_qkv)
+        (
+            _,
+            _,
+            _,
+            _,
+            softmax_out,
+            fmha_out,
+            gate_out,
+            out,
+        ) = _legacy_C_ops.fused_gate_attention(
+            query,
+            key,
+            q_weight,
+            k_weight,
+            v_weight,
+            qkv_weight,
+            nonbatched_bias,
+            src_mask,
+            gating_w,
+            gating_b,
+            output_w,
+            output_b,
+            'has_gating',
+            self.has_gating,
+            'merge_qkv',
+            self.merge_qkv,
+        )
 
-        paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
-                                 retain_graph=True)
-        return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
-                                    out)
+        paddle.autograd.backward(
+            [out], [paddle.to_tensor(self.dout)], retain_graph=True
+        )
+        return self.collect_outputs(
+            query, key, softmax_out, fmha_out, gate_out, out
+        )
 
     def check(self, ref, out, atol, rtol, check_equal, name):
-
         def _convert(value):
             if self.dtype == "bfloat16":
                 return convert_uint16_to_float(value)
@@ -252,19 +309,25 @@ class TestFusedGateAttentionOp(OpTest):
         if check_equal:
             self.assertTrue(
                 np.equal(_convert(ref), _convert(out)).all(),
-                "Checking < {} > failed!".format(name))
+                "Checking < {} > failed!".format(name),
+            )
         else:
             np.testing.assert_allclose(
                 _convert(ref),
                 _convert(out),
                 atol=atol,
                 rtol=rtol,
-                err_msg="Checking < {} > failed!".format(name))
+                err_msg="Checking < {} > failed!".format(name),
+            )
 
     def check_output_and_grad(self, atol, rtol):
         output_names = [
-            "softmax_out", "fmha_out", "gate_out", "out", "query_grad",
-            "key_grad"
+            "softmax_out",
+            "fmha_out",
+            "gate_out",
+            "out",
+            "query_grad",
+            "key_grad",
         ]
         outputs_ref = self.get_reference_out()
         outputs_fused = self.get_fused_gate_attention_out()
@@ -280,22 +343,26 @@ class TestFusedGateAttentionOp(OpTest):
                 # that in fused ops, check_equal is set to False and we use allclose
                 # to check the correctness.
                 check_equal = False
-                self.check(ref_res.numpy(), fused_res.numpy(), atol, rtol,
-                           check_equal, output_names[i])
+                self.check(
+                    ref_res.numpy(),
+                    fused_res.numpy(),
+                    atol,
+                    rtol,
+                    check_equal,
+                    output_names[i],
+                )
 
     def test_output_and_grad(self):
         self.check_output_and_grad(atol=1e-5, rtol=1e-6)
 
 
 class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.batch_size = 2
 
 
 class TestSeparatedQKVCase(TestFusedGateAttentionOp):
-
     def config(self):
         self.dtype = "float32"
         self.has_gating = False
@@ -312,7 +379,6 @@ class TestSeparatedQKVCase(TestFusedGateAttentionOp):
 
 
 class TestMergeQKVNoBiasGatingCase(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.has_gating = False
@@ -320,7 +386,6 @@ class TestMergeQKVNoBiasGatingCase(TestFusedGateAttentionOp):
 
 
 class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.dtype = "float16"
@@ -332,18 +397,18 @@ class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
 
 
 class TestMergeQKVLargeBatchSizeFp16Case(TestMergeQKVFp16Case):
-
     def config(self):
         super().config()
         self.batch_size = 2
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11000,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11000
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
 )
 class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.dtype = "bfloat16"
@@ -353,7 +418,6 @@ class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
 
 
 class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
-
     def config(self):
         super().config()
         self.batch_size = 2