Unverified commit 1c465824 authored by zhaoyingli, committed by GitHub

[AutoParallel] update every rank has global view process_groups (#54067)

* global view process_group

* fix import

* fix attr

* fix tuner init comm
Parent 2b546613
......@@ -795,8 +795,6 @@ class Engine:
)
else:
for process_group in all_process_groups:
if self._cur_rank not in process_group.ranks:
continue
process_group.instantiate()
def _initialize(self, mode):
......
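With the membership check removed, `Engine._initialize` now calls `instantiate()` on every process group on every rank; the `global_rank in self.ranks` guard moves into `ProcessGroup.instantiate()` (next file). A minimal sketch of that caller/callee split, using hypothetical `ToyProcessGroup` and `initialize_all_groups` names rather than the real Paddle classes:

```python
# Toy sketch (not the Paddle API): the membership check lives inside
# instantiate(), so the caller simply walks the full group list.

class ToyProcessGroup:
    def __init__(self, group_id, ranks):
        self.id = group_id
        self.ranks = ranks
        self._is_instantiate = False

    def instantiate(self, cur_rank):
        if self._is_instantiate:
            return
        if len(self.ranks) >= 2 and cur_rank in self.ranks:
            # Only member ranks would build the real communicator here.
            print(f"rank {cur_rank}: building communicator for group {self.id} {self.ranks}")
        # Non-member ranks still register the group, so a later global
        # barrier is consistent across the whole cluster.
        self._is_instantiate = True


def initialize_all_groups(cur_rank, all_groups):
    # New behaviour: no `if cur_rank not in pg.ranks: continue` at the call site.
    for pg in all_groups:
        pg.instantiate(cur_rank)


if __name__ == "__main__":
    groups = [ToyProcessGroup(1, [0, 1]), ToyProcessGroup(2, [2, 3])]
    initialize_all_groups(cur_rank=0, all_groups=groups)
```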
......@@ -15,11 +15,13 @@
from collections import OrderedDict
import paddle
from paddle import _legacy_C_ops
from paddle.framework import core, in_dynamic_mode
from paddle.tensor import fill_constant
from paddle.framework import core
from ..collective import _get_global_env, _new_ring_id
from ..utils.log_utils import get_logger
from .utils import dygraph_guard
logger = get_logger("INFO", __name__)
def get_all_process_groups():
......@@ -122,6 +124,7 @@ class ProcessGroup:
def is_instantiate(self):
return self._is_instantiate
@dygraph_guard
def instantiate(self):
if self._is_instantiate:
return
......@@ -129,7 +132,10 @@ class ProcessGroup:
genv = _get_global_env()
global_rank = genv.rank
if self.nranks >= 2:
if self.nranks >= 2 and global_rank in self.ranks:
logger.info(
f"group_id: {self.id}, ranks: {self.ranks}, nranks: {self.nranks}, trainer_endpoints: {genv.current_endpoint}"
)
strategy = core.ParallelStrategy()
strategy.nranks = self.nranks
strategy.local_rank = self.local_rank(global_rank)
......@@ -156,9 +162,6 @@ class ProcessGroup:
else:
raise AssertionError('No CUDA device found')
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group
paddle.disable_static()
if core.is_compiled_with_cuda():
paddle.set_device(
'gpu:%d' % paddle.distributed.ParallelEnv().dev_id
......@@ -175,17 +178,19 @@ class ProcessGroup:
paddle.distributed.ParallelEnv().dev_id,
),
)
tmp = (
    paddle.to_tensor([1], dtype="int32")
    if in_dynamic_mode()
    else fill_constant([0], dtype="int32", value="1")
)
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group
barrier_tensor = paddle.full([1], 1, dtype="int32")
paddle._legacy_C_ops.barrier(
    barrier_tensor, barrier_tensor, 'ring_id', ring_id
)
# use legacy ops
_legacy_C_ops.c_allreduce_sum_(
    tmp, 'use_calc_stream', True, 'ring_id', self.id
)
if self.nranks > 1:
    barrier_tensor = paddle.full([1], 1, dtype="int32")
    paddle._legacy_C_ops.barrier(
        barrier_tensor, barrier_tensor, 'ring_id', 0
    )
_legacy_C_ops.c_sync_calc_stream(tmp, tmp)
paddle.enable_static()
self._is_instantiate = True
......
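The reworked `instantiate()` has two synchronization points: member ranks barrier on the group's own ring after the NCCL context is set up, and then every rank, member or not, joins a barrier on ring 0. The `@dygraph_guard` decorator (defined in utils.py further down) replaces the manual `paddle.disable_static()` / `paddle.enable_static()` pair. A stubbed sketch of that control flow, with a fake `_barrier` standing in for `paddle._legacy_C_ops.barrier`:

```python
# Structural sketch (stubs, not the Paddle API) of the new instantiate() flow.

def _barrier(ring_id):
    # Stand-in for paddle._legacy_C_ops.barrier(tensor, tensor, 'ring_id', ring_id).
    print(f"barrier on ring {ring_id}")


def instantiate(group_id, ranks, global_rank):
    nranks = len(ranks)
    if nranks >= 2 and global_rank in ranks:
        # The real code builds core.ParallelStrategy / NCCLParallelContext here,
        # then synchronizes the member ranks on this group's own ring.
        _barrier(ring_id=group_id)
    if nranks > 1:
        # Global sync on ring 0: because every rank registered every group,
        # all ranks reach this barrier and pipeline-parallel init cannot hang.
        _barrier(ring_id=0)


instantiate(group_id=3, ranks=[0, 1], global_rank=2)  # non-member still hits the ring-0 barrier
```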
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
from collections import OrderedDict
from functools import reduce
import paddle
......@@ -456,11 +457,11 @@ class Inserter:
)
inputs = {'X': [tensor]}
outputs = {"Out": [out]}
attrs = {"in_place": False}
slice_op = block._insert_op(
attrs = {"in_place": False, "op_role": op_role}
assign_op = block._insert_op(
idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
)
slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out
# use split once
......@@ -1458,7 +1459,7 @@ class Resharder:
if not serial
else source_tensor.shape
)
op_desc_seq = {}
op_desc_seq = OrderedDict()
# TODO: if the target process group has the same process with source process group
if set(target_process_group).intersection(
......@@ -1723,6 +1724,23 @@ class Resharder:
self, block, op_desc_seq, var_name, reshard_op, dist_attr
):
"""Parse op desc sequence and insert op in the block"""
# Parse communicator groups for all ranks.
# Ensure every rank has a global view of the communicator groups for the entire cluster,
# so that when communicators are initialized for pipeline parallel, every rank can
# conduct a correct global synchronization.
for rank_id in op_desc_seq:
op_desc_list = op_desc_seq[rank_id]
for op_desc in op_desc_list:
if isinstance(op_desc, AllGatherOpDesc):
new_process_group(op_desc.group)
elif isinstance(op_desc, AllGatherConcatOpDesc):
new_process_group(op_desc.group)
elif isinstance(op_desc, SendOpDesc):
new_process_group([op_desc.src, op_desc.dst])
elif isinstance(op_desc, RecvOpDesc):
new_process_group([op_desc.src, op_desc.dst])
tensor_list = []
partition_tensor_list = []
if self.rank_id not in op_desc_seq.keys():
......@@ -2632,7 +2650,7 @@ class Resharder:
item,
recv_rank,
)
if self.rank_id == recv_rank:
elif self.rank_id == recv_rank:
# if recv bool data, recv then cast
self._hadnle_recv(
block,
......@@ -2642,6 +2660,11 @@ class Resharder:
item,
recv_rank,
)
else:
# Ensure every rank has a global view of the communicator groups for the entire cluster,
# so that when communicators are initialized for pipeline parallel, every rank can
# conduct a correct global synchronization.
new_process_group([item, recv_rank])
else:
for index, tensor_process in enumerate(
tensor_processes
......@@ -2659,11 +2682,16 @@ class Resharder:
self._handle_send(
block, idx, var, op, item, recv_rank
)
if self.rank_id == recv_rank:
elif self.rank_id == recv_rank:
# if recv bool data, recv then cast
self._hadnle_recv(
block, idx, var, op, item, recv_rank
)
else:
# Ensure every rank has a global view of the communicator groups for the entire cluster,
# so that when communicators are initialized for pipeline parallel, every rank can
# conduct a correct global synchronization.
new_process_group([item, recv_rank])
cur_op_count = len(block.ops)
idx_offset = (
......
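The new pre-pass in `parse_op_desc` registers a process group for every communicator that any rank will use, so the group registry, and hence the ring ids, end up identical on all ranks even where the local rank inserts no reshard op. A toy sketch of that idea, using plain dicts for the op descriptors and a hypothetical registry in place of the real `new_process_group`:

```python
_group_registry = {}  # frozenset(ranks) -> ring id; stand-in for the real registry


def new_process_group(ranks):
    # Deduplicate by member set so repeated registration is a no-op,
    # mirroring how the real helper reuses an existing group.
    key = frozenset(ranks)
    if key not in _group_registry:
        _group_registry[key] = len(_group_registry) + 1
    return _group_registry[key]


def register_all_groups(op_desc_seq):
    # op_desc_seq maps rank_id -> list of reshard op descriptors; here a
    # descriptor is just a dict carrying either a "group" or a "src"/"dst" pair.
    for op_desc_list in op_desc_seq.values():
        for op_desc in op_desc_list:
            if "group" in op_desc:      # AllGather / AllGatherConcat
                new_process_group(op_desc["group"])
            elif "src" in op_desc:      # Send / Recv
                new_process_group([op_desc["src"], op_desc["dst"]])


register_all_groups({
    0: [{"group": [0, 1]}, {"src": 0, "dst": 2}],
    2: [{"group": [2, 3]}, {"src": 0, "dst": 2}],
})
print(_group_registry)  # same mapping on every rank, member or not
```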
......@@ -91,8 +91,6 @@ def init_process_groups(group_map, rank):
# TODO should instantiate global group first
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if process_group.id == 0 or rank not in process_group.ranks:
continue
print(process_group)
process_group.instantiate()
......
......@@ -22,12 +22,12 @@ from functools import reduce
import numpy as np
import paddle
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.framework import core
from paddle.framework.io_utils import is_belong_to_optimizer, is_parameter
from paddle.static import Variable
from .dist_attribute import OperatorDistAttr, TensorDistAttr
from .process_group import get_all_process_groups
from .process_mesh import ProcessMesh
OpRole = core.op_proto_and_checker_maker.OpRole
......@@ -55,6 +55,8 @@ def get_logger(log_level, name="auto_parallel"):
)
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
else:
logger.setLevel(log_level)
return logger
......@@ -1816,6 +1818,8 @@ def debug_program(program, path, name):
def ring_id_to_process_group(ring_id):
from .process_group import get_all_process_groups
for g in get_all_process_groups():
if g.id == ring_id:
return g
......@@ -2355,3 +2359,17 @@ def is_dep_skip_op(op):
return True
return False
def _dygraph_guard_(func):
def __impl__(*args, **kwargs):
if paddle.framework.in_dynamic_mode():
return func(*args, **kwargs)
else:
with paddle.fluid.dygraph.guard():
return func(*args, **kwargs)
return __impl__
dygraph_guard = wrap_decorator(_dygraph_guard_)
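A small usage sketch of the new `dygraph_guard` decorator: the wrapped function always executes in dynamic (imperative) mode, entering a temporary dygraph guard when the caller is in static-graph mode. The import path below is an assumption based on where this utils.py lived at the time of the commit:

```python
import paddle
# Assumed module path for this commit's source tree; it may differ in later releases.
from paddle.distributed.auto_parallel.utils import dygraph_guard


@dygraph_guard
def report_mode():
    # Returns True when running in dynamic (imperative) mode.
    return paddle.framework.in_dynamic_mode()


paddle.enable_static()
print(report_mode())   # True: the guard temporarily switches to dynamic mode
paddle.disable_static()
print(report_mode())   # True: already dynamic, the guard is a no-op
```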