未验证 提交 719d96b9 编写于 作者: Yuang Liu 提交者: GitHub

Do not fuse parameters when syncing params and buffers in DataParallel (dp) (#56437)

上级 14393611
......@@ -145,7 +145,11 @@ def build_groups(vars, group_size):
@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(
model, comm_group=None, src_rank=0, is_model_parallel=False
model,
comm_group=None,
src_rank=0,
is_model_parallel=False,
fuse_params=True,
):
model_vars = []
for _, param in model._obtain_parameters_buffers().items():
......@@ -170,6 +174,7 @@ def sync_params_buffers(
if len(model_vars) == 0:
return
if fuse_params:
# group size is 128M
coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)
......@@ -186,6 +191,11 @@ def sync_params_buffers(
outputs={'Out': origin_vars},
attrs={'sections': var_len, 'axis': 0},
)
else:
for var in model_vars:
paddle.distributed.broadcast(
var, src=src_rank, group=comm_group, sync_op=True
)
class DataParallel(layers.Layer):
......@@ -398,7 +408,7 @@ class DataParallel(layers.Layer):
), "ProcessGroup must be an instance of Group in DataParallel."
# sync buffer and params
sync_params_buffers(self._layers)
sync_params_buffers(self._layers, fuse_params=False)
self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
# NOTE(shenliang03): We can set environment variables to control
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册