From 09a604774faeffa67a4c32fc7b1cbb47479a538e Mon Sep 17 00:00:00 2001 From: caizejun <1129569290@qq.com> Date: Fri, 28 Jul 2023 14:16:39 +0800 Subject: [PATCH] fix fused multihead matmul unitest (#55755) * bugfix fused_multihead_matmul * fix test case of fused multihead matmul --------- Co-authored-by: bukejiyu <395822456@qq.com> --- .../test_fused_multihead_matmul_op.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/test/legacy_test/test_fused_multihead_matmul_op.py b/test/legacy_test/test_fused_multihead_matmul_op.py index 1f24227e3d5..d600aef0b98 100644 --- a/test/legacy_test/test_fused_multihead_matmul_op.py +++ b/test/legacy_test/test_fused_multihead_matmul_op.py @@ -47,12 +47,10 @@ class TestFusedMultiHeadMatmulOp_biasqk2(OpTest): self.config() h = self.seq_len w = self.head_number * self.size_per_head - self.Input = ( - np.random.random((self.batch_size, h, w)).astype("float32") - 0.5 - ) - self.WQ = np.random.random((w, w)).astype("float32") - self.KQ = np.random.random((w, w)).astype("float32") - self.VQ = np.random.random((w, w)).astype("float32") + self.Input = np.random.random((self.batch_size, h, w)).astype("float32") + self.WQ = np.eye(w).astype("float32") + self.KQ = np.eye(w).astype("float32") + self.VQ = np.eye(w).astype("float32") self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape( (w, 3, w) ) @@ -154,12 +152,10 @@ class TestFusedMultiheadMatmulOp(OpTest): self.config() h = self.seq_len w = self.head_number * self.size_per_head - self.Input = ( - np.random.random((self.batch_size, h, w)).astype("float32") - 0.5 - ) - self.WQ = np.random.random((w, w)).astype("float32") - self.KQ = np.random.random((w, w)).astype("float32") - self.VQ = np.random.random((w, w)).astype("float32") + self.Input = np.random.random((self.batch_size, h, w)).astype("float32") + self.WQ = np.eye(w).astype("float32") + self.KQ = np.eye(w).astype("float32") + self.VQ = np.eye(w).astype("float32") self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape( (w, 3, w) ) -- GitLab