From 043d3bf46272f5e350fde5884911f503e103017f Mon Sep 17 00:00:00 2001 From: ceci3 Date: Thu, 6 Feb 2020 15:21:40 +0800 Subject: [PATCH] Fix nas api unittest (#90) --- paddleslim/common/controller_client.py | 2 +- paddleslim/common/controller_server.py | 1 + paddleslim/common/sa_controller.py | 4 +- paddleslim/nas/sa_nas.py | 1 - tests/test_nas_search_space.py | 69 --------------- tests/test_sa_nas.py | 112 +++++++++++++++++-------- 6 files changed, 78 insertions(+), 111 deletions(-) delete mode 100644 tests/test_nas_search_space.py diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py index 36ae4990..c60bcfaf 100644 --- a/paddleslim/common/controller_client.py +++ b/paddleslim/common/controller_client.py @@ -56,7 +56,7 @@ class ControllerClient(object): socket_client.send("{}\t{}\t{}\t{}\t{}".format( self._key, tokens, reward, iter, self._client_name).encode()) response = socket_client.recv(1024).decode() - if response.strip('\n').split("\t") == "ok": + if "ok" in response.strip('\n').split("\t"): return True else: return False diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py index fb4d76bc..5e1ef737 100644 --- a/paddleslim/common/controller_server.py +++ b/paddleslim/common/controller_server.py @@ -66,6 +66,7 @@ class ControllerServer(object): _logger.info("ControllerServer - listen on: [{}:{}]".format( self._ip, self._port)) thread = Thread(target=self.run) + thread.setDaemon(True) thread.start() return str(thread) diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py index bc80045a..b1034762 100644 --- a/paddleslim/common/sa_controller.py +++ b/paddleslim/common/sa_controller.py @@ -115,7 +115,6 @@ class SAController(EvolutionaryController): self._searched[str(tokens)] = reward temperature = self._init_temperature * self._reduce_rate**(client_num * self._iter) - self._current_tokens = tokens if (reward > self._reward) or (np.random.random() <= math.exp( (reward - self._reward) / temperature)): self._reward = reward @@ -164,8 +163,7 @@ class SAController(EvolutionaryController): ) sys.exit() - if self._constrain_func is None or self._max_try_times is None: - return new_tokens + self._current_tokens = new_tokens return new_tokens diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index 02dd5e92..245f9048 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -20,7 +20,6 @@ import json import hashlib import time import paddle.fluid as fluid -from ..core import VarWrapper, OpWrapper, GraphWrapper from ..common import SAController from ..common import get_logger from ..analysis import flops diff --git a/tests/test_nas_search_space.py b/tests/test_nas_search_space.py deleted file mode 100644 index ad373cf1..00000000 --- a/tests/test_nas_search_space.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -sys.path.append('..') -import unittest -import paddle.fluid as fluid -from nas.search_space_factory import SearchSpaceFactory - - -class TestSearchSpace(unittest.TestCase): - def test_searchspace(self): - # if output_size is 1, the model will add fc layer in the end. - config = {'input_size': 224, 'output_size': 7, 'block_num': 5} - space = SearchSpaceFactory() - - my_space = space.get_search_space([('MobileNetV2Space', config)]) - model_arch = my_space.token2arch() - - train_prog = fluid.Program() - startup_prog = fluid.Program() - with fluid.program_guard(train_prog, startup_prog): - input_size = config['input_size'] - model_input = fluid.layers.data( - name='model_in', - shape=[1, 3, input_size, input_size], - dtype='float32', - append_batch_size=False) - predict = model_arch[0](model_input) - self.assertTrue(predict.shape[2] == config['output_size']) - - -class TestMultiSearchSpace(unittest.TestCase): - space = SearchSpaceFactory() - - config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5} - config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2} - my_space = space.get_search_space( - [('MobileNetV2Space', config0), ('ResNetSpace', config1)]) - model_archs = my_space.token2arch() - - train_prog = fluid.Program() - startup_prog = fluid.Program() - with fluid.program_guard(train_prog, startup_prog): - input_size = config0['input_size'] - model_input = fluid.layers.data( - name='model_in', - shape=[1, 3, input_size, input_size], - dtype='float32', - append_batch_size=False) - for model_arch in model_archs: - predict = model_arch(model_input) - model_input = predict - print(predict) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py index a4203a85..8630136e 100644 --- a/tests/test_sa_nas.py +++ b/tests/test_sa_nas.py @@ -11,52 +11,90 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import sys -sys.path.append("../") import unittest import paddle.fluid as fluid from paddleslim.nas import SANAS -from paddleslim.nas import SearchSpaceFactory from paddleslim.analysis import flops +import numpy as np +def compute_op_num(program): + params = {} + ch_list = [] + for block in program.blocks: + for param in block.all_parameters(): + if len(param.shape) == 4: + params[param.name] = param.shape + ch_list.append(int(param.shape[0])) + return params, ch_list class TestSANAS(unittest.TestCase): - def test_nas(self): - - factory = SearchSpaceFactory() - config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5} - config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2} - configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)] - - space = factory.get_search_space([('MobileNetV2Space', config0)]) - origin_arch = space.token2arch()[0] - - main_program = fluid.Program() - s_program = fluid.Program() - with fluid.program_guard(main_program, s_program): - input = fluid.data( - name="input", shape=[None, 3, 224, 224], dtype="float32") - origin_arch(input) - base_flops = flops(main_program) - - search_steps = 3 - sa_nas = SANAS( - configs, - search_steps=search_steps, - server_addr=("", 0), - is_server=True) - - for i in range(search_steps): - archs = sa_nas.next_archs() - main_program = fluid.Program() - s_program = fluid.Program() - with fluid.program_guard(main_program, s_program): - input = fluid.data( - name="input", shape=[None, 3, 224, 224], dtype="float32") - archs[0](input) - sa_nas.reward(1) - self.assertTrue(flops(main_program) < base_flops) + def setUp(self): + self.init_test_case() + port = np.random.randint(8337, 8773) + self.sanas = SANAS(configs=self.configs, server_addr=("", port), save_checkpoint=None) + def init_test_case(self): + self.configs=[('MobileNetV2BlockSpace', {'block_mask':[0]})] + self.filter_num = np.array([ + 3, 4, 8, 12, 16, 24, 32, 48, 64, 80, 96, 128, 144, 160, 192, 224, + 256, 320, 384, 512 + ]) + self.k_size = np.array([3, 5]) + self.multiply = np.array([1, 2, 3, 4, 5, 6]) + self.repeat = np.array([1, 2, 3, 4, 5, 6]) + + def check_chnum_convnum(self, program): + current_tokens = self.sanas.current_info()['current_tokens'] + channel_exp = self.multiply[current_tokens[0]] + filter_num = self.filter_num[current_tokens[1]] + repeat_num = self.repeat[current_tokens[2]] + + conv_list, ch_pro = compute_op_num(program) + ### assert conv number + self.assertTrue((repeat_num * 3) == len(conv_list), "the number of conv is NOT match, the number compute from token: {}, actual conv number: {}".format(repeat_num * 3, len(conv_list))) + + ### assert number of channels + ch_token = [] + init_ch_num = 32 + for i in range(repeat_num): + ch_token.append(init_ch_num * channel_exp) + ch_token.append(init_ch_num * channel_exp) + ch_token.append(filter_num) + init_ch_num = filter_num + + self.assertTrue(str(ch_token) == str(ch_pro), "channel num is WRONG, channel num from token is {}, channel num come fom program is {}".format(str(ch_token), str(ch_pro))) + + def test_all_function(self): + ### unittest for next_archs + next_program = fluid.Program() + startup_program = fluid.Program() + token2arch_program = fluid.Program() + + with fluid.program_guard(next_program, startup_program): + inputs = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') + archs = self.sanas.next_archs() + for arch in archs: + output = arch(inputs) + inputs = output + self.check_chnum_convnum(next_program) + + ### unittest for reward + self.assertTrue(self.sanas.reward(float(1.0)), "reward is False") + + ### uniitest for tokens2arch + with fluid.program_guard(token2arch_program, startup_program): + inputs = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') + arch = self.sanas.tokens2arch(self.sanas.current_info()['current_tokens']) + for arch in archs: + output = arch(inputs) + inputs = output + self.check_chnum_convnum(token2arch_program) + + ### unittest for current_info + current_info = self.sanas.current_info() + self.assertTrue(isinstance(current_info, dict), "the type of current info must be dict, but now is {}".format(type(current_info))) if __name__ == '__main__': unittest.main() -- GitLab