From a58362222cf53a63f32c9bddc1d4afde65e25387 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Fri, 9 Sep 2022 08:44:14 +0800 Subject: [PATCH] fix fused_attention with mp unit test case fail on A100 with CUDA >= 11.6 (#45883) --- .../tests/unittests/static_model_parallel_fused_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py b/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py index e34da4c45a7..0fdc7ac0218 100644 --- a/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py +++ b/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py @@ -36,7 +36,7 @@ def get_param_attr(weight, bias): DTYPE = "float32" MODEL_PARALLEL_SIZE = 2 n_head = 2 * MODEL_PARALLEL_SIZE -d_key = 4 +d_key = 2 hidden = n_head * d_key -- GitLab