diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index 40633788f12d45be4bc30f25044e700656034d93..d75f490fd0152955c4d4b2f191d068d247e37849 100644 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -139,7 +139,8 @@ def distributed_model(model): model, comm_buffer_size=strategy.fuse_grad_size_in_MB, last_comm_buffer_size=strategy.last_comm_group_size_MB, - find_unused_parameters=strategy.find_unused_parameters) + find_unused_parameters=strategy.find_unused_parameters, + group=fleet_env._hcg.get_data_parallel_group()) elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL: model = TensorParallel(model, fleet_env._hcg, strategy=strategy) elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL: