diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index f57caaa6beb6fec59b618a689b44652f0cf259fc..b9dca29b49f5f298a740ca5d40e99d27c5558de7 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -60,7 +60,7 @@ class SANAS(object): self._init_temperature = init_temperature self._is_server = is_server self._configs = configs - self._keys = hashlib.md5(str(self._configs)).hexdigest() + self._key = hashlib.md5(str(self._configs)).hexdigest() server_ip, server_port = server_addr if server_ip == None or server_ip == "": @@ -90,6 +90,7 @@ class SANAS(object): search_steps=search_steps, key=self._key) self._controller_server.start() + server_port = self._controller_server.port() self._controller_client = ControllerClient( server_ip, server_port, key=self._key) diff --git a/paddleslim/nas/search_space/combine_search_space.py b/paddleslim/nas/search_space/combine_search_space.py index 667720a9110aa92e096a4f8fa30bb3e4b3e3cecb..37459ebc5d351f7149d5b87737e190569ccb1bbe 100644 --- a/paddleslim/nas/search_space/combine_search_space.py +++ b/paddleslim/nas/search_space/combine_search_space.py @@ -51,8 +51,11 @@ class CombineSearchSpace(object): model space(class) """ cls = SEARCHSPACE.get(key) - space = cls(config['input_size'], config['output_size'], - config['block_num'], config['block_mask']) + block_mask = config['block_mask'] if 'block_mask' in config else None + space = cls(config['input_size'], + config['output_size'], + config['block_num'], + block_mask=block_mask) return space diff --git a/paddleslim/nas/search_space/search_space_base.py b/paddleslim/nas/search_space/search_space_base.py index 6a83f86005a5fb2408f7f85f40dff8a9e5cba819..c80a88b0209ff306188b827efa6249e1f1f03142 100644 --- a/paddleslim/nas/search_space/search_space_base.py +++ b/paddleslim/nas/search_space/search_space_base.py @@ -19,7 +19,12 @@ class SearchSpaceBase(object): """Controller for Neural Architecture Search. """ - def __init__(self, input_size, output_size, block_num, block_mask, *argss): + def __init__(self, + input_size, + output_size, + block_num, + block_mask=None, + *argss): self.input_size = input_size self.output_size = output_size self.block_num = block_num diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py index 5666e1410a820c09bc10fa0b10d282434c7837fe..a4203a85a898632ac2102eb61ab7dd7b475e73ef 100644 --- a/tests/test_sa_nas.py +++ b/tests/test_sa_nas.py @@ -40,7 +40,11 @@ class TestSANAS(unittest.TestCase): base_flops = flops(main_program) search_steps = 3 - sa_nas = SANAS(configs, search_steps=search_steps, is_server=True) + sa_nas = SANAS( + configs, + search_steps=search_steps, + server_addr=("", 0), + is_server=True) for i in range(search_steps): archs = sa_nas.next_archs()