diff --git a/python/paddle/distributed/fleet/layers/mpu/random.py b/python/paddle/distributed/fleet/layers/mpu/random.py index aabb8dc3fa14efa3809dcc92e89e892c05f75321..c0d4f6320f4b6ed02ebeb9299afb4eebf1990c9d 100644 --- a/python/paddle/distributed/fleet/layers/mpu/random.py +++ b/python/paddle/distributed/fleet/layers/mpu/random.py @@ -51,10 +51,10 @@ class RNGStatesTracker: self.seeds_.add(seed) if name in self.states_: raise ValueError(f'state {name} already exists') - orig_rng_state = paddle.get_cuda_rng_state() + orig_rng_state = paddle.get_rng_state() paddle.seed(seed) - self.states_[name] = paddle.get_cuda_rng_state() - paddle.set_cuda_rng_state(orig_rng_state) + self.states_[name] = paddle.get_rng_state() + paddle.set_rng_state(orig_rng_state) def get_states_tracker(self): states = {} @@ -69,13 +69,13 @@ class RNGStatesTracker: def rng_state(self, name=MODEL_PARALLEL_RNG): if name not in self.states_: raise ValueError(f'state {name} does not exist') - orig_cuda_rng_state = paddle.get_cuda_rng_state() - paddle.set_cuda_rng_state(self.states_[name]) + orig_rng_state = paddle.get_rng_state() + paddle.set_rng_state(self.states_[name]) try: yield finally: - self.states_[name] = paddle.get_cuda_rng_state() - paddle.set_cuda_rng_state(orig_cuda_rng_state) + self.states_[name] = paddle.get_rng_state() + paddle.set_rng_state(orig_rng_state) RNG_STATE_TRACKER = RNGStatesTracker()