diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py
index a34807d2b7377ee1cea7864a828d8feaf65d689d..217920debe4de258175925c5fbb42b19d1f328d5 100644
--- a/python/paddle/distributed/parallel.py
+++ b/python/paddle/distributed/parallel.py
@@ -145,7 +145,11 @@ def build_groups(vars, group_size):
 @imperative_base.no_grad
 @framework.dygraph_only
 def sync_params_buffers(
-    model, comm_group=None, src_rank=0, is_model_parallel=False
+    model,
+    comm_group=None,
+    src_rank=0,
+    is_model_parallel=False,
+    fuse_params=True,
 ):
     model_vars = []
     for _, param in model._obtain_parameters_buffers().items():
@@ -170,22 +174,28 @@ def sync_params_buffers(
     if len(model_vars) == 0:
         return
 
-    # group size is 128M
-    coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)
+    if fuse_params:
+        # group size is 128M
+        coalesced_vars = build_groups(model_vars, 128 * 1024 * 1024)
 
-    for coalesced_var, _, _ in coalesced_vars:
-        paddle.distributed.broadcast(
-            coalesced_var, src=src_rank, group=comm_group, sync_op=True
-        )
+        for coalesced_var, _, _ in coalesced_vars:
+            paddle.distributed.broadcast(
+                coalesced_var, src=src_rank, group=comm_group, sync_op=True
+            )
 
-    for coalesced_var, origin_vars, var_shapes in coalesced_vars:
-        var_len = [np.prod(v_shape) for v_shape in var_shapes]
-        paddle.fluid.framework._dygraph_tracer().trace_op(
-            type='split',
-            inputs={'X': coalesced_var},
-            outputs={'Out': origin_vars},
-            attrs={'sections': var_len, 'axis': 0},
-        )
+        for coalesced_var, origin_vars, var_shapes in coalesced_vars:
+            var_len = [np.prod(v_shape) for v_shape in var_shapes]
+            paddle.fluid.framework._dygraph_tracer().trace_op(
+                type='split',
+                inputs={'X': coalesced_var},
+                outputs={'Out': origin_vars},
+                attrs={'sections': var_len, 'axis': 0},
+            )
+    else:
+        for var in model_vars:
+            paddle.distributed.broadcast(
+                var, src=src_rank, group=comm_group, sync_op=True
+            )
 
 
 class DataParallel(layers.Layer):
@@ -398,7 +408,7 @@ class DataParallel(layers.Layer):
                 ), "ProcessGroup must be an instance of Group in DataParallel."
 
             # sync buffer and params
-            sync_params_buffers(self._layers)
+            sync_params_buffers(self._layers, fuse_params=False)
 
             self.comm_buffer_size = int(comm_buffer_size * 1024 * 1024)
             # NOTE(shenliang03): We can set environment variables to control
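
Usage sketch (not part of the diff): the hunks above add a fuse_params flag to sync_params_buffers and make DataParallel pass fuse_params=False. Below is a minimal, hedged example of how the flag might be exercised in a dygraph data-parallel job; the Linear layer, process count, and script name are illustrative assumptions, not taken from this change.

    # Hedged sketch: assumes launch via
    #   python -m paddle.distributed.launch --nproc_per_node=2 demo.py
    import paddle
    import paddle.distributed as dist
    from paddle.distributed.parallel import sync_params_buffers

    dist.init_parallel_env()

    model = paddle.nn.Linear(8, 8)  # illustrative model only

    # fuse_params=False broadcasts each parameter/buffer tensor individually
    # from src_rank, skipping the 128M coalesce-broadcast-split path above.
    sync_params_buffers(model, src_rank=0, fuse_params=False)

    # DataParallel itself now calls sync_params_buffers(..., fuse_params=False),
    # so wrapping the layer performs the same per-tensor broadcast.
    dp_model = paddle.DataParallel(model)

The else branch avoids allocating the temporary coalesced buffers entirely, which appears to be the reason DataParallel opts out of fusing here.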