Commit 7d8ad45d (unverified)

Authored August 16, 2022 by Jeff Rasley; committed via GitHub on August 17, 2022.

Fix regression w. dist_init_required (#2225)

Parent: 9b418c1e

Showing 4 changed files with 106 additions and 14 deletions (+106 / -14):

  .pre-commit-config.yaml        +1   -1
  deepspeed/comm/comm.py         +5   -0
  tests/unit/comm/test_dist.py   +83  -1
  tests/unit/common.py           +17  -12
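In short: when an application initialized torch.distributed itself before handing control to DeepSpeed, init_distributed could fail instead of adopting the existing process group. This commit short-circuits initialization in that case and gives the unit-test harness init_distributed / set_dist_env switches so individual tests can manage process-group setup themselves.

A minimal sketch of the scenario being fixed (single-process NCCL setup; the address, port, and world size are illustrative, and only the deepspeed.init_distributed call itself mirrors this commit's tests):

import torch
import deepspeed

# The application owns process-group setup...
torch.distributed.init_process_group(backend='nccl',
                                     init_method='tcp://127.0.0.1:29500',
                                     world_size=1,
                                     rank=0)

# ...and DeepSpeed should detect and adopt the existing group rather than
# initializing again. Prior to this commit, this call had regressed.
deepspeed.init_distributed('nccl')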
.pre-commit-config.yaml

The new test file calls torch.distributed directly on purpose, so it is added to the check-torchdist exclude list.

@@ -39,7 +39,7 @@ repos:
       name: check-torchdist
       entry: ./scripts/check-torchdist.py
       language: script
-      exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py)
+      exclude: ^(deepspeed/comm/|docs/|benchmarks/|scripts/check-torchdist.py|deepspeed/moe/sharded_moe.py|deepspeed/runtime/comm/coalesced_collectives.py|deepspeed/elasticity/elastic_agent.py|deepspeed/launcher/launch.py|tests/unit/comm/test_dist.py)
       # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm
 - repo: https://github.com/codespell-project/codespell
deepspeed/comm/comm.py

@@ -600,6 +600,11 @@ def init_distributed(dist_backend="nccl",
     if dist_init_required is None:
         dist_init_required = cdb is None or not cdb.is_initialized()

+    if cdb is None and torch.distributed.is_initialized():
+        # The user initialized torch.dist themselves, create cdb and short-circuit
+        cdb = TorchBackend(dist_backend, timeout, init_method)
+        return
+
     if dist_init_required is False:
         assert (
             cdb is not None and cdb.is_initialized() is True
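Read together with the assert in the trailing context, dist_init_required now behaves roughly as follows. This is a paraphrase based on the hunk above and the new tests, and the three calls are alternatives, not one script:

import deepspeed

# None (the default): auto-detect; initialize only if nothing is initialized yet.
deepspeed.init_distributed('nccl', dist_init_required=None)

# True: request initialization; harmless if torch.distributed was already
# initialized by the caller (see TestDistInit.test_already_init below).
deepspeed.init_distributed('nccl', dist_init_required=True)

# False: promise that a process group already exists; per the assert above,
# this fails if nothing was initialized beforehand.
deepspeed.init_distributed('nccl', dist_init_required=False)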
tests/unit/comm/test_dist.py

 import torch
 import deepspeed.comm as dist
 import deepspeed
-from tests.unit.common import DistributedTest
+from tests.unit.common import DistributedTest, get_master_port
+from tests.unit.simple_model import SimpleModel
 import pytest

@@ -71,3 +73,83 @@ class TestDistAllReduce(DistributedTest):
         result = torch.ones(1, 3).cuda() * sum_of_ranks
         dist.all_reduce(x)
         assert torch.all(x == result)
+
+
+@pytest.mark.parametrize("dist_init_required", [True, False, None])
+class TestDistInit(DistributedTest):
+    init_distributed = False
+
+    def test_already_init(self, dist_init_required):
+        torch.distributed.init_process_group('nccl')
+        deepspeed.init_distributed('nccl', dist_init_required=dist_init_required)
+
+    def test_no_init(self, dist_init_required):
+        if dist_init_required or dist_init_required is None:
+            deepspeed.init_distributed('nccl', dist_init_required=dist_init_required)
+        else:
+            # torch.dist is not done and for some reason the user says they don't want it done
+            with pytest.raises(Exception):
+                deepspeed.init_distributed('nccl', dist_init_required=dist_init_required)
+
+
+class TestDistInitNoEnv(DistributedTest):
+    world_size = 1
+    init_distributed = False
+    set_dist_env = False
+
+    def test(self):
+        torch.distributed.init_process_group(backend='nccl',
+                                             init_method=f"tcp://127.0.0.1:{get_master_port()}",
+                                             world_size=1,
+                                             rank=0)
+        assert torch.distributed.is_initialized()
+        deepspeed.init_distributed('nccl', auto_mpi_discovery=True)
+
+
+@pytest.mark.parametrize("dist_init_required", [True, False])
+class TestDistInitWithModel(DistributedTest):
+    init_distributed = False
+
+    def test_already_init(self, dist_init_required):
+        torch.distributed.init_process_group('nccl')
+        model = SimpleModel(4)
+        config_dict = {"train_micro_batch_size_per_gpu": 1, "optimizer": {"type": "Adam", "params": {}}}
+        engine, *_ = deepspeed.initialize(model=model,
+                                          config=config_dict,
+                                          model_parameters=model.parameters(),
+                                          dist_init_required=dist_init_required)
+
+    def test_no_init(self, dist_init_required):
+        model = SimpleModel(4)
+        config_dict = {"train_micro_batch_size_per_gpu": 1, "optimizer": {"type": "Adam", "params": {}}}
+        if dist_init_required:
+            engine, *_ = deepspeed.initialize(model=model,
+                                              config=config_dict,
+                                              model_parameters=model.parameters(),
+                                              dist_init_required=dist_init_required)
+        else:
+            # torch.dist is not done and for some reason the user says they don't want it done
+            with pytest.raises(Exception):
+                engine, *_ = deepspeed.initialize(model=model,
+                                                  config=config_dict,
+                                                  model_parameters=model.parameters(),
+                                                  dist_init_required=dist_init_required)
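All three new test classes share the DistInit name prefix, so they can be run in isolation with a standard pytest name filter (assuming a CUDA machine, since the tests use the NCCL backend):

pytest tests/unit/comm/test_dist.py -k "DistInit"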
tests/unit/common.py

@@ -67,6 +67,8 @@ class DistributedTest(ABC):
     is_dist_test = True
     world_size = 2
     backend = "nccl"
+    init_distributed = True
+    set_dist_env = True

     # Temporary directory that is shared among test methods in a class
     @pytest.fixture(autouse=True, scope="class")

@@ -151,20 +153,22 @@ class DistributedTest(ABC):
     def _dist_init(self, local_rank, num_procs, skip_msg):
         """Initialize deepspeed.comm and execute the user function. """
-        os.environ['MASTER_ADDR'] = '127.0.0.1'
-        os.environ['MASTER_PORT'] = get_master_port()
-        os.environ['LOCAL_RANK'] = str(local_rank)
-        # NOTE: unit tests don't support multi-node so local_rank == global rank
-        os.environ['RANK'] = str(local_rank)
-        os.environ['WORLD_SIZE'] = str(num_procs)
+        if self.set_dist_env:
+            os.environ['MASTER_ADDR'] = '127.0.0.1'
+            os.environ['MASTER_PORT'] = get_master_port()
+            os.environ['LOCAL_RANK'] = str(local_rank)
+            # NOTE: unit tests don't support multi-node so local_rank == global rank
+            os.environ['RANK'] = str(local_rank)
+            os.environ['WORLD_SIZE'] = str(num_procs)

         # turn off NCCL logging if set
         os.environ.pop('NCCL_DEBUG', None)
         set_cuda_visibile()

-        deepspeed.init_distributed(dist_backend=self.backend)
-        dist.barrier()
+        if self.init_distributed:
+            deepspeed.init_distributed(dist_backend=self.backend)
+            dist.barrier()

         if torch.cuda.is_available():
             torch.cuda.set_device(local_rank)

@@ -177,10 +181,11 @@ class DistributedTest(ABC):
         else:
             raise e

-        # make sure all ranks finish at the same time
-        dist.barrier()
-        # tear down after test completes
-        dist.destroy_process_group()
+        if self.init_distributed or dist.is_initialized():
+            # make sure all ranks finish at the same time
+            dist.barrier()
+            # tear down after test completes
+            dist.destroy_process_group()

     def distributed_test(world_size=2, backend='nccl'):
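With these two class attributes in place, a test class can opt out of the harness's automatic setup, exactly as TestDistInitNoEnv above does. A minimal sketch of the pattern (the class name and test body here are illustrative):

class TestManualSetup(DistributedTest):
    world_size = 1
    init_distributed = False  # harness will not call deepspeed.init_distributed()
    set_dist_env = False      # harness will not export MASTER_ADDR/RANK/etc.

    def test(self):
        # The test is now responsible for any process-group setup it needs.
        ...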