Unverified · Commit 1c465824 · Authored by: zhaoyingli · Committed by: GitHub

[AutoParallel] Update process_groups so that every rank has a global view (#54067)

* global view process_group

* fix import

* fix attr

* fix tuner init comm
Parent 2b546613
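The core invariant behind this commit: process-group ids are handed out in registration order, so every rank must register the same groups in the same order to agree on the ids, even for groups it is not a member of. Below is a minimal sketch of that invariant; `_group_map` and this `new_process_group` body are illustrative assumptions, not Paddle's actual implementation.

```python
# Hypothetical sketch of order-based group-id assignment; not Paddle code.
_group_map = {}  # sorted rank tuple -> group id


def new_process_group(ranks):
    key = tuple(sorted(ranks))
    if key not in _group_map:
        # Ids follow registration order: an identical registration order on
        # every rank yields identical ids cluster-wide.
        _group_map[key] = len(_group_map)
    return _group_map[key]


# Every rank registers the pipeline send/recv pair, including ranks that are
# neither sender nor receiver, so the id is globally consistent.
assert new_process_group([3, 0]) == new_process_group([0, 3])
```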
@@ -795,8 +795,6 @@ class Engine:
                 )
             else:
                 for process_group in all_process_groups:
-                    if self._cur_rank not in process_group.ranks:
-                        continue
                     process_group.instantiate()
 
     def _initialize(self, mode):
...
@@ -15,11 +15,13 @@
 from collections import OrderedDict
 
 import paddle
-from paddle import _legacy_C_ops
-from paddle.framework import core, in_dynamic_mode
-from paddle.tensor import fill_constant
+from paddle.framework import core
 
 from ..collective import _get_global_env, _new_ring_id
+from ..utils.log_utils import get_logger
+from .utils import dygraph_guard
+
+logger = get_logger("INFO", __name__)
 
 
 def get_all_process_groups():
@@ -122,6 +124,7 @@ class ProcessGroup:
     def is_instantiate(self):
         return self._is_instantiate
 
+    @dygraph_guard
     def instantiate(self):
         if self._is_instantiate:
             return
@@ -129,7 +132,10 @@ class ProcessGroup:
         genv = _get_global_env()
         global_rank = genv.rank
 
-        if self.nranks >= 2:
+        if self.nranks >= 2 and global_rank in self.ranks:
+            logger.info(
+                f"group_id: {self.id}, ranks: {self.ranks}, nranks: {self.nranks}, trainer_endpoints: {genv.current_endpoint}"
+            )
             strategy = core.ParallelStrategy()
             strategy.nranks = self.nranks
             strategy.local_rank = self.local_rank(global_rank)
@@ -156,9 +162,6 @@ class ProcessGroup:
             else:
                 raise AssertionError('No CUDA device found')
 
-        # TODO(shenliang03): This is a temporary solution to solve the problem of
-        # hang caused by cross-creation of new_group
-        paddle.disable_static()
         if core.is_compiled_with_cuda():
             paddle.set_device(
                 'gpu:%d' % paddle.distributed.ParallelEnv().dev_id
@@ -175,17 +178,19 @@ class ProcessGroup:
                     paddle.distributed.ParallelEnv().dev_id,
                 ),
             )
-            tmp = (
-                paddle.to_tensor([1], dtype="int32")
-                if in_dynamic_mode()
-                else fill_constant([0], dtype="int32", value="1")
-            )
-            # use legacy ops
-            _legacy_C_ops.c_allreduce_sum_(
-                tmp, 'use_calc_stream', True, 'ring_id', self.id
-            )
-            _legacy_C_ops.c_sync_calc_stream(tmp, tmp)
-        paddle.enable_static()
+
+            # TODO(shenliang03): This is a temporary solution to solve the problem of
+            # hang caused by cross-creation of new_group
+            barrier_tensor = paddle.full([1], 1, dtype="int32")
+            paddle._legacy_C_ops.barrier(
+                barrier_tensor, barrier_tensor, 'ring_id', ring_id
+            )
+
+        if self.nranks > 1:
+            barrier_tensor = paddle.full([1], 1, dtype="int32")
+            paddle._legacy_C_ops.barrier(
+                barrier_tensor, barrier_tensor, 'ring_id', 0
+            )
 
         self._is_instantiate = True
...
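After this change, instantiate() has two synchronization phases: ranks that belong to the group build the communicator and barrier on the group's own ring, then every rank joins a barrier on the global ring 0. The sketch below simulates the two phases with threads standing in for ranks; threading.Barrier is a stand-in for the NCCL barrier op, and all names and sizes are assumptions for illustration.

```python
import threading

NRANKS = 4
GROUP_RANKS = {1, 2}  # the group being instantiated

group_ring = threading.Barrier(len(GROUP_RANKS))  # barrier on the group's ring
ring_zero = threading.Barrier(NRANKS)             # barrier on global ring 0


def instantiate(rank):
    # Phase 1: only member ranks build the communicator and sync on its ring.
    if rank in GROUP_RANKS:
        group_ring.wait()
    # Phase 2: every rank joins the ring-0 barrier. If non-member ranks
    # skipped instantiate() (the old engine.py guard), member ranks would
    # hang here waiting for them.
    ring_zero.wait()


threads = [threading.Thread(target=instantiate, args=(r,)) for r in range(NRANKS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("all ranks passed the global barrier")
```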
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
+from collections import OrderedDict
 from functools import reduce
 
 import paddle
@@ -456,11 +457,11 @@ class Inserter:
             )
             inputs = {'X': [tensor]}
             outputs = {"Out": [out]}
-            attrs = {"in_place": False}
-            slice_op = block._insert_op(
+            attrs = {"in_place": False, "op_role": op_role}
+            assign_op = block._insert_op(
                 idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
             )
-            slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
+            assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
             return out
 
         # use split once
@@ -1458,7 +1459,7 @@ class Resharder:
             if not serial
             else source_tensor.shape
         )
-        op_desc_seq = {}
+        op_desc_seq = OrderedDict()
 
         # TODO: if the target process group has the same process with source process group
         if set(target_process_group).intersection(
@@ -1723,6 +1724,23 @@ class Resharder:
         self, block, op_desc_seq, var_name, reshard_op, dist_attr
     ):
         """Parse op desc sequence and insert op in the block"""
+
+        # Parse all communicator groups for all ranks.
+        # Ensure every rank has a global view of the communicator groups of the
+        # entire cluster, so that every rank can take part in a correct global
+        # synchronization when pipeline-parallel communicators are initialized.
+        for rank_id in op_desc_seq:
+            op_desc_list = op_desc_seq[rank_id]
+            for op_desc in op_desc_list:
+                if isinstance(op_desc, AllGatherOpDesc):
+                    new_process_group(op_desc.group)
+                elif isinstance(op_desc, AllGatherConcatOpDesc):
+                    new_process_group(op_desc.group)
+                elif isinstance(op_desc, SendOpDesc):
+                    new_process_group([op_desc.src, op_desc.dst])
+                elif isinstance(op_desc, RecvOpDesc):
+                    new_process_group([op_desc.src, op_desc.dst])
+
         tensor_list = []
         partition_tensor_list = []
         if self.rank_id not in op_desc_seq.keys():
@@ -2632,7 +2650,7 @@ class Resharder:
                                     item,
                                     recv_rank,
                                 )
-                            if self.rank_id == recv_rank:
+                            elif self.rank_id == recv_rank:
                                 # if recv bool data, recv then cast
                                 self._hadnle_recv(
                                     block,
@@ -2642,6 +2660,11 @@ class Resharder:
                                     item,
                                     recv_rank,
                                 )
+                            else:
+                                # Ensure every rank has a global view of the communicator groups of
+                                # the entire cluster, so that every rank can take part in a correct
+                                # global synchronization when pipeline-parallel communicators are initialized.
+                                new_process_group([item, recv_rank])
                         else:
                             for index, tensor_process in enumerate(
                                 tensor_processes
@@ -2659,11 +2682,16 @@ class Resharder:
                                     self._handle_send(
                                         block, idx, var, op, item, recv_rank
                                     )
-                                if self.rank_id == recv_rank:
+                                elif self.rank_id == recv_rank:
                                     # if recv bool data, recv then cast
                                     self._hadnle_recv(
                                         block, idx, var, op, item, recv_rank
                                     )
+                                else:
+                                    # Ensure every rank has a global view of the communicator groups of
+                                    # the entire cluster, so that every rank can take part in a correct
+                                    # global synchronization when pipeline-parallel communicators are initialized.
+                                    new_process_group([item, recv_rank])
                                 cur_op_count = len(block.ops)
                                 idx_offset = (
...
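parse_op_desc now walks the op descriptors of every rank in op_desc_seq before inserting any op, so a rank that neither sends nor receives still registers the send/recv pair and keeps its group table identical to everyone else's. Here is a self-contained sketch of that pre-pass under assumed descriptor shapes; this SendOpDesc is a stand-in dataclass and this new_process_group a stand-in registry, not the Paddle classes.

```python
from dataclasses import dataclass


# Stand-in descriptor; the real SendOpDesc/RecvOpDesc live in reshard.py.
@dataclass
class SendOpDesc:
    src: int
    dst: int


group_table = []  # stands in for the global process-group map


def new_process_group(ranks):
    key = tuple(sorted(ranks))
    if key not in group_table:
        group_table.append(key)


# Descriptors for *all* ranks, keyed by rank id (toy data).
op_desc_seq = {0: [SendOpDesc(src=0, dst=2)], 2: [SendOpDesc(src=0, dst=2)]}

# The pre-pass iterates every rank's list, not just this rank's, so rank 1
# (which neither sends nor receives) still registers the (0, 2) group.
for rank_id in op_desc_seq:
    for op_desc in op_desc_seq[rank_id]:
        if isinstance(op_desc, SendOpDesc):
            new_process_group([op_desc.src, op_desc.dst])

print(group_table)  # [(0, 2)], identical on every rank
```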
@@ -91,8 +91,6 @@ def init_process_groups(group_map, rank):
     # TODO should instantiate global group first
     all_process_groups = get_all_process_groups()
     for process_group in all_process_groups:
-        if process_group.id == 0 or rank not in process_group.ranks:
-            continue
         print(process_group)
         process_group.instantiate()
...
@@ -22,12 +22,12 @@ from functools import reduce
 
 import numpy as np
 
 import paddle
+from paddle.fluid.wrapped_decorator import wrap_decorator
 from paddle.framework import core
 from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
 from paddle.static import Variable
 
 from .dist_attribute import OperatorDistAttr, TensorDistAttr
-from .process_group import get_all_process_groups
 from .process_mesh import ProcessMesh
 
 OpRole = core.op_proto_and_checker_maker.OpRole
@@ -55,6 +55,8 @@ def get_logger(log_level, name="auto_parallel"):
         )
         log_handler.setFormatter(log_format)
         logger.addHandler(log_handler)
+    else:
+        logger.setLevel(log_level)
     return logger
@@ -1816,6 +1818,8 @@ def debug_program(program, path, name):
 
 
 def ring_id_to_process_group(ring_id):
+    from .process_group import get_all_process_groups
+
     for g in get_all_process_groups():
         if g.id == ring_id:
             return g
@@ -2355,3 +2359,17 @@ def is_dep_skip_op(op):
             return True
 
     return False
+
+
+def _dygraph_guard_(func):
+    def __impl__(*args, **kwargs):
+        if paddle.framework.in_dynamic_mode():
+            return func(*args, **kwargs)
+        else:
+            with paddle.fluid.dygraph.guard():
+                return func(*args, **kwargs)
+
+    return __impl__
+
+
+dygraph_guard = wrap_decorator(_dygraph_guard_)
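The new dygraph_guard decorator replaces the manual paddle.disable_static()/paddle.enable_static() pair that instantiate() used before: if the caller is already in dynamic mode the wrapped function runs as-is, otherwise it runs inside a temporary dygraph guard that restores the previous mode on exit. A usage sketch; the function name below is made up, and the import path is assumed from this diff's module layout.

```python
import paddle
# Assumed path for this commit's utils.py; adjust to the actual module.
from paddle.distributed.auto_parallel.utils import dygraph_guard


@dygraph_guard
def make_barrier_tensor():
    # Executes in dynamic (imperative) mode even when called from
    # static-graph code; the guard restores the caller's mode on return.
    return paddle.full([1], 1, dtype="int32")
```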