提交 f68db656 编写于 作者: W wanghaoshuang

Fix sa controller.

上级 96daf92f
......@@ -195,11 +195,12 @@ def compress(args):
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_try_times=300,
max_client_num=10,
search_steps=100,
max_ratios=0.9,
min_ratios=0.,
is_server=True,
key="auto_pruner")
while True:
......
......@@ -101,7 +101,7 @@ class ControllerServer(object):
reward = messages[2]
iter = messages[3]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward), iter)
self._controller.update(tokens, float(reward), int(iter))
response = "ok"
conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr,
......
......@@ -32,7 +32,7 @@ class SAController(EvolutionaryController):
range_table=None,
reduce_rate=0.85,
init_temperature=1024,
max_iter_number=300,
max_try_times=None,
init_tokens=None,
constrain_func=None):
"""Initialize.
......@@ -40,7 +40,7 @@ class SAController(EvolutionaryController):
range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature.
max_iter_number(int): max iteration number.
max_try_times(int): max try times before get legal tokens.
init_tokens(list<int>): The initial tokens.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
"""
......@@ -50,7 +50,7 @@ class SAController(EvolutionaryController):
len(self._range_table) == 2)
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_iter_number = max_iter_number
self._max_try_times = max_try_times
self._reward = -1
self._tokens = init_tokens
self._constrain_func = constrain_func
......@@ -72,6 +72,7 @@ class SAController(EvolutionaryController):
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
iter = int(iter)
if iter > self._iter:
self._iter = iter
temperature = self._init_temperature * self._reduce_rate**self._iter
......@@ -100,9 +101,9 @@ class SAController(EvolutionaryController):
self._range_table[1][index] + 1)
_logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index]))
if self._constrain_func is None:
if self._constrain_func is None or self._max_try_times is None:
return new_tokens
for _ in range(self._max_iter_number):
for _ in range(self._max_try_times):
if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:]
......
......@@ -60,7 +60,7 @@ class SANAS(object):
self._init_temperature = init_temperature
self._is_server = is_server
self._configs = configs
self._keys = hashlib.md5(self._configs).hexdigest()
self._keys = hashlib.md5(str(self._configs)).hexdigest()
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
......@@ -75,8 +75,12 @@ class SANAS(object):
range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table))
controller = SAController(
range_table, self._reduce_rate, self._init_temperature,
self._max_try_number, init_tokens, self._constrain_func)
range_table,
self._reduce_rate,
self._init_temperature,
max_try_times=None,
init_tokens=init_tokens,
constrain_func=None)
max_client_num = 100
self._controller_server = ControllerServer(
......
......@@ -42,7 +42,7 @@ class AutoPruner(object):
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_try_times=300,
max_client_num=10,
search_steps=300,
max_ratios=[0.9],
......@@ -66,7 +66,7 @@ class AutoPruner(object):
server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens.
max_try_times(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
......@@ -88,7 +88,7 @@ class AutoPruner(object):
self._pruned_latency = pruned_latency
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_try_number = max_try_number
self._max_try_times = max_try_times
self._is_server = is_server
self._range_table = self._get_range_table(min_ratios, max_ratios)
......@@ -110,7 +110,7 @@ class AutoPruner(object):
init_tokens = self._ratios2tokens(self._init_ratios)
_logger.info("range table: {}".format(self._range_table))
controller = SAController(self._range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
self._init_temperature, self._max_try_times,
init_tokens, self._constrain_func)
server_ip, server_port = server_addr
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册