diff --git a/paddleslim/nas/search_space/__init__.py b/paddleslim/nas/search_space/__init__.py
index 0e9157d9e5d7c19e29a1d35506fd8498fd65c96d..c8bef8db17e4a4cea110a3ef3fd4f3d7edceeedc 100644
--- a/paddleslim/nas/search_space/__init__.py
+++ b/paddleslim/nas/search_space/__init__.py
@@ -14,6 +14,8 @@
 
 import mobilenetv2
 from .mobilenetv2 import *
+import resnet
+from .resnet import *
 import search_space_registry
 from search_space_registry import *
 import search_space_factory
@@ -26,3 +28,4 @@ __all__ += mobilenetv2.__all__
 __all__ += search_space_registry.__all__
 __all__ += search_space_factory.__all__
 __all__ += search_space_base.__all__
+__all__ += resnet.__all__
diff --git a/paddleslim/nas/search_space/base_layer.py b/paddleslim/nas/search_space/base_layer.py
index 431e025ed824c7c78ac1e38ad27a785315b0fef4..2e769ec6339b639732995849e9f819a08b749c92 100644
--- a/paddleslim/nas/search_space/base_layer.py
+++ b/paddleslim/nas/search_space/base_layer.py
@@ -59,5 +59,7 @@ def conv_bn_layer(input,
         moving_variance_name=bn_name + '_variance')
     if act == 'relu6':
         return fluid.layers.relu6(bn)
+    elif act == 'sigmoid':
+        return fluid.layers.sigmoid(bn)
     else:
         return bn
diff --git a/paddleslim/nas/search_space/combine_search_space.py b/paddleslim/nas/search_space/combine_search_space.py
new file mode 100644
index 0000000000000000000000000000000000000000..371bcf5347ebbc21e0688d1611ed3b298b940eb1
--- /dev/null
+++ b/paddleslim/nas/search_space/combine_search_space.py
@@ -0,0 +1,99 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from .search_space_base import SearchSpaceBase
+from .search_space_registry import SEARCHSPACE
+from .base_layer import conv_bn_layer
+
+__all__ = ["CombineSearchSpace"]
+
+
+class CombineSearchSpace(object):
+    """
+    Combine Search Space.
+    Args:
+        config_lists(list): list of (key, config) pairs, one per search
+            space; key(str) is a registered search space name and
+            config(dict) holds 'input_size', 'output_size', 'block_num'.
+    """
+    def __init__(self, config_lists):
+        self.lens = len(config_lists)
+        self.spaces = [self._get_single_search_space(key, config)
+                       for key, config in config_lists]
+
+    def _get_single_search_space(self, key, config):
+        """
+        get specific model space based on key and config.
+
+        Args:
+            key(str): model space name.
+            config(dict): basic config information.
+        return:
+            model space(class)
+        """
+        cls = SEARCHSPACE.get(key)
+        return cls(config['input_size'], config['output_size'],
+                   config['block_num'])
+
+    def init_tokens(self):
+        """
+        Combine init tokens.
+
+        Also records each space's token count in self.single_token_num,
+        which token2arch uses to split the combined tokens.
+        """
+        tokens = []
+        self.single_token_num = []
+        for space in self.spaces:
+            space_tokens = space.init_tokens()
+            tokens.extend(space_tokens)
+            self.single_token_num.append(len(space_tokens))
+        return tokens
+
+    def range_table(self):
+        """Combine range table."""
+        range_tables = []
+        for space in self.spaces:
+            range_tables.extend(space.range_table())
+        return range_tables
+
+    def token2arch(self, tokens=None):
+        """
+        Combine model arch
+
+        Args:
+            tokens(list): combined tokens; init tokens are used if None.
+        return:
+            list of model arch, one per search space.
+        """
+        # Always refresh self.single_token_num so caller-supplied tokens can
+        # be split even when init_tokens() has not been called before.
+        init_tokens = self.init_tokens()
+        if tokens is None:
+            tokens = init_tokens
+
+        model_archs = []
+        start_idx = 0
+        for space, token_num in zip(self.spaces, self.single_token_num):
+            end_idx = start_idx + token_num
+            model_archs.append(space.token2arch(tokens[start_idx:end_idx]))
+            start_idx = end_idx
+        return model_archs
diff --git a/paddleslim/nas/search_space/mobilenetv2.py b/paddleslim/nas/search_space/mobilenetv2.py
index 270da2b17d7cd61e782daeb88e8f9897d1bbe844..90e0b2a0e704a2f44a8031b05d475419bc534677 100644
--- a/paddleslim/nas/search_space/mobilenetv2.py
+++ b/paddleslim/nas/search_space/mobilenetv2.py
@@ -52,6 +52,7 @@ class MobileNetV2Space(SearchSpaceBase):
         self.scale = scale
         self.class_dim = class_dim
 
+
     def init_tokens(self):
         """
         The initial token send to controller.
@@ -60,7 +61,7 @@ class MobileNetV2Space(SearchSpaceBase):
         """
         # original MobileNetV2
         # yapf: disable
-        return [4,          # 1, 16, 1
+        init_token_base = [4,          # 1, 16, 1
                 4, 5, 1, 0, # 6, 24, 1
                 4, 5, 1, 0, # 6, 24, 2
                 4, 4, 2, 0, # 6, 32, 3
@@ -70,13 +71,20 @@ class MobileNetV2Space(SearchSpaceBase):
                 4, 9, 0, 0] # 6, 320, 1
         # yapf: enable
 
+        # 1 head token + 4 per bottleneck; each block >= 5 holds 2 bottlenecks.
+        if self.block_num < 5:
+            self.token_len = 1 + (self.block_num - 1) * 4
+        else:
+            self.token_len = 1 + (2 * self.block_num - 5) * 4
+        return init_token_base[:self.token_len]
+
     def range_table(self):
         """
         get range table of current search space
         """
         # head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
         # yapf: disable
-        return [7,
+        range_table_base = [7,
                 5, 8, 6, 2,
                 5, 8, 6, 2,
                 5, 8, 6, 2,
@@ -85,48 +93,38 @@ class MobileNetV2Space(SearchSpaceBase):
                 5, 10, 6, 2,
                 5, 12, 6, 2]
         # yapf: enable
+        return range_table_base[:len(self.init_tokens())]
 
     def token2arch(self, tokens=None):
         """
         return net_arch function
         """
-        if tokens is None:
-            tokens = self.init_tokens()
-
-        base_bottleneck_params_list = [
-            (1, self.head_num[tokens[0]], 1, 1, 3),
-            (self.multiply[tokens[1]], self.filter_num1[tokens[2]],
-             self.repeat[tokens[3]], 2, self.k_size[tokens[4]]),
-            (self.multiply[tokens[5]], self.filter_num1[tokens[6]],
-             self.repeat[tokens[7]], 2, self.k_size[tokens[8]]),
-            (self.multiply[tokens[9]], self.filter_num2[tokens[10]],
-             self.repeat[tokens[11]], 2, self.k_size[tokens[12]]),
-            (self.multiply[tokens[13]], self.filter_num3[tokens[14]],
-             self.repeat[tokens[15]], 2, self.k_size[tokens[16]]),
-            (self.multiply[tokens[17]], self.filter_num3[tokens[18]],
-             self.repeat[tokens[19]], 1, self.k_size[tokens[20]]),
-            (self.multiply[tokens[21]], self.filter_num5[tokens[22]],
-             self.repeat[tokens[23]], 2, self.k_size[tokens[24]]),
-            (self.multiply[tokens[25]], self.filter_num6[tokens[26]],
-             self.repeat[tokens[27]], 1, self.k_size[tokens[28]]),
-        ]
 
         assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format(
             self.block_num)
 
-        # the stride = 2 means downsample feature map in the convolution, so only when stride=2, block_num minus 1,
-        # otherwise, add layers to params_list directly.
-        bottleneck_params_list = []
-        for param_list in base_bottleneck_params_list:
-            if param_list[3] == 1:
-                bottleneck_params_list.append(param_list)
-            else:
-                if self.block_num > 1:
-                    bottleneck_params_list.append(param_list)
-                    self.block_num -= 1
-                else:
-                    break
+        if tokens is None:
+            tokens = self.init_tokens()
 
+        bottleneck_params_list = []
+        if self.block_num >= 1: bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1, 3))
+        if self.block_num >= 2: bottleneck_params_list.append((self.multiply[tokens[1]], self.filter_num1[tokens[2]],
+                                                              self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
+        if self.block_num >= 3: bottleneck_params_list.append((self.multiply[tokens[5]], self.filter_num1[tokens[6]],
+                                                              self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
+        if self.block_num >= 4: bottleneck_params_list.append((self.multiply[tokens[9]], self.filter_num2[tokens[10]],
+                                                              self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
+        if self.block_num >= 5:
+            bottleneck_params_list.append((self.multiply[tokens[13]], self.filter_num3[tokens[14]],
+                                           self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
+            bottleneck_params_list.append((self.multiply[tokens[17]], self.filter_num3[tokens[18]],
+                                           self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
+        if self.block_num >= 6:
+            bottleneck_params_list.append((self.multiply[tokens[21]], self.filter_num5[tokens[22]],
+                                           self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
+            bottleneck_params_list.append((self.multiply[tokens[25]], self.filter_num6[tokens[26]],
+                                           self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
+
         def net_arch(input):
             #conv1
             # all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
@@ -137,7 +135,7 @@ class MobileNetV2Space(SearchSpaceBase):
                 stride=2,
                 padding='SAME',
                 act='relu6',
-                name='conv1_1')
+                name='mobilenetv2_conv1_1')
 
             # bottleneck sequences
             i = 1
@@ -145,7 +143,7 @@ class MobileNetV2Space(SearchSpaceBase):
             for layer_setting in bottleneck_params_list:
                 t, c, n, s, k = layer_setting
                 i += 1
-                input = self.invresi_blocks(
+                input = self._invresi_blocks(
                     input=input,
                     in_c=in_c,
                     t=t,
@@ -153,7 +151,7 @@ class MobileNetV2Space(SearchSpaceBase):
                     n=n,
                     s=s,
                     k=k,
-                    name='conv' + str(i))
+                    name='mobilenetv2_conv' + str(i))
                 in_c = int(c * self.scale)
 
             # if output_size is 1, add fc layer in the end
@@ -161,8 +159,8 @@ class MobileNetV2Space(SearchSpaceBase):
                 input = fluid.layers.fc(
                     input=input,
                     size=self.class_dim,
-                    param_attr=ParamAttr(name='fc10_weights'),
-                    bias_attr=ParamAttr(name='fc10_offset'))
+                    param_attr=ParamAttr(name='mobilenetv2_fc_weights'),
+                    bias_attr=ParamAttr(name='mobilenetv2_fc_offset'))
             else:
                 assert self.output_size == input.shape[2], \
                     ("output_size must EQUAL to input_size / (2^block_num)."
@@ -173,7 +171,7 @@ class MobileNetV2Space(SearchSpaceBase):
 
         return net_arch
 
-    def shortcut(self, input, data_residual):
+    def _shortcut(self, input, data_residual):
         """Build shortcut layer.
         Args:
             input(Variable): input.
@@ -183,7 +181,7 @@ class MobileNetV2Space(SearchSpaceBase):
         """
         return fluid.layers.elementwise_add(input, data_residual)
 
-    def inverted_residual_unit(self,
+    def _inverted_residual_unit(self,
                                input,
                                num_in_filter,
                                num_filters,
@@ -240,10 +238,10 @@ class MobileNetV2Space(SearchSpaceBase):
             name=name + '_linear')
         out = linear_out
         if ifshortcut:
-            out = self.shortcut(input=input, data_residual=out)
+            out = self._shortcut(input=input, data_residual=out)
         return out
 
-    def invresi_blocks(self, input, in_c, t, c, n, s, k, name=None):
+    def _invresi_blocks(self, input, in_c, t, c, n, s, k, name=None):
         """Build inverted residual blocks.
         Args:
             input: Variable, input.
@@ -257,7 +255,7 @@ class MobileNetV2Space(SearchSpaceBase):
         Returns:
             Variable, layers output.
         """
-        first_block = self.inverted_residual_unit(
+        first_block = self._inverted_residual_unit(
             input=input,
             num_in_filter=in_c,
             num_filters=c,
@@ -271,7 +269,7 @@ class MobileNetV2Space(SearchSpaceBase):
         last_c = c
 
         for i in range(1, n):
-            last_residual_block = self.inverted_residual_unit(
+            last_residual_block = self._inverted_residual_unit(
                 input=last_residual_block,
                 num_in_filter=last_c,
                 num_filters=c,
diff --git a/paddleslim/nas/search_space/resnet.py b/paddleslim/nas/search_space/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6ac5817ce89190987d67f1eda644fa2aef79037
--- /dev/null
+++ b/paddleslim/nas/search_space/resnet.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from .search_space_base import SearchSpaceBase
+from .base_layer import conv_bn_layer
+from .search_space_registry import SEARCHSPACE
+
+__all__ = ["ResNetSpace"]
+
+
+@SEARCHSPACE.register
+class ResNetSpace(SearchSpaceBase):
+    """Minimal ResNet search space: builds one conv-bn layer with sigmoid."""
+
+    def __init__(self, input_size, output_size, block_num, scale=1.0,
+                 class_dim=1000):
+        super(ResNetSpace, self).__init__(input_size, output_size, block_num)
+        self.scale = scale
+        self.class_dim = class_dim
+
+    def init_tokens(self):
+        """Return the initial tokens (six zeros)."""
+        return [0, 0, 0, 0, 0, 0]
+
+    def range_table(self):
+        """Return the range table, one entry per token."""
+        return [3, 3, 3, 3, 3, 3]
+
+    def token2arch(self, tokens=None):
+        """Return the net_arch function; init tokens are used if None."""
+        if tokens is None:
+            tokens = self.init_tokens()
+        # NOTE: `tokens` does not influence the architecture yet.
+
+        def net_arch(input):
+            return conv_bn_layer(
+                input, num_filters=32, filter_size=3, stride=2,
+                padding='SAME', act='sigmoid', name='resnet_conv1_1')
+
+        return net_arch
diff --git a/paddleslim/nas/search_space/search_space_base.py b/paddleslim/nas/search_space/search_space_base.py
index a68ec63acbb90191c931508ed96e1bb2d8261e48..bb1ce0f8a4bbd0b18d36fa9199a6ff814ab13236 100644
--- a/paddleslim/nas/search_space/search_space_base.py
+++ b/paddleslim/nas/search_space/search_space_base.py
@@ -39,6 +39,6 @@ class SearchSpaceBase(object):
         Args:
             tokens(list): The tokens which represent a network.
         Return:
-            list
+            model arch
         """
         raise NotImplementedError('Abstract method.')
diff --git a/paddleslim/nas/search_space/search_space_factory.py b/paddleslim/nas/search_space/search_space_factory.py
index 11d8377d9056eddbeed43a6ff3f1aa67c6cf664c..2fc0be834445e13ddef5d6664d13a69fb6904aa6 100644
--- a/paddleslim/nas/search_space/search_space_factory.py
+++ b/paddleslim/nas/search_space/search_space_factory.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from search_space_registry import SEARCHSPACE
+from .combine_search_space import CombineSearchSpace
 
 __all__ = ["SearchSpaceFactory"]
 
@@ -21,18 +21,17 @@ class SearchSpaceFactory(object):
     def __init__(self):
         pass
 
-    def get_search_space(self, key, config):
+    def get_search_space(self, config_lists):
         """
-        get specific model space based on key and config.
+        get model spaces based on list(key, config).
 
         Args:
-            key(str): model space name.
-            config(dict): basic config information.
+            config_lists(list): list of (key, config) pairs, where key(str)
+                is a registered model space name and config(dict) holds
+                'input_size', 'output_size' and 'block_num'.
         return:
-            model space(class)
+            CombineSearchSpace wrapping one model space per pair.
         """
-        cls = SEARCHSPACE.get(key)
-        space = cls(config['input_size'], config['output_size'],
-                    config['block_num'])
+        assert isinstance(config_lists, list), "configs must be a list"
 
-        return space
+        return CombineSearchSpace(config_lists)
diff --git a/tests/test_searchspace.py b/tests/test_searchspace.py
index 5e3557bbfc93b277fbc87339bcf6be605d313456..c751bdc3d31d1822051436b71035b64ee6963fac 100644
--- a/tests/test_searchspace.py
+++ b/tests/test_searchspace.py
@@ -24,7 +24,7 @@ class TestSearchSpaceFactory(unittest.TestCase):
         config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
         space = SearchSpaceFactory()
 
-        my_space = space.get_search_space('MobileNetV2Space', config)
+        my_space = space.get_search_space([('MobileNetV2Space', config)])
         model_arch = my_space.token2arch()
 
         train_prog = fluid.Program()
@@ -36,10 +36,35 @@ class TestSearchSpaceFactory(unittest.TestCase):
                 shape=[1, 3, input_size, input_size],
                 dtype='float32',
                 append_batch_size=False)
-            print('input shape', model_input.shape)
-            predict = model_arch(model_input)
-            print('output shape', predict.shape)
+            predict = model_arch[0](model_input)
+            self.assertTrue(predict.shape[2] == config['output_size'])
 
 
+class TestMultiSearchSpace(unittest.TestCase):
+    def test_multi_search_space(self):
+        # Chain two search spaces: the MobileNetV2 output feeds ResNetSpace.
+        config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
+        config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
+        space = SearchSpaceFactory()
+        my_space = space.get_search_space([('MobileNetV2Space', config0),
+                                           ('ResNetSpace', config1)])
+        model_archs = my_space.token2arch()
+        self.assertEqual(len(model_archs), 2)
+
+        train_prog = fluid.Program()
+        startup_prog = fluid.Program()
+        with fluid.program_guard(train_prog, startup_prog):
+            input_size = config0['input_size']
+            model_input = fluid.layers.data(
+                name='model_in',
+                shape=[1, 3, input_size, input_size],
+                dtype='float32',
+                append_batch_size=False)
+            for model_arch in model_archs:
+                predict = model_arch(model_input)
+                model_input = predict
+            self.assertIsNotNone(predict)
+
+
 if __name__ == '__main__':
     unittest.main()