未验证 提交 719d96b9 编写于 作者: Y Yuang Liu 提交者: GitHub

do not use fuse for sync param in dp (#56437)

上级 14393611
......@@ -145,7 +145,11 @@ def build_groups(vars, group_size):
@imperative_base.no_grad
@framework.dygraph_only
def sync_params_buffers(
model, comm_group=None, src_rank=0, is_model_parallel=False
model,
comm_group=None,
src_rank=0,
is_model_parallel=False,
fuse_params=True,
):
model_vars = []
for _, param in model._obtain_parameters_buffers().items():
......@@ -170,22 +174,28 @@ def sync_params_buffers(
if len(model_vars) == 0:
return
# group size is 128M
coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)
if fuse_params:
# group size is 128M
coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)
for coalesced_var, _, _ in coalesced_vars:
paddle.distributed.broadcast(
coalesced_var, src=src_rank, group=comm_group, sync_op=True
)
for coalesced_var, _, _ in coalesced_vars:
paddle.distributed.broadcast(
coalesced_var, src=src_rank, group=comm_group, sync_op=True
)
for coalesced_var, origin_vars, var_shapes in coalesced_vars:
var_len = [np.prod(v_shape) for v_shape in var_shapes]
paddle.fluid.framework._dygraph_tracer().trace_op(
type='split',
inputs={'X': coalesced_var},
outputs={'Out': origin_vars},
attrs={'sections': var_len, 'axis': 0},
)
for coalesced_var, origin_vars, var_shapes in coalesced_vars:
var_len = [np.prod(v_shape) for v_shape in var_shapes]
paddle.fluid.framework._dygraph_tracer().trace_op(
type='split',
inputs={'X': coalesced_var},
outputs={'Out': origin_vars},
attrs={'sections': var_len, 'axis': 0},
)
else:
for var in model_vars:
paddle.distributed.broadcast(
var, src=src_rank, group=comm_group, sync_op=True
)
class DataParallel(layers.Layer):
......@@ -398,7 +408,7 @@ class DataParallel(layers.Layer):
), "ProcessGroup must be an instance of Group in DataParallel."
# sync buffer and params
sync_params_buffers(self._layers)
sync_params_buffers(self._layers, fuse_params=False)
self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
# NOTE(shenliang03): We can set environment variables to control
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册