Unverified · Commit 7f3e82fe · Author: mzl · Committer: GitHub

do allgather only in shared optimizer states groups (#4167)

* skip all-gather

* add notes

---------
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 7711bdbb
......@@ -944,6 +944,10 @@ def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_align
        partition_id = dist.get_rank(group=dp_process_group[group_id])
        dp_world_size = dist.get_world_size(group=dp_process_group[group_id])
        if dp_world_size == 1:
            # no groups share optimizer states
            # pipeline parallel with bf16 will default call this even if dp size = 1.
            continue
        num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size)
        shard_size = partitioned_params[partition_id].numel() // num_shards
......
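The early exit and the shard arithmetic in the hunk above can be tried in isolation. The following is a minimal sketch, not DeepSpeed code: `plan_shards` is a hypothetical helper, and plain integers stand in for `partitioned_params[partition_id].numel()` and the `torch.distributed` world size.

```python
def plan_shards(partition_numel, dp_world_size, allgather_bucket_size):
    """Hypothetical helper mirroring the arithmetic in the hunk above.

    Returns None when the group has a single rank (the all-gather is skipped),
    otherwise a (num_shards, shard_size) tuple.
    """
    if dp_world_size == 1:
        # A one-rank group has nobody to gather from, so skip it,
        # as the `continue` in the diff does.
        return None
    # Same integer arithmetic as the diff: gathered size divided by bucket size,
    # with at least one shard.
    num_shards = max(1, partition_numel * dp_world_size // allgather_bucket_size)
    shard_size = partition_numel // num_shards
    return num_shards, shard_size


if __name__ == "__main__":
    print(plan_shards(1000, 1, 500))    # single-rank group: skipped
    print(plan_shards(1000, 4, 500))    # gathered size exceeds the bucket: split into shards
    print(plan_shards(100, 2, 10**6))   # fits in one bucket: one shard
```

Under these assumptions, a group of size 1 never reaches the shard computation, which matches the intent stated in the commit title: all-gather runs only for groups that actually share optimizer states.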