# Third party code
#
# The following code is copied or modified from:
# https://github.com/suragnair/alpha-zero-general

import math

import numpy as np

EPS = 1e-8


class MCTS():
    """
    This class handles the MCTS tree.
    """

    def __init__(self, game, nn_agent, args, dirichlet_noise=False):
        self.game = game
        self.nn_agent = nn_agent
        self.args = args
        self.dirichlet_noise = dirichlet_noise
        self.Qsa = {}  # stores Q values for s,a (as defined in the paper)
        self.Nsa = {}  # stores #times edge s,a was visited
        self.Ns = {}  # stores #times board s was visited
        self.Ps = {}  # stores initial policy (returned by neural net)

        self.Es = {}  # stores game.getGameEnded for board s
        self.Vs = {}  # stores game.getValidMoves for board s

    def getActionProb(self, canonicalBoard, temp=1):
        """
        Performs numMCTSSims simulations of MCTS starting from canonicalBoard.

        Returns:
            probs: a policy vector where the probability of the ith action is
                   proportional to Nsa[(s,a)]**(1./temp)
        """
        for i in range(self.args.numMCTSSims):
            # apply Dirichlet noise only on the first simulation from the root
            dir_noise = (i == 0 and self.dirichlet_noise)
            self.search(canonicalBoard, dirichlet_noise=dir_noise)

        s = self.game.stringRepresentation(canonicalBoard)
        counts = [
            self.Nsa[(s, a)] if (s, a) in self.Nsa else 0
            for a in range(self.game.getActionSize())
        ]

        if temp == 0:
            # deterministic play: put all mass on one (randomly tie-broken)
            # most-visited action
            bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
            bestA = np.random.choice(bestAs)
            probs = [0] * len(counts)
            probs[bestA] = 1
            return probs

        counts = [x**(1. / temp) for x in counts]
        counts_sum = float(sum(counts))
        probs = [x / counts_sum for x in counts]
        return probs

    def search(self, canonicalBoard, dirichlet_noise=False):
        """
        Performs one iteration of MCTS. It is recursively called until a leaf
        node is found. The action chosen at each node is the one with the
        maximum upper confidence bound, as in the paper.

        Once a leaf node is found, the neural network is called to return an
        initial policy P and a value v for the state. This value is propagated
        up the search path. In case the leaf node is a terminal state, the
        outcome is propagated up the search path. The values of Ns, Nsa, Qsa
        are updated.

        NOTE: the return values are the negative of the value of the current
        state. This is done since v is in [-1, 1] and if v is the value of a
        state for the current player, then its value is -v for the other
        player.

        Returns:
            v: the negative of the value of the current canonicalBoard
        """
        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # terminal node
            return -self.Es[s]

        if s not in self.Ps:
            # leaf node: expand it with the network's policy and value
            self.Ps[s], v = self.nn_agent.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # mask invalid moves
            if dirichlet_noise:
                self.applyDirNoise(s, valids)
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # If all valid moves were masked, make all valid moves equally
                # probable.
                #
                # NB! All valid moves may be masked if your NNet architecture
                # is insufficient, you have overfitting, or something else is
                # wrong. If you get dozens or hundreds of these messages you
                # should pay attention to your NNet and/or training process.
print("All valid moves were masked, doing a workaround.") self.Ps[s] = self.Ps[s] + valids self.Ps[s] /= np.sum(self.Ps[s]) self.Vs[s] = valids self.Ns[s] = 0 return -v valids = self.Vs[s] if dirichlet_noise: self.applyDirNoise(s, valids) sum_Ps_s = np.sum(self.Ps[s]) self.Ps[s] /= sum_Ps_s # renormalize cur_best = -float('inf') best_act = -1 # pick the action with the highest upper confidence bound for a in range(self.game.getActionSize()): if valids[a]: if (s, a) in self.Qsa: u = self.Qsa[ (s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt( self.Ns[s]) / (1 + self.Nsa[(s, a)]) else: u = self.args.cpuct * self.Ps[s][a] * math.sqrt( self.Ns[s] + EPS) # Q = 0 ? if u > cur_best: cur_best = u best_act = a a = best_act next_s, next_player = self.game.getNextState(canonicalBoard, 1, a) next_s = self.game.getCanonicalForm(next_s, next_player) v = self.search(next_s) if (s, a) in self.Qsa: self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[ (s, a)] + v) / (self.Nsa[(s, a)] + 1) self.Nsa[(s, a)] += 1 else: self.Qsa[(s, a)] = v self.Nsa[(s, a)] = 1 self.Ns[s] += 1 return -v def applyDirNoise(self, s, valids): dir_values = np.random.dirichlet( [self.args.dirichletAlpha] * np.count_nonzero(valids)) dir_idx = 0 for idx in range(len(self.Ps[s])): if self.Ps[s][idx]: self.Ps[s][idx] = (0.75 * self.Ps[s][idx]) + ( 0.25 * dir_values[dir_idx]) dir_idx += 1