未验证 提交 fdb09f05 编写于 作者: C ceci3 提交者: GitHub

update nas (#18)

上级 cbaac40d
......@@ -14,6 +14,7 @@
"""The controller used to search hyperparameters or neural architecture"""
import os
import sys
import copy
import math
import logging
......@@ -34,20 +35,21 @@ class SAController(EvolutionaryController):
range_table=None,
reduce_rate=0.85,
init_temperature=1024,
max_try_times=None,
max_try_times=300,
init_tokens=None,
reward=-1,
max_reward=-1,
iters=0,
best_tokens=None,
constrain_func=None,
checkpoints=None):
checkpoints=None,
searched=None):
"""Initialize.
Args:
range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature.
max_try_times(int): max try times before get legal tokens.
max_try_times(int): max try times before get legal tokens. Default: 300.
init_tokens(list<int>): The initial tokens. Default: None.
reward(float): The reward of current tokens. Default: -1.
max_reward(float): The max reward in the search of sanas, in general, best tokens get max reward. Default: -1.
......@@ -55,6 +57,7 @@ class SAController(EvolutionaryController):
best_tokens(list<int>): The best tokens in the search of sanas, in general, best tokens get max reward. Default: None.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
checkpoints(str): if checkpoints is None, do not save checkpoints; otherwise save the scene to the checkpoints file. Default: None.
searched(dict<str, float>): tokens that have already been searched, keyed by str(tokens) and mapped to their reward. Default: None.
"""
super(SAController, self).__init__()
self._range_table = range_table
......@@ -70,6 +73,7 @@ class SAController(EvolutionaryController):
self._best_tokens = best_tokens
self._iter = iters
self._checkpoints = checkpoints
self._searched = searched if searched != None else dict()
def __getstate__(self):
d = {}
......@@ -78,6 +82,18 @@ class SAController(EvolutionaryController):
d[key] = self.__dict__[key]
return d
# Read-only accessor for the best tokens stored on the controller (self._best_tokens).
@property
def best_tokens(self):
return self._best_tokens
# Read-only accessor for the max reward stored on the controller (self._max_reward).
@property
def max_reward(self):
return self._max_reward
# Read-only accessor for the controller's current tokens (self._tokens).
@property
def current_tokens(self):
return self._tokens
def update(self, tokens, reward, iter):
"""
Update the controller according to latest tokens and reward.
......@@ -88,6 +104,7 @@ class SAController(EvolutionaryController):
iter = int(iter)
if iter > self._iter:
self._iter = iter
self._searched[str(tokens)] = reward
temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)):
......@@ -112,22 +129,31 @@ class SAController(EvolutionaryController):
tokens = control_token[:]
else:
tokens = self._tokens
new_tokens = tokens[:]
index = int(len(self._range_table[0]) * np.random.random())
new_tokens[index] = np.random.randint(self._range_table[0][index],
self._range_table[1][index])
_logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index]))
if self._constrain_func is None or self._max_try_times is None:
return new_tokens
for _ in range(self._max_try_times):
if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:]
new_tokens[index] = np.random.randint(
self._range_table[0][index], self._range_table[1][index])
for it in range(self._max_try_times):
new_tokens = tokens[:]
index = int(len(self._range_table[0]) * np.random.random())
new_tokens[index] = np.random.randint(self._range_table[0][index],
self._range_table[1][index])
_logger.debug("change index[{}] from {} to {}".format(
index, tokens[index], new_tokens[index]))
if self._searched.has_key(str(new_tokens)):
_logger.debug('get next tokens including searched tokens: {}'.
format(new_tokens))
continue
else:
self._searched[str(new_tokens)] = -1
break
if it == self._max_try_times - 1:
_logger.info(
"cannot get a effective search space which is not searched in max try times!!!"
)
sys.exit()
if self._constrain_func is None or self._max_try_times is None:
return new_tokens
return new_tokens
def _save_checkpoint(self, output_dir):
......
......@@ -93,29 +93,32 @@ class SANAS(object):
premax_reward = scene['_max_reward']
prebest_tokens = scene['_best_tokens']
preiter = scene['_iter']
psearched = screen['_searched']
else:
preinit_tokens = init_tokens
prereward = -1
premax_reward = -1
prebest_tokens = None
preiter = 0
psearched = None
controller = SAController(
self._controller = SAController(
range_table,
self._reduce_rate,
self._init_temperature,
max_try_times=None,
max_try_times=500,
init_tokens=preinit_tokens,
reward=prereward,
max_reward=premax_reward,
iters=preiter,
best_tokens=prebest_tokens,
constrain_func=None,
checkpoints=save_checkpoint)
checkpoints=save_checkpoint,
searched = psearched)
max_client_num = 100
self._controller_server = ControllerServer(
controller=controller,
controller=self._controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
......@@ -137,6 +140,18 @@ class SANAS(object):
# Convert tokens into an architecture by delegating to the search space's token2arch().
# NOTE(review): the returned type is defined by the search space, which is not visible here.
def tokens2arch(self, tokens):
return self._search_space.token2arch(tokens)
def current_info(self):
    """
    Get a snapshot of the search state held by the controller.

    Returns:
        dict<name, value>: a dictionary with the keys ``best_tokens``,
            ``best_reward`` (the controller's max reward) and
            ``current_tokens`` (the controller's current tokens).
    """
    controller = self._controller
    # Key names and insertion order are part of the returned contract.
    return {
        'best_tokens': controller.best_tokens,
        'best_reward': controller.max_reward,
        'current_tokens': controller.current_tokens,
    }
def next_archs(self):
"""
Get next network architectures.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册