Unverified commit c5c3093b, authored by Logan Adams, committed by GitHub

Merge branch 'master' into loadams/low-cpu-mem-ut

......@@ -944,6 +944,10 @@ def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_align
        partition_id = dist.get_rank(group=dp_process_group[group_id])
        dp_world_size = dist.get_world_size(group=dp_process_group[group_id])
        if dp_world_size == 1:
            # no groups share optimizer states
            # pipeline parallel with bf16 will default call this even if dp size = 1.
            continue
        num_shards = max(1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size)
        shard_size = partitioned_params[partition_id].numel() // num_shards
......
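The hunk above adds an early exit for a data-parallel world size of 1 and is followed by integer shard arithmetic. The sketch below isolates just that arithmetic so it can be checked by hand. `plan_shards` is a hypothetical helper written for illustration, not a DeepSpeed function; it takes plain integers in place of the tensor `numel()` and `dist.get_world_size(...)` calls, and it assumes nothing about the rest of `all_gather_dp_groups` beyond the lines shown.

```python
def plan_shards(partition_numel, dp_world_size, allgather_bucket_size):
    """Hypothetical helper mirroring the arithmetic in the hunk above.

    Returns None when dp_world_size == 1 (the case the new `continue` skips),
    otherwise a (num_shards, shard_size) tuple.
    """
    if dp_world_size == 1:
        # Per the source comment: no groups share optimizer states,
        # so this group is skipped.
        return None
    # Same integer expressions as in the diff: at least one shard,
    # more shards as the total gathered size exceeds the bucket size.
    num_shards = max(1, partition_numel * dp_world_size // allgather_bucket_size)
    shard_size = partition_numel // num_shards
    return num_shards, shard_size


if __name__ == "__main__":
    print(plan_shards(1000, 1, 500))   # skipped group
    print(plan_shards(1000, 4, 500))   # gathered size exceeds the bucket
    print(plan_shards(100, 2, 1000))   # gathered size fits in one bucket
```

Because both expressions use floor division, `num_shards * shard_size` can be smaller than the partition size; how the remainder is handled lies in the elided part of the function and is not shown here.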