diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index ec0690a10bc33f80c069d5bfaca1c8620f7e5a3b..b3bf3889a347b5c487ea454543f324a1edfdc63e 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -224,6 +224,8 @@ def _recompute_without_reentrant( cur_device = paddle.get_device() if 'gpu:' in cur_device: fw_cuda_rng_state = paddle.get_cuda_rng_state() + elif 'xpu:' in cur_device: + fw_cuda_rng_state = paddle.get_rng_state() elif ( cur_device.split(':')[0] in paddle.device.get_all_custom_device_type()