Commit 6bea95e5 (unverified)

Authored Feb 21, 2020 by Olatunji Ruwase; committed via GitHub on Feb 21, 2020
Merge branch 'master' into olruwase/legacy_optimizer_fusion
Parents: 932268a4, 001abe23
Showing 8 changed files with 147 additions and 93 deletions (+147, -93).
- deepspeed/__init__.py (+5, -1)
- deepspeed/pt/deepspeed_light.py (+8, -0)
- deepspeed/pt/deepspeed_run.py (+10, -4)
- docs/features.md (+2, -1)
- install.sh (+1, -6)
- tests/unit/simple_model.py (+46, -0)
- tests/unit/test_config.py (+54, -0)
- tests/unit/test_fp16.py (+21, -81)
deepspeed/__init__.py

```diff
@@ -54,7 +54,7 @@ def initialize(args,
                     step(), state_dict(), and load_state_dict() methods
         mpu: Optional: A model parallelism unit object that implements
-            get_model/data_parallel_group/rank/size()
+            get_{model,data}_parallel_{rank,group,world_size}()
         dist_init_required: Optional: Initializes torch.distributed
@@ -128,6 +128,10 @@ def _add_core_arguments(parser):
                        type=str,
                        help='DeepSpeed json configuration file.')
+    group.add_argument('--deepscale_config',
+                       default=None,
+                       type=str,
+                       help='Deprecated DeepSpeed json configuration file.')
     return parser
```
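For context on the deprecated flag, here is a hedged sketch of how a training script would pick up both spellings. It assumes the public `deepspeed.add_config_arguments()` wrapper around the `_add_core_arguments()` shown above; `ds_config.json` is an illustrative file name, not something shipped with DeepSpeed:

```python
# Hedged sketch: parsing the --deepscale_config alias from a user script.
# Assumes deepspeed.add_config_arguments() wraps _add_core_arguments().
import argparse
import deepspeed

parser = argparse.ArgumentParser(description='demo training script')
parser.add_argument('--local_rank', type=int, default=0)
parser = deepspeed.add_config_arguments(parser)

args = parser.parse_args(['--deepspeed_config', 'ds_config.json'])
print(args.deepspeed_config)  # ds_config.json
print(args.deepscale_config)  # None -- deprecated spelling left unset
```

Passing both flags at once would trip the assertion added to `_do_args_sanity_check` in the next file.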
deepspeed/pt/deepspeed_light.py

```diff
@@ -325,6 +325,14 @@ class DeepSpeedLight(Module):

     # Validate command line arguments
     def _do_args_sanity_check(self, args):
+        if hasattr(args, 'deepscale_config') and args.deepscale_config is not None:
+            logging.warning(
+                "************ --deepscale_config is deprecated, please use --deepspeed_config ************"
+            )
+            if hasattr(args, 'deepspeed_config'):
+                assert args.deepspeed_config is None, "Not sure how to proceed, we were given both a deepscale_config and deepspeed_config"
+            args.deepspeed_config = args.deepscale_config
+
         assert hasattr(args, 'local_rank') and type(args.local_rank) == int, \
             'DeepSpeed requires integer command line parameter --local_rank'
```
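The fallback can be exercised outside the engine. A minimal, self-contained sketch (illustrative code, not DeepSpeed's) that mirrors the logic above:

```python
# Standalone mirror of the deprecation fallback added above; runnable as-is.
import argparse
import logging

def do_args_sanity_check(args):
    # Remap the deprecated flag onto the supported one, refusing ambiguity.
    if getattr(args, 'deepscale_config', None) is not None:
        logging.warning("--deepscale_config is deprecated, "
                        "please use --deepspeed_config")
        assert getattr(args, 'deepspeed_config', None) is None, \
            "given both a deepscale_config and a deepspeed_config"
        args.deepspeed_config = args.deepscale_config
    assert isinstance(getattr(args, 'local_rank', None), int), \
        'DeepSpeed requires integer command line parameter --local_rank'

args = argparse.Namespace(deepscale_config='ds_config.json',
                          deepspeed_config=None,
                          local_rank=0)
do_args_sanity_check(args)
assert args.deepspeed_config == 'ds_config.json'  # old flag remapped
```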
deepspeed/pt/deepspeed_run.py

```diff
@@ -15,6 +15,7 @@ import collections
 from copy import deepcopy

 DLTS_HOSTFILE = "/job/hostfile"
+EXPORT_ENVS = ["NCCL", "PYTHONPATH"]


 def parse_args(args=None):
@@ -305,13 +306,18 @@ def main(args=None):
         num_gpus_per_node = None

     curr_path = os.path.abspath('.')
+    if 'PYTHONPATH' in env:
+        env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
+    else:
+        env['PYTHONPATH'] = curr_path

-    nccl_export = ""
-    for nccl_var in filter(lambda x: "NCCL_" in x, env.keys()):
-        nccl_export += "export {}={}; ".format(nccl_var, env[nccl_var])
+    exports = ""
+    for var in env.keys():
+        if any(map(lambda name: name in var, EXPORT_ENVS)):
+            exports += "export {}={}; ".format(var, env[var])

     deepspeed_launch = [
-        nccl_export,
+        exports,
         "cd {};".format(curr_path),
         sys.executable,
         "-u",
```
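A hedged, standalone sketch of the propagation logic: it rebuilds the same `export VAR=...;` prefix string that `main()` now prepends to the remote launch command, using the `EXPORT_ENVS` constant introduced in this diff:

```python
# Standalone sketch of the launcher's environment propagation (illustrative).
import os

EXPORT_ENVS = ["NCCL", "PYTHONPATH"]  # substring match, added in this diff

env = dict(os.environ)
curr_path = os.path.abspath('.')

# Prepend the current directory so the training script stays importable
# on every worker the launcher spawns.
if 'PYTHONPATH' in env:
    env['PYTHONPATH'] = curr_path + ":" + env['PYTHONPATH']
else:
    env['PYTHONPATH'] = curr_path

# Forward every variable whose name contains an EXPORT_ENVS substring,
# e.g. NCCL_DEBUG or NCCL_SOCKET_IFNAME, plus PYTHONPATH itself.
exports = ""
for var in env.keys():
    if any(name in var for name in EXPORT_ENVS):
        exports += "export {}={}; ".format(var, env[var])

print(exports)  # e.g. "export NCCL_DEBUG=INFO; export PYTHONPATH=/cwd; "
```

The substring match is deliberate: a single "NCCL" entry catches the whole NCCL_* family of tuning variables.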
docs/features.md

````diff
@@ -68,10 +68,11 @@ mpu.get_model_parallel_rank()
 mpu.get_model_parallel_group()
 mpu.get_model_parallel_world_size()
-mpu.get_data_parallel_rank/group/world_size()
+mpu.get_data_parallel_rank()
+mpu.get_data_parallel_group()
+mpu.get_data_parallel_world_size()
 ```

 ### Integration with Megatron-LM
 DeepSpeed is fully compatible with [Megatron](https://github.com/NVIDIA/Megatron-LM).
 Please see the [Megatron-LM tutorial](tutorials/MegatronGPT2Tutorial.md) for details.
````
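For readers wiring up their own model parallelism, a hedged sketch of the smallest object satisfying the six-method interface listed above: model parallelism of size 1 and data parallelism over all ranks. The class name and the group caching are assumptions; only the getter names come from the docs:

```python
# Hypothetical minimal mpu; assumes torch.distributed.init_process_group()
# has already been called on every rank.
import torch.distributed as dist

class TrivialMPU:
    def __init__(self):
        # new_group() is collective: every rank must create every group,
        # so build all single-rank model-parallel groups up front.
        self._mp_groups = [dist.new_group([r])
                           for r in range(dist.get_world_size())]

    def get_model_parallel_rank(self):
        return 0                      # alone in its model-parallel group

    def get_model_parallel_group(self):
        return self._mp_groups[dist.get_rank()]

    def get_model_parallel_world_size(self):
        return 1

    def get_data_parallel_rank(self):
        return dist.get_rank()        # data parallelism spans all ranks

    def get_data_parallel_group(self):
        return dist.group.WORLD

    def get_data_parallel_world_size(self):
        return dist.get_world_size()
```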
install.sh

```diff
@@ -109,16 +109,11 @@ if [ "$third_party_install" == "1" ]; then
     sudo -H pip install third_party/apex/dist/apex*.whl
 fi
 if [ "$deepspeed_install" == "1" ]; then
-    echo "Installing deepspeed"
+    echo "Building deepspeed wheel"
     python setup.py bdist_wheel
 fi

 if [ "$local_only" == "1" ]; then
     if [ "$third_party_install" == "1" ]; then
         echo "Installing apex locally"
         sudo -H pip uninstall -y apex
         sudo -H pip install third_party/apex/dist/apex*.whl
     fi
     if [ "$deepspeed_install" == "1" ]; then
         echo "Installing deepspeed"
         sudo -H pip uninstall -y deepspeed
```
tests/unit/simple_model.py (new file, mode 100644)

```python
import os
import json
import argparse
import torch


class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList(
                [torch.nn.Linear(hidden_dim, hidden_dim)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        hidden_dim = x
        hidden_dim = self.linear(hidden_dim)
        return self.cross_entropy_loss(hidden_dim, y)


def random_dataloader(model, total_samples, hidden_dim, device):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples,
                             hidden_dim,
                             device=device,
                             dtype=torch.half)
    train_label = torch.empty(total_samples,
                              dtype=torch.long,
                              device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size)
    return train_loader


def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path


def args_from_dict(tmpdir, config_dict):
    config_path = create_config_from_dict(tmpdir, config_dict)
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepspeed = True
    args.deepspeed_config = config_path
    args.local_rank = 0
    return args
```
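A hedged sketch of how these helpers compose in a test, mirroring the patterns in test_config.py and test_fp16.py below; `test_fp16_smoke` is a hypothetical name, and a CUDA device plus an initializable process group are assumed:

```python
# Hypothetical test using the new helpers (illustrative, not in the commit).
import deepspeed
from simple_model import SimpleModel, args_from_dict, random_dataloader

def test_fp16_smoke(tmpdir):  # tmpdir is the standard pytest fixture
    config_dict = {"train_batch_size": 1, "fp16": {"enabled": True}}
    args = args_from_dict(tmpdir, config_dict)  # writes temp_config.json
    model = SimpleModel(hidden_dim=10)
    model, _, _, _ = deepspeed.initialize(args=args,
                                          model=model,
                                          model_parameters=model.parameters())
    for x, y in random_dataloader(model=model, total_samples=4,
                                  hidden_dim=10, device=model.device):
        loss = model(x, y)
        model.backward(loss)  # engine-managed backward, not loss.backward()
        model.step()
```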
tests/unit/test_config.py

```diff
 # A test on its own
 import torch
 import pytest
+import json
+import argparse
 from common import distributed_test
+from simple_model import SimpleModel, create_config_from_dict, random_dataloader

 import torch.distributed as dist

 # A test on its own
@@ -100,3 +103,54 @@ def test_batch_config(num_ranks, batch, micro_batch, gas, success):
     """Run batch config test """
     _test_batch_config(num_ranks, batch, micro_batch, gas, success)
+
+
+def test_temp_config_json(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+    }
+    config_path = create_config_from_dict(tmpdir, config_dict)
+    config_json = json.load(open(config_path, 'r'))
+    assert 'train_batch_size' in config_json
+
+
+def test_deprecated_deepscale_config(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+    config_path = create_config_from_dict(tmpdir, config_dict)
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args(args='')
+    args.deepscale_config = config_path
+    args.local_rank = 0
+
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim)
+
+    @distributed_test(world_size=[1])
+    def _test_deprecated_deepscale_config(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters(),
+                                              dist_init_required=False)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=5,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_deprecated_deepscale_config(args=args, model=model, hidden_dim=hidden_dim)
```
tests/unit/test_fp16.py

```diff
@@ -5,67 +5,7 @@ import pytest
 import json
 import os
 from common import distributed_test
-
-def create_config_from_dict(tmpdir, config_dict):
-    config_path = os.path.join(tmpdir, 'temp_config.json')
-    with open(config_path, 'w') as fd:
-        json.dump(config_dict, fd)
-    return config_path
-
-
-class SimpleModel(torch.nn.Module):
-    def __init__(self, hidden_dim, empty_grad=False):
-        super(SimpleModel, self).__init__()
-        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
-        if empty_grad:
-            self.layers2 = torch.nn.ModuleList(
-                [torch.nn.Linear(hidden_dim, hidden_dim)])
-        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
-
-    def forward(self, x, y):
-        hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
-        return self.cross_entropy_loss(hidden_dim, y)
-
-
-def test_temp_config_json(tmpdir):
-    config_dict = {
-        "train_batch_size": 1,
-    }
-    config_path = create_config_from_dict(tmpdir, config_dict)
-    config_json = json.load(open(config_path, 'r'))
-    assert 'train_batch_size' in config_json
-
-
-def prepare_optimizer_parameters(model):
-    param_optimizer = list(model.named_parameters())
-    optimizer_grouped_parameters = [{
-        'params': [p for n, p in param_optimizer],
-        'weight_decay': 0.0
-    }]
-    return optimizer_grouped_parameters
-
-
-def get_data_loader(model, total_samples, hidden_dim, device):
-    batch_size = model.train_micro_batch_size_per_gpu()
-    train_data = torch.randn(total_samples,
-                             hidden_dim,
-                             device=device,
-                             dtype=torch.half)
-    train_label = torch.empty(total_samples,
-                              dtype=torch.long,
-                              device=device).random_(hidden_dim)
-    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
-    train_loader = torch.utils.data.DataLoader(train_dataset,
-                                               batch_size=batch_size)
-    return train_loader
-
-
-def get_args(tmpdir, config_dict):
-    config_path = create_config_from_dict(tmpdir, config_dict)
-    parser = argparse.ArgumentParser()
-    args = parser.parse_args(args='')
-    args.deepspeed = True
-    args.deepspeed_config = config_path
-    args.local_rank = 0
-    return args
+from simple_model import SimpleModel, random_dataloader, args_from_dict


 def test_lamb_fp16_basic(tmpdir):
@@ -83,7 +23,7 @@ def test_lamb_fp16_basic(tmpdir):
             "enabled": True
         }
     }
-    args = get_args(tmpdir, config_dict)
+    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
     model = SimpleModel(hidden_dim, empty_grad=False)
@@ -94,10 +34,10 @@ def test_lamb_fp16_basic(tmpdir):
                                              model=model,
                                              model_parameters=model.parameters(),
                                              dist_init_required=False)
-        data_loader = get_data_loader(model=model,
-                                      total_samples=50,
-                                      hidden_dim=hidden_dim,
-                                      device=model.device)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
         for n, batch in enumerate(data_loader):
             loss = model(batch[0], batch[1])
             model.backward(loss)
@@ -121,7 +61,7 @@ def test_lamb_fp16_empty_grad(tmpdir):
             "enabled": True
         }
     }
-    args = get_args(tmpdir, config_dict)
+    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
     model = SimpleModel(hidden_dim, empty_grad=True)
@@ -132,10 +72,10 @@ def test_lamb_fp16_empty_grad(tmpdir):
                                              model=model,
                                              model_parameters=model.parameters(),
                                              dist_init_required=False)
-        data_loader = get_data_loader(model=model,
-                                      total_samples=50,
-                                      hidden_dim=hidden_dim,
-                                      device=model.device)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
         for n, batch in enumerate(data_loader):
             loss = model(batch[0], batch[1])
             model.backward(loss)
@@ -152,7 +92,7 @@ def test_adamw_fp16_basic(tmpdir):
             "enabled": True
         }
     }
-    args = get_args(tmpdir, config_dict)
+    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
     model = SimpleModel(hidden_dim, empty_grad=False)
@@ -164,10 +104,10 @@ def test_adamw_fp16_basic(tmpdir):
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)
-        data_loader = get_data_loader(model=model,
-                                      total_samples=50,
-                                      hidden_dim=hidden_dim,
-                                      device=model.device)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
         for n, batch in enumerate(data_loader):
             loss = model(batch[0], batch[1])
             model.backward(loss)
@@ -184,7 +124,7 @@ def test_adamw_fp16_empty_grad(tmpdir):
             "enabled": True
         }
     }
-    args = get_args(tmpdir, config_dict)
+    args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
     model = SimpleModel(hidden_dim, empty_grad=True)
@@ -196,10 +136,10 @@ def test_adamw_fp16_empty_grad(tmpdir):
                                              model=model,
                                              optimizer=optimizer,
                                              dist_init_required=False)
-        data_loader = get_data_loader(model=model,
-                                      total_samples=50,
-                                      hidden_dim=hidden_dim,
-                                      device=model.device)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
         for n, batch in enumerate(data_loader):
             loss = model(batch[0], batch[1])
             model.backward(loss)
```