From 23baa8c60959dc60e6f0eae4be2ff6ae2c95d50a Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Thu, 25 May 2023 06:26:45 +0800 Subject: [PATCH] [XPU] change cuda_rng_state to rng_state in fleet random (#54077) --- .../paddle/distributed/fleet/layers/mpu/random.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/paddle/distributed/fleet/layers/mpu/random.py b/python/paddle/distributed/fleet/layers/mpu/random.py index aabb8dc3fa1..c0d4f6320f4 100644 --- a/python/paddle/distributed/fleet/layers/mpu/random.py +++ b/python/paddle/distributed/fleet/layers/mpu/random.py @@ -51,10 +51,10 @@ class RNGStatesTracker: self.seeds_.add(seed) if name in self.states_: raise ValueError(f'state {name} already exists') - orig_rng_state = paddle.get_cuda_rng_state() + orig_rng_state = paddle.get_rng_state() paddle.seed(seed) - self.states_[name] = paddle.get_cuda_rng_state() - paddle.set_cuda_rng_state(orig_rng_state) + self.states_[name] = paddle.get_rng_state() + paddle.set_rng_state(orig_rng_state) def get_states_tracker(self): states = {} @@ -69,13 +69,13 @@ class RNGStatesTracker: def rng_state(self, name=MODEL_PARALLEL_RNG): if name not in self.states_: raise ValueError(f'state {name} does not exist') - orig_cuda_rng_state = paddle.get_cuda_rng_state() - paddle.set_cuda_rng_state(self.states_[name]) + orig_rng_state = paddle.get_rng_state() + paddle.set_rng_state(self.states_[name]) try: yield finally: - self.states_[name] = paddle.get_cuda_rng_state() - paddle.set_cuda_rng_state(orig_cuda_rng_state) + self.states_[name] = paddle.get_rng_state() + paddle.set_rng_state(orig_rng_state) RNG_STATE_TRACKER = RNGStatesTracker() -- GitLab