Greenplum / DeepSpeed
Commit 852c524a (unverified)
Authored Jan 25, 2021 by sdtblck; committed via GitHub on Jan 25, 2021

Add optional timeout parameter to deepspeed.init_distributed (#637)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>

Parent: 34c83a5a
Showing 2 changed files, with 21 additions and 12 deletions (+21 −12):

- deepspeed/constants.py (+8 −0)
- deepspeed/utils/distributed.py (+13 −12)
deepspeed/constants.py @ 852c524a

```diff
 '''
 Copyright 2020 The Microsoft DeepSpeed Team
 '''
+from datetime import timedelta

 #############################################
 # Torch distributed constants
 #############################################
 TORCH_DISTRIBUTED_DEFAULT_PORT = 29500
+
+# Default process group wide timeout, if applicable.
+# This only applies to the gloo and nccl backends
+# (only if NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1).
+# To make an attempt at backwards compatibility with THD, we use an
+# extraordinarily high default timeout, given that THD did not have timeouts.
+default_pg_timeout = timedelta(minutes=30)
```
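The new constant is an ordinary `datetime.timedelta`, so the 30-minute default can be inspected directly (a standalone sketch, independent of DeepSpeed):

```python
from datetime import timedelta

# The default process-group timeout, as defined in deepspeed/constants.py.
default_pg_timeout = timedelta(minutes=30)

print(default_pg_timeout.total_seconds())  # 1800.0
```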
deepspeed/utils/distributed.py @ 852c524a

```diff
@@ -3,25 +3,25 @@ Copyright 2020 The Microsoft DeepSpeed Team
 '''
 import os
 import torch
+from datetime import timedelta

 from .logging import logger
-from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
+from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout


 def init_distributed(dist_backend="nccl",
                      auto_mpi_discovery=True,
                      distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT,
-                     verbose=True):
-    """Initialize torch.distributed backend, potentially performing MPI discovery if needed.
+                     verbose=True,
+                     timeout=default_pg_timeout):
+    """
+    Initialize torch.distributed backend, potentially performing MPI discovery if needed

     Arguments:
-        dist_backend: torch distributed backend, e.g., nccl, mpi, gloo
-        auto_mpi_discovery: if distributed environment variables are not set, attempt to discover them from MPI
-        distributed_port: torch distributed backend port
-        verbose: verbose logging
+        dist_backend (str): torch distributed backend, e.g., nccl, mpi, gloo
+        auto_mpi_discovery (bool): if distributed environment variables are not set, attempt to discover them from MPI
+        distributed_port (int, optional): torch distributed backend port
+        verbose (bool, optional): verbose logging
+        timeout (timedelta, optional): Timeout for operations executed against the process group. Default value equals 30 minutes.
     """
     required_env = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
     if auto_mpi_discovery and not all(map(lambda v: v in os.environ, required_env)):
@@ -38,7 +38,8 @@ def init_distributed(dist_backend="nccl",
     if verbose:
         logger.info("Initializing torch distributed with backend: {}".format(dist_backend))
-    torch.distributed.init_process_group(backend=dist_backend)
+    assert isinstance(timeout, timedelta)
+    torch.distributed.init_process_group(backend=dist_backend, timeout=timeout)


 def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
```
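After this change, a caller can override the 30-minute default by passing a `timedelta`. The sketch below shows the intended call shape (the `deepspeed.init_distributed` call itself requires a real distributed environment, so it is shown in a comment) and mirrors the `isinstance` check the diff adds:

```python
from datetime import timedelta

# Intended usage after this commit (sketch; needs a distributed launch to run):
#
#   import deepspeed
#   deepspeed.init_distributed(dist_backend="nccl",
#                              timeout=timedelta(minutes=5))
#
# init_distributed now asserts its timeout argument is a datetime.timedelta;
# this helper mirrors that validation.
def is_valid_timeout(timeout):
    return isinstance(timeout, timedelta)

print(is_valid_timeout(timedelta(minutes=5)))  # True
print(is_valid_timeout(300))                   # False: raw seconds are rejected
```

Requiring a `timedelta` rather than a raw number of seconds makes the unit explicit at the call site and fails fast on the common mistake of passing an integer.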