未验证 提交 09a60477 编写于 作者: C caizejun 提交者: GitHub

fix fused multihead matmul unitest (#55755)

* bugfix fused_multihead_matmul

* fix test case of fused multihead matmul

---------
Co-authored-by: Nbukejiyu <395822456@qq.com>
上级 b76c2f94
......@@ -47,12 +47,10 @@ class TestFusedMultiHeadMatmulOp_biasqk2(OpTest):
self.config()
h = self.seq_len
w = self.head_number * self.size_per_head
self.Input = (
np.random.random((self.batch_size, h, w)).astype("float32") - 0.5
)
self.WQ = np.random.random((w, w)).astype("float32")
self.KQ = np.random.random((w, w)).astype("float32")
self.VQ = np.random.random((w, w)).astype("float32")
self.Input = np.random.random((self.batch_size, h, w)).astype("float32")
self.WQ = np.eye(w).astype("float32")
self.KQ = np.eye(w).astype("float32")
self.VQ = np.eye(w).astype("float32")
self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape(
(w, 3, w)
)
......@@ -154,12 +152,10 @@ class TestFusedMultiheadMatmulOp(OpTest):
self.config()
h = self.seq_len
w = self.head_number * self.size_per_head
self.Input = (
np.random.random((self.batch_size, h, w)).astype("float32") - 0.5
)
self.WQ = np.random.random((w, w)).astype("float32")
self.KQ = np.random.random((w, w)).astype("float32")
self.VQ = np.random.random((w, w)).astype("float32")
self.Input = np.random.random((self.batch_size, h, w)).astype("float32")
self.WQ = np.eye(w).astype("float32")
self.KQ = np.eye(w).astype("float32")
self.VQ = np.eye(w).astype("float32")
self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape(
(w, 3, w)
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册