Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
18d55e54
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
18d55e54
编写于
12月 07, 2022
作者:
Q
Quentin Anthony
提交者:
GitHub
12月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update barrier and reduce_scatter_base to conform to PyTorch signatures (#2570)
Co-authored-by:
N
Jeff Rasley
<
jerasley@microsoft.com
>
上级
731965db
变更
2
显示空白变更内容
内联
并排
Showing
2 changed files
with
44 additions
and
4 deletions
+44
-4
deepspeed/comm/comm.py
deepspeed/comm/comm.py
+23
-2
deepspeed/comm/torch.py
deepspeed/comm/torch.py
+21
-2
未找到文件。
deepspeed/comm/comm.py
浏览文件 @
18d55e54
...
...
@@ -255,6 +255,7 @@ def has_reduce_scatter_base():
def
reduce_scatter_fn
(
output_tensor
,
tensor
,
op
=
ReduceOp
.
SUM
,
group
=
None
,
async_op
=
False
,
prof
=
False
,
...
...
@@ -265,6 +266,7 @@ def reduce_scatter_fn(output_tensor,
if
cdb
.
has_reduce_scatter_base
:
return
reduce_scatter_base
(
output_tensor
,
tensor
,
op
=
op
,
group
=
group
,
async_op
=
async_op
,
prof
=
prof
,
...
...
@@ -279,6 +281,7 @@ def reduce_scatter_fn(output_tensor,
input_tensor_lst
=
list
(
torch
.
chunk
(
tensor
,
cdb
.
get_world_size
(
group
)))
return
reduce_scatter
(
output_tensor
,
input_tensor_lst
,
op
=
op
,
group
=
group
,
async_op
=
async_op
,
prof
=
prof
,
...
...
@@ -288,6 +291,7 @@ def reduce_scatter_fn(output_tensor,
@
timed_op
def
reduce_scatter_base
(
output_tensor
,
tensor
,
op
=
ReduceOp
.
SUM
,
group
=
None
,
async_op
=
False
,
prof
=
False
,
...
...
@@ -296,6 +300,7 @@ def reduce_scatter_base(output_tensor,
global
cdb
return
cdb
.
reduce_scatter_base
(
output_tensor
=
output_tensor
,
input_tensor
=
tensor
,
op
=
op
,
group
=
group
,
async_op
=
async_op
)
...
...
@@ -453,9 +458,25 @@ def scatter(tensor,
@timed_op
def barrier(group=None,
            async_op=False,
            device_ids=None,
            prof=False,
            log_name='barrier',
            debug=get_caller_func()):
    """Forward a barrier request to the active communication backend.

    Args:
        group: Process group handed to the backend; ``None`` is passed
            through unchanged and left for the backend to interpret.
        async_op: Forwarded as-is to the backend's ``barrier``.
        device_ids: Forwarded as-is to the backend's ``barrier``.
        prof, log_name, debug: Not read in this body; presumably consumed by
            the ``timed_op`` decorator -- TODO confirm.

    Returns:
        Whatever ``cdb.barrier`` returns.
    """
    # NOTE: the ``debug`` default is evaluated once, when this function is
    # defined, not on every call.
    global cdb
    return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
@timed_op
def monitored_barrier(group=None,
                      timeout=None,
                      wait_all_ranks=False,
                      prof=False,
                      log_name='monitored_barrier',
                      debug=get_caller_func()):
    """Forward a monitored-barrier request to the active communication backend.

    Args:
        group: Process group handed to the backend; ``None`` is passed
            through unchanged and left for the backend to interpret.
        timeout: Forwarded as-is to the backend's ``monitored_barrier``.
        wait_all_ranks: Forwarded as-is to the backend's ``monitored_barrier``.
        prof, log_name, debug: Not read in this body; presumably consumed by
            the ``timed_op`` decorator -- TODO confirm.

    Returns:
        Whatever ``cdb.monitored_barrier`` returns.
    """
    global cdb
    # Dispatch to the backend's monitored_barrier, not barrier: the backend
    # barrier() takes (group, async_op, device_ids) and would raise TypeError
    # on the timeout / wait_all_ranks keywords.
    return cdb.monitored_barrier(group=group,
                                 timeout=timeout,
                                 wait_all_ranks=wait_all_ranks)
def
log_summary
():
...
...
deepspeed/comm/torch.py
浏览文件 @
18d55e54
...
...
@@ -95,11 +95,13 @@ class TorchBackend(Backend):
def
reduce_scatter_base
(
self
,
output_tensor
,
input_tensor
,
op
=
ReduceOp
.
SUM
,
group
=
None
,
async_op
=
False
):
if
self
.
has_reduce_scatter_base
:
return
torch
.
distributed
.
_reduce_scatter_base
(
output_tensor
,
input_tensor
,
op
=
self
.
_reduce_op
(
op
),
group
=
group
,
async_op
=
async_op
)
else
:
...
...
@@ -149,8 +151,25 @@ class TorchBackend(Backend):
group
=
group
,
async_op
=
async_op
)
def barrier(self,
            group=torch.distributed.GroupMember.WORLD,
            async_op=False,
            device_ids=None):
    """Call ``torch.distributed.barrier`` on ``group``.

    Args:
        group: Process group to synchronise. ``None`` is replaced with the
            current ``torch.distributed.GroupMember.WORLD`` at call time.
        async_op: Forwarded as-is to ``torch.distributed.barrier``.
        device_ids: Forwarded as-is to ``torch.distributed.barrier``.

    Returns:
        Whatever ``torch.distributed.barrier`` returns.
    """
    # The signature default is captured when the method is defined, so it can
    # still be None here; re-read WORLD at call time in that case.
    if group is None:
        group = torch.distributed.GroupMember.WORLD
    return torch.distributed.barrier(group=group,
                                     async_op=async_op,
                                     device_ids=device_ids)
def monitored_barrier(self,
                      group=torch.distributed.GroupMember.WORLD,
                      timeout=None,
                      wait_all_ranks=False):
    """Call ``torch.distributed.monitored_barrier`` on ``group``.

    Args:
        group: Process group to synchronise. ``None`` is replaced with the
            current ``torch.distributed.GroupMember.WORLD`` at call time.
        timeout: Forwarded as-is to ``torch.distributed.monitored_barrier``.
        wait_all_ranks: Forwarded as-is to
            ``torch.distributed.monitored_barrier``.

    Returns:
        Whatever ``torch.distributed.monitored_barrier`` returns.
    """
    # Resolve the group at call time; the signature default may have been
    # captured as None when the method was defined.
    target_group = torch.distributed.GroupMember.WORLD if group is None else group
    return torch.distributed.monitored_barrier(
        group=target_group,
        timeout=timeout,
        wait_all_ranks=wait_all_ranks,
    )
def get_rank(self, group=None):
    """Return the result of ``torch.distributed.get_rank`` for ``group``.

    Args:
        group: Process group to query; passed through unchanged, including
            ``None``.
    """
    rank = torch.distributed.get_rank(group=group)
    return rank
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录