未验证 提交 23baa8c6 编写于 作者: H houj04 提交者: GitHub

[XPU] change cuda_rng_state to rng_state in fleet random (#54077)

上级 44044d80
......@@ -51,10 +51,10 @@ class RNGStatesTracker:
self.seeds_.add(seed)
if name in self.states_:
raise ValueError(f'state {name} already exists')
orig_rng_state = paddle.get_cuda_rng_state()
orig_rng_state = paddle.get_rng_state()
paddle.seed(seed)
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_rng_state)
self.states_[name] = paddle.get_rng_state()
paddle.set_rng_state(orig_rng_state)
def get_states_tracker(self):
states = {}
......@@ -69,13 +69,13 @@ class RNGStatesTracker:
def rng_state(self, name=MODEL_PARALLEL_RNG):
if name not in self.states_:
raise ValueError(f'state {name} does not exist')
orig_cuda_rng_state = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(self.states_[name])
orig_rng_state = paddle.get_rng_state()
paddle.set_rng_state(self.states_[name])
try:
yield
finally:
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_cuda_rng_state)
self.states_[name] = paddle.get_rng_state()
paddle.set_rng_state(orig_rng_state)
RNG_STATE_TRACKER = RNGStatesTracker()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册