提交 f68db656 编写于 作者: W wanghaoshuang

Fix sa controller.

上级 96daf92f
...@@ -195,11 +195,12 @@ def compress(args): ...@@ -195,11 +195,12 @@ def compress(args):
server_addr=("", 0), server_addr=("", 0),
init_temperature=100, init_temperature=100,
reduce_rate=0.85, reduce_rate=0.85,
max_try_number=300, max_try_times=300,
max_client_num=10, max_client_num=10,
search_steps=100, search_steps=100,
max_ratios=0.9, max_ratios=0.9,
min_ratios=0., min_ratios=0.,
is_server=True,
key="auto_pruner") key="auto_pruner")
while True: while True:
......
...@@ -101,7 +101,7 @@ class ControllerServer(object): ...@@ -101,7 +101,7 @@ class ControllerServer(object):
reward = messages[2] reward = messages[2]
iter = messages[3] iter = messages[3]
tokens = [int(token) for token in tokens.split(",")] tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward), iter) self._controller.update(tokens, float(reward), int(iter))
response = "ok" response = "ok"
conn.send(response.encode()) conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr, _logger.debug("send message to {}: [{}]".format(addr,
......
...@@ -32,7 +32,7 @@ class SAController(EvolutionaryController): ...@@ -32,7 +32,7 @@ class SAController(EvolutionaryController):
range_table=None, range_table=None,
reduce_rate=0.85, reduce_rate=0.85,
init_temperature=1024, init_temperature=1024,
max_iter_number=300, max_try_times=None,
init_tokens=None, init_tokens=None,
constrain_func=None): constrain_func=None):
"""Initialize. """Initialize.
...@@ -40,7 +40,7 @@ class SAController(EvolutionaryController): ...@@ -40,7 +40,7 @@ class SAController(EvolutionaryController):
range_table(list<int>): Range table. range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature. reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature. init_temperature(float): Init temperature.
max_iter_number(int): max iteration number. max_try_times(int): max try times before get legal tokens.
init_tokens(list<int>): The initial tokens. init_tokens(list<int>): The initial tokens.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
""" """
...@@ -50,7 +50,7 @@ class SAController(EvolutionaryController): ...@@ -50,7 +50,7 @@ class SAController(EvolutionaryController):
len(self._range_table) == 2) len(self._range_table) == 2)
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_iter_number = max_iter_number self._max_try_times = max_try_times
self._reward = -1 self._reward = -1
self._tokens = init_tokens self._tokens = init_tokens
self._constrain_func = constrain_func self._constrain_func = constrain_func
...@@ -72,6 +72,7 @@ class SAController(EvolutionaryController): ...@@ -72,6 +72,7 @@ class SAController(EvolutionaryController):
tokens(list<int>): The tokens generated in last step. tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens. reward(float): The reward of tokens.
""" """
iter = int(iter)
if iter > self._iter: if iter > self._iter:
self._iter = iter self._iter = iter
temperature = self._init_temperature * self._reduce_rate**self._iter temperature = self._init_temperature * self._reduce_rate**self._iter
...@@ -100,9 +101,9 @@ class SAController(EvolutionaryController): ...@@ -100,9 +101,9 @@ class SAController(EvolutionaryController):
self._range_table[1][index] + 1) self._range_table[1][index] + 1)
_logger.debug("change index[{}] from {} to {}".format(index, tokens[ _logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index])) index], new_tokens[index]))
if self._constrain_func is None: if self._constrain_func is None or self._max_try_times is None:
return new_tokens return new_tokens
for _ in range(self._max_iter_number): for _ in range(self._max_try_times):
if not self._constrain_func(new_tokens): if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random()) index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:] new_tokens = tokens[:]
......
...@@ -60,7 +60,7 @@ class SANAS(object): ...@@ -60,7 +60,7 @@ class SANAS(object):
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._is_server = is_server self._is_server = is_server
self._configs = configs self._configs = configs
self._keys = hashlib.md5(self._configs).hexdigest() self._keys = hashlib.md5(str(self._configs)).hexdigest()
server_ip, server_port = server_addr server_ip, server_port = server_addr
if server_ip == None or server_ip == "": if server_ip == None or server_ip == "":
...@@ -75,8 +75,12 @@ class SANAS(object): ...@@ -75,8 +75,12 @@ class SANAS(object):
range_table = (len(range_table) * [0], range_table) range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table)) _logger.info("range table: {}".format(range_table))
controller = SAController( controller = SAController(
range_table, self._reduce_rate, self._init_temperature, range_table,
self._max_try_number, init_tokens, self._constrain_func) self._reduce_rate,
self._init_temperature,
max_try_times=None,
init_tokens=init_tokens,
constrain_func=None)
max_client_num = 100 max_client_num = 100
self._controller_server = ControllerServer( self._controller_server = ControllerServer(
......
...@@ -42,7 +42,7 @@ class AutoPruner(object): ...@@ -42,7 +42,7 @@ class AutoPruner(object):
server_addr=("", 0), server_addr=("", 0),
init_temperature=100, init_temperature=100,
reduce_rate=0.85, reduce_rate=0.85,
max_try_number=300, max_try_times=300,
max_client_num=10, max_client_num=10,
search_steps=300, search_steps=300,
max_ratios=[0.9], max_ratios=[0.9],
...@@ -66,7 +66,7 @@ class AutoPruner(object): ...@@ -66,7 +66,7 @@ class AutoPruner(object):
server_addr(tuple): A tuple of server ip and server port for controller server. server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy. init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy. reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens. max_try_times(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server. max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching. search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`. max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
...@@ -88,7 +88,7 @@ class AutoPruner(object): ...@@ -88,7 +88,7 @@ class AutoPruner(object):
self._pruned_latency = pruned_latency self._pruned_latency = pruned_latency
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_try_number = max_try_number self._max_try_times = max_try_times
self._is_server = is_server self._is_server = is_server
self._range_table = self._get_range_table(min_ratios, max_ratios) self._range_table = self._get_range_table(min_ratios, max_ratios)
...@@ -110,7 +110,7 @@ class AutoPruner(object): ...@@ -110,7 +110,7 @@ class AutoPruner(object):
init_tokens = self._ratios2tokens(self._init_ratios) init_tokens = self._ratios2tokens(self._init_ratios)
_logger.info("range table: {}".format(self._range_table)) _logger.info("range table: {}".format(self._range_table))
controller = SAController(self._range_table, self._reduce_rate, controller = SAController(self._range_table, self._reduce_rate,
self._init_temperature, self._max_try_number, self._init_temperature, self._max_try_times,
init_tokens, self._constrain_func) init_tokens, self._constrain_func)
server_ip, server_port = server_addr server_ip, server_port = server_addr
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册