未验证 提交 86bbb0f2 编写于 作者: L lilong12 提交者: GitHub

add backend for heter training (#41526) (#41651)

上级 9f2ae360
......@@ -138,7 +138,7 @@ _group_map_by_name = {}
# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"
_valid_backend_list = ['nccl', 'gloo', 'hccl']
_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter']
_default_store = None # the default tcp store
_default_backend = None
......@@ -234,6 +234,31 @@ def _new_process_group_impl(backend,
pg = core.ProcessGroupNCCL(store, rank, world_size, group_id)
elif backend == "hccl":
pg = core.ProcessGroupHCCL(store, rank, world_size, group_id)
elif backend == "heter":
cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
assert cluster_id >= 0, "please set the CLUSTER_ID variable."
cluster_size = os.getenv("CLUSTER_SIZE", None)
assert cluster_size, "please set the CLUSTER_SIZE variable."
cluster_size = cluster_size.split(",")
cluster_size = [int(s) for s in cluster_size]
switch_ep = os.getenv("CLUSTER_SWITCH", None)
assert switch_ep, "please set the CLUSTER_SWITCH variable."
cluster_size_cumsum = np.cumsum(cluster_size)
cluster_offset = 0 if cluster_id == 0 else cluster_size_cumsum[
cluster_id - 1]
global_rank = cluster_offset + rank
global_world_size = cluster_size_cumsum[-1]
pg = core.ProcessGroupHeter(
store,
rank=global_rank,
world_size=global_world_size,
gid=0,
local_rank=rank,
local_size=world_size,
gloo_rank=cluster_id,
gloo_size=len(cluster_size),
with_switch=True,
switch_endpoint=switch_ep)
return pg
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册