diff --git a/paddleslim/core/__init__.py b/paddleslim/core/__init__.py index f6ad753dfa11c695be32a2211e27dddfdaed7072..9bb4b2bf34f3aed2f74ce4fd5936527b17737181 100644 --- a/paddleslim/core/__init__.py +++ b/paddleslim/core/__init__.py @@ -14,4 +14,8 @@ from . import graph_wrapper from .graph_wrapper import * +from . import registry +from .registry import * + __all__ = graph_wrapper.__all__ +__all__ += registry.__all__ diff --git a/paddleslim/nas/utils/registry.py b/paddleslim/core/registry.py similarity index 51% rename from paddleslim/nas/utils/registry.py rename to paddleslim/core/registry.py index 5d055a9c3de98db19bc2f6ccd85f762b384e6ce3..208dceca1ff9958591b7e427d47124f3c57e4d5b 100644 --- a/paddleslim/nas/utils/registry.py +++ b/paddleslim/core/registry.py @@ -1,16 +1,36 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect +__all__ = ["Registry"] + + class Registry(object): def __init__(self, name): self._name = name self._module_dict = dict() + def __repr__(self): - format_str = self.__class__.__name__ + '(name={}, items={})'.format(self._name, list(self._module_dict.keys())) + format_str = self.__class__.__name__ + '(name={}, items={})'.format( + self._name, list(self._module_dict.keys())) return format_str @property def name(self): return self._name + @property def module_dict(self): return self._module_dict @@ -20,12 +40,14 @@ class Registry(object): def _register_module(self, module_class): if not inspect.isclass(module_class): - raise TypeError('module must be a class, but receive {}.'.format(type(module_class))) + raise TypeError('module must be a class, but receive {}.'.format( + type(module_class))) module_name = module_class.__name__ if module_name in self._module_dict: - raise KeyError('{} is already registered in {}.'.format(module_name, self.name)) + raise KeyError('{} is already registered in {}.'.format( + module_name, self.name)) self._module_dict[module_name] = module_class - def register_module(self, cls): + def register(self, cls): self._register_module(cls) return cls diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py index 9d0531501ca43921438ee5b2fb58ac0ad2396d1b..2f5509144f53529ae717b72bbb7252b4b06a0048 100644 --- a/paddleslim/nas/__init__.py +++ b/paddleslim/nas/__init__.py @@ -11,3 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import search_space +from search_space import * + +__all__ = [] +__all__ += search_space.__all__ diff --git a/paddleslim/nas/utils/__init__.py b/paddleslim/nas/search_space/__init__.py similarity index 61% rename from paddleslim/nas/utils/__init__.py rename to paddleslim/nas/search_space/__init__.py index 9d0531501ca43921438ee5b2fb58ac0ad2396d1b..767760fabca2803d2cd9c6ec4a51ec2562fcfb5c 100644 --- a/paddleslim/nas/utils/__init__.py +++ b/paddleslim/nas/search_space/__init__.py @@ -11,3 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .mobilenetv2_space import MobileNetV2Space +import search_space_registry +from search_space_registry import * +import search_space_factory +from search_space_factory import * +import search_space_base +from search_space_base import * + +__all__ = ["MobileNetV2Space"] +__all__ += search_space_registry.__all__ +__all__ += search_space_factory.__all__ +__all__ += search_space_base.__all__ diff --git a/paddleslim/nas/searchspace/base_layer.py b/paddleslim/nas/search_space/base_layer.py similarity index 59% rename from paddleslim/nas/searchspace/base_layer.py rename to paddleslim/nas/search_space/base_layer.py index 75ce180b4279e32601fe3fab6fea44d719a1a701..431e025ed824c7c78ac1e38ad27a785315b0fef4 100644 --- a/paddleslim/nas/searchspace/base_layer.py +++ b/paddleslim/nas/search_space/base_layer.py @@ -16,7 +16,15 @@ import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr -def conv_bn_layer(input, filter_size, num_filters, stride, padding, num_groups=1, act=None, name=None, use_cudnn=True): +def conv_bn_layer(input, + filter_size, + num_filters, + stride, + padding, + num_groups=1, + act=None, + name=None, + use_cudnn=True): """Build convolution and batch normalization layers. Args: input(Variable): input. @@ -31,11 +39,24 @@ def conv_bn_layer(input, filter_size, num_filters, stride, padding, num_groups=1 Returns: Variable, layers output. """ - conv = fluid.layers.conv2d(input, num_filters=num_filters, filter_size=filter_size, stride=stride, padding=padding, - groups=num_groups, act=None, use_cudnn=use_cudnn, param_attr=ParamAttr(name=name+'_weights'), bias_attr=False) + conv = fluid.layers.conv2d( + input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) bn_name = name + '_bn' - bn = fluid.layers.batch_norm(input=conv, param_attr=ParamAttr(name=bn_name+'_scale'), bias_attr=ParamAttr(name=bn_name+'_offset'), - moving_mean_name=bn_name+'_mean', moving_variance_name=bn_name+'_variance') + bn = fluid.layers.batch_norm( + input=conv, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(name=bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') if act == 'relu6': return fluid.layers.relu6(bn) else: diff --git a/paddleslim/nas/searchspace/mobilenetv2_space.py b/paddleslim/nas/search_space/mobilenetv2_space.py similarity index 73% rename from paddleslim/nas/searchspace/mobilenetv2_space.py rename to paddleslim/nas/search_space/mobilenetv2_space.py index e09e00a6d8d685dff6e0e373ad97629b48dafd2a..0455fc9521bf5881cb6ad7eedeb97276a01b39bb 100644 --- a/paddleslim/nas/searchspace/mobilenetv2_space.py +++ b/paddleslim/nas/search_space/mobilenetv2_space.py @@ -17,30 +17,39 @@ from __future__ import division from __future__ import print_function import sys -sys.path.append('..') import numpy as np import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr -from searchspacebase import SearchSpaceBase +from .search_space_base import SearchSpaceBase from .base_layer import conv_bn_layer -from .registry import SEARCHSPACE +from .search_space_registry import SEARCHSPACE -@SEARCHSPACE.register_module + +@SEARCHSPACE.register class MobileNetV2Space(SearchSpaceBase): - def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000): - super(MobileNetV2Space, self).__init__(input_size, output_size, block_num) - self.head_num = np.array([3,4,8,12,16,24,32]) #7 - self.filter_num1 = np.array([3,4,8,12,16,24,32,48]) #8 - self.filter_num2 = np.array([8,12,16,24,32,48,64,80]) #8 - self.filter_num3 = np.array([16,24,32,48,64,80,96,128]) #8 - self.filter_num4 = np.array([24,32,48,64,80,96,128,144,160,192]) #10 - self.filter_num5 = np.array([32,48,64,80,96,128,144,160,192,224]) #10 - self.filter_num6 = np.array([64,80,96,128,144,160,192,224,256,320,384,512]) #12 - self.k_size = np.array([3,5]) #2 - self.multiply = np.array([1,2,3,4,6]) #5 - self.repeat = np.array([1,2,3,4,5,6]) #6 - self.scale=scale - self.class_dim=class_dim + def __init__(self, + input_size, + output_size, + block_num, + scale=1.0, + class_dim=1000): + super(MobileNetV2Space, self).__init__(input_size, output_size, + block_num) + self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7 + self.filter_num1 = np.array([3, 4, 8, 12, 16, 24, 32, 48]) #8 + self.filter_num2 = np.array([8, 12, 16, 24, 32, 48, 64, 80]) #8 + self.filter_num3 = np.array([16, 24, 32, 48, 64, 80, 96, 128]) #8 + self.filter_num4 = np.array( + [24, 32, 48, 64, 80, 96, 128, 144, 160, 192]) #10 + self.filter_num5 = np.array( + [32, 48, 64, 80, 96, 128, 144, 160, 192, 224]) #10 + self.filter_num6 = np.array( + [64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384, 512]) #12 + self.k_size = np.array([3, 5]) #2 + self.multiply = np.array([1, 2, 3, 4, 6]) #5 + self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6 + self.scale = scale + self.class_dim = class_dim def init_tokens(self): """ @@ -49,28 +58,47 @@ class MobileNetV2Space(SearchSpaceBase): each line in the following represent the index of the [expansion_factor, filter_num, repeat_num, kernel_size] """ # original MobileNetV2 - return [4, # 1, 16, 1 - 4, 5, 1, 0, # 6, 24, 1 - 4, 5, 1, 0, # 6, 24, 2 - 4, 4, 2, 0, # 6, 32, 3 - 4, 4, 3, 0, # 6, 64, 4 - 4, 5, 2, 0, # 6, 96, 3 - 4, 7, 2, 0, # 6, 160, 3 - 4, 9, 0, 0] # 6, 320, 1 + return [ + 4, # 1, 16, 1 + 4, + 5, + 1, + 0, # 6, 24, 1 + 4, + 5, + 1, + 0, # 6, 24, 2 + 4, + 4, + 2, + 0, # 6, 32, 3 + 4, + 4, + 3, + 0, # 6, 64, 4 + 4, + 5, + 2, + 0, # 6, 96, 3 + 4, + 7, + 2, + 0, # 6, 160, 3 + 4, + 9, + 0, + 0 + ] # 6, 320, 1 def range_table(self): """ get range table of current search space """ # head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size] - return [7, - 5, 8, 6, 2, - 5, 8, 6, 2, - 5, 8, 6, 2, - 5, 8, 6, 2, - 5, 10, 6, 2, - 5, 10, 6, 2, - 5, 12, 6, 2] + return [ + 7, 5, 8, 6, 2, 5, 8, 6, 2, 5, 8, 6, 2, 5, 8, 6, 2, 5, 10, 6, 2, 5, + 10, 6, 2, 5, 12, 6, 2 + ] def token2arch(self, tokens=None): """ @@ -81,16 +109,24 @@ class MobileNetV2Space(SearchSpaceBase): base_bottleneck_params_list = [ (1, self.head_num[tokens[0]], 1, 1, 3), - (self.multiply[tokens[1]], self.filter_num1[tokens[2]], self.repeat[tokens[3]], 2, self.k_size[tokens[4]]), - (self.multiply[tokens[5]], self.filter_num1[tokens[6]], self.repeat[tokens[7]], 2, self.k_size[tokens[8]]), - (self.multiply[tokens[9]], self.filter_num2[tokens[10]], self.repeat[tokens[11]], 2, self.k_size[tokens[12]]), - (self.multiply[tokens[13]], self.filter_num3[tokens[14]], self.repeat[tokens[15]], 2, self.k_size[tokens[16]]), - (self.multiply[tokens[17]], self.filter_num3[tokens[18]], self.repeat[tokens[19]], 1, self.k_size[tokens[20]]), - (self.multiply[tokens[21]], self.filter_num5[tokens[22]], self.repeat[tokens[23]], 2, self.k_size[tokens[24]]), - (self.multiply[tokens[25]], self.filter_num6[tokens[26]], self.repeat[tokens[27]], 1, self.k_size[tokens[28]]), + (self.multiply[tokens[1]], self.filter_num1[tokens[2]], + self.repeat[tokens[3]], 2, self.k_size[tokens[4]]), + (self.multiply[tokens[5]], self.filter_num1[tokens[6]], + self.repeat[tokens[7]], 2, self.k_size[tokens[8]]), + (self.multiply[tokens[9]], self.filter_num2[tokens[10]], + self.repeat[tokens[11]], 2, self.k_size[tokens[12]]), + (self.multiply[tokens[13]], self.filter_num3[tokens[14]], + self.repeat[tokens[15]], 2, self.k_size[tokens[16]]), + (self.multiply[tokens[17]], self.filter_num3[tokens[18]], + self.repeat[tokens[19]], 1, self.k_size[tokens[20]]), + (self.multiply[tokens[21]], self.filter_num5[tokens[22]], + self.repeat[tokens[23]], 2, self.k_size[tokens[24]]), + (self.multiply[tokens[25]], self.filter_num6[tokens[26]], + self.repeat[tokens[27]], 1, self.k_size[tokens[28]]), ] - assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format(self.block_num) + assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format( + self.block_num) # the stride = 2 means downsample feature map in the convolution, so only when stride=2, block_num minus 1, # otherwise, add layers to params_list directly. @@ -136,10 +172,11 @@ class MobileNetV2Space(SearchSpaceBase): # if output_size is 1, add fc layer in the end if self.output_size == 1: - input = fluid.layers.fc(input=input, - size=self.class_dim, - param_attr=ParamAttr(name='fc10_weights'), - bias_attr=ParamAttr(name='fc10_offset')) + input = fluid.layers.fc( + input=input, + size=self.class_dim, + param_attr=ParamAttr(name='fc10_weights'), + bias_attr=ParamAttr(name='fc10_offset')) else: assert self.output_size == input.shape[2], \ ("output_size must EQUAL to input_size / (2^block_num)." @@ -150,7 +187,6 @@ class MobileNetV2Space(SearchSpaceBase): return net_arch - def shortcut(self, input, data_residual): """Build shortcut layer. Args: @@ -161,7 +197,6 @@ class MobileNetV2Space(SearchSpaceBase): """ return fluid.layers.elementwise_add(input, data_residual) - def inverted_residual_unit(self, input, num_in_filter, @@ -222,15 +257,7 @@ class MobileNetV2Space(SearchSpaceBase): out = self.shortcut(input=input, data_residual=out) return out - def invresi_blocks(self, - input, - in_c, - t, - c, - n, - s, - k, - name=None): + def invresi_blocks(self, input, in_c, t, c, n, s, k, name=None): """Build inverted residual blocks. Args: input: Variable, input. @@ -268,5 +295,3 @@ class MobileNetV2Space(SearchSpaceBase): expansion_factor=t, name=name + '_' + str(i + 1)) return last_residual_block - - diff --git a/paddleslim/nas/searchspacebase.py b/paddleslim/nas/search_space/search_space_base.py similarity index 100% rename from paddleslim/nas/searchspacebase.py rename to paddleslim/nas/search_space/search_space_base.py index cc1d462aca3b81c231df53ec2b5995cbb1deb5d5..a68ec63acbb90191c931508ed96e1bb2d8261e48 100644 --- a/paddleslim/nas/searchspacebase.py +++ b/paddleslim/nas/search_space/search_space_base.py @@ -14,6 +14,7 @@ __all__ = ['SearchSpaceBase'] + class SearchSpaceBase(object): """Controller for Neural Architecture Search. """ @@ -41,4 +42,3 @@ class SearchSpaceBase(object): list """ raise NotImplementedError('Abstract method.') - diff --git a/paddleslim/nas/searchspacefactory.py b/paddleslim/nas/search_space/search_space_factory.py similarity index 89% rename from paddleslim/nas/searchspacefactory.py rename to paddleslim/nas/search_space/search_space_factory.py index 10d076e8722a89cb46bac94360b1968e1de9e33a..11d8377d9056eddbeed43a6ff3f1aa67c6cf664c 100644 --- a/paddleslim/nas/searchspacefactory.py +++ b/paddleslim/nas/search_space/search_space_factory.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from searchspace.registry import SEARCHSPACE +from search_space_registry import SEARCHSPACE + +__all__ = ["SearchSpaceFactory"] + class SearchSpaceFactory(object): def __init__(self): @@ -29,8 +32,7 @@ class SearchSpaceFactory(object): model space(class) """ cls = SEARCHSPACE.get(key) - space = cls(config['input_size'], config['output_size'], config['block_num']) + space = cls(config['input_size'], config['output_size'], + config['block_num']) return space - - diff --git a/paddleslim/nas/searchspace/__init__.py b/paddleslim/nas/search_space/search_space_registry.py similarity index 86% rename from paddleslim/nas/searchspace/__init__.py rename to paddleslim/nas/search_space/search_space_registry.py index d1b5c527794b03967e6ce77f8bb16e883c06dbbf..2fea80fba4c908759e6123d3d898e94d7ef54c42 100644 --- a/paddleslim/nas/searchspace/__init__.py +++ b/paddleslim/nas/search_space/search_space_registry.py @@ -12,4 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .mobilenetv2_space import MobileNetV2Space +from ...core import Registry + +__all__ = ["SEARCHSPACE"] + +SEARCHSPACE = Registry('searchspace') diff --git a/paddleslim/nas/searchspace/registry.py b/paddleslim/nas/searchspace/registry.py deleted file mode 100644 index 33fb7212e99c886e4ad720f21db145e6f0e2443f..0000000000000000000000000000000000000000 --- a/paddleslim/nas/searchspace/registry.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys -sys.path.append('..') -from utils.registry import Registry - -SEARCHSPACE = Registry('searchspace') diff --git a/paddleslim/nas/test_searchspace.py b/paddleslim/nas/test_searchspace.py deleted file mode 100644 index 4761bf36e2722dfed83acfd69621a4e586a2ed93..0000000000000000000000000000000000000000 --- a/paddleslim/nas/test_searchspace.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle.fluid as fluid -from searchspacefactory import SearchSpaceFactory -if __name__ == '__main__': - # if output_size is 1, the model will add fc layer in the end. - config = {'input_size': 224, 'output_size': 7, 'block_num': 5} - space = SearchSpaceFactory() - - my_space = space.get_search_space('MobileNetV2Space', config) - model_arch = my_space.token2arch() - - train_prog = fluid.Program() - startup_prog = fluid.Program() - with fluid.program_guard(train_prog, startup_prog): - input_size= config['input_size'] - model_input = fluid.layers.data(name='model_in', shape=[1, 3, input_size, input_size], dtype='float32', append_batch_size=False) - print('input shape', model_input.shape) - predict = model_arch(model_input) - print('output shape', predict.shape) - - - #for op in train_prog.global_block().ops: - # print(op.type) diff --git a/tests/test_searchspace.py b/tests/test_searchspace.py new file mode 100644 index 0000000000000000000000000000000000000000..5e3557bbfc93b277fbc87339bcf6be605d313456 --- /dev/null +++ b/tests/test_searchspace.py @@ -0,0 +1,45 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +sys.path.append("../") +import unittest +import paddle.fluid as fluid +from paddleslim.nas import SearchSpaceFactory + + +class TestSearchSpaceFactory(unittest.TestCase): + def test_factory(self): + # if output_size is 1, the model will add fc layer in the end. + config = {'input_size': 224, 'output_size': 7, 'block_num': 5} + space = SearchSpaceFactory() + + my_space = space.get_search_space('MobileNetV2Space', config) + model_arch = my_space.token2arch() + + train_prog = fluid.Program() + startup_prog = fluid.Program() + with fluid.program_guard(train_prog, startup_prog): + input_size = config['input_size'] + model_input = fluid.layers.data( + name='model_in', + shape=[1, 3, input_size, input_size], + dtype='float32', + append_batch_size=False) + print('input shape', model_input.shape) + predict = model_arch(model_input) + print('output shape', predict.shape) + + +if __name__ == '__main__': + unittest.main()