提交 656fabca 编写于 作者: W wanghaoshuang

Merge branch 'fix_nas' into 'develop'

Fix sa nas

See merge request !38
......@@ -195,11 +195,12 @@ def compress(args):
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_try_times=300,
max_client_num=10,
search_steps=100,
max_ratios=0.9,
min_ratios=0.,
is_server=True,
key="auto_pruner")
while True:
......
......@@ -38,7 +38,7 @@ class ControllerClient(object):
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
def update(self, tokens, reward):
def update(self, tokens, reward, iter):
"""
Update the controller according to latest tokens and reward.
Args:
......@@ -48,8 +48,8 @@ class ControllerClient(object):
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode())
socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward,
iter).encode())
response = socket_client.recv(1024).decode()
if response.strip('\n').split("\t") == "ok":
return True
......
......@@ -51,23 +51,8 @@ class ControllerServer(object):
self._port = address[1]
self._ip = address[0]
self._key = key
self._socket_file = "./controller_server.socket"
def start(self):
open(self._socket_file, 'a').close()
socket_file = open(self._socket_file, 'r+')
lock(socket_file)
tid = socket_file.readline()
if tid == '':
_logger.info("start controller server...")
tid = self._start()
socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
tid, self._ip, self._port))
_logger.info("started controller server...")
unlock(socket_file)
socket_file.close()
def _start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num)
......@@ -82,7 +67,6 @@ class ControllerServer(object):
def close(self):
"""Close the server."""
self._closed = True
os.remove(self._socket_file)
_logger.info("server closed!")
def port(self):
......@@ -109,14 +93,15 @@ class ControllerServer(object):
_logger.debug("recv message from {}: [{}]".format(addr,
message))
messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key):
if (len(messages) < 4) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format(
addr, message))
continue
tokens = messages[1]
reward = messages[2]
iter = messages[3]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward))
self._controller.update(tokens, float(reward), int(iter))
response = "ok"
conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr,
......
......@@ -32,7 +32,7 @@ class SAController(EvolutionaryController):
range_table=None,
reduce_rate=0.85,
init_temperature=1024,
max_iter_number=300,
max_try_times=None,
init_tokens=None,
constrain_func=None):
"""Initialize.
......@@ -40,7 +40,7 @@ class SAController(EvolutionaryController):
range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature.
max_iter_number(int): max iteration number.
max_try_times(int): max try times before get legal tokens.
init_tokens(list<int>): The initial tokens.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
"""
......@@ -50,7 +50,7 @@ class SAController(EvolutionaryController):
len(self._range_table) == 2)
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_iter_number = max_iter_number
self._max_try_times = max_try_times
self._reward = -1
self._tokens = init_tokens
self._constrain_func = constrain_func
......@@ -65,14 +65,16 @@ class SAController(EvolutionaryController):
d[key] = self.__dict__[key]
return d
def update(self, tokens, reward):
def update(self, tokens, reward, iter):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
self._iter += 1
iter = int(iter)
if iter > self._iter:
self._iter = iter
temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)):
......@@ -99,9 +101,9 @@ class SAController(EvolutionaryController):
self._range_table[1][index] + 1)
_logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index]))
if self._constrain_func is None:
if self._constrain_func is None or self._max_try_times is None:
return new_tokens
for _ in range(self._max_iter_number):
for _ in range(self._max_try_times):
if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:]
......
......@@ -15,6 +15,7 @@
import socket
import logging
import numpy as np
import hashlib
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
......@@ -58,38 +59,40 @@ class SANAS(object):
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._is_server = is_server
self._configs = configs
self._keys = hashlib.md5(str(self._configs)).hexdigest()
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
server_ip = self._get_host_ip()
# create controller server
if self._is_server:
factory = SearchSpaceFactory()
self._search_space = factory.get_search_space(configs)
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table))
controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func)
controller = SAController(
range_table,
self._reduce_rate,
self._init_temperature,
max_try_times=None,
init_tokens=init_tokens,
constrain_func=None)
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
server_ip = self._get_host_ip()
max_client_num = 100
self._controller_server = ControllerServer(
controller=controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=key)
# create controller server
if self._is_server:
key=self._key)
self._controller_server.start()
self._controller_client = ControllerClient(
self._controller_server.ip(),
self._controller_server.port(),
key=key)
server_ip, server_port, key=self._key)
self._iter = 0
......@@ -115,4 +118,5 @@ class SANAS(object):
bool: True means updating successfully while false means failure.
"""
self._iter += 1
return self._controller_client.update(self._current_tokens, score)
return self._controller_client.update(self._current_tokens, score,
self._iter)
......@@ -42,7 +42,7 @@ class AutoPruner(object):
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_try_times=300,
max_client_num=10,
search_steps=300,
max_ratios=[0.9],
......@@ -66,7 +66,7 @@ class AutoPruner(object):
server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens.
max_try_times(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
......@@ -88,7 +88,7 @@ class AutoPruner(object):
self._pruned_latency = pruned_latency
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_try_number = max_try_number
self._max_try_times = max_try_times
self._is_server = is_server
self._range_table = self._get_range_table(min_ratios, max_ratios)
......@@ -110,7 +110,7 @@ class AutoPruner(object):
init_tokens = self._ratios2tokens(self._init_ratios)
_logger.info("range table: {}".format(self._range_table))
controller = SAController(self._range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
self._init_temperature, self._max_try_times,
init_tokens, self._constrain_func)
server_ip, server_port = server_addr
......@@ -212,7 +212,7 @@ class AutoPruner(object):
self._restore(self._scope)
self._param_backup = {}
tokens = self._ratios2tokens(self._current_ratios)
self._controller_client.update(tokens, score)
self._controller_client.update(tokens, score, self._iter)
self._iter += 1
def _restore(self, scope):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册