From 1c465824f5b132736cca50efdd745c22f53c8434 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Fri, 26 May 2023 15:50:06 +0800
Subject: [PATCH] [AutoParallel] update every rank has global view
 process_groups (#54067)

* global view process_group

* fix import

* fix attr

* fix tuner init comm
---
 .../distributed/auto_parallel/engine.py       |  2 -
 .../auto_parallel/process_group.py            | 37 +++++++++--------
 .../distributed/auto_parallel/reshard.py      | 40 ++++++++++++++++---
 .../auto_parallel/tuner/profiler.py           |  2 -
 .../paddle/distributed/auto_parallel/utils.py | 20 +++++++++-
 5 files changed, 74 insertions(+), 27 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index 728a60e18a4..7a979a86420 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -795,8 +795,6 @@ class Engine:
                 )
             else:
                 for process_group in all_process_groups:
-                    if self._cur_rank not in process_group.ranks:
-                        continue
                     process_group.instantiate()
 
     def _initialize(self, mode):
diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py
index b5669b850b6..e7d8a758161 100644
--- a/python/paddle/distributed/auto_parallel/process_group.py
+++ b/python/paddle/distributed/auto_parallel/process_group.py
@@ -15,11 +15,13 @@
 from collections import OrderedDict
 
 import paddle
-from paddle import _legacy_C_ops
-from paddle.framework import core, in_dynamic_mode
-from paddle.tensor import fill_constant
+from paddle.framework import core
 
 from ..collective import _get_global_env, _new_ring_id
+from ..utils.log_utils import get_logger
+from .utils import dygraph_guard
+
+logger = get_logger("INFO", __name__)
 
 
 def get_all_process_groups():
@@ -122,6 +124,7 @@ class ProcessGroup:
     def is_instantiate(self):
         return self._is_instantiate
 
+    @dygraph_guard
     def instantiate(self):
         if self._is_instantiate:
             return
@@ -129,7 +132,10 @@ class ProcessGroup:
         genv = _get_global_env()
         global_rank = genv.rank
 
-        if self.nranks >= 2:
+        if self.nranks >= 2 and global_rank in self.ranks:
+            logger.info(
+                f"group_id: {self.id}, ranks: {self.ranks}, nranks: {self.nranks}, trainer_endpoints: {genv.current_endpoint}"
+            )
             strategy = core.ParallelStrategy()
             strategy.nranks = self.nranks
             strategy.local_rank = self.local_rank(global_rank)
@@ -156,9 +162,6 @@ class ProcessGroup:
             else:
                 raise AssertionError('No CUDA device found')
 
-            # TODO(shenliang03): This is a temporary solution to solve the problem of
-            # hang caused by cross-creation of new_group
-            paddle.disable_static()
             if core.is_compiled_with_cuda():
                 paddle.set_device(
                     'gpu:%d' % paddle.distributed.ParallelEnv().dev_id
@@ -175,17 +178,19 @@ class ProcessGroup:
                     paddle.distributed.ParallelEnv().dev_id,
                 ),
             )
-            tmp = (
-                paddle.to_tensor([1], dtype="int32")
-                if in_dynamic_mode()
-                else fill_constant([0], dtype="int32", value="1")
+
+            # TODO(shenliang03): This is a temporary solution to solve the problem of
+            # hang caused by cross-creation of new_group
+            barrier_tensor = paddle.full([1], 1, dtype="int32")
+            paddle._legacy_C_ops.barrier(
+                barrier_tensor, barrier_tensor, 'ring_id', ring_id
             )
-            # use legacy ops
-            _legacy_C_ops.c_allreduce_sum_(
-                tmp, 'use_calc_stream', True, 'ring_id', self.id
+
+        if self.nranks > 1:
+            barrier_tensor = paddle.full([1], 1, dtype="int32")
+            paddle._legacy_C_ops.barrier(
+                barrier_tensor, barrier_tensor, 'ring_id', 0
             )
-            _legacy_C_ops.c_sync_calc_stream(tmp, tmp)
-        paddle.enable_static()
 
         self._is_instantiate = True
 
diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py
index 897dec1f12b..788528717a3 100644
--- a/python/paddle/distributed/auto_parallel/reshard.py
+++ b/python/paddle/distributed/auto_parallel/reshard.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
+from collections import OrderedDict
 from functools import reduce
 
 import paddle
@@ -456,11 +457,11 @@ class Inserter:
             )
             inputs = {'X': [tensor]}
             outputs = {"Out": [out]}
-            attrs = {"in_place": False}
-            slice_op = block._insert_op(
+            attrs = {"in_place": False, "op_role": op_role}
+            assign_op = block._insert_op(
                 idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
             )
-            slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
+            assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
             return out
 
     # use split once
@@ -1458,7 +1459,7 @@ class Resharder:
             if not serial
             else source_tensor.shape
         )
-        op_desc_seq = {}
+        op_desc_seq = OrderedDict()
 
         # TODO: if the target process group has the same process with source process group
        if set(target_process_group).intersection(
@@ -1723,6 +1724,23 @@ class Resharder:
         self, block, op_desc_seq, var_name, reshard_op, dist_attr
     ):
         """Parse op desc sequence and insert op in the block"""
+
+        # Parse all communicator groups for all ranks
+        # Ensure every rank has a global view of communicator groups for the entire cluster.
+        # When initializing communicators for pipeline parallel, every rank could
+        # conduct a correct global synchronization.
+        for rank_id in op_desc_seq:
+            op_desc_list = op_desc_seq[rank_id]
+            for op_desc in op_desc_list:
+                if isinstance(op_desc, AllGatherOpDesc):
+                    new_process_group(op_desc.group)
+                elif isinstance(op_desc, AllGatherConcatOpDesc):
+                    new_process_group(op_desc.group)
+                elif isinstance(op_desc, SendOpDesc):
+                    new_process_group([op_desc.src, op_desc.dst])
+                elif isinstance(op_desc, RecvOpDesc):
+                    new_process_group([op_desc.src, op_desc.dst])
+
         tensor_list = []
         partition_tensor_list = []
         if self.rank_id not in op_desc_seq.keys():
@@ -2632,7 +2650,7 @@ class Resharder:
                                 item,
                                 recv_rank,
                             )
-                        if self.rank_id == recv_rank:
+                        elif self.rank_id == recv_rank:
                             # if recv bool data, recv then cast
                             self._hadnle_recv(
                                 block,
@@ -2642,6 +2660,11 @@ class Resharder:
                                 item,
                                 recv_rank,
                             )
+                        else:
+                            # Ensure every rank has a global view of communicator groups for the entire cluster.
+                            # When initializing communicators for pipeline parallel, every rank could
+                            # conduct a correct global synchronization.
+                            new_process_group([item, recv_rank])
                     else:
                         for index, tensor_process in enumerate(
                             tensor_processes
@@ -2659,11 +2682,16 @@ class Resharder:
                                 self._handle_send(
                                     block, idx, var, op, item, recv_rank
                                 )
-                            if self.rank_id == recv_rank:
+                            elif self.rank_id == recv_rank:
                                 # if recv bool data, recv then cast
                                 self._hadnle_recv(
                                     block, idx, var, op, item, recv_rank
                                 )
+                            else:
+                                # Ensure every rank has a global view of communicator groups for the entire cluster.
+                                # When initializing communicators for pipeline parallel, every rank could
+                                # conduct a correct global synchronization.
+                                new_process_group([item, recv_rank])
 
             cur_op_count = len(block.ops)
             idx_offset = (
diff --git a/python/paddle/distributed/auto_parallel/tuner/profiler.py b/python/paddle/distributed/auto_parallel/tuner/profiler.py
index 27e0fa49845..486db968ee3 100644
--- a/python/paddle/distributed/auto_parallel/tuner/profiler.py
+++ b/python/paddle/distributed/auto_parallel/tuner/profiler.py
@@ -91,8 +91,6 @@ def init_process_groups(group_map, rank):
     # TODO should instantiate global group first
     all_process_groups = get_all_process_groups()
     for process_group in all_process_groups:
-        if process_group.id == 0 or rank not in process_group.ranks:
-            continue
         print(process_group)
         process_group.instantiate()
 
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index f4dfb8d9c20..d5a196a080d 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -22,12 +22,12 @@ from functools import reduce
 import numpy as np
 
 import paddle
+from paddle.fluid.wrapped_decorator import wrap_decorator
 from paddle.framework import core
 from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
 from paddle.static import Variable
 
 from .dist_attribute import OperatorDistAttr, TensorDistAttr
-from .process_group import get_all_process_groups
 from .process_mesh import ProcessMesh
 
 OpRole = core.op_proto_and_checker_maker.OpRole
@@ -55,6 +55,8 @@ def get_logger(log_level, name="auto_parallel"):
         )
         log_handler.setFormatter(log_format)
         logger.addHandler(log_handler)
+    else:
+        logger.setLevel(log_level)
     return logger
 
 
@@ -1816,6 +1818,8 @@ def debug_program(program, path, name):
 
 
 def ring_id_to_process_group(ring_id):
+    from .process_group import get_all_process_groups
+
     for g in get_all_process_groups():
         if g.id == ring_id:
             return g
@@ -2355,3 +2359,17 @@ def is_dep_skip_op(op):
         return True
 
     return False
+
+
+def _dygraph_guard_(func):
+    def __impl__(*args, **kwargs):
+        if paddle.framework.in_dynamic_mode():
+            return func(*args, **kwargs)
+        else:
+            with paddle.fluid.dygraph.guard():
+                return func(*args, **kwargs)
+
+    return __impl__
+
+
+dygraph_guard = wrap_decorator(_dygraph_guard_)
-- 
GitLab
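
The pattern this patch converges on: every rank registers every process group (a global view of the cluster), while ProcessGroup.instantiate() only builds a communicator when the current rank is actually a member of the group. A minimal, self-contained sketch of that idea in plain Python follows; the names below (Group, new_group, _group_map) are hypothetical stand-ins for illustration only, not Paddle's real process_group API.

# Sketch of the "global view + conditional instantiate" pattern (illustrative only).
_group_map = {}  # every rank builds the same registry, regardless of membership


class Group:
    def __init__(self, group_id, ranks):
        self.id = group_id
        self.ranks = list(ranks)
        self.instantiated = False

    def instantiate(self, cur_rank):
        # Only ranks that belong to the group would create a real communicator;
        # the other ranks still know the group exists, so group ids and the
        # initialization order stay consistent across the whole cluster.
        if len(self.ranks) >= 2 and cur_rank in self.ranks:
            print(f"rank {cur_rank}: init communicator for group {self.id} {self.ranks}")
        self.instantiated = True


def new_group(ranks):
    # Registering the same ranks twice returns the existing group.
    key = tuple(sorted(ranks))
    if key not in _group_map:
        _group_map[key] = Group(len(_group_map), sorted(ranks))
    return _group_map[key]


if __name__ == "__main__":
    cur_rank = 0
    # every rank registers all groups, even those it does not belong to
    for ranks in ([0, 1], [2, 3], [0, 1, 2, 3]):
        new_group(ranks)
    for group in _group_map.values():
        group.instantiate(cur_rank)

Run as "rank 0", only the groups containing rank 0 report communicator initialization, yet all three groups are registered. This mirrors how reshard.py now calls new_process_group for every rank's op descriptors while engine.py and the tuner profiler instantiate every group unconditionally.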