Commit f55c0b33 (unverified)
Authored May 16, 2022 by ShenLiang; committed via GitHub on May 16, 2022.
[Bug]Fix recompute random in modelparallel (#42747)
* fix recompute in mp
* fix recompute
Parent: 8eecd852
2 changed files with 23 additions and 22 deletions (+23 -22):

python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py (+5 -19)
python/paddle/distributed/fleet/utils/recompute.py (+18 -3)
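What the commit fixes: with preserve_rng_state=True, the generic recompute path only saved and restored the global CUDA RNG state before replaying the forward pass in backward(), so random ops that draw from the model-parallel RNG tracker (e.g. dropout under tensor model parallelism) could produce a different mask during recomputation. The diffs below snapshot the tracker states in forward() and restore them together with the CUDA RNG state through a shared swith_rng_state_tracker helper. A minimal usage sketch follows; recompute and its preserve_rng_state flag are the public wrapper around the patched classes (not shown in these hunks), while the toy layer and data are illustrative only and assume a CUDA device.

import paddle
from paddle.distributed.fleet.utils import recompute


class Block(paddle.nn.Layer):
    # A block containing a random op (dropout); its activations are recomputed in backward.
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(64, 64)
        self.dropout = paddle.nn.Dropout(p=0.1)

    def forward(self, x):
        return self.dropout(self.linear(x))


block = Block()
x = paddle.randn([8, 64])
x.stop_gradient = False  # recompute expects at least one input that needs grad

# preserve_rng_state=True snapshots the forward RNG state (and, after this fix,
# the model-parallel RNG tracker states), so the recomputed forward reproduces
# the same dropout mask in the backward pass.
y = recompute(block, x, preserve_rng_state=True)
y.sum().backward()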
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -19,7 +19,7 @@ from paddle.fluid import core
 from paddle import _C_ops
 from paddle.autograd import PyLayer, EagerPyLayer
 from paddle.fluid import framework
-from ...utils.recompute import check_recompute_necessary, detach_variable
+from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
 from ..parallel_layers.random import get_rng_state_tracker
 from paddle.fluid.framework import in_dygraph_mode
@@ -151,20 +151,6 @@ def _merge_activation(tensor):
     return _all_gather(tensor, group=mp_group)
 
 
-@contextlib.contextmanager
-def _swith_rng_state_tracker(rng_state, tracker):
-    orig_cuda_rng_state = paddle.get_cuda_rng_state()
-    orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
-
-    paddle.set_cuda_rng_state(rng_state)
-    get_rng_state_tracker().set_states_tracker(tracker)
-    try:
-        yield
-    finally:
-        paddle.set_cuda_rng_state(orig_cuda_rng_state)
-        get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
-
-
 class _HPEagerRecomputeFunction(EagerPyLayer):
     """
     Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
@@ -261,7 +247,7 @@ class _HPEagerRecomputeFunction(EagerPyLayer):
             tracer._has_grad = True
 
             # need restore auto_cast state as well as w/b list
-            with _swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
-                                          ctx.fwd_cuda_rng_state_tracker):
+            with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
+                                         ctx.fwd_cuda_rng_state_tracker):
                 with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
@@ -393,7 +379,7 @@ class _HPRecomputeFunction(PyLayer):
             tracer._has_grad = True
 
             # need restore auto_cast state as well as w/b list
-            with _swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
-                                          ctx.fwd_cuda_rng_state_tracker):
+            with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
+                                         ctx.fwd_cuda_rng_state_tracker):
                 with paddle.amp.auto_cast(
                         enable=ctx.is_fw_autocast,
python/paddle/distributed/fleet/utils/recompute.py
@@ -53,18 +53,24 @@ def check_recompute_necessary(inputs):
 
 
 @contextlib.contextmanager
-def swith_rng_state(rng_state):
+def swith_rng_state_tracker(rng_state, tracker):
+    from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
+
     orig_cuda_rng_state = paddle.get_cuda_rng_state()
+    orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
     paddle.set_cuda_rng_state(rng_state)
+    get_rng_state_tracker().set_states_tracker(tracker)
     try:
         yield
     finally:
         paddle.set_cuda_rng_state(orig_cuda_rng_state)
+        get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
 
 
 class EagerRecomputeFunction(EagerPyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         if framework._dygraph_tracer()._has_grad:
             check_recompute_necessary(args)
@@ -98,6 +104,8 @@ class EagerRecomputeFunction(EagerPyLayer):
                     "Recompute with RNG perserve is not support current device: {}.".
                     format(cur_device))
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
+            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
+            ).get_states_tracker()
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
@@ -126,6 +134,7 @@ class EagerRecomputeFunction(EagerPyLayer):
     @staticmethod
     def backward(ctx, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         with paddle.fluid.dygraph.guard():
             # TODO need to check the recompute calling is vaild or not
@@ -143,7 +152,8 @@ class EagerRecomputeFunction(EagerPyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state(ctx.fw_cuda_rng_state):
+                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
+                                             ctx.fwd_cuda_rng_state_tracker):
                     with paddle.amp.auto_cast(
                             enable=ctx.is_fw_autocast,
                             custom_white_list=ctx.amp_white_list,
@@ -199,6 +209,7 @@ class EagerRecomputeFunction(EagerPyLayer):
 
 class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         if framework._dygraph_tracer()._has_grad:
             check_recompute_necessary(args)
@@ -232,6 +243,8 @@ class RecomputeFunction(PyLayer):
                     "Recompute with RNG perserve is not support current device: {}.".
                     format(cur_device))
             ctx.fw_cuda_rng_state = paddle.get_cuda_rng_state()
+            ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
+            ).get_states_tracker()
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
@@ -260,6 +273,7 @@ class RecomputeFunction(PyLayer):
     @staticmethod
     def backward(ctx, *args):
+        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
         with paddle.fluid.dygraph.guard():
             # TODO need to check the recompute calling is vaild or not
@@ -277,7 +291,8 @@ class RecomputeFunction(PyLayer):
             # NOTE support AMP
             # need restore auto_cast state as well as w/b list
             if ctx.preserve_rng_state:
-                with swith_rng_state(ctx.fw_cuda_rng_state):
+                with swith_rng_state_tracker(ctx.fw_cuda_rng_state,
+                                             ctx.fwd_cuda_rng_state_tracker):
                     with paddle.amp.auto_cast(
                             enable=ctx.is_fw_autocast,
                             custom_white_list=ctx.amp_white_list,
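For reference, a minimal sketch of the guarantee the extended context manager provides (illustrative, not part of the commit: swith_rng_state_tracker, get_rng_state_tracker and the two state snapshots come from the diff above; the dropout replay is an assumed example and needs a CUDA device):

import paddle
from paddle.distributed.fleet.utils.recompute import swith_rng_state_tracker
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker

# Snapshot both RNG sources the way forward() now does.
fwd_cuda_rng_state = paddle.get_cuda_rng_state()
fwd_cuda_rng_state_tracker = get_rng_state_tracker().get_states_tracker()

x = paddle.ones([4, 4])
first = paddle.nn.functional.dropout(x, p=0.5)  # consumes RNG state

# Replay under the saved states, the way the recomputed backward now does.
with swith_rng_state_tracker(fwd_cuda_rng_state, fwd_cuda_rng_state_tracker):
    replay = paddle.nn.functional.dropout(x, p=0.5)

# The dropout mask should be identical in both runs; on exiting the `with`
# block the pre-existing CUDA RNG state and tracker states are restored.
assert bool(paddle.all(first == replay))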