提交 c49d7615 编写于 作者: W wanghaoshuang

1. Make block mask optional in sa nas.

2. Fix unittest of sa nas
上级 05264ac2
......@@ -60,7 +60,7 @@ class SANAS(object):
self._init_temperature = init_temperature
self._is_server = is_server
self._configs = configs
self._keys = hashlib.md5(str(self._configs)).hexdigest()
self._key = hashlib.md5(str(self._configs)).hexdigest()
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
......@@ -90,6 +90,7 @@ class SANAS(object):
search_steps=search_steps,
key=self._key)
self._controller_server.start()
server_port = self._controller_server.port()
self._controller_client = ControllerClient(
server_ip, server_port, key=self._key)
......
......@@ -51,8 +51,11 @@ class CombineSearchSpace(object):
model space(class)
"""
cls = SEARCHSPACE.get(key)
space = cls(config['input_size'], config['output_size'],
config['block_num'], config['block_mask'])
block_mask = config['block_mask'] if 'block_mask' in config else None
space = cls(config['input_size'],
config['output_size'],
config['block_num'],
block_mask=block_mask)
return space
......
......@@ -19,7 +19,12 @@ class SearchSpaceBase(object):
"""Controller for Neural Architecture Search.
"""
def __init__(self, input_size, output_size, block_num, block_mask, *argss):
def __init__(self,
input_size,
output_size,
block_num,
block_mask=None,
*argss):
self.input_size = input_size
self.output_size = output_size
self.block_num = block_num
......
......@@ -40,7 +40,11 @@ class TestSANAS(unittest.TestCase):
base_flops = flops(main_program)
search_steps = 3
sa_nas = SANAS(configs, search_steps=search_steps, is_server=True)
sa_nas = SANAS(
configs,
search_steps=search_steps,
server_addr=("", 0),
is_server=True)
for i in range(search_steps):
archs = sa_nas.next_archs()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册