Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
61e6d069
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 12 个月
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
61e6d069
编写于
8月 22, 2023
作者:
L
Logan Adams
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' into loadams/low-cpu-mem-ut
上级
c5c3093b
5e16eb2c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
47 addition
and
25 deletion
+47
-25
deepspeed/inference/engine.py
deepspeed/inference/engine.py
+4
-1
deepspeed/module_inject/auto_tp.py
deepspeed/module_inject/auto_tp.py
+0
-15
deepspeed/module_inject/auto_tp_model_utils.py
deepspeed/module_inject/auto_tp_model_utils.py
+15
-0
deepspeed/module_inject/fusedqkv_utils.py
deepspeed/module_inject/fusedqkv_utils.py
+3
-1
setup.py
setup.py
+25
-8
未找到文件。
deepspeed/inference/engine.py
浏览文件 @
61e6d069
...
...
@@ -26,7 +26,7 @@ from ..module_inject.policy import TransformerPolicy
from
..module_inject.auto_tp
import
AutoTP
from
..module_inject.replace_policy
import
generic_policies
from
..module_inject.auto_tp_model_utils
import
build_bloom_alibi_tensor
,
build_mpt_atten_bias_tensor
from
..module_inject.auto_tp_model_utils
import
build_bloom_alibi_tensor
,
build_mpt_atten_bias_tensor
,
build_mpt_alibi_tensor
DS_INFERENCE_ENABLED
=
False
from
torch
import
nn
...
...
@@ -187,6 +187,9 @@ class InferenceEngine(Module):
if
hasattr
(
self
.
module
,
'transformer'
):
if
hasattr
(
self
.
module
.
transformer
,
'build_alibi_tensor'
):
self
.
module
.
transformer
.
build_alibi_tensor
=
build_bloom_alibi_tensor
if
hasattr
(
self
.
module
.
transformer
,
'build_mpt_alibi_tensor'
):
self
.
module
.
transformer
.
build_mpt_alibi_tensor_orig
=
self
.
module
.
transformer
.
build_mpt_alibi_tensor
self
.
module
.
transformer
.
__class__
.
build_mpt_alibi_tensor
=
build_mpt_alibi_tensor
def
build_attn_bias
(
self
):
if
hasattr
(
self
.
module
,
'transformer'
):
...
...
deepspeed/module_inject/auto_tp.py
浏览文件 @
61e6d069
...
...
@@ -296,21 +296,6 @@ class AutoTP():
if
getattr
(
child
,
"replaced"
,
False
)
==
True
:
return
weight_shape
=
child
.
weight
.
shape
if
name
==
'attn.Wqkv'
and
self
.
module
.
_get_name
()
==
'MPTBlock'
:
# MPT block qkv weight's allocation is different from other models, it's [3,num_head,head_dim,hidden_size]
# instead of [num_head,3,head_dim,hidden_size]
new_weight
=
torch
.
empty
((
weight_shape
[
0
]
//
self
.
mp_size
,
weight_shape
[
1
],
),
device
=
child
.
weight
.
device
,
dtype
=
child
.
weight
.
dtype
)
reversed_dim
=
True
mp_replace
=
ReplaceWithTensorSlicing
(
mp_group
=
self
.
mp_group
,
out_dim
=
0
)
# todo: can we remove new tensor allocation if we use strided copy?
mp_replace
.
strided_copy
(
new_weight
,
child
.
weight
.
data
,
num_splits
=
3
,
int8
=
reversed_dim
)
setattr
(
child
,
"replaced"
,
True
)
return
LinearLayer
(
weight
=
new_weight
.
to
(
get_accelerator
().
current_device_name
()),
bias
=
None
)
mp_replace
=
ReplaceWithTensorSlicing
(
mp_group
=
self
.
mp_group
)
if
name
in
self
.
all_reduce_linears
:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
...
...
deepspeed/module_inject/auto_tp_model_utils.py
浏览文件 @
61e6d069
...
...
@@ -76,3 +76,18 @@ def build_mpt_atten_bias_tensor(self,
offset
=
dist
.
get_rank
()
*
num_heads_per_rank
attn_bias
=
attn_bias
[:,
offset
:
num_heads_per_rank
+
offset
,
:,
:]
return
attn_bias
,
attention_mask
def
build_mpt_alibi_tensor
(
self
,
num_heads
,
sequence_length
,
alibi_bias_max
=
8
,
device
=
None
)
->
torch
.
Tensor
:
r
"""
Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
"""
alibi
=
self
.
build_mpt_alibi_tensor_orig
(
num_heads
,
sequence_length
,
alibi_bias_max
,
device
)
if
dist
.
is_initialized
():
num_heads_per_rank
=
int
(
num_heads
/
dist
.
get_world_size
())
offset
=
dist
.
get_rank
()
*
num_heads_per_rank
alibi
=
alibi
[
offset
:
num_heads_per_rank
+
offset
,
:,
:]
return
alibi
deepspeed/module_inject/fusedqkv_utils.py
浏览文件 @
61e6d069
...
...
@@ -16,7 +16,7 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
def
require_tp_fused_qkvw
(
name
,
mp_size
):
fused_qkvw_name_list
=
[
'qkv_proj'
,
'query_key_value'
]
fused_qkvw_name_list
=
[
'qkv_proj'
,
'query_key_value'
,
'attn.Wqkv'
]
if
mp_size
==
1
:
return
False
...
...
@@ -33,6 +33,8 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
'CodeGenBlock'
:
'codegentype'
,
'BloomBlock'
:
'bloomtype'
,
'GLMBlock'
:
'glmtype'
,
"MPTBlock"
:
'glmtype'
,
"MptBlock"
:
'glmtype'
,
}
def
_codegen_type_transpose
(
input
,
mp_size
,
codegen_mp_num
=
4
):
...
...
setup.py
浏览文件 @
61e6d069
...
...
@@ -24,6 +24,7 @@ import subprocess
from
setuptools
import
setup
,
find_packages
from
setuptools.command
import
egg_info
import
time
import
typing
torch_available
=
True
try
:
...
...
@@ -56,6 +57,22 @@ def fetch_requirements(path):
return
[
r
.
strip
()
for
r
in
fd
.
readlines
()]
def
is_env_set
(
key
):
"""
Checks if an environment variable is set and not "".
"""
return
bool
(
os
.
environ
.
get
(
key
,
None
))
def
get_env_if_set
(
key
,
default
:
typing
.
Any
=
""
):
"""
Returns an environment variable if it is set and not "",
otherwise returns a default value. In contrast, the fallback
parameter of os.environ.get() is skipped if the variable is set to "".
"""
return
os
.
environ
.
get
(
key
,
None
)
or
default
install_requires
=
fetch_requirements
(
'requirements/requirements.txt'
)
extras_require
=
{
'1bit'
:
[],
# add cupy based on cuda/rocm version
...
...
@@ -116,14 +133,14 @@ if torch_available and not torch.cuda.is_available():
print
(
"[WARNING] Torch did not find cuda available, if cross-compiling or running with cpu only "
"you can ignore this message. Adding compute capability for Pascal, Volta, and Turing "
"(compute capabilities 6.0, 6.1, 6.2)"
)
if
os
.
environ
.
get
(
"TORCH_CUDA_ARCH_LIST"
,
None
)
is
None
:
if
not
is_env_set
(
"TORCH_CUDA_ARCH_LIST"
)
:
os
.
environ
[
"TORCH_CUDA_ARCH_LIST"
]
=
get_default_compute_capabilities
()
ext_modules
=
[]
# Default to pre-install kernels to false so we rely on JIT on Linux, opposite on Windows.
BUILD_OP_PLATFORM
=
1
if
sys
.
platform
==
"win32"
else
0
BUILD_OP_DEFAULT
=
int
(
os
.
environ
.
g
et
(
'DS_BUILD_OPS'
,
BUILD_OP_PLATFORM
))
BUILD_OP_DEFAULT
=
int
(
get_env_if_s
et
(
'DS_BUILD_OPS'
,
BUILD_OP_PLATFORM
))
print
(
f
"DS_BUILD_OPS=
{
BUILD_OP_DEFAULT
}
"
)
if
BUILD_OP_DEFAULT
:
...
...
@@ -147,7 +164,7 @@ def op_envvar(op_name):
def
op_enabled
(
op_name
):
env_var
=
op_envvar
(
op_name
)
return
int
(
os
.
environ
.
g
et
(
env_var
,
BUILD_OP_DEFAULT
))
return
int
(
get_env_if_s
et
(
env_var
,
BUILD_OP_DEFAULT
))
compatible_ops
=
dict
.
fromkeys
(
ALL_OPS
.
keys
(),
False
)
...
...
@@ -160,7 +177,7 @@ for op_name, builder in ALL_OPS.items():
# If op is requested but not available, throw an error.
if
op_enabled
(
op_name
)
and
not
op_compatible
:
env_var
=
op_envvar
(
op_name
)
if
env_var
not
in
os
.
environ
:
if
not
is_env_set
(
env_var
)
:
builder
.
warning
(
f
"One can disable
{
op_name
}
with
{
env_var
}
=0"
)
abort
(
f
"Unable to pre-compile
{
op_name
}
"
)
...
...
@@ -179,7 +196,7 @@ print(f'Install Ops={install_ops}')
# Write out version/git info.
git_hash_cmd
=
"git rev-parse --short HEAD"
git_branch_cmd
=
"git rev-parse --abbrev-ref HEAD"
if
command_exists
(
'git'
)
and
'DS_BUILD_STRING'
not
in
os
.
environ
:
if
command_exists
(
'git'
)
and
not
is_env_set
(
'DS_BUILD_STRING'
)
:
try
:
result
=
subprocess
.
check_output
(
git_hash_cmd
,
shell
=
True
)
git_hash
=
result
.
decode
(
'utf-8'
).
strip
()
...
...
@@ -216,11 +233,11 @@ version_str = open('version.txt', 'r').read().strip()
# Example: DS_BUILD_STRING=".dev20201022" python setup.py sdist bdist_wheel.
# Building wheel for distribution, update version file.
if
'DS_BUILD_STRING'
in
os
.
environ
:
if
is_env_set
(
'DS_BUILD_STRING'
)
:
# Build string env specified, probably building for distribution.
with
open
(
'build.txt'
,
'w'
)
as
fd
:
fd
.
write
(
os
.
environ
.
get
(
'DS_BUILD_STRING'
)
)
version_str
+=
os
.
environ
.
get
(
'DS_BUILD_STRING'
)
fd
.
write
(
os
.
environ
[
'DS_BUILD_STRING'
]
)
version_str
+=
os
.
environ
[
'DS_BUILD_STRING'
]
elif
os
.
path
.
isfile
(
'build.txt'
):
# build.txt exists, probably installing from distribution.
with
open
(
'build.txt'
,
'r'
)
as
fd
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录