Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
87eaf8f9
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 大约 1 年
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
87eaf8f9
编写于
3月 06, 2023
作者:
L
Lev Kurilenko
提交者:
GitHub
3月 06, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Check for local CUDA graphs when enable_cuda_graph=True (#2941)
上级
2ede0d94
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
73 addition
and
15 deletion
+73
-15
deepspeed/inference/engine.py
deepspeed/inference/engine.py
+28
-1
deepspeed/model_implementations/diffusers/unet.py
deepspeed/model_implementations/diffusers/unet.py
+3
-3
deepspeed/model_implementations/diffusers/vae.py
deepspeed/model_implementations/diffusers/vae.py
+6
-6
deepspeed/model_implementations/features/__init__.py
deepspeed/model_implementations/features/__init__.py
+1
-0
deepspeed/model_implementations/features/cuda_graph.py
deepspeed/model_implementations/features/cuda_graph.py
+24
-0
deepspeed/model_implementations/transformers/clip_encoder.py
deepspeed/model_implementations/transformers/clip_encoder.py
+3
-3
deepspeed/module_inject/containers/unet.py
deepspeed/module_inject/containers/unet.py
+4
-1
deepspeed/module_inject/containers/vae.py
deepspeed/module_inject/containers/vae.py
+4
-1
未找到文件。
deepspeed/inference/engine.py
浏览文件 @
87eaf8f9
...
...
@@ -24,6 +24,8 @@ from deepspeed.accelerator import get_accelerator
from
..module_inject.policy
import
TransformerPolicy
from
..module_inject.auto_tp
import
AutoTP
from
..module_inject.replace_policy
import
generic_policies
DS_INFERENCE_ENABLED
=
False
from
torch
import
nn
...
...
@@ -155,6 +157,9 @@ class InferenceEngine(Module):
if
config
.
tensor_parallel
.
tp_size
>
1
:
assert
not
config
.
enable_cuda_graph
,
"Cuda graph is not supported for model parallelism"
# Check if local CUDA graphs can be created in replacement modules
self
.
local_cuda_graph
=
self
.
_local_cuda_graph_used
(
self
.
module
)
def
profile_model_time
(
self
,
use_cuda_events
=
True
):
if
not
self
.
model_profile_enabled
and
not
self
.
_config
.
enable_cuda_graph
:
self
.
module
.
register_forward_pre_hook
(
self
.
_pre_forward_hook
)
...
...
@@ -512,6 +517,27 @@ class InferenceEngine(Module):
self
.
_model_times
=
[]
return
model_times
def
_module_match
(
self
,
module
):
for
policy
in
generic_policies
:
policy
=
policy
()
if
policy
.
match_replaced
(
module
):
return
True
return
False
def
_local_cuda_graph_used
(
self
,
module
):
if
isinstance
(
module
,
torch
.
nn
.
Module
):
return
False
else
:
sub_module_cuda_graph
=
False
for
name
in
module
.
__dict__
.
keys
():
sub_module
=
getattr
(
module
,
name
)
if
self
.
_module_match
(
sub_module
)
and
hasattr
(
sub_module
,
"enable_cuda_graph"
):
sub_module_cuda_graph
=
True
return
sub_module_cuda_graph
def
forward
(
self
,
*
inputs
,
**
kwargs
):
"""Execute forward propagation
...
...
@@ -525,7 +551,8 @@ class InferenceEngine(Module):
get_accelerator
().
synchronize
()
start
=
time
.
time
()
if
get_accelerator
().
device_name
()
==
'cuda'
and
self
.
_config
.
enable_cuda_graph
:
if
get_accelerator
().
device_name
(
)
==
'cuda'
and
self
.
_config
.
enable_cuda_graph
and
not
self
.
local_cuda_graph
:
if
self
.
cuda_graph_created
:
outputs
=
self
.
_graph_replay
(
*
inputs
,
**
kwargs
)
else
:
...
...
deepspeed/model_implementations/diffusers/unet.py
浏览文件 @
87eaf8f9
...
...
@@ -2,11 +2,12 @@
Copyright 2022 The Microsoft DeepSpeed Team
'''
import
torch
from
..features.cuda_graph
import
CUDAGraph
class
DSUNet
(
torch
.
nn
.
Module
):
class
DSUNet
(
CUDAGraph
,
torch
.
nn
.
Module
):
def
__init__
(
self
,
unet
,
enable_cuda_graph
=
True
):
super
().
__init__
()
super
().
__init__
(
enable_cuda_graph
=
enable_cuda_graph
)
self
.
unet
=
unet
# SD pipeline accesses this attribute
self
.
in_channels
=
unet
.
in_channels
...
...
@@ -17,7 +18,6 @@ class DSUNet(torch.nn.Module):
self
.
unet
.
requires_grad_
(
requires_grad
=
False
)
self
.
unet
.
to
(
memory_format
=
torch
.
channels_last
)
self
.
cuda_graph_created
=
False
self
.
enable_cuda_graph
=
enable_cuda_graph
def
_graph_replay
(
self
,
*
inputs
,
**
kwargs
):
for
i
in
range
(
len
(
inputs
)):
...
...
deepspeed/model_implementations/diffusers/vae.py
浏览文件 @
87eaf8f9
...
...
@@ -2,11 +2,12 @@
Copyright 2022 The Microsoft DeepSpeed Team
'''
import
torch
from
..features.cuda_graph
import
CUDAGraph
class
DSVAE
(
torch
.
nn
.
Module
):
class
DSVAE
(
CUDAGraph
,
torch
.
nn
.
Module
):
def
__init__
(
self
,
vae
,
enable_cuda_graph
=
True
):
super
().
__init__
()
super
().
__init__
(
enable_cuda_graph
=
enable_cuda_graph
)
self
.
vae
=
vae
self
.
device
=
self
.
vae
.
device
self
.
dtype
=
self
.
vae
.
dtype
...
...
@@ -14,7 +15,6 @@ class DSVAE(torch.nn.Module):
self
.
decoder_cuda_graph_created
=
False
self
.
encoder_cuda_graph_created
=
False
self
.
all_cuda_graph_created
=
False
self
.
enable_cuda_graph
=
enable_cuda_graph
def
_graph_replay_decoder
(
self
,
*
inputs
,
**
kwargs
):
for
i
in
range
(
len
(
inputs
)):
...
...
@@ -104,7 +104,7 @@ class DSVAE(torch.nn.Module):
else
:
return
self
.
_encode
(
*
inputs
,
**
kwargs
)
def
_graph_replay
_all
(
self
,
*
inputs
,
**
kwargs
):
def
_graph_replay
(
self
,
*
inputs
,
**
kwargs
):
for
i
in
range
(
len
(
inputs
)):
if
torch
.
is_tensor
(
inputs
[
i
]):
self
.
static_inputs
[
i
].
copy_
(
inputs
[
i
])
...
...
@@ -117,10 +117,10 @@ class DSVAE(torch.nn.Module):
def
forward
(
self
,
*
inputs
,
**
kwargs
):
if
self
.
enable_cuda_graph
:
if
self
.
cuda_graph_created
:
outputs
=
self
.
_graph_replay
_all
(
*
inputs
,
**
kwargs
)
outputs
=
self
.
_graph_replay
(
*
inputs
,
**
kwargs
)
else
:
self
.
_create_cuda_graph
(
*
inputs
,
**
kwargs
)
outputs
=
self
.
_graph_replay
_all
(
*
inputs
,
**
kwargs
)
outputs
=
self
.
_graph_replay
(
*
inputs
,
**
kwargs
)
return
outputs
else
:
return
self
.
_forward
(
*
inputs
,
**
kwargs
)
...
...
deepspeed/model_implementations/features/__init__.py
0 → 100644
浏览文件 @
87eaf8f9
'''Copyright The Microsoft DeepSpeed Team'''
deepspeed/model_implementations/features/cuda_graph.py
0 → 100644
浏览文件 @
87eaf8f9
'''
Copyright 2023 The Microsoft DeepSpeed Team
'''
from
abc
import
ABC
,
abstractmethod
class
CUDAGraph
(
ABC
):
def
__init__
(
self
,
enable_cuda_graph
=
False
):
super
().
__init__
()
self
.
enable_cuda_graph
=
enable_cuda_graph
@
abstractmethod
def
_create_cuda_graph
(
self
):
"""
Create CUDA graph(s)
"""
raise
NotImplementedError
@
abstractmethod
def
_graph_replay
(
self
):
"""
Replay CUDA graph(s)
"""
raise
NotImplementedError
deepspeed/model_implementations/transformers/clip_encoder.py
浏览文件 @
87eaf8f9
...
...
@@ -3,11 +3,12 @@ Copyright 2022 The Microsoft DeepSpeed Team
'''
import
torch
from
deepspeed.accelerator
import
get_accelerator
from
..features.cuda_graph
import
CUDAGraph
class
DSClipEncoder
(
torch
.
nn
.
Module
):
class
DSClipEncoder
(
CUDAGraph
,
torch
.
nn
.
Module
):
def
__init__
(
self
,
enc
,
enable_cuda_graph
=
False
):
super
().
__init__
()
super
().
__init__
(
enable_cuda_graph
=
enable_cuda_graph
)
enc
.
text_model
.
_build_causal_attention_mask
=
self
.
_build_causal_attention_mask
self
.
enc
=
enc
self
.
device
=
self
.
enc
.
device
...
...
@@ -18,7 +19,6 @@ class DSClipEncoder(torch.nn.Module):
self
.
static_output
=
[
None
,
None
]
self
.
_cuda_graphs
=
[
None
,
None
]
self
.
iter
=
0
self
.
enable_cuda_graph
=
enable_cuda_graph
self
.
config
=
self
.
enc
.
config
def
_build_causal_attention_mask
(
self
,
bsz
,
seq_len
,
dtype
):
...
...
deepspeed/module_inject/containers/unet.py
浏览文件 @
87eaf8f9
...
...
@@ -5,6 +5,7 @@ import torch
from
torch.nn.parameter
import
Parameter
from
..policy
import
DSPolicy
from
...model_implementations.diffusers.unet
import
DSUNet
class
UNetPolicy
(
DSPolicy
):
...
...
@@ -19,9 +20,11 @@ class UNetPolicy(DSPolicy):
def
match
(
self
,
module
):
return
isinstance
(
module
,
self
.
_orig_layer_class
)
def
match_replaced
(
self
,
module
):
return
isinstance
(
module
,
DSUNet
)
def
apply
(
self
,
module
,
enable_cuda_graph
=
True
):
# TODO(cmikeh2): Enable cuda graph should be an inference configuration
from
...model_implementations.diffusers.unet
import
DSUNet
return
DSUNet
(
module
,
enable_cuda_graph
=
enable_cuda_graph
)
def
attention
(
self
,
client_module
):
...
...
deepspeed/module_inject/containers/vae.py
浏览文件 @
87eaf8f9
...
...
@@ -2,6 +2,7 @@
Copyright 2022 The Microsoft DeepSpeed Team
'''
from
..policy
import
DSPolicy
from
...model_implementations.diffusers.vae
import
DSVAE
class
VAEPolicy
(
DSPolicy
):
...
...
@@ -20,9 +21,11 @@ class VAEPolicy(DSPolicy):
def
match
(
self
,
module
):
return
isinstance
(
module
,
self
.
_orig_layer_class
)
def
match_replaced
(
self
,
module
):
return
isinstance
(
module
,
DSVAE
)
def
apply
(
self
,
module
,
enable_cuda_graph
=
True
):
# TODO(cmikeh2): Enable cuda graph should be an inference configuration
from
...model_implementations.diffusers.vae
import
DSVAE
return
DSVAE
(
module
,
enable_cuda_graph
=
enable_cuda_graph
)
# NOTE (lekurile): Should we have a diffusers policy class?
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录