未验证 提交 6e4cba14 编写于 作者: W wuhuachaocoding 提交者: GitHub

fix the combination bug of sharding stage1 + dp (#46631)

上级 6512e087
...@@ -139,7 +139,8 @@ def distributed_model(model): ...@@ -139,7 +139,8 @@ def distributed_model(model):
model, model,
comm_buffer_size=strategy.fuse_grad_size_in_MB, comm_buffer_size=strategy.fuse_grad_size_in_MB,
last_comm_buffer_size=strategy.last_comm_group_size_MB, last_comm_buffer_size=strategy.last_comm_group_size_MB,
find_unused_parameters=strategy.find_unused_parameters) find_unused_parameters=strategy.find_unused_parameters,
group=fleet_env._hcg.get_data_parallel_group())
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL: elif fleet_env._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
model = TensorParallel(model, fleet_env._hcg, strategy=strategy) model = TensorParallel(model, fleet_env._hcg, strategy=strategy)
elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL: elif fleet_env._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册