From fdb09f05c6ae9b7627aeb9051217d9f51c6fde4d Mon Sep 17 00:00:00 2001
From: ceci3
Date: Tue, 31 Dec 2019 11:10:29 +0800
Subject: [PATCH] update nas (#18)

---
 paddleslim/common/sa_controller.py | 60 +++++++++++++++++++++---------
 paddleslim/nas/sa_nas.py           | 23 ++++++++++--
 2 files changed, 62 insertions(+), 21 deletions(-)

diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py
index 2beb9eed..8a081761 100644
--- a/paddleslim/common/sa_controller.py
+++ b/paddleslim/common/sa_controller.py
@@ -14,6 +14,7 @@
 """The controller used to search hyperparameters or neural architecture"""
 import os
+import sys
 import copy
 import math
 import logging
@@ -34,20 +35,21 @@ class SAController(EvolutionaryController):
                  range_table=None,
                  reduce_rate=0.85,
                  init_temperature=1024,
-                 max_try_times=None,
+                 max_try_times=300,
                  init_tokens=None,
                  reward=-1,
                  max_reward=-1,
                  iters=0,
                  best_tokens=None,
                  constrain_func=None,
-                 checkpoints=None):
+                 checkpoints=None,
+                 searched=None):
         """Initialize.
         Args:
             range_table(list): Range table.
             reduce_rate(float): The decay rate of temperature.
             init_temperature(float): Init temperature.
-            max_try_times(int): max try times before get legal tokens.
+            max_try_times(int): max try times before getting legal tokens. Default: 300.
             init_tokens(list): The initial tokens. Default: None.
             reward(float): The reward of current tokens. Default: -1.
             max_reward(float): The max reward in the search of sanas, in general, best tokens get max reward. Default: -1.
@@ -55,6 +57,7 @@ class SAController(EvolutionaryController):
             best_tokens(list): The best tokens in the search of sanas, in general, best tokens get max reward. Default: None.
             constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
             checkpoints(str): if checkpoint is None, donnot save checkpoints, else save scene to checkpoints file.
+            searched(dict): a record of tokens that have already been searched. Default: None.
         """
         super(SAController, self).__init__()
         self._range_table = range_table
@@ -70,6 +73,7 @@ class SAController(EvolutionaryController):
         self._best_tokens = best_tokens
         self._iter = iters
         self._checkpoints = checkpoints
+        self._searched = searched if searched is not None else dict()
 
     def __getstate__(self):
         d = {}
@@ -78,6 +82,18 @@ class SAController(EvolutionaryController):
                 d[key] = self.__dict__[key]
         return d
 
+    @property
+    def best_tokens(self):
+        return self._best_tokens
+
+    @property
+    def max_reward(self):
+        return self._max_reward
+
+    @property
+    def current_tokens(self):
+        return self._tokens
+
     def update(self, tokens, reward, iter):
         """
         Update the controller according to latest tokens and reward.
@@ -88,6 +104,7 @@ class SAController(EvolutionaryController):
         iter = int(iter)
         if iter > self._iter:
             self._iter = iter
+            self._searched[str(tokens)] = reward
         temperature = self._init_temperature * self._reduce_rate**self._iter
         if (reward > self._reward) or (np.random.random() <= math.exp(
                 (reward - self._reward) / temperature)):
@@ -112,22 +129,31 @@ class SAController(EvolutionaryController):
             tokens = control_token[:]
         else:
             tokens = self._tokens
-        new_tokens = tokens[:]
-        index = int(len(self._range_table[0]) * np.random.random())
-        new_tokens[index] = np.random.randint(self._range_table[0][index],
-                                              self._range_table[1][index])
-        _logger.debug("change index[{}] from {} to {}".format(index, tokens[
-            index], new_tokens[index]))
-        if self._constrain_func is None or self._max_try_times is None:
-            return new_tokens
-        for _ in range(self._max_try_times):
-            if not self._constrain_func(new_tokens):
-                index = int(len(self._range_table[0]) * np.random.random())
-                new_tokens = tokens[:]
-                new_tokens[index] = np.random.randint(
-                    self._range_table[0][index], self._range_table[1][index])
+        for _ in range(self._max_try_times):
+            new_tokens = tokens[:]
+            index = int(len(self._range_table[0]) * np.random.random())
+            new_tokens[index] = np.random.randint(self._range_table[0][index],
+                                                  self._range_table[1][index])
+            _logger.debug("change index[{}] from {} to {}".format(
+                index, tokens[index], new_tokens[index]))
+
+            if str(new_tokens) in self._searched:
+                _logger.debug('next tokens {} have already been searched'.
+                              format(new_tokens))
+                continue
             else:
+                self._searched[str(new_tokens)] = -1
                 break
+        else:
+            # every candidate within max_try_times was already searched
+            _logger.info(
+                "cannot get an unsearched set of tokens within max try times!"
+            )
+            sys.exit()
+
+        if self._constrain_func is None or self._max_try_times is None:
+            return new_tokens
+
         return new_tokens
 
     def _save_checkpoint(self, output_dir):
diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py
index 34ec47b8..428385f5 100644
--- a/paddleslim/nas/sa_nas.py
+++ b/paddleslim/nas/sa_nas.py
@@ -93,29 +93,32 @@ class SANAS(object):
             premax_reward = scene['_max_reward']
             prebest_tokens = scene['_best_tokens']
             preiter = scene['_iter']
+            psearched = scene['_searched']
         else:
             preinit_tokens = init_tokens
             prereward = -1
             premax_reward = -1
             prebest_tokens = None
             preiter = 0
+            psearched = None
 
-        controller = SAController(
+        self._controller = SAController(
             range_table,
             self._reduce_rate,
             self._init_temperature,
-            max_try_times=None,
+            max_try_times=500,
             init_tokens=preinit_tokens,
             reward=prereward,
             max_reward=premax_reward,
             iters=preiter,
             best_tokens=prebest_tokens,
             constrain_func=None,
-            checkpoints=save_checkpoint)
+            checkpoints=save_checkpoint,
+            searched=psearched)
 
         max_client_num = 100
         self._controller_server = ControllerServer(
-            controller=controller,
+            controller=self._controller,
             address=(server_ip, server_port),
             max_client_num=max_client_num,
             search_steps=search_steps,
@@ -137,6 +140,18 @@ class SANAS(object):
     def tokens2arch(self, tokens):
         return self._search_space.token2arch(tokens)
 
+    def current_info(self):
+        """
+        Get current information, including the best tokens and best reward found in the whole search, and the current tokens.
+        Returns:
+            dict: a dictionary including best tokens, best reward and current tokens.
+ """ + current_dict = dict() + current_dict['best_tokens'] = self._controller.best_tokens + current_dict['best_reward'] = self._controller.max_reward + current_dict['current_tokens'] = self._controller.current_tokens + return current_dict + def next_archs(self): """ Get next network architectures. -- GitLab