提交 c49d7615 编写于 作者: W wanghaoshuang

1. Make block mask optional in sa nas.

2. Fix unittest of sa nas
上级 05264ac2
...@@ -60,7 +60,7 @@ class SANAS(object): ...@@ -60,7 +60,7 @@ class SANAS(object):
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._is_server = is_server self._is_server = is_server
self._configs = configs self._configs = configs
self._keys = hashlib.md5(str(self._configs)).hexdigest() self._key = hashlib.md5(str(self._configs)).hexdigest()
server_ip, server_port = server_addr server_ip, server_port = server_addr
if server_ip == None or server_ip == "": if server_ip == None or server_ip == "":
...@@ -90,6 +90,7 @@ class SANAS(object): ...@@ -90,6 +90,7 @@ class SANAS(object):
search_steps=search_steps, search_steps=search_steps,
key=self._key) key=self._key)
self._controller_server.start() self._controller_server.start()
server_port = self._controller_server.port()
self._controller_client = ControllerClient( self._controller_client = ControllerClient(
server_ip, server_port, key=self._key) server_ip, server_port, key=self._key)
......
...@@ -51,8 +51,11 @@ class CombineSearchSpace(object): ...@@ -51,8 +51,11 @@ class CombineSearchSpace(object):
model space(class) model space(class)
""" """
cls = SEARCHSPACE.get(key) cls = SEARCHSPACE.get(key)
space = cls(config['input_size'], config['output_size'], block_mask = config['block_mask'] if 'block_mask' in config else None
config['block_num'], config['block_mask']) space = cls(config['input_size'],
config['output_size'],
config['block_num'],
block_mask=block_mask)
return space return space
......
...@@ -19,7 +19,12 @@ class SearchSpaceBase(object): ...@@ -19,7 +19,12 @@ class SearchSpaceBase(object):
"""Controller for Neural Architecture Search. """Controller for Neural Architecture Search.
""" """
def __init__(self, input_size, output_size, block_num, block_mask, *argss): def __init__(self,
input_size,
output_size,
block_num,
block_mask=None,
*argss):
self.input_size = input_size self.input_size = input_size
self.output_size = output_size self.output_size = output_size
self.block_num = block_num self.block_num = block_num
......
...@@ -40,7 +40,11 @@ class TestSANAS(unittest.TestCase): ...@@ -40,7 +40,11 @@ class TestSANAS(unittest.TestCase):
base_flops = flops(main_program) base_flops = flops(main_program)
search_steps = 3 search_steps = 3
sa_nas = SANAS(configs, search_steps=search_steps, is_server=True) sa_nas = SANAS(
configs,
search_steps=search_steps,
server_addr=("", 0),
is_server=True)
for i in range(search_steps): for i in range(search_steps):
archs = sa_nas.next_archs() archs = sa_nas.next_archs()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册