diff --git a/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py b/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py index e34da4c45a7c36ca6e4cd35b4ff24c920c547db1..0fdc7ac0218991887088fa5c180ffcc505973e2b 100644 --- a/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py +++ b/python/paddle/fluid/tests/unittests/static_model_parallel_fused_attention.py @@ -36,7 +36,7 @@ def get_param_attr(weight, bias): DTYPE = "float32" MODEL_PARALLEL_SIZE = 2 n_head = 2 * MODEL_PARALLEL_SIZE -d_key = 4 +d_key = 2 hidden = n_head * d_key