未验证 提交 629f20e2 编写于 作者: R Roc 提交者: GitHub

Hybrid recompute for xpu (#50717)

上级 652d12cc
......@@ -94,10 +94,8 @@ class _HPRecomputeFunction(PyLayer):
ctx.kwargs = kwargs
# store the rng states
ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = (
get_rng_state_tracker().get_states_tracker()
)
ctx.fwd_rng_state = paddle.get_rng_state()
ctx.fwd_rng_state_tracker = get_rng_state_tracker().get_states_tracker()
# save config info
ctx.mp_group = mp_group
......@@ -112,7 +110,7 @@ class _HPRecomputeFunction(PyLayer):
cur_device = paddle.get_device()
assert (
'gpu:' in paddle.get_device()
'gpu:' in paddle.get_device() or 'xpu:' in paddle.get_device()
), "Recompute with RNG is not support current device: {}.".format(
cur_device
)
......@@ -203,7 +201,7 @@ class _HPRecomputeFunction(PyLayer):
# need restore auto_cast state as well as w/b list
with swith_rng_state_tracker(
ctx.fwd_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker
ctx.fwd_rng_state, ctx.fwd_rng_state_tracker
):
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册