Commit f5453124 (unverified)
Authored Jul 11, 2020 by Jeff Rasley; committed via GitHub on Jul 11, 2020
Parent: 4a3234e0

Support amp deepspeed backend (#286)

* add amp support for deepspeed (non-ZeRO)
* tests for amp mode
Showing 4 changed files with 169 additions and 14 deletions (+169 −14):

deepspeed/pt/deepspeed_config.py     +19   −0
deepspeed/pt/deepspeed_constants.py  +18   −0
deepspeed/pt/deepspeed_light.py      +32  −14
tests/unit/test_fp16.py             +100   −0
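In short, this commit lets a DeepSpeed config opt into NVIDIA Apex AMP for non-ZeRO training via a new "amp" section: "enabled" turns it on and every other key is forwarded to apex.amp.initialize. A minimal usage sketch, distilled from the test_adam_amp_basic test added below (assumes Apex, a CUDA device, and a distributed launch such as the deepspeed launcher; the Linear model, the ds_config.json path, and the argparse wiring are illustrative stand-ins for the repo's test helpers):

import argparse
import json
import torch
import deepspeed

# Hypothetical config file mirroring the "amp" block exercised by the new tests.
ds_config = {
    "train_batch_size": 1,
    "steps_per_print": 1,
    "amp": {"enabled": True, "opt_level": "O1"},
}
with open("ds_config.json", "w") as f:
    json.dump(ds_config, f)

args = argparse.Namespace(local_rank=0, deepspeed_config="ds_config.json")

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# deepspeed.initialize hands the module/optimizer pair to apex.amp.initialize
# when "amp" is enabled (see deepspeed_light.py below).
engine, optimizer, _, _ = deepspeed.initialize(args=args, model=model, optimizer=optimizer)

x = torch.randn(1, 10, device=engine.device)
loss = engine(x).sum()
engine.backward(loss)   # routed through amp.scale_loss when amp is enabled
engine.step()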
deepspeed/pt/deepspeed_config.py
@@ -5,6 +5,7 @@ Licensed under the MIT license.
 import torch
 import json
+import copy
 from deepspeed.pt.deepspeed_constants import *
 from deepspeed.pt.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE
 from deepspeed.pt.deepspeed_config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys
@@ -18,6 +19,22 @@ LAMB_OPTIMIZER = 'lamb'
 DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]


+def get_amp_enabled(param_dict):
+    if AMP in param_dict.keys():
+        return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT)
+    else:
+        return False
+
+
+def get_amp_params(param_dict):
+    if AMP in param_dict.keys():
+        amp_params = copy.copy(param_dict[AMP])
+        amp_params.pop(AMP_ENABLED)
+        return amp_params
+    else:
+        return False
+
+
 def get_fp16_enabled(param_dict):
     if FP16 in param_dict.keys():
         return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
@@ -315,6 +332,8 @@ class DeepSpeedConfig(object):
         self.gradient_clipping = get_gradient_clipping(param_dict)
         self.fp16_enabled = get_fp16_enabled(param_dict)
+        self.amp_enabled = get_amp_enabled(param_dict)
+        self.amp_params = get_amp_params(param_dict)
         self.loss_scale = get_loss_scale(param_dict)
         self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict)
         self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)
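To make the passthrough behavior concrete, here is a small standalone illustration (not part of the commit) of what the two new helpers return for a sample config; the real code routes "enabled" through get_scalar_param, but the effect is the same:

# Illustration only: plain-dict re-implementation of the parsing logic above,
# showing that every "amp" key except "enabled" is kept and later forwarded
# to apex.amp.initialize(**amp_params).
import copy

AMP, AMP_ENABLED, AMP_ENABLED_DEFAULT = "amp", "enabled", False

def get_amp_enabled(param_dict):
    if AMP in param_dict:
        return param_dict[AMP].get(AMP_ENABLED, AMP_ENABLED_DEFAULT)
    return False

def get_amp_params(param_dict):
    if AMP in param_dict:
        amp_params = copy.copy(param_dict[AMP])
        amp_params.pop(AMP_ENABLED)  # 'enabled' is DeepSpeed-only, not an Apex argument
        return amp_params
    return False

cfg = {"train_batch_size": 2, "amp": {"enabled": True, "opt_level": "O2"}}
print(get_amp_enabled(cfg))  # True
print(get_amp_params(cfg))   # {'opt_level': 'O2'}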
deepspeed/pt/deepspeed_constants.py
@@ -117,6 +117,24 @@ FP16_HYSTERESIS_DEFAULT = 2
 FP16_MIN_LOSS_SCALE = "min_loss_scale"
 FP16_MIN_LOSS_SCALE_DEFAULT = 1

+#########################################
+# Apex AMP support
+#########################################
+# Use Apex AMP for mixed precision support, all parameters (other than 'enabled') will be passed to
+# amp.initialize(model, optimizer, **amp_params)
+# See apex documentation for supported parameters/features: https://nvidia.github.io/apex/amp.html#apex.amp.initialize
+AMP_FORMAT = '''
+"amp": {
+  "enabled": true,
+  "opt_level": "O1",
+  ...
+}
+'''
+AMP = "amp"
+AMP_ENABLED = "enabled"
+AMP_ENABLED_DEFAULT = False
+
 #########################################
 # Gradient clipping
 #########################################
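For reference, a valid JSON rendering of that "amp" block might look like the following (my illustration; the AMP_FORMAT string above is only a schematic). Apex's "opt_level" ranges from "O0" (pure FP32) through "O1" (patched mixed precision) and "O2" (almost-FP16 with FP32 master weights) to "O3" (pure FP16):

import json

# Hypothetical config fragment; every key except "enabled" is forwarded verbatim
# to apex.amp.initialize(model, optimizer, **amp_params).
amp_block = json.loads('''
{
  "amp": {
    "enabled": true,
    "opt_level": "O1"
  }
}
''')
print(amp_block["amp"])  # {'enabled': True, 'opt_level': 'O1'}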
deepspeed/pt/deepspeed_light.py
@@ -8,6 +8,7 @@ import warnings
 import torch.distributed as dist
 from torch.nn.modules import Module
 from torch.distributed.distributed_c10d import _get_global_rank
+from apex import amp
 from tensorboardX import SummaryWriter
@@ -312,6 +313,12 @@ class DeepSpeedLight(Module):
     def fp16_enabled(self):
         return self._config.fp16_enabled

+    def amp_enabled(self):
+        return self._config.amp_enabled
+
+    def amp_params(self):
+        return self._config.amp_params
+
     def loss_scale(self):
         return self._config.loss_scale
@@ -449,28 +456,33 @@ class DeepSpeedLight(Module):
             assert self.dynamic_loss_scale(), \
                 'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())

+    def _broadcast_model(self):
+        for p in self.module.parameters():
+            if torch.is_tensor(p):
+                dist.broadcast(p, self.broadcast_src_rank, group=self.data_parallel_group)
+
     def _configure_distributed_model(self, model):
         self.module = model
         if self.fp16_enabled():
             self.module.half()
         self.module.to(self.device)

         if self.mpu is None:
             self.data_parallel_group = _initialize_parameter_parallel_groups()
             self.dp_world_size = dist.get_world_size()
-            src_rank = 0
+            self.broadcast_src_rank = 0
         else:
             self.data_parallel_group = self.mpu.get_data_parallel_group()
             self.dp_world_size = self.mpu.get_data_parallel_world_size()
-            src_rank = _get_global_rank(self.mpu.get_data_parallel_group(), 0)
-        logger.info(f"global src_rank={src_rank}")
-        for p in self.module.parameters():
-            if torch.is_tensor(p):
-                dist.broadcast(p, src_rank, group=self.data_parallel_group)
+            self.broadcast_src_rank = _get_global_rank(self.mpu.get_data_parallel_group(), 0)
+            logger.info(f"global src_rank={self.broadcast_src_rank}")

+        # TODO: support new AMP optimizer
+        # self.module.half()
+        # self.module.to(self.local_rank)
+        #self.module, self.optimizer = amp.initialize(self.module, self.optimizer, opt_level="O2")
+        if not self.amp_enabled():
+            self._broadcast_model()

     # Configure optimizer
     def _configure_optimizer(self, client_optimizer, model_parameters):
@@ -486,6 +498,7 @@ class DeepSpeedLight(Module):
         logger.info('DeepSpeed Basic Optimizer = {}'.format(basic_optimizer))

         if self.zero_optimization():
+            assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
             if self.optimizer_name() != ADAM_OPTIMIZER:
                 assert self.zero_allow_untested_optimizer(), \
                     'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
@@ -494,6 +507,12 @@ class DeepSpeedLight(Module):
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                 )
             self.optimizer = self._configure_zero_optimizer(basic_optimizer)
+        elif self.amp_enabled():
+            assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
+            amp_params = self.amp_params()
+            logger.info(f"Initializing AMP with these params: {amp_params}")
+            self.module, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params)
+            self._broadcast_model()
         elif self.fp16_enabled():
             self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
         else:
@@ -748,12 +767,11 @@ class DeepSpeedLight(Module):
         if self.zero_optimization():
             self.optimizer.backward(loss)
+        elif self.amp_enabled():
+            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                scaled_loss.backward()
         elif self.fp16_enabled():
             self.optimizer.backward(loss)
-            # TODO: Use new AMP semantics as below
-            # with amp.scale_loss(loss, self.optimizer) as scaled_loss:
-            #     scaled_loss.backward()
         else:
             loss.backward()
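For context, the engine calls above wrap the standard Apex AMP pattern; note also that _broadcast_model() runs only after amp.initialize(), presumably so the parameters being broadcast are the ones Apex may have re-cast. A minimal standalone sketch of that underlying pattern (my illustration, assuming Apex and a CUDA device are available; the Linear model is a placeholder):

import torch
from apex import amp  # requires NVIDIA Apex; https://nvidia.github.io/apex/amp.html

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Everything in the config's "amp" block except "enabled" lands in this call.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(4, 10, device="cuda")
loss = model(x).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()   # what DeepSpeedLight.backward() does when amp is enabled
optimizer.step()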
tests/unit/test_fp16.py
@@ -395,3 +395,103 @@ def test_zero_empty_partition(tmpdir, zero_stage):
             model.step()

     _test_zero_empty_partition(args)
+
+
+def test_adam_amp_basic(tmpdir):
+    config_dict = {"train_batch_size": 1, "steps_per_print": 1, "amp": {"enabled": True}}
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_adam_amp_basic(args, model, hidden_dim):
+        optimizer = torch.optim.Adam(params=model.parameters())
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              optimizer=optimizer)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_amp_basic(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_lamb_amp_basic(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Lamb",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "amp": {
+            "enabled": True,
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1, 2])
+    def _test_lamb_amp_basic(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_lamb_amp_basic(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adam_amp_o2(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "amp": {
+            "enabled": True,
+            "opt_level": "O2"
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1, 2])
+    def _test_adam_amp_o2(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)
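Note on running these tests: they rely on the repo's unit-test helpers (args_from_dict, SimpleModel, random_dataloader, distributed_test) and require Apex plus at least one CUDA device, two for the world_size=[1, 2] cases. Something like `pytest tests/unit/test_fp16.py -k amp` would select just the new AMP tests (standard pytest `-k` filtering, not a command specified by this commit).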