diff --git a/benchmark/torch/AlphaZero/Coach.py b/benchmark/torch/AlphaZero/Coach.py
index 01394b076db969db42a7277b5d95f82bd661db3d..ad856b9e96b31881f1af5a60253ac22a7d9c24a5 100644
--- a/benchmark/torch/AlphaZero/Coach.py
+++ b/benchmark/torch/AlphaZero/Coach.py
@@ -54,9 +54,9 @@ class Coach():
 
         self.test_dataset = get_test_dataset()
 
-    def _run_remote_tasks(self, signal_queue):
+    def _run_remote_tasks(self, signal_queue, seed):
         # The remote actor will actually run on the local machine or other machines of xparl cluster
-        remote_actor = Actor(self.game, self.args)
+        remote_actor = Actor(self.game, self.args, seed)
 
         while True:
             # receive running task signal
@@ -91,12 +91,12 @@ class Coach():
 
         # connect to xparl cluster to submit jobs
         parl.connect(self.args.master_address)
 
-        for i in range(self.args.actors_num):
+        for seed in range(self.args.actors_num):
             signal_queue = queue.Queue()
             self.remote_actors_signal_queues.append(signal_queue)
 
             remote_thread = threading.Thread(
-                target=self._run_remote_tasks, args=(signal_queue, ))
+                target=self._run_remote_tasks, args=(signal_queue, seed))
             remote_thread.setDaemon(True)
             remote_thread.start()
diff --git a/benchmark/torch/AlphaZero/actor.py b/benchmark/torch/AlphaZero/actor.py
index 5ed719b92d292903f81f7c92a983927bf5c9cab5..19d684bb47370126af394c07f571882c3b32948c 100644
--- a/benchmark/torch/AlphaZero/actor.py
+++ b/benchmark/torch/AlphaZero/actor.py
@@ -23,7 +23,8 @@ from utils import win_loss_draw
 
 @parl.remote_class
 class Actor(object):
-    def __init__(self, game, args):
+    def __init__(self, game, args, seed):
+        np.random.seed(seed)
         os.environ['OMP_NUM_THREADS'] = "1"
         self.game = game
         self.args = args
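
For reference, the patch threads a distinct `seed` from `Coach` into each remote `Actor` and calls `np.random.seed(seed)` inside the actor's process, so parallel actors draw from different NumPy random streams instead of replaying identical self-play games. Below is a minimal, PARL-free sketch of that pattern under stated assumptions: `Worker` and `sample_moves` are hypothetical stand-ins for the real `Actor` and its self-play logic, and the sketch uses a per-worker `numpy.random.default_rng` rather than the global `np.random.seed` call so it stays self-contained in one process.

```python
import threading
import numpy as np


class Worker:
    """Hypothetical stand-in for the remote Actor: owns its own seeded RNG."""

    def __init__(self, seed):
        # One Generator per worker: distinct seeds give distinct,
        # reproducible random streams across parallel workers.
        self.rng = np.random.default_rng(seed)

    def sample_moves(self, n_actions, n_moves):
        # Stand-in for stochastic move selection during self-play.
        return self.rng.integers(0, n_actions, size=n_moves).tolist()


def run_worker(seed, results):
    results[seed] = Worker(seed).sample_moves(n_actions=7, n_moves=5)


results = {}
threads = []
for seed in range(4):  # mirrors `for seed in range(self.args.actors_num)`
    t = threading.Thread(target=run_worker, args=(seed, results))
    t.daemon = True  # same intent as remote_thread.setDaemon(True)
    t.start()
    threads.append(t)

for t in threads:
    t.join()

# With distinct seeds the sampled move sequences differ per worker;
# with a shared (or missing) seed, every worker could generate the
# same "random" games, which is what the patch avoids.
print(results)
```

Using the loop index as the seed, as the patch does, is enough to de-correlate the actors; if fully reproducible runs are also wanted, the same indices could be offset by a single experiment-level seed.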