机器未来 / Paddle (forked from PaddlePaddle / Paddle)

Unverified commit 625dd722
Authored Apr 04, 2022 by ShenLiang; committed by GitHub on Apr 04, 2022
Parent: a6b6bcbf

fix recompute (#41396)

Showing 2 changed files with 191 additions and 67 deletions (+191 / -67)
python/paddle/distributed/fleet/utils/recompute.py             +143 / -4
python/paddle/fluid/tests/unittests/test_dygraph_recompute.py   +48 / -63
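For orientation (a sketch, not part of the commit): recompute(function, *args) trades memory for compute by re-running `function` during the backward pass instead of storing its intermediate activations. A minimal usage sketch, assuming a GPU device because the default preserve_rng_state=True path requires one (the diff below enforces this); `Block` is an illustrative name:

    import paddle
    from paddle.distributed.fleet.utils import recompute

    class Block(paddle.nn.Layer):
        def __init__(self):
            super(Block, self).__init__()
            self.fc = paddle.nn.Linear(8, 8)

        def forward(self, x):
            # re-run self.fc during backward instead of storing its activations
            return recompute(self.fc, x)

    x = paddle.randn([2, 8])
    x.stop_gradient = False
    loss = Block()(x).sum()
    loss.backward()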
python/paddle/distributed/fleet/utils/recompute.py
@@ -14,9 +14,11 @@
 import paddle
 from paddle.fluid import core
-from paddle.autograd import PyLayer
+from paddle.autograd import PyLayer, EagerPyLayer
 from paddle.fluid import framework
 import contextlib
+from paddle.fluid.framework import in_dygraph_mode

 import logging

 logger = logging.getLogger(__name__)
@@ -32,7 +34,7 @@ __all__ = []
 def detach_variable(inputs):
     out = []
     for inp in inputs:
-        if not isinstance(inp, core.VarBase):
+        if not isinstance(inp, (core.eager.Tensor, core.VarBase)):
             out.append(inp)
             continue
@@ -44,7 +46,7 @@ def detach_variable(inputs):
 def check_recompute_necessary(inputs):
     if not any(input_.stop_gradient == False for input_ in inputs
-               if isinstance(input_, paddle.Tensor)):
+               if isinstance(input_, (core.eager.Tensor, paddle.Tensor))):
         logger.warn(
             "[Recompute]: None of the inputs to current recompute block need grad, "
             "therefore there is NO need to recompute this block in backward !")
@@ -60,6 +62,140 @@ def swith_rng_state(rng_state):
     paddle.set_cuda_rng_state(orig_cuda_rng_state)


+class EagerRecomputeFunction(EagerPyLayer):
+    @staticmethod
+    def forward(ctx, run_function, preserve_rng_state, *args):
+        if framework._dygraph_tracer()._has_grad:
+            check_recompute_necessary(args)
+
+        # store for recomputing
+        ctx.run_function = run_function
+        ctx.preserve_rng_state = preserve_rng_state
+
+        # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
+        # the order of tensors in backward()'s output should be the same as tensors in forward()'s input
+        # None tensor inputs will be filtered in backward inputs.
+
+        # save input for backward
+        ctx.inputs = []
+        ctx.tensor_indices = []
+        tensor_inputs = []
+        for i, arg in enumerate(args):
+            if paddle.is_tensor(arg):
+                tensor_inputs.append(arg)
+                ctx.tensor_indices.append(i)
+                ctx.inputs.append(None)
+            else:
+                ctx.inputs.append(arg)
+        ctx.save_for_backward(*tensor_inputs)
+
+        # NOTE recompute with restore RNG only support one senario where one process for one cuda gpu.
+        # one process with multiple gpu and mix-gpu-cpu senarios are not support
+        if ctx.preserve_rng_state:
+            cur_device = paddle.get_device()
+            if 'gpu:' not in cur_device:
+                raise RuntimeError(
+                    "Recompute with RNG perserve is not support current device: {}.".
+                    format(cur_device))
+            ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
+
+        # TODO support AMP
+        tracer = framework._dygraph_tracer()
+        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        if tracer._amp_level == core.AmpLevel.O2:
+            ctx.amp_level = 'O2'
+        elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
+            ctx.amp_level = 'O1'
+        else:
+            raise ValueError("unsupported amp level: {}".format(
+                tracer._amp_level))
+
+        if tracer._amp_dtype == 'float16':
+            ctx.amp_dtype = 'float16'
+        elif tracer._amp_dtype in ('bfloat16', 'float32'):
+            ctx.amp_dtype = 'bfloat16'
+        else:
+            raise ValueError("unsupported amp dtype: {}".format(
+                tracer._amp_dtype))
+
+        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
+
+        with paddle.no_grad():
+            outputs = run_function(*args)
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        with paddle.fluid.dygraph.guard():
+            # TODO need to check the recompute calling is vaild or not
+
+            # Restore inputs
+            inputs = list(ctx.inputs)
+            tensor_indices = ctx.tensor_indices
+            tensors = ctx.saved_tensor()
+            for i, idx in enumerate(tensor_indices):
+                inputs[idx] = tensors[i]
+
+            # paddle.enable_grad()
+            tracer = framework._dygraph_tracer()
+            tracer._has_grad = True
+
+            # NOTE support AMP
+            # need restore auto_cast state as well as w/b list
+            if ctx.preserve_rng_state:
+                with swith_rng_state(ctx.fw_cuda_rng_state):
+                    with paddle.amp.auto_cast(
+                            enable=ctx.is_fw_autocast,
+                            custom_white_list=ctx.amp_white_list,
+                            custom_black_list=ctx.amp_black_list,
+                            level=ctx.amp_level,
+                            dtype=ctx.amp_dtype):
+                        detached_inputs = detach_variable(tuple(inputs))
+                        outputs = ctx.run_function(*detached_inputs)
+            else:
+                with paddle.amp.auto_cast(
+                        enable=ctx.is_fw_autocast,
+                        custom_white_list=ctx.amp_white_list,
+                        custom_black_list=ctx.amp_black_list,
+                        level=ctx.amp_level,
+                        dtype=ctx.amp_dtype):
+                    detached_inputs = detach_variable(tuple(inputs))
+                    outputs = ctx.run_function(*detached_inputs)
+
+            if isinstance(outputs, core.eager.Tensor):
+                outputs = (outputs, )
+            assert len(outputs) == len(args)
+
+            # run backward() with only tensor that requires grad
+            forward_outputs_with_grad = []
+            # NOTE In Transformer-like network, if user put the attention mask into the recompute segment output,
+            # pylayer will force the stop_gradient of attention mask to be False, which will make the number of
+            # tensor that need grad does not match.
+            # the following backward_inputs_with_grad is used to avoid this case.
+            backward_inputs_with_grad = []
+            for i in range(len(outputs)):
+                if isinstance(outputs[i],
+                              core.eager.Tensor) and not outputs[i].stop_gradient:
+                    forward_outputs_with_grad.append(outputs[i])
+                    backward_inputs_with_grad.append(args[i])
+
+            if len(forward_outputs_with_grad) == 0:
+                raise RuntimeError(
+                    "none of output has requires_grad=True, this recompute() is not necessary"
+                )
+
+            # actually backward
+            with paddle.amp.auto_cast(enable=False):
+                paddle.autograd.backward(forward_outputs_with_grad,
+                                         backward_inputs_with_grad)
+
+            grads = tuple(inp.grad for inp in detached_inputs
+                          if isinstance(inp, core.eager.Tensor))
+            return grads
+
+
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
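The NOTE above spells out the PyLayer contract that EagerRecomputeFunction relies on: backward() must return exactly one gradient per tensor input of forward(), in the same order. A minimal sketch of that contract (not from this commit; `Square` is an illustrative name), using the same save_for_backward/saved_tensor helpers seen in the diff:

    import paddle
    from paddle.autograd import PyLayer

    class Square(PyLayer):
        @staticmethod
        def forward(ctx, x):
            # stash tensor inputs for backward, as EagerRecomputeFunction does
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        def backward(ctx, grad_out):
            x, = ctx.saved_tensor()
            # one gradient per tensor input of forward(), in order
            return 2.0 * x * grad_out

    x = paddle.to_tensor([3.0], stop_gradient=False)
    y = Square.apply(x)
    y.sum().backward()
    print(x.grad)  # [6.]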
@@ -315,4 +451,7 @@ def recompute(function, *args, **kwargs):
         raise ValueError("Unexpected keyword arguments: " + ",".join(
             arg for arg in kwargs))

-    return RecomputeFunction.apply(function, preserve, *args)
+    if in_dygraph_mode():
+        return EagerRecomputeFunction.apply(function, preserve, *args)
+    else:
+        return RecomputeFunction.apply(function, preserve, *args)
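Net effect of the changes to this file: recompute() now dispatches at runtime, routing to EagerRecomputeFunction under the new eager dygraph mode and to the existing RecomputeFunction otherwise. A hedged sketch of exercising both paths with the same _test_eager_guard helper the updated tests use (`run_once` is an illustrative name; preserve_rng_state=False is passed so that no GPU is needed):

    import paddle
    from paddle.distributed.fleet.utils import recompute
    from paddle.fluid.framework import _test_eager_guard

    def run_once():
        layer = paddle.nn.Linear(4, 4)
        x = paddle.randn([2, 4])
        x.stop_gradient = False
        # preserve_rng_state=False skips the GPU-only RNG snapshot
        y = recompute(layer, x, preserve_rng_state=False)
        y.sum().backward()

    with _test_eager_guard():
        run_once()  # in_dygraph_mode() is True: EagerRecomputeFunction.apply
    run_once()      # legacy path: RecomputeFunction.apply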
python/paddle/fluid/tests/unittests/test_dygraph_recompute.py
@@ -23,6 +23,7 @@ from paddle.distributed.fleet.utils import recompute
 import random

 import paddle.fluid.layers as layers
+from paddle.fluid.framework import _test_eager_guard


 def get_fc_block(block_idx, input_size, is_last=False):
@@ -141,96 +142,75 @@ def run_model(recompute_block=[],


 class TestPyLayer(unittest.TestCase):
-    def test_fc_net_with_dropout(self):
+    def test_base_case(self, enable_autocast=False, pure_fp16=False):
         def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
             self.assertEqual(loss_ref, loss)
             self.assertEqual(param_ref, param)
             self.assertEqual(grad_ref, grad)

         # without recompute
-        loss_ref, param_ref, grad_ref = run_model(recompute_block=[])
+        loss_ref, param_ref, grad_ref = run_model(
+            recompute_block=[],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)

         # recompute second block
-        loss, param, grad = run_model(recompute_block=[1])
+        loss, param, grad = run_model(
+            recompute_block=[1],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

         # recompute fourth block
-        loss, param, grad = run_model(recompute_block=[3])
+        loss, param, grad = run_model(
+            recompute_block=[3],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

         # recompute second to fourth block
-        loss, param, grad = run_model(recompute_block=[1, 2, 3])
+        loss, param, grad = run_model(
+            recompute_block=[1, 2, 3],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

         # recompute second & fourth block
-        loss, param, grad = run_model(recompute_block=[1, 3])
+        loss, param, grad = run_model(
+            recompute_block=[1, 3],
+            enable_autocast=enable_autocast,
+            pure_fp16=pure_fp16)
         check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)

-    def test_fc_net_without_restore_rng(self):
-        loss_ref, param_ref, grad_ref = run_model(
-            recompute_block=[2],
-            recompute_kwargs={"preserve_rng_state": False},
-            enable_autocast=True)
-
-    def test_fc_net_with_amp(self):
-        def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
-            self.assertEqual(loss_ref, loss)
-            self.assertEqual(param_ref, param)
-            self.assertEqual(grad_ref, grad)
-
-        # without recompute
-        loss_ref, param_ref, grad_ref = run_model(
-            recompute_block=[], enable_autocast=True)
-
-        # recompute second block
-        loss, param, grad = run_model(recompute_block=[1], enable_autocast=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute fourth block
-        loss, param, grad = run_model(recompute_block=[3], enable_autocast=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second to fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 2, 3], enable_autocast=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second & fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 3], enable_autocast=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-    def test_fc_net_with_fp16(self):
-        def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad):
-            self.assertEqual(loss_ref, loss)
-            self.assertEqual(param_ref, param)
-            self.assertEqual(grad_ref, grad)
-
-        # without recompute
-        loss_ref, param_ref, grad_ref = run_model(
-            recompute_block=[], enable_autocast=True, pure_fp16=True)
-
-        # recompute second block
-        loss, param, grad = run_model(
-            recompute_block=[1], enable_autocast=True, pure_fp16=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute fourth block
-        loss, param, grad = run_model(
-            recompute_block=[3], enable_autocast=True, pure_fp16=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second to fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 2, 3], enable_autocast=True, pure_fp16=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
-
-        # recompute second & fourth block
-        loss, param, grad = run_model(
-            recompute_block=[1, 3], enable_autocast=True, pure_fp16=True)
-        check_identical(loss_ref, param_ref, grad_ref, loss, param, grad)
+    def test_fc_net_with_dropout(self):
+        with _test_eager_guard():
+            self.test_base_case()
+        self.test_base_case()
+
+    def test_fc_net_without_restore_rng(self):
+        with _test_eager_guard():
+            loss_ref, param_ref, grad_ref = run_model(
+                recompute_block=[2],
+                recompute_kwargs={"preserve_rng_state": False},
+                enable_autocast=True)
+
+    def test_fc_net_with_amp(self):
+        with _test_eager_guard():
+            self.test_base_case(enable_autocast=True)
+        self.test_base_case(enable_autocast=True)
+
+    def test_fc_net_with_fp16(self):
+        with _test_eager_guard():
+            self.test_base_case(enable_autocast=True, pure_fp16=True)
+        self.test_base_case(enable_autocast=True, pure_fp16=True)

     def test_recompute_kwargs(self):
+        with _test_eager_guard():
+            paddle.set_device("gpu")
+            kwargs = {"is_test": False}
+            with self.assertRaises(ValueError):
+                loss_ref, param_ref, grad_ref = run_model(
+                    recompute_block=[2], recompute_kwargs=kwargs)
         paddle.set_device("gpu")
         kwargs = {"is_test": False}
         with self.assertRaises(ValueError):
@@ -238,6 +218,11 @@ class TestPyLayer(unittest.TestCase):
                 recompute_block=[2], recompute_kwargs=kwargs)

     def test_recompute_cpu_rng(self):
+        with _test_eager_guard():
+            paddle.set_device("cpu")
+            with self.assertRaises(RuntimeError):
+                loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
+
         paddle.set_device("cpu")
         with self.assertRaises(RuntimeError):
             loss_ref, param_ref, grad_ref = run_model(recompute_block=[2])
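In short, the test refactor folds the duplicated per-scenario assertions into test_base_case(enable_autocast, pure_fp16) and runs every scenario twice, once inside _test_eager_guard() and once outside, so that both the new EagerRecomputeFunction path and the legacy RecomputeFunction path are exercised; the kwargs and CPU-RNG error cases get the same dual-mode treatment.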