Unverified · Commit 7696ae02 authored by YUNSHEN XIE, committed by GitHub

[Cherry-pick] add condition of skipif (#49407)

* resolve conflict

* fix format error
Parent 2a438b0a
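
Apart from the formatting-only changes, the substantive change in this cherry-pick is the extra clause in the skipIf guard of the bfloat16 test class: the test is now also skipped when the GPU's compute capability is below 8.0, because native bfloat16 kernels need an Ampere-or-newer device. A minimal, self-contained sketch of that guard pattern follows; the helper and the test class here are illustrative stand-ins, not this file's actual definitions.

import unittest

import paddle
from paddle.fluid import core


def _cuda_toolkit_version():
    # Stand-in for the test module's get_cuda_version() helper: encode the
    # CUDA toolkit version as an integer, e.g. "11.2" -> 11020, so that a
    # comparison such as `< 11000` means "older than CUDA 11.0".
    if not core.is_compiled_with_cuda():
        return -1
    major, minor = paddle.version.cuda().split('.')[:2]
    return int(major) * 1000 + int(minor) * 10


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or _cuda_toolkit_version() < 11000
    or paddle.device.cuda.get_device_capability()[0] < 8,
    "requires CUDA, CUDA toolkit >= 11.0, and compute capability >= 8.0",
)
class IllustrativeBF16Case(unittest.TestCase):
    def test_placeholder(self):
        # The real TestMergeQKVBF16Case exercises fused_gate_attention in
        # bfloat16; this placeholder only demonstrates the skip condition.
        pass


if __name__ == '__main__':
    unittest.main()
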
@@ -30,10 +30,10 @@ from paddle.fluid.framework import default_main_program
from paddle.fluid import core
@unittest.skipIf(not core.is_compiled_with_cuda(),
"Paddle is not compiled with CUDA")
@unittest.skipIf(
not core.is_compiled_with_cuda(), "Paddle is not compiled with CUDA"
)
class TestFusedGateAttentionOp(OpTest):
def setUp(self):
self.__class__.op_type = "fused_gate_attention"
# use autograd to check grad in this unittest.
@@ -57,7 +57,6 @@ class TestFusedGateAttentionOp(OpTest):
self.bias_attr = True
def generate_input_data(self):
def _random(shape):
if self.dtype == "bfloat16":
data = np.random.random(shape).astype("float32")
@@ -67,7 +66,8 @@ class TestFusedGateAttentionOp(OpTest):
np.random.seed(123)
self.query = _random(
(self.batch_size, self.msa_len, self.res_len, self.q_dim))
(self.batch_size, self.msa_len, self.res_len, self.q_dim)
)
self.q_weight = _random((self.q_dim, self.num_heads, self.head_dim))
self.k_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
self.v_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
@@ -80,15 +80,18 @@ class TestFusedGateAttentionOp(OpTest):
self.qkv_weight = np.stack([q_weight_t, k_weight_t, v_weight_t])
else:
self.key = _random(
(self.batch_size, self.msa_len, self.m_size, self.kv_dim))
(self.batch_size, self.msa_len, self.m_size, self.kv_dim)
)
self.qkv_weight = None
self.attn_mask = _random(
(self.batch_size, self.msa_len, 1, 1, self.m_size))
(self.batch_size, self.msa_len, 1, 1, self.m_size)
)
if self.bias_attr:
self.nonbatched_bias = _random(
(self.batch_size, 1, self.num_heads, self.res_len, self.m_size))
(self.batch_size, 1, self.num_heads, self.res_len, self.m_size)
)
if self.has_gating:
self.gating_w = _random((self.q_dim, self.num_heads, self.head_dim))
@@ -98,12 +101,17 @@ class TestFusedGateAttentionOp(OpTest):
self.output_b = _random((self.out_dim))
self.dout = _random(
(self.batch_size, self.msa_len, self.res_len, self.q_dim))
(self.batch_size, self.msa_len, self.res_len, self.q_dim)
)
def collect_outputs(self, query, key, softmax_out, fmha_out, gate_out, out):
outputs = [
softmax_out, fmha_out, gate_out if self.has_gating else None, out,
query.grad, None if self.merge_qkv else key.grad
softmax_out,
fmha_out,
gate_out if self.has_gating else None,
out,
query.grad,
None if self.merge_qkv else key.grad,
]
return outputs
@@ -111,14 +119,17 @@ class TestFusedGateAttentionOp(OpTest):
paddle.disable_static(place=paddle.CUDAPlace(0))
query = paddle.to_tensor(self.query, stop_gradient=False)
key = query if self.merge_qkv else paddle.to_tensor(self.key,
stop_gradient=False)
key = (
query
if self.merge_qkv
else paddle.to_tensor(self.key, stop_gradient=False)
)
q_weight = paddle.to_tensor(self.q_weight, stop_gradient=False)
k_weight = paddle.to_tensor(self.k_weight, stop_gradient=False)
v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
c = self.head_dim**(-0.5)
c = self.head_dim ** (-0.5)
# [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim]
# -> [batch_size, msa_len, res_len, num_heads, head_dim]
q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c
@@ -136,8 +147,9 @@ class TestFusedGateAttentionOp(OpTest):
# -> [batch_size, msa_len, num_heads, res_len, m_size]
logits = logits + src_mask
if self.bias_attr:
nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
stop_gradient=False)
nonbatched_bias = paddle.to_tensor(
self.nonbatched_bias, stop_gradient=False
)
# [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, 1, num_heads, res_len, m_size]
# -> [batch_size, msa_len, num_heads, res_len, m_size]
logits = logits + nonbatched_bias
@@ -159,14 +171,22 @@ class TestFusedGateAttentionOp(OpTest):
# gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
# gating_w) + gating_b
gating_w_2d = paddle.reshape(
gating_w, shape=[self.q_dim, self.num_heads * self.head_dim])
gating_w, shape=[self.q_dim, self.num_heads * self.head_dim]
)
gate_values_4d = paddle.matmul(query, gating_w_2d)
gate_values = paddle.reshape(
gate_values_4d,
shape=[
self.batch_size, self.msa_len, self.res_len, self.num_heads,
self.head_dim
]) + gating_b
gate_values = (
paddle.reshape(
gate_values_4d,
shape=[
self.batch_size,
self.msa_len,
self.res_len,
self.num_heads,
self.head_dim,
],
)
+ gating_b
)
gate_values = nn.functional.sigmoid(gate_values)
gate_out = fmha_out * gate_values
else:
@@ -183,20 +203,32 @@ class TestFusedGateAttentionOp(OpTest):
gate_out,
shape=[
self.batch_size * self.msa_len * self.res_len,
self.num_heads * self.head_dim
])
self.num_heads * self.head_dim,
],
)
output_w_2d = paddle.reshape(
output_w, shape=[self.num_heads * self.head_dim, self.out_dim])
output_w, shape=[self.num_heads * self.head_dim, self.out_dim]
)
out_2d = paddle.matmul(gate_out_2d, output_w_2d)
out = paddle.reshape(
out_2d,
shape=[self.batch_size, self.msa_len, self.res_len, self.out_dim
]) + output_b
paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
retain_graph=True)
return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
out)
out = (
paddle.reshape(
out_2d,
shape=[
self.batch_size,
self.msa_len,
self.res_len,
self.out_dim,
],
)
+ output_b
)
paddle.autograd.backward(
[out], [paddle.to_tensor(self.dout)], retain_graph=True
)
return self.collect_outputs(
query, key, softmax_out, fmha_out, gate_out, out
)
def get_fused_gate_attention_out(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -218,8 +250,9 @@ class TestFusedGateAttentionOp(OpTest):
src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
if self.bias_attr:
nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
stop_gradient=False)
nonbatched_bias = paddle.to_tensor(
self.nonbatched_bias, stop_gradient=False
)
else:
nonbatched_bias = None
if self.has_gating:
@@ -232,18 +265,42 @@ class TestFusedGateAttentionOp(OpTest):
output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
_, _, _, _, softmax_out, fmha_out, gate_out, out = _legacy_C_ops.fused_gate_attention(
query, key, q_weight, k_weight, v_weight, qkv_weight,
nonbatched_bias, src_mask, gating_w, gating_b, output_w, output_b,
'has_gating', self.has_gating, 'merge_qkv', self.merge_qkv)
paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
retain_graph=True)
return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
out)
(
_,
_,
_,
_,
softmax_out,
fmha_out,
gate_out,
out,
) = _legacy_C_ops.fused_gate_attention(
query,
key,
q_weight,
k_weight,
v_weight,
qkv_weight,
nonbatched_bias,
src_mask,
gating_w,
gating_b,
output_w,
output_b,
'has_gating',
self.has_gating,
'merge_qkv',
self.merge_qkv,
)
paddle.autograd.backward(
[out], [paddle.to_tensor(self.dout)], retain_graph=True
)
return self.collect_outputs(
query, key, softmax_out, fmha_out, gate_out, out
)
def check(self, ref, out, atol, rtol, check_equal, name):
def _convert(value):
if self.dtype == "bfloat16":
return convert_uint16_to_float(value)
@@ -252,19 +309,25 @@ class TestFusedGateAttentionOp(OpTest):
if check_equal:
self.assertTrue(
np.equal(_convert(ref), _convert(out)).all(),
"Checking < {} > failed!".format(name))
"Checking < {} > failed!".format(name),
)
else:
np.testing.assert_allclose(
_convert(ref),
_convert(out),
atol=atol,
rtol=rtol,
err_msg="Checking < {} > failed!".format(name))
err_msg="Checking < {} > failed!".format(name),
)
def check_output_and_grad(self, atol, rtol):
output_names = [
"softmax_out", "fmha_out", "gate_out", "out", "query_grad",
"key_grad"
"softmax_out",
"fmha_out",
"gate_out",
"out",
"query_grad",
"key_grad",
]
outputs_ref = self.get_reference_out()
outputs_fused = self.get_fused_gate_attention_out()
@@ -280,22 +343,26 @@ class TestFusedGateAttentionOp(OpTest):
# that in fused ops, check_equal is set to False and we use allclose
# to check the correctness.
check_equal = False
self.check(ref_res.numpy(), fused_res.numpy(), atol, rtol,
check_equal, output_names[i])
self.check(
ref_res.numpy(),
fused_res.numpy(),
atol,
rtol,
check_equal,
output_names[i],
)
def test_output_and_grad(self):
self.check_output_and_grad(atol=1e-5, rtol=1e-6)
class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp):
def config(self):
super().config()
self.batch_size = 2
class TestSeparatedQKVCase(TestFusedGateAttentionOp):
def config(self):
self.dtype = "float32"
self.has_gating = False
@@ -312,7 +379,6 @@ class TestSeparatedQKVCase(TestFusedGateAttentionOp):
class TestMergeQKVNoBiasGatingCase(TestFusedGateAttentionOp):
def config(self):
super().config()
self.has_gating = False
@@ -320,7 +386,6 @@ class TestMergeQKVNoBiasGatingCase(TestFusedGateAttentionOp):
class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
def config(self):
super().config()
self.dtype = "float16"
@@ -332,18 +397,18 @@ class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
class TestMergeQKVLargeBatchSizeFp16Case(TestMergeQKVFp16Case):
def config(self):
super().config()
self.batch_size = 2
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11000,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
not core.is_compiled_with_cuda()
or get_cuda_version() < 11000
or paddle.device.cuda.get_device_capability()[0] < 8,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
)
class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
def config(self):
super().config()
self.dtype = "bfloat16"
@@ -353,7 +418,6 @@ class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
def config(self):
super().config()
self.batch_size = 2
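
When the BF16 cases above are reported as skipped, it can help to see which clause of the guard triggered on the current machine. The following diagnostic sketch uses only public Paddle APIs and assumes at least one visible GPU when Paddle is CUDA-enabled; the printed wording is illustrative.

import paddle
from paddle.fluid import core

# Print the facts that the skipIf guards evaluate, so a skipped
# TestMergeQKVBF16Case run can be explained on a given machine.
if not core.is_compiled_with_cuda():
    print("Paddle was built without CUDA; all GPU-only cases are skipped.")
else:
    print("CUDA toolkit Paddle was built with:", paddle.version.cuda())
    major, minor = paddle.device.cuda.get_device_capability()
    print("GPU compute capability: {}.{}".format(major, minor))
    if major < 8:
        print("Compute capability < 8.0, so the bfloat16 cases are skipped.")
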