From 58fe161f047966e7bfb6f3073541c4184a66d472 Mon Sep 17 00:00:00 2001 From: haosicheng <47998305+HarperCy@users.noreply.github.com> Date: Tue, 6 Jun 2023 16:12:03 +0800 Subject: [PATCH] fix recompute not support bug on xpu (#54367) --- python/paddle/distributed/fleet/recompute/recompute.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index ec0690a10bc..b3bf3889a34 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -224,6 +224,8 @@ def _recompute_without_reentrant( cur_device = paddle.get_device() if 'gpu:' in cur_device: fw_cuda_rng_state = paddle.get_cuda_rng_state() + elif 'xpu:' in cur_device: + fw_cuda_rng_state = paddle.get_rng_state() elif ( cur_device.split(':')[0] in paddle.device.get_all_custom_device_type() -- GitLab