提交 7279e146 编写于 作者: W wanghaoshuang

Refine sa nas.

上级 72c800e9
......@@ -33,40 +33,31 @@ _logger = get_logger(__name__, level=logging.INFO)
class SANAS(object):
def __init__(self,
configs,
max_flops=None,
max_latency=None,
server_addr=("", 0),
server_addr=("", 8881),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=300,
key="sa_nas",
is_server=True):
is_server=False):
"""
Search a group of ratios used to prune program.
Args:
configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num).
`key` is the name of search space with data type str. `input_size` and `output_size` are
input size and output size of searched sub-network. `block_num` is the number of blocks in searched network.
max_flops(int): The max flops of searched network. None means no constrains. Default: None.
max_latency(float): The max latency of searched network. None means no constrains. Default: None.
server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching.
key(str): Identity used in communication between controller server and clients.
is_server(bool): Whether current host is controller server. Default: True.
"""
if not is_server:
assert server_addr[
0] != "", "You should set the IP and port of server when is_server is False."
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_try_number = max_try_number
self._is_server = is_server
self._max_flops = max_flops
self._max_latency = max_latency
self._configs = configs
......@@ -75,9 +66,7 @@ class SANAS(object):
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
print range_table
_logger.info("range table: {}".format(range_table))
controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func)
......@@ -85,7 +74,7 @@ class SANAS(object):
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
server_ip = self._get_host_ip()
max_client_num = 100
self._controller_server = ControllerServer(
controller=controller,
address=(server_ip, server_port),
......@@ -107,24 +96,6 @@ class SANAS(object):
def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname())
def _constrain_func(self, tokens):
if (self._max_flops is None) and (self._max_latency is None):
return True
archs = self._search_space.token2arch(tokens)
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
i = 0
for config, arch in zip(self._configs, archs):
input_size = config[1]["input_size"]
input = fluid.data(
name="data_{}".format(i),
shape=[None, 3, input_size, input_size],
dtype="float32")
output = arch(input)
i += 1
return flops(main_program) < self._max_flops
def next_archs(self):
"""
Get next network architectures.
......
......@@ -40,8 +40,7 @@ class TestSANAS(unittest.TestCase):
base_flops = flops(main_program)
search_steps = 3
sa_nas = SANAS(
configs, max_flops=base_flops, search_steps=search_steps)
sa_nas = SANAS(configs, search_steps=search_steps, is_server=True)
for i in range(search_steps):
archs = sa_nas.next_archs()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册