From c49d7615a725e1402e9599c9d26df70c52809c33 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 22 Nov 2019 15:01:19 +0800 Subject: [PATCH] 1. Make block mask optional in sa nas. 2. Fix unittest of sa nas --- paddleslim/nas/sa_nas.py | 3 ++- paddleslim/nas/search_space/combine_search_space.py | 7 +++++-- paddleslim/nas/search_space/search_space_base.py | 7 ++++++- tests/test_sa_nas.py | 6 +++++- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index f57caaa6..b9dca29b 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -60,7 +60,7 @@ class SANAS(object): self._init_temperature = init_temperature self._is_server = is_server self._configs = configs - self._keys = hashlib.md5(str(self._configs)).hexdigest() + self._key = hashlib.md5(str(self._configs)).hexdigest() server_ip, server_port = server_addr if server_ip == None or server_ip == "": @@ -90,6 +90,7 @@ class SANAS(object): search_steps=search_steps, key=self._key) self._controller_server.start() + server_port = self._controller_server.port() self._controller_client = ControllerClient( server_ip, server_port, key=self._key) diff --git a/paddleslim/nas/search_space/combine_search_space.py b/paddleslim/nas/search_space/combine_search_space.py index 667720a9..37459ebc 100644 --- a/paddleslim/nas/search_space/combine_search_space.py +++ b/paddleslim/nas/search_space/combine_search_space.py @@ -51,8 +51,11 @@ class CombineSearchSpace(object): model space(class) """ cls = SEARCHSPACE.get(key) - space = cls(config['input_size'], config['output_size'], - config['block_num'], config['block_mask']) + block_mask = config['block_mask'] if 'block_mask' in config else None + space = cls(config['input_size'], + config['output_size'], + config['block_num'], + block_mask=block_mask) return space diff --git a/paddleslim/nas/search_space/search_space_base.py b/paddleslim/nas/search_space/search_space_base.py index 6a83f860..c80a88b0 100644 --- a/paddleslim/nas/search_space/search_space_base.py +++ b/paddleslim/nas/search_space/search_space_base.py @@ -19,7 +19,12 @@ class SearchSpaceBase(object): """Controller for Neural Architecture Search. """ - def __init__(self, input_size, output_size, block_num, block_mask, *argss): + def __init__(self, + input_size, + output_size, + block_num, + block_mask=None, + *argss): self.input_size = input_size self.output_size = output_size self.block_num = block_num diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py index 5666e141..a4203a85 100644 --- a/tests/test_sa_nas.py +++ b/tests/test_sa_nas.py @@ -40,7 +40,11 @@ class TestSANAS(unittest.TestCase): base_flops = flops(main_program) search_steps = 3 - sa_nas = SANAS(configs, search_steps=search_steps, is_server=True) + sa_nas = SANAS( + configs, + search_steps=search_steps, + server_addr=("", 0), + is_server=True) for i in range(search_steps): archs = sa_nas.next_archs() -- GitLab