Greenplum / DeepSpeed · Commit 07887f66 (unverified)

Authored Dec 12, 2021 by Gary Miguel; committed by GitHub on Dec 12, 2021
Parent: 64c2946a

sharded_moe: make top1gating ONNX-exportable (#1578)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Showing 1 changed file with 53 additions and 33 deletions.

deepspeed/moe/sharded_moe.py (+53, −33)
```diff
@@ -22,6 +22,7 @@ import torch
 from torch import Tensor
+import torch.distributed as dist
 from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 
 if TYPE_CHECKING:
     Base = Module[Tensor]
```
```diff
@@ -79,12 +80,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)
 
-import torch.distributed as dist
-
-# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
-# See https://arxiv.org/pdf/2006.16668.pdf for details.
-
 # Based on https://github.com/pytorch/pytorch/pull/40762
 class _AllToAll(torch.autograd.Function):
     @staticmethod
```
```diff
@@ -102,16 +97,13 @@ class _AllToAll(torch.autograd.Function):
         return (None, _AllToAll.apply(ctx.group, *grad_output))
 
-from torch import nn
-import torch.nn.functional as F
-
-import math
-
 # einsum rewrites are on par or more performant
 # switch can be bubbled up in future
 USE_EINSUM = True
 
+# einsum dimensions: (g)roup, (s)equence, (e)xpert, (m)odel, (c)apacity
+# See https://arxiv.org/pdf/2006.16668.pdf for details.
 def einsum(rule, a, b):
     if USE_EINSUM:
         return torch.einsum(rule, a, b)
```
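The rule strings fed to this wrapper use the GShard dimension names from the comment above. As a quick shape check (toy sizes, not part of the commit) of the `"se,sc->sec"` rule used later in `top1gating`:

```python
import torch

# (s)equence x (e)xpert gate weights combined with a
# (s)equence x (c)apacity location matrix give an [s, e, c]
# combine tensor: one weight per (token, expert, capacity slot).
s, e, c = 8, 4, 2
gates = torch.rand(s, e)
locations_sc = torch.rand(s, c)
combine = torch.einsum("se,sc->sec", gates, locations_sc)
print(combine.shape)  # torch.Size([8, 4, 2])
```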
```diff
@@ -143,14 +135,47 @@ def einsum(rule, a, b):
         return torch.einsum(rule, a, b)
 
 
-def top1gating(logits: torch.Tensor,
+# The following functions are extracted and scripted
+# because otherwise during a torch.jit.trace, the non-Tensor
+# values used in the calculations get recorded as constants.
+# torch.jit.script coerces them into Tensors and preserves
+# their dynamic shapes. This enables ONNX export.
+# We can't script the entire top1gating function because it
+# includes stateful caching logic which is incompatible with ONNX.
+@torch.jit.script
+def _capacity(gates: Tensor,
+              capacity_factor: Tensor,
+              min_capacity: Tensor) -> Tensor:
+    # gates has shape of SE
+    num_tokens = gates.shape[0]
+    num_experts = gates.shape[1]
+    # to(torch.int64) works around a bug in torch.onnx.export:
+    # it should cast k to int64 when converting torch.topk but it doesn't.
+    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
+    if capacity < min_capacity:
+        capacity = min_capacity.to(torch.int64)
+    return capacity
+
+
+@torch.jit.script
+def _top_idx(source, k):
+    return torch.topk(source, k=k, dim=0)[1]
+
+
+@torch.jit.script
+def _one_hot_to_float(x, num_classes):
+    return F.one_hot(x, num_classes=num_classes).float()
+
+
+def top1gating(logits: Tensor,
               capacity_factor: float,
               min_capacity: int,
-              used_token: torch.Tensor = None,
+              used_token: Tensor = None,
               noisy_gate_policy: Optional[str] = None,
               drop_tokens: bool = True,
               use_rts: bool = True,
               use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top1Gating on logits."""
```
```diff
@@ -159,19 +184,16 @@ def top1gating(logits: torch.Tensor,
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)
 
-    # gates has shape of SE
-    num_tokens = int(gates.shape[0])
-    num_experts = int(gates.shape[1])
-    # round-up
-    capacity = math.ceil((num_tokens / num_experts) * capacity_factor)
-    if capacity < min_capacity:
-        capacity = min_capacity
+    capacity = _capacity(gates,
+                         torch.tensor(capacity_factor),
+                         torch.tensor(min_capacity))
 
     # Create a mask for 1st's expert per token
     # noisy gating
     indices1_s = torch.argmax(
         logits_w_noise if noisy_gate_policy == 'RSample' else gates,
         dim=1)
+    num_experts = int(gates.shape[1])
     mask1 = F.one_hot(indices1_s, num_classes=num_experts)
 
     # mask only used tokens
```
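For intuition about the mask construction kept here (toy numbers, not from the commit): `argmax` picks each token's top expert and `F.one_hot` turns that choice into the S×E routing mask:

```python
import torch
import torch.nn.functional as F

gates = torch.tensor([[0.7, 0.2, 0.1],   # 4 tokens, 3 experts
                      [0.1, 0.8, 0.1],
                      [0.3, 0.3, 0.4],
                      [0.6, 0.3, 0.1]])
indices1_s = torch.argmax(gates, dim=1)       # tensor([0, 1, 2, 0])
mask1 = F.one_hot(indices1_s, num_classes=3)  # shape [s, e], one 1 per row
print(mask1.sum(dim=0))  # tokens per expert: tensor([2, 1, 1])
```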
```diff
@@ -207,7 +229,7 @@ def top1gating(logits: torch.Tensor,
         assert logits.shape[0] >= min_capacity, "No. of tokens (batch-size) should be greater than min_capacity. Either set min_capacity to 0 or increase your batch size."
 
-        _, top_idx = torch.topk(mask1_rand, k=capacity, dim=0)
+        top_idx = _top_idx(mask1_rand, capacity)
 
         new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
         mask1 = new_mask1
```
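`_top_idx` keeps only the index half of `torch.topk`'s `(values, indices)` pair; taken along `dim=0`, the indices say which tokens each expert admits before hitting `capacity`. A toy check (numbers are illustrative):

```python
import torch

mask1_rand = torch.tensor([[0.9, 0.1],
                           [0.4, 0.8],
                           [0.7, 0.5]])  # 3 tokens, 2 experts
top_idx = torch.topk(mask1_rand, k=2, dim=0)[1]  # indices only, as in _top_idx
print(top_idx)  # tensor([[0, 1],
                #         [2, 2]]) -- rows of the 2 largest entries per column
```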
```diff
@@ -236,7 +258,7 @@ def top1gating(logits: torch.Tensor,
     mask1_float = mask1.float()
     gates = gates * mask1_float
 
-    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()
+    locations1_sc = _one_hot_to_float(locations1_s, capacity)
     combine_weights = einsum("se,sc->sec", gates, locations1_sc)
 
     dispatch_mask = combine_weights.bool()
```
```diff
@@ -244,24 +266,22 @@ def top1gating(logits: torch.Tensor,
     return l_aux, combine_weights, dispatch_mask, exp_counts
 
 
-def top2gating(logits: torch.Tensor,
+def top2gating(logits: Tensor,
               capacity_factor: float) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     # logits_fp32 = logits.to(torch.float32)
     gates = F.softmax(logits, dim=1)
 
-    # gates has shape of SE
-    num_tokens = int(gates.shape[0])
-    num_experts = int(gates.shape[1])
-    # capacity = (2 * num_tokens // num_experts) * capacity_factor
-    # round-up
-    capacity = math.ceil((2 * num_tokens / num_experts) * capacity_factor)
+    capacity = _capacity(gates,
+                         torch.tensor(capacity_factor * 2),
+                         torch.tensor(min_capacity))
 
     # Create a mask for 1st's expert per token
     indices1_s = torch.argmax(gates, dim=1)
+    num_experts = int(gates.shape[1])
     mask1 = F.one_hot(indices1_s, num_classes=num_experts)
 
     # Create a mask for 2nd's expert per token using Gumbel-max trick
```
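The `capacity_factor * 2` mirrors the old `2 * num_tokens / num_experts` numerator: with top-2 routing every token occupies two expert slots. A worked example with hypothetical sizes:

```python
import torch

# 8 tokens over 4 experts at capacity_factor 1.0: top-1 budgets
# ceil(8/4 * 1.0) = 2 slots per expert; top-2 doubles the factor,
# ceil(8/4 * 2.0) = 4, since each token claims two slots.
gates = torch.rand(8, 4).softmax(dim=1)
factor2 = torch.tensor(1.0 * 2)
capacity = torch.ceil((gates.shape[0] / gates.shape[1]) * factor2).to(torch.int64)
print(capacity)  # tensor(4)
```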
```diff
@@ -308,8 +328,8 @@ def top2gating(logits: torch.Tensor,
     # Calculate combine_weights and dispatch_mask
     gates1 = einsum("s,se->se", gates1_s, mask1_float)
     gates2 = einsum("s,se->se", gates2_s, mask2_float)
-    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()
-    locations2_sc = F.one_hot(locations2_s, num_classes=capacity).float()
+    locations1_sc = _one_hot_to_float(locations1_s, capacity)
+    locations2_sc = _one_hot_to_float(locations2_s, capacity)
     combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
     combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
     combine_weights = combine1_sec + combine2_sec
```
```diff
@@ -318,7 +338,7 @@ def top2gating(logits: torch.Tensor,
     return l_aux, combine_weights, dispatch_mask, exp_counts
 
 
-class TopKGate(torch.nn.Module):
+class TopKGate(Module):
     """Gate module which implements Top2Gating as described in Gshard_.
     ::
```
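For context, a minimal sketch of the export path this commit unblocks. The wrapper module and sizes are hypothetical, and only `combine_weights` is returned so the exported graph has purely tensor outputs:

```python
import torch
from deepspeed.moe.sharded_moe import top1gating

class Top1GateWrapper(torch.nn.Module):
    # Hypothetical wrapper so torch.onnx.export sees a Module
    # with tensor-only inputs and outputs.
    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        l_aux, combine_weights, dispatch_mask, exp_counts = top1gating(
            logits, capacity_factor=1.0, min_capacity=4)
        return combine_weights

logits = torch.randn(8, 4)  # 8 tokens, 4 experts
torch.onnx.export(Top1GateWrapper(), (logits, ), "top1gate.onnx",
                  input_names=["logits"],
                  output_names=["combine_weights"],
                  # the token axis can stay dynamic because the scripted
                  # helpers no longer burn shapes into the traced graph
                  dynamic_axes={"logits": {0: "tokens"}})
```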