diff --git a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py index c88a967035874022b0f9d94d2bfab1a393c56aee..74ccd16656724c3befa21ff3f6b4e2d1b3dd4513 100644 --- a/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py +++ b/python/paddle/distributed/fleet/utils/hybrid_parallel_util.py @@ -193,7 +193,7 @@ def fused_allreduce_gradients_with_group( else _apply_collective_grads ) with framework.no_grad(): - apply_func(parameter_list, group, bucket_size) + apply_func(parameter_list, group, bucket_size, scale) def fused_allreduce_gradients(parameter_list, hcg):