提交 656fabca 编写于 作者: W wanghaoshuang

Merge branch 'fix_nas' into 'develop'

Fix sa nas

See merge request !38
...@@ -195,11 +195,12 @@ def compress(args): ...@@ -195,11 +195,12 @@ def compress(args):
server_addr=("", 0), server_addr=("", 0),
init_temperature=100, init_temperature=100,
reduce_rate=0.85, reduce_rate=0.85,
max_try_number=300, max_try_times=300,
max_client_num=10, max_client_num=10,
search_steps=100, search_steps=100,
max_ratios=0.9, max_ratios=0.9,
min_ratios=0., min_ratios=0.,
is_server=True,
key="auto_pruner") key="auto_pruner")
while True: while True:
......
...@@ -38,7 +38,7 @@ class ControllerClient(object): ...@@ -38,7 +38,7 @@ class ControllerClient(object):
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key self._key = key
def update(self, tokens, reward): def update(self, tokens, reward, iter):
""" """
Update the controller according to latest tokens and reward. Update the controller according to latest tokens and reward.
Args: Args:
...@@ -48,8 +48,8 @@ class ControllerClient(object): ...@@ -48,8 +48,8 @@ class ControllerClient(object):
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port)) socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens]) tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward,
.encode()) iter).encode())
response = socket_client.recv(1024).decode() response = socket_client.recv(1024).decode()
if response.strip('\n').split("\t") == "ok": if response.strip('\n').split("\t") == "ok":
return True return True
......
...@@ -51,23 +51,8 @@ class ControllerServer(object): ...@@ -51,23 +51,8 @@ class ControllerServer(object):
self._port = address[1] self._port = address[1]
self._ip = address[0] self._ip = address[0]
self._key = key self._key = key
self._socket_file = "./controller_server.socket"
def start(self): def start(self):
open(self._socket_file, 'a').close()
socket_file = open(self._socket_file, 'r+')
lock(socket_file)
tid = socket_file.readline()
if tid == '':
_logger.info("start controller server...")
tid = self._start()
socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
tid, self._ip, self._port))
_logger.info("started controller server...")
unlock(socket_file)
socket_file.close()
def _start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address) self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num) self._socket_server.listen(self._max_client_num)
...@@ -82,7 +67,6 @@ class ControllerServer(object): ...@@ -82,7 +67,6 @@ class ControllerServer(object):
def close(self): def close(self):
"""Close the server.""" """Close the server."""
self._closed = True self._closed = True
os.remove(self._socket_file)
_logger.info("server closed!") _logger.info("server closed!")
def port(self): def port(self):
...@@ -109,14 +93,15 @@ class ControllerServer(object): ...@@ -109,14 +93,15 @@ class ControllerServer(object):
_logger.debug("recv message from {}: [{}]".format(addr, _logger.debug("recv message from {}: [{}]".format(addr,
message)) message))
messages = message.strip('\n').split("\t") messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key): if (len(messages) < 4) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format( _logger.debug("recv noise from {}: [{}]".format(
addr, message)) addr, message))
continue continue
tokens = messages[1] tokens = messages[1]
reward = messages[2] reward = messages[2]
iter = messages[3]
tokens = [int(token) for token in tokens.split(",")] tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward)) self._controller.update(tokens, float(reward), int(iter))
response = "ok" response = "ok"
conn.send(response.encode()) conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr, _logger.debug("send message to {}: [{}]".format(addr,
......
...@@ -32,7 +32,7 @@ class SAController(EvolutionaryController): ...@@ -32,7 +32,7 @@ class SAController(EvolutionaryController):
range_table=None, range_table=None,
reduce_rate=0.85, reduce_rate=0.85,
init_temperature=1024, init_temperature=1024,
max_iter_number=300, max_try_times=None,
init_tokens=None, init_tokens=None,
constrain_func=None): constrain_func=None):
"""Initialize. """Initialize.
...@@ -40,7 +40,7 @@ class SAController(EvolutionaryController): ...@@ -40,7 +40,7 @@ class SAController(EvolutionaryController):
range_table(list<int>): Range table. range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature. reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature. init_temperature(float): Init temperature.
max_iter_number(int): max iteration number. max_try_times(int): max try times before get legal tokens.
init_tokens(list<int>): The initial tokens. init_tokens(list<int>): The initial tokens.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
""" """
...@@ -50,7 +50,7 @@ class SAController(EvolutionaryController): ...@@ -50,7 +50,7 @@ class SAController(EvolutionaryController):
len(self._range_table) == 2) len(self._range_table) == 2)
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_iter_number = max_iter_number self._max_try_times = max_try_times
self._reward = -1 self._reward = -1
self._tokens = init_tokens self._tokens = init_tokens
self._constrain_func = constrain_func self._constrain_func = constrain_func
...@@ -65,15 +65,17 @@ class SAController(EvolutionaryController): ...@@ -65,15 +65,17 @@ class SAController(EvolutionaryController):
d[key] = self.__dict__[key] d[key] = self.__dict__[key]
return d return d
def update(self, tokens, reward): def update(self, tokens, reward, iter):
""" """
Update the controller according to latest tokens and reward. Update the controller according to latest tokens and reward.
Args: Args:
tokens(list<int>): The tokens generated in last step. tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens. reward(float): The reward of tokens.
""" """
self._iter += 1 iter = int(iter)
temperature = self._init_temperature * self._reduce_rate**self._iter if iter > self._iter:
self._iter = iter
temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp( if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)): (reward - self._reward) / temperature)):
self._reward = reward self._reward = reward
...@@ -99,9 +101,9 @@ class SAController(EvolutionaryController): ...@@ -99,9 +101,9 @@ class SAController(EvolutionaryController):
self._range_table[1][index] + 1) self._range_table[1][index] + 1)
_logger.debug("change index[{}] from {} to {}".format(index, tokens[ _logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index])) index], new_tokens[index]))
if self._constrain_func is None: if self._constrain_func is None or self._max_try_times is None:
return new_tokens return new_tokens
for _ in range(self._max_iter_number): for _ in range(self._max_try_times):
if not self._constrain_func(new_tokens): if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random()) index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:] new_tokens = tokens[:]
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import socket import socket
import logging import logging
import numpy as np import numpy as np
import hashlib
import paddle.fluid as fluid import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController from ..common import SAController
...@@ -58,38 +59,40 @@ class SANAS(object): ...@@ -58,38 +59,40 @@ class SANAS(object):
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._is_server = is_server self._is_server = is_server
self._configs = configs self._configs = configs
self._keys = hashlib.md5(str(self._configs)).hexdigest()
factory = SearchSpaceFactory()
self._search_space = factory.get_search_space(configs)
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table))
controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func)
server_ip, server_port = server_addr server_ip, server_port = server_addr
if server_ip == None or server_ip == "": if server_ip == None or server_ip == "":
server_ip = self._get_host_ip() server_ip = self._get_host_ip()
max_client_num = 100
self._controller_server = ControllerServer(
controller=controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=key)
# create controller server # create controller server
if self._is_server: if self._is_server:
factory = SearchSpaceFactory()
self._search_space = factory.get_search_space(configs)
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table))
controller = SAController(
range_table,
self._reduce_rate,
self._init_temperature,
max_try_times=None,
init_tokens=init_tokens,
constrain_func=None)
max_client_num = 100
self._controller_server = ControllerServer(
controller=controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=self._key)
self._controller_server.start() self._controller_server.start()
self._controller_client = ControllerClient( self._controller_client = ControllerClient(
self._controller_server.ip(), server_ip, server_port, key=self._key)
self._controller_server.port(),
key=key)
self._iter = 0 self._iter = 0
...@@ -115,4 +118,5 @@ class SANAS(object): ...@@ -115,4 +118,5 @@ class SANAS(object):
bool: True means updating successfully while false means failure. bool: True means updating successfully while false means failure.
""" """
self._iter += 1 self._iter += 1
return self._controller_client.update(self._current_tokens, score) return self._controller_client.update(self._current_tokens, score,
self._iter)
...@@ -42,7 +42,7 @@ class AutoPruner(object): ...@@ -42,7 +42,7 @@ class AutoPruner(object):
server_addr=("", 0), server_addr=("", 0),
init_temperature=100, init_temperature=100,
reduce_rate=0.85, reduce_rate=0.85,
max_try_number=300, max_try_times=300,
max_client_num=10, max_client_num=10,
search_steps=300, search_steps=300,
max_ratios=[0.9], max_ratios=[0.9],
...@@ -66,7 +66,7 @@ class AutoPruner(object): ...@@ -66,7 +66,7 @@ class AutoPruner(object):
server_addr(tuple): A tuple of server ip and server port for controller server. server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy. init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy. reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens. max_try_times(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server. max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching. search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`. max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
...@@ -88,7 +88,7 @@ class AutoPruner(object): ...@@ -88,7 +88,7 @@ class AutoPruner(object):
self._pruned_latency = pruned_latency self._pruned_latency = pruned_latency
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_try_number = max_try_number self._max_try_times = max_try_times
self._is_server = is_server self._is_server = is_server
self._range_table = self._get_range_table(min_ratios, max_ratios) self._range_table = self._get_range_table(min_ratios, max_ratios)
...@@ -110,7 +110,7 @@ class AutoPruner(object): ...@@ -110,7 +110,7 @@ class AutoPruner(object):
init_tokens = self._ratios2tokens(self._init_ratios) init_tokens = self._ratios2tokens(self._init_ratios)
_logger.info("range table: {}".format(self._range_table)) _logger.info("range table: {}".format(self._range_table))
controller = SAController(self._range_table, self._reduce_rate, controller = SAController(self._range_table, self._reduce_rate,
self._init_temperature, self._max_try_number, self._init_temperature, self._max_try_times,
init_tokens, self._constrain_func) init_tokens, self._constrain_func)
server_ip, server_port = server_addr server_ip, server_port = server_addr
...@@ -212,7 +212,7 @@ class AutoPruner(object): ...@@ -212,7 +212,7 @@ class AutoPruner(object):
self._restore(self._scope) self._restore(self._scope)
self._param_backup = {} self._param_backup = {}
tokens = self._ratios2tokens(self._current_ratios) tokens = self._ratios2tokens(self._current_ratios)
self._controller_client.update(tokens, score) self._controller_client.update(tokens, score, self._iter)
self._iter += 1 self._iter += 1
def _restore(self, scope): def _restore(self, scope):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册