Commit 9f7126fc (unverified)
Authored on Mar 04, 2022 by Olatunji Ruwase; committed by GitHub on Mar 04, 2022.
Refactor moe/non-moe gradient reduction (#1811)
Parent: 60fc06c6
Showing 1 changed file with 35 additions and 31 deletions.

deepspeed/runtime/engine.py  +35 −31
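The change splits the monolithic `buffered_allreduce_fallback` into three helpers: `_get_gradients_for_reduction` collects dense and MoE-expert gradients into separate containers, `_reduce_non_expert_gradients` all-reduces the dense gradients over the data-parallel group, and `_reduce_expert_gradients` all-reduces each expert group's gradients over its own expert data-parallel group. `buffered_allreduce_fallback` itself becomes a thin orchestrator over the three; per the commit title this is a pure refactor.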
```diff
@@ -2074,8 +2074,8 @@ class DeepSpeedEngine(Module):
         if len(small_bucket) > 0:
             self.allreduce_and_copy(small_bucket, dp_group)
 
-    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
-        grads = []
+    def _get_gradients_for_reduction(self):
+        non_expert_grads = []
         expert_grads = {}
         if self.has_moe_layers:
             for key in self.expert_data_parallel_group.keys():
```
```diff
@@ -2091,23 +2091,19 @@ class DeepSpeedEngine(Module):
                 param.grad = torch.zeros(param.size(),
                                          dtype=param.dtype,
                                          device=param.device)
-                if is_moe_param(param):
-                    expert_grads[param.group_name].append(param.grad.data)
-                else:
-                    grads.append(param.grad.data)
-            else:
-                grad_data = param.grad.data
-                if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
-                    if is_moe_param(param):
-                        expert_grads[param.group_name].append(SparseTensor(grad_data))
-                    else:
-                        grads.append(SparseTensor(grad_data))
-                else:
-                    if is_moe_param(param):
-                        expert_grads[param.group_name].append(grad_data)
-                    else:
-                        grads.append(grad_data)
+
+            grad_data = param.grad.data
+            if param_name in self.sparse_tensor_module_names or grad_data.is_sparse:
+                grad_data = SparseTensor(grad_data)
+
+            if is_moe_param(param):
+                expert_grads[param.group_name].append(grad_data)
+            else:
+                non_expert_grads.append(grad_data)
+
+        return non_expert_grads, expert_grads
+
+    def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
         split_buckets = split_half_float_double_sparse(grads)
         for _, bucket_tuple in enumerate(split_buckets):
             bucket_type, bucket = bucket_tuple
```
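The collection loop now normalizes each gradient once, wrapping it in `SparseTensor` when the parameter is registered in `sparse_tensor_module_names` or the grad itself is sparse, and only then routes it to the expert or non-expert container; the old version duplicated that routing inside every branch. The extracted `_reduce_non_expert_gradients` keeps the existing strategy of bucketing gradients by element type via `split_half_float_double_sparse` before reducing. Below is a minimal sketch of that bucketing idea; `split_by_type` is a hypothetical stand-in, not the DeepSpeed implementation:

```python
# Hypothetical stand-in for split_half_float_double_sparse: group gradients
# by element type so each collective operates on homogeneous buckets.
from collections import defaultdict

import torch

def split_by_type(grads):
    buckets = defaultdict(list)
    for g in grads:
        key = "sparse" if g.is_sparse else str(g.dtype)
        buckets[key].append(g)
    # Return (bucket_type, tensors) pairs, mirroring the (bucket_type, bucket)
    # tuples the engine unpacks from split_half_float_double_sparse.
    return list(buckets.items())

grads = [torch.ones(4, dtype=torch.float16),
         torch.ones(4, dtype=torch.float32),
         torch.ones(4, dtype=torch.float16)]
for bucket_type, bucket in split_by_type(grads):
    print(bucket_type, len(bucket))  # torch.float16 2, then torch.float32 1
```

Homogeneous buckets matter because a fused all-reduce flattens a bucket into one tensor, which cannot mix fp16, fp32, and sparse payloads.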
```diff
@@ -2124,21 +2120,29 @@ class DeepSpeedEngine(Module):
                                          dp_group=dp_group,
                                          numel_per_bucket=elements_per_buffer)
 
+    def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
+        for ep_name, expert_grads_group in expert_grads.items():
+            expert_split_buckets = split_half_float_double_sparse(expert_grads_group)
+            for i, bucket_tuple in enumerate(expert_split_buckets):
+                bucket_type, bucket = bucket_tuple
+                if bucket_type == SparseTensor.type():
+                    self.sparse_allreduce_no_retain(bucket, groups._get_expert_data_parallel_group(ep_name))
+                else:
+                    # Separate between diff groups
+                    self.allreduce_no_retain(
+                        bucket,
+                        dp_group=groups._get_expert_data_parallel_group(ep_name),
+                        numel_per_bucket=elements_per_buffer)
+
+    def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
+        non_expert_grads, expert_grads = self._get_gradients_for_reduction()
+        self._reduce_non_expert_gradients(non_expert_grads, elements_per_buffer)
         if self.has_moe_layers:
-            for ep_name, expert_grads_group in expert_grads.items():
-                expert_split_buckets = split_half_float_double_sparse(expert_grads_group)
-                for i, bucket_tuple in enumerate(expert_split_buckets):
-                    bucket_type, bucket = bucket_tuple
-                    if bucket_type == SparseTensor.type():
-                        self.sparse_allreduce_no_retain(bucket, groups._get_expert_data_parallel_group(ep_name))
-                    else:
-                        # Separate between diff groups
-                        self.allreduce_no_retain(
-                            bucket,
-                            dp_group=groups._get_expert_data_parallel_group(ep_name),
-                            numel_per_bucket=elements_per_buffer)
+            self._reduce_expert_gradients(expert_grads, elements_per_buffer)
 
     def sparse_allreduce_no_retain(self, bucket, dp_group):
         allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group)
```
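Expert gradients are reduced over `groups._get_expert_data_parallel_group(ep_name)` rather than the global data-parallel group because an expert's parameters are replicated only across the ranks that host that expert. A toy sketch of the distinction follows; the single-process `gloo` setup, the group handles, and the `"ep_group_0"` key are illustrative assumptions, not DeepSpeed API:

```python
import os

import torch
import torch.distributed as dist

# Single-process setup so the sketch runs standalone; real MoE training has
# world_size > 1 with distinct dense and expert process groups.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

dp_group = dist.new_group(ranks=[0])  # all data-parallel ranks
ep_group = dist.new_group(ranks=[0])  # only the ranks hosting one expert

def reduce_grads(non_expert_grads, expert_grads):
    # Dense parameters exist on every rank: average over the full dp group.
    for g in non_expert_grads:
        dist.all_reduce(g, group=dp_group)
        g /= dist.get_world_size(group=dp_group)
    # Expert parameters exist only on their expert's ranks: average over the
    # (smaller) expert data-parallel group instead.
    for grads in expert_grads.values():
        for g in grads:
            dist.all_reduce(g, group=ep_group)
            g /= dist.get_world_size(group=ep_group)

reduce_grads([torch.ones(2)], {"ep_group_0": [torch.ones(2)]})
dist.destroy_process_group()
```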
登录