未验证 提交 a5836222 编写于 作者: L LiYuRio 提交者: GitHub

fix fused_attention with mp unit test case fail on A100 with CUDA >= 11.6 (#45883)

上级 7c1dc754
...@@ -36,7 +36,7 @@ def get_param_attr(weight, bias): ...@@ -36,7 +36,7 @@ def get_param_attr(weight, bias):
DTYPE = "float32" DTYPE = "float32"
MODEL_PARALLEL_SIZE = 2 MODEL_PARALLEL_SIZE = 2
n_head = 2 * MODEL_PARALLEL_SIZE n_head = 2 * MODEL_PARALLEL_SIZE
d_key = 4 d_key = 2
hidden = n_head * d_key hidden = n_head * d_key
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册