Commit a148bd33 (unverified)
Add configurable intermediate size to transformer kernels (#423)

Authored by RezaYazdaniAminabadi on Sep 21, 2020; committed via GitHub on Sep 21, 2020.
Parent: a825f996
Showing 5 changed files with 46 additions and 27 deletions (+46 −27):

  csrc/includes/context.h                     +1  −1
  csrc/transformer/ds_transformer_cuda.cpp    +13 −6
  deepspeed/ops/transformer/transformer.py    +19 −12
  tests/unit/test_cuda_backward.py            +6  −4
  tests/unit/test_cuda_forward.py             +7  −4
csrc/includes/context.h

@@ -69,7 +69,7 @@ public:
         if (!_workspace) {
             assert(_workspace == nullptr);
             cudaMalloc(&_workspace, size);
-        } else if (_workSpaceSize != size) {
+        } else if (_workSpaceSize < size) {
             cudaFree(_workspace);
             cudaMalloc(&_workspace, size);
         }
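The one-character change above makes the workspace grow-only: the buffer is reallocated only when a larger size is requested, instead of on any size mismatch, so layers requesting differing (but not larger) workspace sizes no longer trigger repeated cudaFree/cudaMalloc cycles. A minimal Python sketch of the same policy (the real code manages a raw CUDA buffer; the class and method names here are illustrative):

    class Workspace:
        """Grow-only scratch buffer, sketching Context::GenWorkSpace."""

        def __init__(self):
            self.buf = None
            self.size = 0

        def ensure(self, size):
            if self.buf is None:
                self.buf = bytearray(size)      # cudaMalloc in the CUDA version
                self.size = size
            elif self.size < size:              # was `!= size` before this commit
                self.buf = bytearray(size)      # cudaFree + cudaMalloc
                self.size = size
            # if size <= self.size: keep the existing, larger buffer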
csrc/transformer/ds_transformer_cuda.cpp

@@ -20,13 +20,14 @@ template <typename T>
 size_t get_workspace_size(int maxBatchSize,
                           int seq_len,
                           int hidden_size,
+                          int intermediate_size,
                           int heads,
                           bool training,
                           bool gelu_checkpoint)
 {
     size_t workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
     if (training) {
-        workSpacesize += (std::max((4 * size_t(maxBatchSize) * seq_len * hidden_size),
+        workSpacesize += (std::max((size_t(maxBatchSize) * seq_len * intermediate_size),
                                    2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
         if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
     }
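The training-time scratch requirement now scales with the configured intermediate size rather than assuming it equals 4 * hidden_size. A Python transcription of the formula above, for reading along (element counts, not bytes; not part of the commit):

    def workspace_elements(batch, seq_len, hidden, intermediate, heads,
                           training=True, gelu_checkpoint=False):
        # Base: activations for the layer itself.
        size = 4 * batch * seq_len * hidden
        if training:
            # Largest training scratch: either the feed-forward intermediate
            # activation or the attention score matrices.
            size += max(batch * seq_len * intermediate,
                        2 * batch * heads * seq_len * seq_len)
            if gelu_checkpoint:
                size += 2 * batch * seq_len * hidden
        return size

    # Example: batch 8, seq 512, hidden 1024, intermediate 4096, 16 heads
    # -> max(16_777_216, 67_108_864): the attention scores dominate here.
    print(workspace_elements(8, 512, 1024, 4096, 16))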
@@ -92,12 +93,12 @@ BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
                                     false,
                                     !normalize_invertible)),
       _ff1(typename FeedForward<T>::Config(batch_size * seq_length,
-                                           4 * hidden_size,
+                                           _intermediate_size,
                                            hidden_size,
                                            gemm_algos[1])),
       _ff2(typename FeedForward<T>::Config(batch_size * seq_length,
                                            hidden_size,
-                                           4 * hidden_size,
+                                           _intermediate_size,
                                            gemm_algos[2])),
       _softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
       _gelu(typename Gelu<T>::Config(_batch_size, _seq_length, _intermediate_size)),
@@ -143,8 +144,13 @@ BertTransformerLayer<T>::~BertTransformerLayer()
 template <typename T>
 void BertTransformerLayer<T>::Initialize()
 {
-    Context::Instance().GenWorkSpace(get_workspace_size<T>(
-        _batch_size, _seq_length, _hidden_size, _heads, _training, _gelu_checkpoint));
+    Context::Instance().GenWorkSpace(get_workspace_size<T>(_batch_size,
+                                                           _seq_length,
+                                                           _hidden_size,
+                                                           _intermediate_size,
+                                                           _heads,
+                                                           _training,
+                                                           _gelu_checkpoint));
 
     if (std::is_same<T, __half>::value) cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
 }
@@ -343,7 +349,8 @@ void BertTransformerLayer<T>::Backward(int bsz,
     T* buf_2 = buf_1 + small_buf_size;
     T* buf_3 = buf_2 + small_buf_size;
 
-    T* ff2_buf = buf_3 + (_gelu_checkpoint ? 3 : 1) * small_buf_size;
+    T* ff2_buf = (_gelu_checkpoint ? buf_2 + (bsz * _seq_length * _intermediate_size)
+                                   : buf_3 + small_buf_size);
     T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);
 
     cudaStream_t streams[2] = {_stream, _stream};
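In Backward, ff2_buf previously sat at a fixed multiple of small_buf_size, which baked in the 4x assumption; it is now placed using the actual intermediate size. A sketch of the offset arithmetic, assuming small_buf_size is bsz * seq_len * hidden (that definition is not shown in this hunk, so it is an assumption; all numbers are illustrative):

    bsz, seq_len, hidden, intermediate, heads = 8, 128, 1024, 3072, 16
    gelu_checkpoint = True
    small_buf_size = bsz * seq_len * hidden   # assumed definition from context

    # Offsets in elements from the start of the shared workspace.
    buf_1 = small_buf_size
    buf_2 = buf_1 + small_buf_size
    buf_3 = buf_2 + small_buf_size

    # Old placement implicitly assumed intermediate == 4 * hidden:
    ff2_buf_old = buf_3 + (3 if gelu_checkpoint else 1) * small_buf_size
    # New placement reserves exactly bsz * seq_len * intermediate after buf_2:
    ff2_buf_new = (buf_2 + bsz * seq_len * intermediate) if gelu_checkpoint \
                  else (buf_3 + small_buf_size)

    ctx_bufB_ptr_recomp = ff2_buf_new + seq_len * seq_len * bsz * heads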
deepspeed/ops/transformer/transformer.py (mode 100644 → 100755)

@@ -18,6 +18,7 @@ class TransformerConfig():
                  batch_size,
                  max_seq_length,
                  hidden_size,
+                 intermediate_size,
                  heads,
                  attn_dropout_ratio,
                  hidden_dropout_ratio,
@@ -26,6 +27,7 @@ class TransformerConfig():
         self.layer_id = -1
         self.batch_size = batch_size
         self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
         self.max_seq_length = max_seq_length
         self.heads = heads
         self.attn_dropout_ratio = attn_dropout_ratio
@@ -44,6 +46,8 @@ class DeepSpeedTransformerConfig(TransformerConfig):
             hidden_size: The hidden size of the transformer layer

+            intermediate_size: The intermediate size of the feed-forward part of transformer layer
+
             heads: The number of heads in the self-attention of the transformer layer

             attn_dropout_ratio: The ratio of dropout for the attention's output

@@ -88,6 +92,7 @@ class DeepSpeedTransformerConfig(TransformerConfig):
                  batch_size=-1,
                  max_seq_length=-1,
                  hidden_size=-1,
+                 intermediate_size=-1,
                  heads=-1,
                  attn_dropout_ratio=-1,
                  hidden_dropout_ratio=-1,
@@ -103,14 +108,16 @@ class DeepSpeedTransformerConfig(TransformerConfig):
                  attn_dropout_checkpoint=False,
                  stochastic_mode=False):
         super(DeepSpeedTransformerConfig,
-              self).__init__(batch_size,
-                             max_seq_length,
-                             hidden_size,
-                             heads,
-                             attn_dropout_ratio,
-                             hidden_dropout_ratio,
-                             num_hidden_layers,
-                             initializer_range)
+              self).__init__(batch_size,
+                             max_seq_length,
+                             hidden_size,
+                             (intermediate_size if intermediate_size > 0 else 4 * hidden_size),
+                             heads,
+                             attn_dropout_ratio,
+                             hidden_dropout_ratio,
+                             num_hidden_layers,
+                             initializer_range)
         self.fp16 = fp16
         self.pre_layer_norm = pre_layer_norm
         self.local_rank = local_rank
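The ternary in the new super().__init__ call gives the config a backward-compatible default: leaving intermediate_size at -1 reproduces the old hard-coded 4x ratio. A tiny sketch of that rule (the helper name is hypothetical):

    def resolve_intermediate_size(hidden_size, intermediate_size=-1):
        # Mirrors `intermediate_size if intermediate_size > 0 else 4 * hidden_size`.
        return intermediate_size if intermediate_size > 0 else 4 * hidden_size

    assert resolve_intermediate_size(1024) == 4096        # unset -> classic 4x ratio
    assert resolve_intermediate_size(1024, 3072) == 3072  # explicit value wins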
@@ -432,12 +439,12 @@ class DeepSpeedTransformerLayer(nn.Module):
         self.attn_nw = nn.Parameter(torch.Tensor(self.config.hidden_size))
         self.attn_nb = nn.Parameter(torch.Tensor(self.config.hidden_size))

-        self.inter_w = nn.Parameter(torch.Tensor(4 * self.config.hidden_size,
-                                                 self.config.hidden_size))
-        self.inter_b = nn.Parameter(torch.Tensor(4 * self.config.hidden_size))
-        self.output_w = nn.Parameter(torch.Tensor(self.config.hidden_size,
-                                                  4 * self.config.hidden_size))
+        self.inter_w = nn.Parameter(torch.Tensor(self.config.intermediate_size,
+                                                 self.config.hidden_size))
+        self.inter_b = nn.Parameter(torch.Tensor(self.config.intermediate_size))
+        self.output_w = nn.Parameter(torch.Tensor(self.config.hidden_size,
+                                                  self.config.intermediate_size))
         self.output_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
         self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size))
         self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size))
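With these parameter shapes keyed to intermediate_size, the feed-forward weights no longer have to be 4x rectangles. A quick shape check under assumed sizes (not from the commit):

    import torch
    import torch.nn.functional as F

    hidden, intermediate, tokens = 1024, 3072, 8 * 128   # illustrative sizes

    inter_w = torch.randn(intermediate, hidden)    # was (4 * hidden, hidden)
    output_w = torch.randn(hidden, intermediate)   # was (hidden, 4 * hidden)

    x = torch.randn(tokens, hidden)
    h = F.linear(x, inter_w)     # -> (tokens, intermediate)
    y = F.linear(h, output_w)    # -> (tokens, hidden)
    assert y.shape == (tokens, hidden)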
@@ -485,7 +492,7 @@ class DeepSpeedTransformerLayer(nn.Module):
                                                self.config.batch_size,
                                                self.config.hidden_size,
                                                self.config.heads,
-                                               4 * self.config.hidden_size,
+                                               self.config.intermediate_size,
                                                self.config.max_seq_length,
                                                self.config.attn_dropout_ratio,
                                                self.config.hidden_dropout_ratio,
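End to end, the user-facing effect is one new config knob. A hedged usage sketch (keyword names as they appear in this diff; the import path is assumed from the file location, and values are illustrative):

    from deepspeed.ops.transformer import DeepSpeedTransformerConfig

    config = DeepSpeedTransformerConfig(
        batch_size=8,
        max_seq_length=128,
        hidden_size=1024,
        intermediate_size=3072,   # new; omit (or pass -1) to get 4 * hidden_size
        heads=16,
        attn_dropout_ratio=0.1,
        hidden_dropout_ratio=0.1,
        num_hidden_layers=24,
        initializer_range=0.02)
    # The config is then passed to DeepSpeedTransformerLayer as before.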
tests/unit/test_cuda_backward.py

@@ -146,7 +146,7 @@ def create_models(ds_config):
         hidden_size=ds_config.hidden_size,
         num_hidden_layers=ds_config.num_hidden_layers,
         num_attention_heads=ds_config.heads,
-        intermediate_size=4 * ds_config.hidden_size,
+        intermediate_size=ds_config.intermediate_size,
         hidden_act="gelu",
         hidden_dropout_prob=ds_config.hidden_dropout_ratio,
         attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
@@ -166,12 +166,12 @@ def create_models(ds_config):
     weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
     weights[4].data.fill_(1.0)
-    weights.append(nn.Parameter(torch.Tensor(4 * ds_config.hidden_size,
+    weights.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size,
                                              ds_config.hidden_size)))
     weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
     weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size,
-                                             4 * ds_config.hidden_size)))
+                                             ds_config.intermediate_size)))
     weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
     weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
     weights[7].data.fill_(1.0)
@@ -181,7 +181,7 @@ def create_models(ds_config):
     for i in range(4):
         biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
         biases[i + 1].data.zero_()
-    biases.append(nn.Parameter(torch.Tensor(4 * ds_config.hidden_size)))
+    biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size)))
     biases[5].data.zero_()
     biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
     biases[6].data.zero_()
@@ -278,6 +278,7 @@ def test_backward(batch_size,
     ds_config.layer_id = None
     ds_config.batch_size = batch_size
     ds_config.hidden_size = hidden_size
+    ds_config.intermediate_size = hidden_size
     ds_config.max_seq_length = seq_len
     ds_config.heads = heads
     ds_config.attn_dropout_ratio = 0.0
@@ -314,6 +315,7 @@ def test_backward(batch_size,
     # ds_config.layer_id = None
     # ds_config.batch_size = batch_size
     # ds_config.hidden_size = hidden_size
+    # ds_config.intermediate_size = 4 * hidden_size
     # ds_config.max_seq_length = seq_len
     # ds_config.heads = heads
     # ds_config.attn_dropout_ratio = 0.0
tests/unit/test_cuda_forward.py

@@ -113,7 +113,7 @@ def create_models(ds_config):
         num_hidden_layers=ds_config.num_hidden_layers,
         num_attention_heads=ds_config.heads,
         batch_size=ds_config.batch_size,
-        intermediate_size=4 * ds_config.hidden_size,
+        intermediate_size=ds_config.intermediate_size,
         hidden_act="gelu",
         hidden_dropout_prob=ds_config.hidden_dropout_ratio,
         attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
@@ -134,12 +134,12 @@ def create_models(ds_config):
     weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
     weights[4].data.fill_(1.0)
-    weights.append(nn.Parameter(torch.Tensor(4 * ds_config.hidden_size,
+    weights.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size,
                                              ds_config.hidden_size)))
     weights[5].data.normal_(mean=0.0, std=ds_config.initializer_range)
     weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size,
-                                             4 * ds_config.hidden_size)))
+                                             ds_config.intermediate_size)))
     weights[6].data.normal_(mean=0.0, std=ds_config.initializer_range)
     weights.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
     weights[7].data.fill_(1.0)
@@ -149,7 +149,7 @@ def create_models(ds_config):
     for i in range(4):
         biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
         biases[i + 1].data.zero_()
-    biases.append(nn.Parameter(torch.Tensor(4 * ds_config.hidden_size)))
+    biases.append(nn.Parameter(torch.Tensor(ds_config.intermediate_size)))
     biases[5].data.zero_()
     biases.append(nn.Parameter(torch.Tensor(ds_config.hidden_size)))
     biases[6].data.zero_()
@@ -246,6 +246,7 @@ def test_forward(batch_size,
     ds_config.layer_id = None
     ds_config.batch_size = batch_size
     ds_config.hidden_size = hidden_size
+    ds_config.intermediate_size = 4 * hidden_size
     ds_config.max_seq_length = seq_len
     ds_config.heads = heads
     ds_config.attn_dropout_ratio = 0.0
@@ -282,6 +283,7 @@ def test_forward_with_small_bsz(batch_size,
     ds_config.layer_id = None
     ds_config.batch_size = batch_size
     ds_config.hidden_size = hidden_size
+    ds_config.intermediate_size = 4 * hidden_size
     ds_config.max_seq_length = seq_len
     ds_config.heads = heads
     ds_config.attn_dropout_ratio = 0.0
@@ -316,6 +318,7 @@ def test_forward_stochastic(batch_size,
     ds_config.layer_id = None
     ds_config.batch_size = batch_size
     ds_config.hidden_size = hidden_size
+    ds_config.intermediate_size = 4 * hidden_size
     ds_config.max_seq_length = seq_len
     ds_config.heads = heads
     ds_config.attn_dropout_ratio = 0.0