diff --git a/demo/sa_nas_mobilenetv2_cifar10.py b/demo/sa_nas_mobilenetv2_cifar10.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e903960b1c783c38d672238d5a2b3a0c1581c4d
--- /dev/null
+++ b/demo/sa_nas_mobilenetv2_cifar10.py
@@ -0,0 +1,126 @@
+import sys
+sys.path.append('..')
+import numpy as np
+import argparse
+import ast
+import paddle
+import paddle.fluid as fluid
+from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory
+from paddleslim.analysis import flops
+from paddleslim.nas import SANAS
+
+
+def create_data_loader():
+    data = fluid.data(name='data', shape=[-1, 3, 32, 32], dtype='float32')
+    label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
+    data_loader = fluid.io.DataLoader.from_generator(
+        feed_list=[data, label],
+        capacity=1024,
+        use_double_buffer=True,
+        iterable=True)
+    return data_loader, data, label
+
+
+def init_sa_nas(config):
+    factory = SearchSpaceFactory()
+    space = factory.get_search_space(config)
+    model_arch = space.token2arch()[0]
+    main_program = fluid.Program()
+    startup_program = fluid.Program()
+
+    with fluid.program_guard(main_program, startup_program):
+        data_loader, data, label = create_data_loader()
+        output = model_arch(data)
+        cost = fluid.layers.mean(
+            fluid.layers.softmax_with_cross_entropy(
+                logits=output, label=label))
+
+        base_flops = flops(main_program)
+        search_steps = 10000000
+
+        ### start a server and a client
+        sa_nas = SANAS(config, max_flops=base_flops, search_steps=search_steps)
+
+        ### start a client, server_addr is the server address
+        #sa_nas = SANAS(config, max_flops=base_flops, server_addr=("10.255.125.38", 18607), search_steps=search_steps, is_server=False)
+
+    return sa_nas, search_steps
+
+
+def search_mobilenetv2_cifar10(config, args):
+    sa_nas, search_steps = init_sa_nas(config)
+    for i in range(search_steps):
+        print('search step: ', i)
+        archs = sa_nas.next_archs()[0]
+        train_program = fluid.Program()
+        test_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(train_program, startup_program):
+            train_loader, data, label = create_data_loader()
+            output = archs(data)
+            cost = fluid.layers.mean(
+                fluid.layers.softmax_with_cross_entropy(
+                    logits=output, label=label))
+            test_program = train_program.clone(for_test=True)
+
+            optimizer = fluid.optimizer.Momentum(
+                learning_rate=0.1,
+                momentum=0.9,
+                regularization=fluid.regularizer.L2Decay(1e-4))
+            optimizer.minimize(cost)
+
+        place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(startup_program)
+        train_reader = paddle.reader.shuffle(
+            paddle.dataset.cifar.train10(cycle=False), buf_size=1024)
+        train_loader.set_sample_generator(
+            train_reader,
+            batch_size=512,
+            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
+
+        test_loader, _, _ = create_data_loader()
+        test_reader = paddle.dataset.cifar.test10(cycle=False)
+        test_loader.set_sample_generator(
+            test_reader,
+            batch_size=256,
+            drop_last=False,
+            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
+
+        for epoch_id in range(10):
+            for batch_id, data in enumerate(train_loader()):
+                loss = exe.run(train_program,
+                               feed=data,
+                               fetch_list=[cost.name])[0]
+                if batch_id % 5 == 0:
+                    print('epoch: {}, batch: {}, loss: {}'.format(
+                        epoch_id, batch_id, loss[0]))
+
+        # average the test loss over all batches instead of keeping only the last one
+        rewards = []
+        for data in test_loader():
+            batch_reward = exe.run(test_program, feed=data,
+                                   fetch_list=[cost.name])[0]
+            rewards.append(np.mean(batch_reward))
+        reward = np.mean(rewards)
+
+        print('reward:', reward)
+        sa_nas.reward(float(reward))
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser(
+        description='SA NAS MobileNetV2 cifar10 argparse')
+    parser.add_argument(
+        '--use_gpu',
+        type=ast.literal_eval,
+        default=True,
+        help='Whether to use GPU in train/test model.')
+    args = parser.parse_args()
+    print(args)
+
+    config_info = {'input_size': 32, 'output_size': 1, 'block_num': 5}
+    config = [('MobileNetV2Space', config_info)]
+
+    search_mobilenetv2_cifar10(config, args)
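A note on running this demo: it starts the SA-NAS controller server and client in the same process. The commented-out line above shows the client-only mode; a minimal sketch of that setup follows, where the address, port, and step count are placeholders rather than values from the demo:

```python
# Client-only SA-NAS sketch: assumes a controller server is already running
# elsewhere; the address, port and search_steps below are placeholders.
from paddleslim.nas import SANAS

config = [('MobileNetV2Space',
           {'input_size': 32, 'output_size': 1, 'block_num': 5})]

sa_nas = SANAS(
    config,
    server_addr=("127.0.0.1", 18607),  # address of the controller server
    search_steps=300,
    is_server=False)  # connect as a client instead of starting a server

archs = sa_nas.next_archs()[0]  # fetch the next candidate architecture
```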
diff --git a/paddleslim/nas/search_space/__init__.py b/paddleslim/nas/search_space/__init__.py
index c8bef8db17e4a4cea110a3ef3fd4f3d7edceeedc..51b433d452b8cd8c3eb32582d9caa43634b700d0 100644
--- a/paddleslim/nas/search_space/__init__.py
+++ b/paddleslim/nas/search_space/__init__.py
@@ -14,6 +14,8 @@
 
 import mobilenetv2
 from .mobilenetv2 import *
+import mobilenetv1
+from .mobilenetv1 import *
 import resnet
 from .resnet import *
 import search_space_registry
@@ -28,4 +30,4 @@ __all__ += mobilenetv2.__all__
+__all__ += mobilenetv1.__all__
 __all__ += search_space_registry.__all__
 __all__ += search_space_factory.__all__
 __all__ += search_space_base.__all__
-
diff --git a/paddleslim/nas/search_space/base_layer.py b/paddleslim/nas/search_space/base_layer.py
index 2e769ec6339b639732995849e9f819a08b749c92..b497c92a2ca57b4acab0c39c5dbd69d30083e295 100644
--- a/paddleslim/nas/search_space/base_layer.py
+++ b/paddleslim/nas/search_space/base_layer.py
@@ -20,7 +20,7 @@ def conv_bn_layer(input,
                   filter_size,
                   num_filters,
                   stride,
-                  padding,
+                  padding='SAME',
                   num_groups=1,
                   act=None,
                   name=None,
@@ -51,15 +51,10 @@ def conv_bn_layer(input,
         param_attr=ParamAttr(name=name + '_weights'),
         bias_attr=False)
     bn_name = name + '_bn'
-    bn = fluid.layers.batch_norm(
-        input=conv,
-        param_attr=ParamAttr(name=bn_name + '_scale'),
-        bias_attr=ParamAttr(name=bn_name + '_offset'),
-        moving_mean_name=bn_name + '_mean',
-        moving_variance_name=bn_name + '_variance')
-    if act == 'relu6':
-        return fluid.layers.relu6(bn)
-    elif act == 'sigmoid':
-        return fluid.layers.sigmoid(bn)
-    else:
-        return bn
+    return fluid.layers.batch_norm(
+        input=conv,
+        act=act,
+        param_attr=ParamAttr(name=bn_name + '_scale'),
+        bias_attr=ParamAttr(name=bn_name + '_offset'),
+        moving_mean_name=bn_name + '_mean',
+        moving_variance_name=bn_name + '_variance')
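With this change `conv_bn_layer` applies the activation inside `batch_norm` instead of as a separate op, and `padding` defaults to `'SAME'`. A sketch of a call site, with illustrative names and shape:

```python
import paddle.fluid as fluid
from paddleslim.nas.search_space.base_layer import conv_bn_layer

# Illustrative call; the variable names and input shape are made up.
data = fluid.data(name='image', shape=[-1, 3, 32, 32], dtype='float32')
out = conv_bn_layer(
    input=data,
    filter_size=3,
    num_filters=32,
    stride=2,  # padding can now be omitted, it defaults to 'SAME'
    act='relu6',  # fused into the batch_norm op rather than a separate layer
    name='example_conv')
```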
diff --git a/paddleslim/nas/search_space/combine_search_space.py b/paddleslim/nas/search_space/combine_search_space.py
index 371bcf5347ebbc21e0688d1611ed3b298b940eb1..667720a9110aa92e096a4f8fa30bb3e4b3e3cecb 100644
--- a/paddleslim/nas/search_space/combine_search_space.py
+++ b/paddleslim/nas/search_space/combine_search_space.py
@@ -25,12 +25,14 @@ from .base_layer import conv_bn_layer
 
 __all__ = ["CombineSearchSpace"]
 
+
 class CombineSearchSpace(object):
     """
     Combine Search Space.
     Args:
         configs(list): multi config.
     """
+
     def __init__(self, config_lists):
         self.lens = len(config_lists)
         self.spaces = []
@@ -50,11 +52,10 @@ class CombineSearchSpace(object):
         """
         cls = SEARCHSPACE.get(key)
         space = cls(config['input_size'], config['output_size'],
-                    config['block_num'])
+                    config['block_num'], config.get('block_mask', None))
         return space
 
-
     def init_tokens(self):
         """
         Combine init tokens.
@@ -96,4 +97,3 @@ class CombineSearchSpace(object):
             model_archs.append(space.token2arch(token))
 
         return model_archs
-
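Because the space-construction helper above falls back to `config.get('block_mask', None)`, configs that omit the key, like the demo's, keep working. A sketch of using `CombineSearchSpace` directly, with an illustrative config:

```python
from paddleslim.nas.search_space.combine_search_space import CombineSearchSpace

# Illustrative config; no 'block_mask' key is needed.
config_lists = [('MobileNetV2Space',
                 {'input_size': 32, 'output_size': 1, 'block_num': 5})]
space = CombineSearchSpace(config_lists)

tokens = space.init_tokens()  # concatenation of every sub-space's init tokens
model_archs = space.token2arch(tokens)  # one net_arch callable per sub-space
```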
diff --git a/paddleslim/nas/search_space/mobilenetv1.py b/paddleslim/nas/search_space/mobilenetv1.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b3277d2cb1b472ccd5e27407e3099b28e64f42b
--- /dev/null
+++ b/paddleslim/nas/search_space/mobilenetv1.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from .search_space_base import SearchSpaceBase
+from .base_layer import conv_bn_layer
+from .search_space_registry import SEARCHSPACE
+
+__all__ = ["MobileNetV1Space"]
+
+
+@SEARCHSPACE.register
+class MobileNetV1Space(SearchSpaceBase):
+    def __init__(self,
+                 input_size,
+                 output_size,
+                 block_num,
+                 block_mask=None,
+                 scale=1.0,
+                 class_dim=1000):
+        super(MobileNetV1Space, self).__init__(input_size, output_size,
+                                               block_num, block_mask)
+        assert self.block_mask is None, 'MobileNetV1Space will use the origin MobileNetV1 as the search space, so only input_size, output_size and block_num are used to search'
+        self.scale = scale
+        self.class_dim = class_dim
+        # self.head_num means the channel number of the first convolution
+        self.head_num = np.array([3, 4, 8, 12, 16, 24, 32])  #7
+        # self.filter_num1 ~ self.filter_num9 means the channel numbers of the following convolutions
+        self.filter_num1 = np.array([3, 4, 8, 12, 16, 24, 32, 48])  #8
+        self.filter_num2 = np.array([8, 12, 16, 24, 32, 48, 64, 80])  #8
+        self.filter_num3 = np.array(
+            [16, 24, 32, 48, 64, 80, 96, 128, 144, 160])  #10
+        self.filter_num4 = np.array(
+            [24, 32, 48, 64, 80, 96, 128, 144, 160, 192])  #10
+        self.filter_num5 = np.array(
+            [32, 48, 64, 80, 96, 128, 144, 160, 192, 224, 256, 320])  #12
+        self.filter_num6 = np.array(
+            [64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384])  #11
+        self.filter_num7 = np.array([
+            64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384, 512, 1024, 1048
+        ])  #14
+        self.filter_num8 = np.array(
+            [128, 144, 160, 192, 224, 256, 320, 384, 512, 576, 640, 704,
+             768])  #13
+        self.filter_num9 = np.array(
+            [160, 192, 224, 256, 320, 384, 512, 640, 768, 832, 1024,
+             1048])  #12
+        # self.k_size means kernel size
+        self.k_size = np.array([3, 5])  #2
+        # self.repeat means the repeat number in the fourth downsample stage
+        self.repeat = np.array([1, 2, 3, 4, 5, 6])  #6
+
+        assert self.block_num < 6, 'MobileNetV1: block number must be less than 6, but the received block number is {}'.format(
+            self.block_num)
+
+    def init_tokens(self):
+        """
+        The initial token.
+        The first entry is the index of the first layer's channel in self.head_num;
+        each following line holds the indices of [filter_num1, filter_num2, kernel_size],
+        and the single depth entry is the repeat count for the fourth downsample stage.
+        """
+        # yapf: disable
+        base_init_tokens = [6,  # 32
+                6, 6, 0,    # 32, 64, 3
+                6, 7, 0,    # 64, 128, 3
+                7, 6, 0,    # 128, 128, 3
+                6, 10, 0,   # 128, 256, 3
+                10, 8, 0,   # 256, 256, 3
+                8, 11, 0,   # 256, 512, 3
+                4,          # depth 5
+                11, 8, 0,   # 512, 512, 3
+                8, 10, 0,   # 512, 1024, 3
+                10, 10, 0]  # 1024, 1024, 3
+        # yapf: enable
+        if self.block_num < 5:
+            self.token_len = 1 + (self.block_num * 2 - 1) * 3
+        else:
+            self.token_len = 2 + (self.block_num * 2 - 1) * 3
+        return base_init_tokens[:self.token_len]
+
+    def range_table(self):
+        """
+        Get range table of current search space, constrains the range of tokens.
+        """
+        # yapf: disable
+        base_range_table = [len(self.head_num),
+                len(self.filter_num1), len(self.filter_num2), len(self.k_size),
+                len(self.filter_num2), len(self.filter_num3), len(self.k_size),
+                len(self.filter_num3), len(self.filter_num4), len(self.k_size),
+                len(self.filter_num4), len(self.filter_num5), len(self.k_size),
+                len(self.filter_num5), len(self.filter_num6), len(self.k_size),
+                len(self.filter_num6), len(self.filter_num7), len(self.k_size),
+                len(self.repeat),
+                len(self.filter_num7), len(self.filter_num8), len(self.k_size),
+                len(self.filter_num8), len(self.filter_num9), len(self.k_size),
+                len(self.filter_num9), len(self.filter_num9), len(self.k_size)]
+        # yapf: enable
+        return base_range_table[:self.token_len]
+
+    def token2arch(self, tokens=None):
+
+        if tokens is None:
+            tokens = self.init_tokens()
+
+        bottleneck_param_list = []
+
+        if self.block_num >= 1:
+            # tokens[0] = 32
+            # 32, 64
+            bottleneck_param_list.append(
+                (self.filter_num1[tokens[1]], self.filter_num2[tokens[2]], 1,
+                 self.k_size[tokens[3]]))
+        if self.block_num >= 2:
+            # 64 128 128 128
+            bottleneck_param_list.append(
+                (self.filter_num2[tokens[4]], self.filter_num3[tokens[5]], 2,
+                 self.k_size[tokens[6]]))
+            bottleneck_param_list.append(
+                (self.filter_num3[tokens[7]], self.filter_num4[tokens[8]], 1,
+                 self.k_size[tokens[9]]))
+        if self.block_num >= 3:
+            # 128 256 256 256
+            bottleneck_param_list.append(
+                (self.filter_num4[tokens[10]], self.filter_num5[tokens[11]], 2,
+                 self.k_size[tokens[12]]))
+            bottleneck_param_list.append(
+                (self.filter_num5[tokens[13]], self.filter_num6[tokens[14]], 1,
+                 self.k_size[tokens[15]]))
+        if self.block_num >= 4:
+            # 256 512 (512 512) * 5
+            bottleneck_param_list.append(
+                (self.filter_num6[tokens[16]], self.filter_num7[tokens[17]], 2,
+                 self.k_size[tokens[18]]))
+            for i in range(self.repeat[tokens[19]]):
+                bottleneck_param_list.append(
+                    (self.filter_num7[tokens[20]],
+                     self.filter_num8[tokens[21]], 1, self.k_size[tokens[22]]))
+        if self.block_num >= 5:
+            # 512 1024 1024 1024
+            bottleneck_param_list.append(
+                (self.filter_num8[tokens[23]], self.filter_num9[tokens[24]], 2,
+                 self.k_size[tokens[25]]))
+            bottleneck_param_list.append(
+                (self.filter_num9[tokens[26]], self.filter_num9[tokens[27]], 1,
+                 self.k_size[tokens[28]]))
+
+        def net_arch(input):
+            input = conv_bn_layer(
+                input=input,
+                filter_size=3,
+                num_filters=self.head_num[tokens[0]],
+                stride=2,
+                name='mobilenetv1')
+
+            for i, layer_setting in enumerate(bottleneck_param_list):
+                filter_num1, filter_num2, stride, kernel_size = layer_setting
+                input = self._depthwise_separable(
+                    input=input,
+                    num_filters1=filter_num1,
+                    num_filters2=filter_num2,
+                    num_groups=filter_num1,
+                    stride=stride,
+                    scale=self.scale,
+                    kernel_size=kernel_size,
+                    name='mobilenetv1_{}'.format(str(i + 1)))
+
+            if self.output_size == 1:
+                print('NOTE: if output_size is 1, add fc layer in the end!!!')
+                input = fluid.layers.fc(
+                    input=input,
+                    size=self.class_dim,
+                    param_attr=ParamAttr(name='mobilenetv1_fc_weights'),
+                    bias_attr=ParamAttr(name='mobilenetv1_fc_offset'))
+            else:
+                assert self.output_size == input.shape[2], \
+                    ("output_size must equal input_size / (2^block_num)."
+                     "But received input_size={}, output_size={}, block_num={}".format(
+                         self.input_size, self.output_size, self.block_num))
+
+            return input
+
+        return net_arch
+
+    def _depthwise_separable(self,
+                             input,
+                             num_filters1,
+                             num_filters2,
+                             num_groups,
+                             stride,
+                             scale,
+                             kernel_size,
+                             name=None):
+        depthwise_conv = conv_bn_layer(
+            input=input,
+            filter_size=kernel_size,
+            num_filters=int(num_filters1 * scale),
+            stride=stride,
+            num_groups=int(num_groups * scale),
+            use_cudnn=False,
+            name=name + '_dw')
+        pointwise_conv = conv_bn_layer(
+            input=depthwise_conv,
+            filter_size=1,
+            num_filters=int(num_filters2 * scale),
+            stride=1,
+            name=name + '_sep')
+
+        return pointwise_conv
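Because `MobileNetV1Space` is registered via `@SEARCHSPACE.register`, it becomes reachable through the factory once the package imports the module. A minimal sketch of driving it, with illustrative config values:

```python
from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory

# Illustrative config; MobileNetV1Space requires block_num < 6.
factory = SearchSpaceFactory()
space = factory.get_search_space(
    [('MobileNetV1Space', {'input_size': 32, 'output_size': 1, 'block_num': 5})])

tokens = space.init_tokens()  # 2 + (2 * 5 - 1) * 3 = 29 entries for block_num=5
archs = space.token2arch(tokens)  # list with one network-builder callable
```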
diff --git a/paddleslim/nas/search_space/mobilenetv2.py b/paddleslim/nas/search_space/mobilenetv2.py
index 28d8a7ea03bc94618b9b5575f837f09879d309c8..e974a676a70546e19aa4649679393031634e7822 100644
--- a/paddleslim/nas/search_space/mobilenetv2.py
+++ b/paddleslim/nas/search_space/mobilenetv2.py
@@ -32,11 +32,15 @@ class MobileNetV2Space(SearchSpaceBase):
                  input_size,
                  output_size,
                  block_num,
+                 block_mask=None,
                  scale=1.0,
                  class_dim=1000):
         super(MobileNetV2Space, self).__init__(input_size, output_size,
-                                               block_num)
+                                               block_num, block_mask)
+        assert self.block_mask is None, 'MobileNetV2Space will use the origin MobileNetV2 as the search space, so only input_size, output_size and block_num are used to search'
+        # self.head_num means the channel number of the first convolution
         self.head_num = np.array([3, 4, 8, 12, 16, 24, 32])  #7
+        # self.filter_num1 ~ self.filter_num6 means the channel numbers of the following convolutions
         self.filter_num1 = np.array([3, 4, 8, 12, 16, 24, 32, 48])  #8
         self.filter_num2 = np.array([8, 12, 16, 24, 32, 48, 64, 80])  #8
         self.filter_num3 = np.array([16, 24, 32, 48, 64, 80, 96, 128])  #8
@@ -46,15 +50,21 @@ class MobileNetV2Space(SearchSpaceBase):
             [32, 48, 64, 80, 96, 128, 144, 160, 192, 224])  #10
         self.filter_num6 = np.array(
             [64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384, 512])  #12
+        # self.k_size means kernel size
         self.k_size = np.array([3, 5])  #2
+        # self.multiply means the expansion factor of each _inverted_residual_unit
         self.multiply = np.array([1, 2, 3, 4, 6])  #5
+        # self.repeat means the repeat number of _inverted_residual_unit in each _invresi_blocks
         self.repeat = np.array([1, 2, 3, 4, 5, 6])  #6
         self.scale = scale
         self.class_dim = class_dim
 
+        assert self.block_num < 7, 'MobileNetV2: block number must be less than 7, but the received block number is {}'.format(
+            self.block_num)
+
     def init_tokens(self):
         """
-        The initial token send to controller.
+        The initial token.
         The first one is the index of the first layers' channel in self.head_num,
         each line in the following represent the index of the [expansion_factor, filter_num, repeat_num, kernel_size]
         """
@@ -80,18 +90,18 @@ class MobileNetV2Space(SearchSpaceBase):
 
     def range_table(self):
         """
-        get range table of current search space
+        Get range table of current search space, constrains the range of tokens.
         """
         # head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
         # yapf: disable
-        range_table_base = [7,
-                5, 8, 6, 2,
-                5, 8, 6, 2,
-                5, 8, 6, 2,
-                5, 8, 6, 2,
-                5, 10, 6, 2,
-                5, 10, 6, 2,
-                5, 12, 6, 2]
+        range_table_base = [len(self.head_num),
+                len(self.multiply), len(self.filter_num1), len(self.repeat), len(self.k_size),
+                len(self.multiply), len(self.filter_num1), len(self.repeat), len(self.k_size),
+                len(self.multiply), len(self.filter_num2), len(self.repeat), len(self.k_size),
+                len(self.multiply), len(self.filter_num3), len(self.repeat), len(self.k_size),
+                len(self.multiply), len(self.filter_num4), len(self.repeat), len(self.k_size),
+                len(self.multiply), len(self.filter_num5), len(self.repeat), len(self.k_size),
+                len(self.multiply), len(self.filter_num6), len(self.repeat), len(self.k_size)]
         range_table_base = list(np.array(range_table_base) - 1)
         # yapf: enable
         return range_table_base[:self.token_len]
@@ -101,12 +111,8 @@ class MobileNetV2Space(SearchSpaceBase):
         """
         return net_arch function
         """
-
-        assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format(
-            self.block_num)
-
         if tokens is None:
             tokens = self.init_tokens()
 
         bottleneck_params_list = []
         if self.block_num >= 1:
@@ -128,7 +134,7 @@ class MobileNetV2Space(SearchSpaceBase):
                 (self.multiply[tokens[13]], self.filter_num3[tokens[14]],
                  self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
             bottleneck_params_list.append(
-                (self.multiply[tokens[17]], self.filter_num3[tokens[18]],
+                (self.multiply[tokens[17]], self.filter_num4[tokens[18]],
                  self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
         if self.block_num >= 6:
             bottleneck_params_list.append(
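Deriving `range_table_base` from the option arrays keeps the table in sync when a filter list changes, and the trailing `- 1` makes each entry read as the largest valid index into its array. A quick self-consistency sketch, with illustrative constructor arguments:

```python
from paddleslim.nas.search_space.mobilenetv2 import MobileNetV2Space

# Every init token should be a valid index, i.e. no larger than the
# corresponding range-table entry. init_tokens() must run first because
# it sets self.token_len, which range_table() slices by.
space = MobileNetV2Space(input_size=32, output_size=1, block_num=5)
tokens = space.init_tokens()
ranges = space.range_table()
assert len(tokens) == len(ranges)
assert all(t <= r for t, r in zip(tokens, ranges))
```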
""" # head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size] # yapf: disable - range_table_base = [7, - 5, 8, 6, 2, - 5, 8, 6, 2, - 5, 8, 6, 2, - 5, 8, 6, 2, - 5, 10, 6, 2, - 5, 10, 6, 2, - 5, 12, 6, 2] + range_table_base = [len(self.head_num), + len(self.multiply), len(self.filter_num1), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num1), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num2), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num3), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num4), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num5), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num6), len(self.repeat), len(self.k_size)] range_table_base = list(np.array(range_table_base) - 1) # yapf: enable return range_table_base[:self.token_len] @@ -101,11 +111,9 @@ class MobileNetV2Space(SearchSpaceBase): return net_arch function """ - assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format( - self.block_num) - if tokens is None: tokens = self.init_tokens() + print(tokens) bottleneck_params_list = [] if self.block_num >= 1: @@ -128,7 +136,7 @@ class MobileNetV2Space(SearchSpaceBase): (self.multiply[tokens[13]], self.filter_num3[tokens[14]], self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) bottleneck_params_list.append( - (self.multiply[tokens[17]], self.filter_num3[tokens[18]], + (self.multiply[tokens[17]], self.filter_num4[tokens[18]], self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) if self.block_num >= 6: bottleneck_params_list.append( diff --git a/paddleslim/nas/search_space/resnet.py b/paddleslim/nas/search_space/resnet.py index 7ed404e5e145c9f173aee95823c8d6ac6a47dfdb..fd761d417575988e8ba8bd99da25372613c5912f 100644 --- a/paddleslim/nas/search_space/resnet.py +++ b/paddleslim/nas/search_space/resnet.py @@ -32,31 +32,144 @@ class ResNetSpace(SearchSpaceBase): input_size, output_size, block_num, - scale=1.0, + block_mask=None, + extract_feature=False, class_dim=1000): - super(ResNetSpace, self).__init__(input_size, output_size, block_num) - pass + super(ResNetSpace, self).__init__(input_size, output_size, block_num, + block_mask) + assert self.block_mask == None, 'ResNetSpace will use origin ResNet as seach space, so use input_size, output_size and block_num to search' + # self.filter_num1 ~ self.filter_num4 means convolution channel + self.filter_num1 = np.array([48, 64, 96, 128, 160, 192, 224]) #7 + self.filter_num2 = np.array([64, 96, 128, 160, 192, 256, 320]) #7 + self.filter_num3 = np.array([128, 160, 192, 256, 320, 384]) #6 + self.filter_num4 = np.array([192, 256, 384, 512, 640]) #5 + # self.repeat1 ~ self.repeat4 means depth of network + self.repeat1 = [2, 3, 4, 5, 6] #5 + self.repeat2 = [2, 3, 4, 5, 6, 7] #6 + self.repeat3 = [2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24] #13 + self.repeat4 = [2, 3, 4, 5, 6, 7] #6 + self.class_dim = class_dim + self.extract_feature = extract_feature + assert self.block_num < 5, 'ResNet: block number must less than 5, but receive block number is {}'.format( + self.block_num) def init_tokens(self): - return [0, 0, 0, 0, 0, 0] + """ + The initial token. 
diff --git a/paddleslim/nas/search_space/search_space_base.py b/paddleslim/nas/search_space/search_space_base.py
index bb1ce0f8a4bbd0b18d36fa9199a6ff814ab13236..6a83f86005a5fb2408f7f85f40dff8a9e5cba819 100644
--- a/paddleslim/nas/search_space/search_space_base.py
+++ b/paddleslim/nas/search_space/search_space_base.py
@@ -19,10 +19,11 @@ class SearchSpaceBase(object):
     """Controller for Neural Architecture Search.
     """
 
-    def __init__(self, input_size, output_size, block_num, *argss):
+    def __init__(self, input_size, output_size, block_num, block_mask, *args):
         self.input_size = input_size
         self.output_size = output_size
         self.block_num = block_num
+        self.block_mask = block_mask
 
     def init_tokens(self):
         """Get init tokens in search space.
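With `block_mask` threaded through the base class, a custom space only needs to accept and forward it. A hypothetical minimal subclass; the space and its single searchable channel list are invented for illustration:

```python
import numpy as np
from paddleslim.nas.search_space.search_space_base import SearchSpaceBase
from paddleslim.nas.search_space.search_space_registry import SEARCHSPACE


@SEARCHSPACE.register
class TinySpace(SearchSpaceBase):
    """Hypothetical example space that searches a single channel count."""

    def __init__(self, input_size, output_size, block_num, block_mask=None):
        super(TinySpace, self).__init__(input_size, output_size, block_num,
                                        block_mask)
        self.channels = np.array([8, 16, 32, 64])

    def init_tokens(self):
        return [0]

    def range_table(self):
        return [len(self.channels)]

    def token2arch(self, tokens=None):
        if tokens is None:
            tokens = self.init_tokens()
        num_filters = self.channels[tokens[0]]

        def net_arch(input):
            # build layers from num_filters here; identity keeps the sketch short
            return input

        return net_arch
```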