未验证 提交 58fe161f 编写于 作者: H haosicheng 提交者: GitHub

fix recompute not support bug on xpu (#54367)

上级 0ff446cd
...@@ -224,6 +224,8 @@ def _recompute_without_reentrant( ...@@ -224,6 +224,8 @@ def _recompute_without_reentrant(
cur_device = paddle.get_device() cur_device = paddle.get_device()
if 'gpu:' in cur_device: if 'gpu:' in cur_device:
fw_cuda_rng_state = paddle.get_cuda_rng_state() fw_cuda_rng_state = paddle.get_cuda_rng_state()
elif 'xpu:' in cur_device:
fw_cuda_rng_state = paddle.get_rng_state()
elif ( elif (
cur_device.split(':')[0] cur_device.split(':')[0]
in paddle.device.get_all_custom_device_type() in paddle.device.get_all_custom_device_type()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册