From 6e4cba14e2f838e2c1a246ece65f39f6af38f23f Mon Sep 17 00:00:00 2001
From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com>
Date: Mon, 10 Oct 2022 15:08:39 +0800
Subject: [PATCH] fix the combination bug of sharding stage1 + dp (#46631)

---
 python/paddle/distributed/fleet/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py
index 40633788f12..d75f490fd01 100644
--- a/python/paddle/distributed/fleet/model.py
+++ b/python/paddle/distributed/fleet/model.py
@@ -139,7 +139,8 @@ def distributed_model(model):
             model,
             comm_buffer_size=strategy.fuse_grad_size_in_MB,
             last_comm_buffer_size=strategy.last_comm_group_size_MB,
-            find_unused_parameters=strategy.find_unused_parameters)
+            find_unused_parameters=strategy.find_unused_parameters,
+            group=fleet_env._hcg.get_data_parallel_group())
     elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
         model = TensorParallel(model, fleet_env._hcg, strategy=strategy)
     elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
--
GitLab
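
Context for the fix: without an explicit group argument, paddle.DataParallel
all-reduces gradients over the default global communication group. When
sharding stage1 is combined with data parallelism, the global group also
contains the sharding ranks, so gradient synchronization must instead be
restricted to the data-parallel group provided by the hybrid communicate
group (HCG), which is exactly what the added group= argument does.

Below is a minimal sketch of a configuration that exercises this code path.
The degree values, GPU count, and placeholder model are illustrative
assumptions and are not part of the patch.

    # Hypothetical sketch: sharding stage1 combined with data parallelism.
    # Assumes 8 GPUs split as 4-way sharding x 2-way dp; adjust to your topology.
    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 2,        # data-parallel replicas (the group passed to DataParallel)
        "mp_degree": 1,
        "pp_degree": 1,
        "sharding_degree": 4,  # sharding stage1 group size
    }
    fleet.init(is_collective=True, strategy=strategy)

    model = paddle.nn.Linear(16, 16)        # placeholder model
    # With the patch, the DataParallel wrapper all-reduces only within
    # fleet_env._hcg.get_data_parallel_group() instead of the global group.
    model = fleet.distributed_model(model)

Such a script would typically be started with the distributed launcher, e.g.
python -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 train.py.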