Unverified commit 719d96b9, authored by Yuang Liu, committed via GitHub

do not use fuse for sync param in dp (#56437)

上级 14393611
...@@ -145,7 +145,11 @@ def build_groups(vars, group_size): ...@@ -145,7 +145,11 @@ def build_groups(vars, group_size):
@imperative_base.no_grad @imperative_base.no_grad
@framework.dygraph_only @framework.dygraph_only
def sync_params_buffers( def sync_params_buffers(
model, comm_group=None, src_rank=0, is_model_parallel=False model,
comm_group=None,
src_rank=0,
is_model_parallel=False,
fuse_params=True,
): ):
model_vars = [] model_vars = []
for _, param in model._obtain_parameters_buffers().items(): for _, param in model._obtain_parameters_buffers().items():
...@@ -170,22 +174,28 @@ def sync_params_buffers( ...@@ -170,22 +174,28 @@ def sync_params_buffers(
if len(model_vars) == 0: if len(model_vars) == 0:
return return
# group size is 128M if fuse_params:
coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024) # group size is 128M
coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)
for coalesced_var, _, _ in coalesced_vars: for coalesced_var, _, _ in coalesced_vars:
paddle.distributed.broadcast( paddle.distributed.broadcast(
coalesced_var, src=src_rank, group=comm_group, sync_op=True coalesced_var, src=src_rank, group=comm_group, sync_op=True
) )
for coalesced_var, origin_vars, var_shapes in coalesced_vars: for coalesced_var, origin_vars, var_shapes in coalesced_vars:
var_len = [np.prod(v_shape) for v_shape in var_shapes] var_len = [np.prod(v_shape) for v_shape in var_shapes]
paddle.fluid.framework._dygraph_tracer().trace_op( paddle.fluid.framework._dygraph_tracer().trace_op(
type='split', type='split',
inputs={'X': coalesced_var}, inputs={'X': coalesced_var},
outputs={'Out': origin_vars}, outputs={'Out': origin_vars},
attrs={'sections': var_len, 'axis': 0}, attrs={'sections': var_len, 'axis': 0},
) )
else:
for var in model_vars:
paddle.distributed.broadcast(
var, src=src_rank, group=comm_group, sync_op=True
)
class DataParallel(layers.Layer): class DataParallel(layers.Layer):
...@@ -398,7 +408,7 @@ class DataParallel(layers.Layer): ...@@ -398,7 +408,7 @@ class DataParallel(layers.Layer):
), "ProcessGroup must be an instance of Group in DataParallel." ), "ProcessGroup must be an instance of Group in DataParallel."
# sync buffer and params # sync buffer and params
sync_params_buffers(self._layers) sync_params_buffers(self._layers, fuse_params=False)
self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024) self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
# NOTE(shenliang03): We can set environment variables to control # NOTE(shenliang03): We can set environment variables to control
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册