Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
522c2bc0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
522c2bc0
编写于
12月 27, 2022
作者:
W
wanghuancoder
提交者:
GitHub
12月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
delete old dygraph pylayer recompute (#49338)
上级
2bbdc47a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
0 addition
and
142 deletion
+0
-142
python/paddle/distributed/fleet/recompute/recompute.py
python/paddle/distributed/fleet/recompute/recompute.py
+0
-142
未找到文件。
python/paddle/distributed/fleet/recompute/recompute.py
浏览文件 @
522c2bc0
...
@@ -18,7 +18,6 @@ import weakref
...
@@ -18,7 +18,6 @@ import weakref
import
paddle
import
paddle
from
paddle
import
framework
from
paddle
import
framework
from
paddle.autograd
import
PyLayer
from
paddle.autograd
import
PyLayer
from
paddle.autograd.py_layer
import
LegacyPyLayer
from
paddle.distributed.fleet.meta_parallel.parallel_layers.random
import
(
from
paddle.distributed.fleet.meta_parallel.parallel_layers.random
import
(
get_rng_state_tracker
,
get_rng_state_tracker
,
)
)
...
@@ -67,147 +66,6 @@ def swith_rng_state_tracker(rng_state, tracker):
...
@@ -67,147 +66,6 @@ def swith_rng_state_tracker(rng_state, tracker):
get_rng_state_tracker
().
set_states_tracker
(
orig_rng_tracker
)
get_rng_state_tracker
().
set_states_tracker
(
orig_rng_tracker
)
class
LegacyRecomputeFunction
(
LegacyPyLayer
):
@
staticmethod
def
forward
(
ctx
,
run_function
,
preserve_rng_state
,
*
args
):
# store for recomputing
ctx
.
run_function
=
run_function
ctx
.
preserve_rng_state
=
preserve_rng_state
# NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input
# the order of tensors in backward()'s output should be the same as tensors in forward()'s input
# None tensor inputs will be filtered in backward inputs.
# save input for backward
ctx
.
inputs
=
[]
ctx
.
tensor_indices
=
[]
tensor_inputs
=
[]
for
i
,
arg
in
enumerate
(
args
):
if
paddle
.
is_tensor
(
arg
):
tensor_inputs
.
append
(
arg
)
ctx
.
tensor_indices
.
append
(
i
)
ctx
.
inputs
.
append
(
None
)
else
:
ctx
.
inputs
.
append
(
arg
)
ctx
.
save_for_backward
(
*
tensor_inputs
)
# NOTE recompute with restore RNG only support one senario where one process for one cuda gpu.
# one process with multiple gpu and mix-gpu-cpu senarios are not support
if
ctx
.
preserve_rng_state
:
ctx
.
fw_rng_state
=
paddle
.
get_rng_state
()
ctx
.
fwd_rng_state_tracker
=
(
get_rng_state_tracker
().
get_states_tracker
()
)
# TODO support AMP
tracer
=
framework
.
_dygraph_tracer
()
ctx
.
is_fw_autocast
=
(
False
if
tracer
.
_amp_level
==
core
.
AmpLevel
.
O0
else
True
)
if
tracer
.
_amp_level
==
core
.
AmpLevel
.
O2
:
ctx
.
amp_level
=
'O2'
elif
tracer
.
_amp_level
in
(
core
.
AmpLevel
.
O1
,
core
.
AmpLevel
.
O0
):
ctx
.
amp_level
=
'O1'
else
:
raise
ValueError
(
"unsupported amp level: {}"
.
format
(
tracer
.
_amp_level
)
)
if
tracer
.
_amp_dtype
==
'float16'
:
ctx
.
amp_dtype
=
'float16'
elif
tracer
.
_amp_dtype
in
(
'bfloat16'
,
'float32'
):
ctx
.
amp_dtype
=
'bfloat16'
else
:
raise
ValueError
(
"unsupported amp dtype: {}"
.
format
(
tracer
.
_amp_dtype
)
)
ctx
.
amp_white_list
,
ctx
.
amp_black_list
=
tracer
.
_get_amp_op_list
()
with
paddle
.
no_grad
():
outputs
=
run_function
(
*
args
)
return
outputs
@
staticmethod
def
backward
(
ctx
,
*
args
):
with
paddle
.
fluid
.
dygraph
.
guard
():
# TODO need to check the recompute calling is vaild or not
# Restore inputs
inputs
=
list
(
ctx
.
inputs
)
tensor_indices
=
ctx
.
tensor_indices
tensors
=
ctx
.
saved_tensor
()
for
i
,
idx
in
enumerate
(
tensor_indices
):
inputs
[
idx
]
=
tensors
[
i
]
# paddle.enable_grad()
tracer
=
framework
.
_dygraph_tracer
()
tracer
.
_has_grad
=
True
# NOTE support AMP
# need restore auto_cast state as well as w/b list
if
ctx
.
preserve_rng_state
:
with
swith_rng_state_tracker
(
ctx
.
fw_rng_state
,
ctx
.
fwd_rng_state_tracker
):
with
paddle
.
amp
.
auto_cast
(
enable
=
ctx
.
is_fw_autocast
,
custom_white_list
=
ctx
.
amp_white_list
,
custom_black_list
=
ctx
.
amp_black_list
,
level
=
ctx
.
amp_level
,
dtype
=
ctx
.
amp_dtype
,
):
detached_inputs
=
detach_variable
(
tuple
(
inputs
))
outputs
=
ctx
.
run_function
(
*
detached_inputs
)
else
:
with
paddle
.
amp
.
auto_cast
(
enable
=
ctx
.
is_fw_autocast
,
custom_white_list
=
ctx
.
amp_white_list
,
custom_black_list
=
ctx
.
amp_black_list
,
level
=
ctx
.
amp_level
,
dtype
=
ctx
.
amp_dtype
,
):
detached_inputs
=
detach_variable
(
tuple
(
inputs
))
outputs
=
ctx
.
run_function
(
*
detached_inputs
)
if
isinstance
(
outputs
,
core
.
VarBase
):
outputs
=
(
outputs
,)
assert
len
(
outputs
)
==
len
(
args
)
# run backward() with only tensor that requires grad
forward_outputs_with_grad
=
[]
# NOTE In Transformer-like network, if user put the attention mask into the recompute segment output,
# pylayer will force the stop_gradient of attention mask to be False, which will make the number of
# tensor that need grad does not match.
# the following backward_inputs_with_grad is used to avoid this case.
backward_inputs_with_grad
=
[]
for
i
in
range
(
len
(
outputs
)):
if
(
isinstance
(
outputs
[
i
],
core
.
VarBase
)
and
not
outputs
[
i
].
stop_gradient
):
forward_outputs_with_grad
.
append
(
outputs
[
i
])
backward_inputs_with_grad
.
append
(
args
[
i
])
if
len
(
forward_outputs_with_grad
)
==
0
:
raise
RuntimeError
(
"none of output has requires_grad=True, this recompute() is not necessary"
)
# actually backward
with
paddle
.
amp
.
auto_cast
(
enable
=
False
):
paddle
.
autograd
.
backward
(
forward_outputs_with_grad
,
backward_inputs_with_grad
)
grads
=
list
(
inp
.
_grad_ivar
()
for
inp
in
detached_inputs
if
isinstance
(
inp
,
core
.
VarBase
)
)
return
grads
class
RecomputeFunction
(
PyLayer
):
class
RecomputeFunction
(
PyLayer
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
run_function
,
preserve_rng_state
,
*
args
,
**
kwargs
):
def
forward
(
ctx
,
run_function
,
preserve_rng_state
,
*
args
,
**
kwargs
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录