提交 7d0e73e8 编写于 作者: L lvmengsi

Merge branch 'refine_nas' into 'develop'

Remove unnecessary arguments in sa nas API

See merge request !36
...@@ -39,7 +39,7 @@ def init_sa_nas(config): ...@@ -39,7 +39,7 @@ def init_sa_nas(config):
search_steps = 10000000 search_steps = 10000000
### start a server and a client ### start a server and a client
sa_nas = SANAS(config, max_flops=base_flops, search_steps=search_steps) sa_nas = SANAS(config, search_steps=search_steps, is_server=True)
### start a client, server_addr is server address ### start a client, server_addr is server address
#sa_nas = SANAS(config, max_flops = base_flops, server_addr=("10.255.125.38", 18607), search_steps = search_steps, is_server=False) #sa_nas = SANAS(config, max_flops = base_flops, server_addr=("10.255.125.38", 18607), search_steps = search_steps, is_server=False)
......
...@@ -33,40 +33,31 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -33,40 +33,31 @@ _logger = get_logger(__name__, level=logging.INFO)
class SANAS(object): class SANAS(object):
def __init__(self, def __init__(self,
configs, configs,
max_flops=None, server_addr=("", 8881),
max_latency=None,
server_addr=("", 0),
init_temperature=100, init_temperature=100,
reduce_rate=0.85, reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=300, search_steps=300,
key="sa_nas", key="sa_nas",
is_server=True): is_server=False):
""" """
Search a group of ratios used to prune program. Search a group of ratios used to prune program.
Args: Args:
configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num). configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num).
`key` is the name of search space with data type str. `input_size` and `output_size` are `key` is the name of search space with data type str. `input_size` and `output_size` are
input size and output size of searched sub-network. `block_num` is the number of blocks in searched network. input size and output size of searched sub-network. `block_num` is the number of blocks in searched network.
max_flops(int): The max flops of searched network. None means no constrains. Default: None.
max_latency(float): The max latency of searched network. None means no constrains. Default: None.
server_addr(tuple): A tuple of server ip and server port for controller server. server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy. init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy. reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching. search_steps(int): The steps of searching.
key(str): Identity used in communication between controller server and clients. key(str): Identity used in communication between controller server and clients.
is_server(bool): Whether current host is controller server. Default: True. is_server(bool): Whether current host is controller server. Default: True.
""" """
if not is_server:
assert server_addr[
0] != "", "You should set the IP and port of server when is_server is False."
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_try_number = max_try_number
self._is_server = is_server self._is_server = is_server
self._max_flops = max_flops
self._max_latency = max_latency
self._configs = configs self._configs = configs
...@@ -75,9 +66,7 @@ class SANAS(object): ...@@ -75,9 +66,7 @@ class SANAS(object):
init_tokens = self._search_space.init_tokens() init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table() range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table) range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table))
print range_table
controller = SAController(range_table, self._reduce_rate, controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number, self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func) init_tokens, self._constrain_func)
...@@ -85,7 +74,7 @@ class SANAS(object): ...@@ -85,7 +74,7 @@ class SANAS(object):
server_ip, server_port = server_addr server_ip, server_port = server_addr
if server_ip == None or server_ip == "": if server_ip == None or server_ip == "":
server_ip = self._get_host_ip() server_ip = self._get_host_ip()
max_client_num = 100
self._controller_server = ControllerServer( self._controller_server = ControllerServer(
controller=controller, controller=controller,
address=(server_ip, server_port), address=(server_ip, server_port),
...@@ -107,24 +96,6 @@ class SANAS(object): ...@@ -107,24 +96,6 @@ class SANAS(object):
def _get_host_ip(self): def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname()) return socket.gethostbyname(socket.gethostname())
def _constrain_func(self, tokens):
if (self._max_flops is None) and (self._max_latency is None):
return True
archs = self._search_space.token2arch(tokens)
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
i = 0
for config, arch in zip(self._configs, archs):
input_size = config[1]["input_size"]
input = fluid.data(
name="data_{}".format(i),
shape=[None, 3, input_size, input_size],
dtype="float32")
output = arch(input)
i += 1
return flops(main_program) < self._max_flops
def next_archs(self): def next_archs(self):
""" """
Get next network architectures. Get next network architectures.
......
...@@ -40,8 +40,7 @@ class TestSANAS(unittest.TestCase): ...@@ -40,8 +40,7 @@ class TestSANAS(unittest.TestCase):
base_flops = flops(main_program) base_flops = flops(main_program)
search_steps = 3 search_steps = 3
sa_nas = SANAS( sa_nas = SANAS(configs, search_steps=search_steps, is_server=True)
configs, max_flops=base_flops, search_steps=search_steps)
for i in range(search_steps): for i in range(search_steps):
archs = sa_nas.next_archs() archs = sa_nas.next_archs()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册