未验证 提交 629f20e2 编写于 作者: R Roc 提交者: GitHub

Hybrid recompute for xpu (#50717)

上级 652d12cc
@@ -94,10 +94,8 @@ class _HPRecomputeFunction(PyLayer):
 ctx.kwargs = kwargs
 # store the rng states
-ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
-ctx.fwd_cuda_rng_state_tracker = (
-get_rng_state_tracker().get_states_tracker()
-)
+ctx.fwd_rng_state = paddle.get_rng_state()
+ctx.fwd_rng_state_tracker = get_rng_state_tracker().get_states_tracker()
 # save config info
 ctx.mp_group = mp_group
@@ -112,7 +110,7 @@ class _HPRecomputeFunction(PyLayer):
 cur_device = paddle.get_device()
 assert (
-'gpu:' in paddle.get_device()
+'gpu:' in paddle.get_device() or 'xpu:' in paddle.get_device()
 ), "Recompute with RNG is not support current device: {}.".format(
 cur_device
 )
@@ -203,7 +201,7 @@ class _HPRecomputeFunction(PyLayer):
 # need restore auto_cast state as well as w/b list
 with swith_rng_state_tracker(
-ctx.fwd_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
+ctx.fwd_rng_state, ctx.fwd_rng_state_tracker
 ):
 with paddle.amp.auto_cast(
 enable=ctx.is_fw_autocast,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册