From 17215fd7e61dcbe313753fdfe22f754145ec7ac9 Mon Sep 17 00:00:00 2001 From: ceci3 <592712189@qq.com> Date: Tue, 3 Dec 2019 13:05:41 +0000 Subject: [PATCH] update nas search space --- demo/nas/block_sa_nas_mobilenetv2.py | 281 ++++++++++++++++++ demo/nas/block_sa_nas_mobilenetv2_cifar10.py | 143 --------- demo/nas/sa_nas_mobilenetv2.py | 21 +- paddleslim/nas/search_space/__init__.py | 4 +- .../nas/search_space/combine_search_space.py | 37 ++- .../nas/search_space/mobilenet_block.py | 28 +- paddleslim/nas/search_space/mobilenetv1.py | 150 +++++----- paddleslim/nas/search_space/mobilenetv2.py | 99 ++---- paddleslim/nas/search_space/resnet.py | 63 ++-- .../nas/search_space/search_space_base.py | 8 +- paddleslim/nas/search_space/utils.py | 14 + 11 files changed, 490 insertions(+), 358 deletions(-) create mode 100644 demo/nas/block_sa_nas_mobilenetv2.py delete mode 100644 demo/nas/block_sa_nas_mobilenetv2_cifar10.py diff --git a/demo/nas/block_sa_nas_mobilenetv2.py b/demo/nas/block_sa_nas_mobilenetv2.py new file mode 100644 index 00000000..fb5f094c --- /dev/null +++ b/demo/nas/block_sa_nas_mobilenetv2.py @@ -0,0 +1,281 @@ +import sys +sys.path.append('..') +import numpy as np +import argparse +import ast +import logging +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddleslim.analysis import flops +from paddleslim.nas import SANAS +from paddleslim.common import get_logger +from optimizer import create_optimizer +import imagenet_reader + +_logger = get_logger(__name__, level=logging.INFO) + +def create_data_loader(image_shape): + data_shape = [-1] + image_shape + data = fluid.data(name='data', shape=data_shape, dtype='float32') + label = fluid.data(name='label', shape=[-1, 1], dtype='int64') + data_loader = fluid.io.DataLoader.from_generator( + feed_list=[data, label], + capacity=1024, + use_double_buffer=True, + iterable=True) + return data_loader, data, label + +def conv_bn_layer(input, + filter_size, + 
num_filters, + stride, + padding='SAME', + num_groups=1, + act=None, + name=None, + use_cudnn=True): + conv = fluid.layers.conv2d( + input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + bn_name = name + '_bn' + return fluid.layers.batch_norm( + input=conv, + act = act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(name=bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + + +def search_mobilenetv2_block(config, args, image_size): + image_shape = [3, image_size, image_size] + if args.is_server: + sa_nas = SANAS(config, server_addr=("", args.port), init_temperature=args.init_temperature, reduce_rate=args.reduce_rate, search_steps=args.search_steps, is_server=True) + else: + sa_nas = SANAS(config, server_addr=(args.server_address, args.port), init_temperature=args.init_temperature, reduce_rate=args.reduce_rate, search_steps=args.search_steps, is_server=False) + + for step in range(args.search_steps): + archs = sa_nas.next_archs()[0] + + train_program = fluid.Program() + test_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + train_loader, data, label = create_data_loader(image_shape) + data = conv_bn_layer(input=data, num_filters=32, filter_size=3, stride=2, padding='SAME', act='relu6', name='mobilenetv2_conv1') + data = archs(data)[0] + data = conv_bn_layer(input=data, num_filters=1280, filter_size=1, stride=1, padding='SAME', act='relu6', name='mobilenetv2_last_conv') + data = fluid.layers.pool2d(input=data, pool_size=7, pool_stride=1, pool_type='avg', global_pooling=True, name='mobilenetv2_last_pool') + output = fluid.layers.fc( + input=data, + size=args.class_dim, + param_attr=ParamAttr(name='mobilenetv2_fc_weights'), + 
bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) + + softmax_out = fluid.layers.softmax(input=output, use_cudnn=False) + cost = fluid.layers.cross_entropy(input=softmax_out, label=label) + avg_cost = fluid.layers.mean(cost) + acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5) + test_program = train_program.clone(for_test=True) + + optimizer = fluid.optimizer.Momentum( + learning_rate=0.1, + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + optimizer.minimize(avg_cost) + + current_flops = flops(train_program) + print('step: {}, current_flops: {}'.format(step, current_flops)) + if current_flops > args.max_flops: + continue + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_program) + + if args.data == 'cifar10': + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.cifar.train10(cycle=False), buf_size=1024), + batch_size=args.batch_size, + drop_last=True) + + test_reader = paddle.batch( + paddle.dataset.cifar.test10(cycle=False), + batch_size=args.batch_size, + drop_last=False) + elif args.data == 'imagenet': + train_reader = paddle.batch( + imagenet_reader.train(), + batch_size=args.batch_size, + drop_last=True) + test_reader = paddle.batch( + imagenet_reader.val(), + batch_size=args.batch_size, + drop_last=False) + + test_loader, _, _ = create_data_loader(image_shape) + train_loader.set_sample_list_generator( + train_reader, + places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) + test_loader.set_sample_list_generator(test_reader, places=place) + + + build_strategy = fluid.BuildStrategy() + train_compiled_program = fluid.CompiledProgram( + train_program).with_data_parallel( + loss_name=avg_cost.name, build_strategy=build_strategy) + for epoch_id in range(args.retain_epoch): + for batch_id, data in enumerate(train_loader()): + fetches = [avg_cost.name] + s_time 
= time.time() + outs = exe.run(train_compiled_program, + feed=data, + fetch_list=fetches)[0] + batch_time = time.time() - s_time + if batch_id % 10 == 0: + _logger.info( + 'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'. + format(step, epoch_id, batch_id, outs[0], batch_time)) + + reward = [] + for batch_id, data in enumerate(test_loader()): + test_fetches = [ + avg_cost.name, acc_top1.name, acc_top5.name + ] + batch_reward = exe.run(test_program, + feed=data, + fetch_list=test_fetches) + reward_avg = np.mean(np.array(batch_reward), axis=1) + reward.append(reward_avg) + + _logger.info( + 'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'. + format(step, batch_id, batch_reward[0], batch_reward[1], + batch_reward[2])) + + finally_reward = np.mean(np.array(reward), axis=0) + _logger.info( + 'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format( + finally_reward[0], finally_reward[1], finally_reward[2])) + + sa_nas.reward(float(finally_reward[1])) + +if __name__ == '__main__': + + parser = argparse.ArgumentParser( + description='SA NAS MobileNetV2 cifar10 argparase') + parser.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=True, + help='Whether to use GPU in train/test model.') + parser.add_argument( + '--class_dim', type=int, default=1000, help='classify number.') + parser.add_argument( + '--batch_size', type=int, default=256, help='batch size.') + parser.add_argument( + '--data', + type=str, + default='cifar10', + choices=['cifar10', 'imagenet'], + help='server address.') + # controller + parser.add_argument( + '--reduce_rate', type=float, default=0.85, help='reduce rate.') + parser.add_argument( + '--init_temperature', + type=float, + default=10.24, + help='init temperature.') + parser.add_argument( + '--is_server', + type=ast.literal_eval, + default=True, + help='Whether to start a server.') + # nas args + parser.add_argument( + '--max_flops', type=int, default=592948064, help='reduce rate.') + 
parser.add_argument( + '--retain_epoch', type=int, default=5, help='train epoch before val.') + parser.add_argument( + '--end_epoch', type=int, default=500, help='end epoch present client.') + parser.add_argument( + '--search_steps', + type=int, + default=100, + help='controller server number.') + parser.add_argument( + '--server_address', type=str, default=None, help='server address.') + parser.add_argument( + '--port', type=int, default=8889, help='server port.') + # optimizer args + parser.add_argument( + '--lr_strategy', + type=str, + default='piecewise_decay', + help='learning rate decay strategy.') + parser.add_argument('--lr', type=float, default=0.1, help='learning rate.') + parser.add_argument( + '--l2_decay', type=float, default=1e-4, help='learning rate decay.') + parser.add_argument( + '--step_epochs', + nargs='+', + type=int, + default=[30, 60, 90], + help="piecewise decay step") + parser.add_argument( + '--momentum_rate', + type=float, + default=0.9, + help='learning rate decay.') + parser.add_argument( + '--warm_up_epochs', + type=float, + default=5.0, + help='learning rate decay.') + parser.add_argument( + '--num_epochs', type=int, default=120, help='learning rate decay.') + parser.add_argument( + '--decay_epochs', type=float, default=2.4, help='learning rate decay.') + parser.add_argument( + '--decay_rate', type=float, default=0.97, help='learning rate decay.') + parser.add_argument( + '--total_images', + type=int, + default=1281167, + help='learning rate decay.') + args = parser.parse_args() + print(args) + + if args.data == 'cifar10': + image_size = 32 + elif args.data == 'imagenet': + image_size = 224 + else: + raise NotImplemented( + 'data must in [cifar10, imagenet], but received: {}'.format( + args.data)) + + # block mask means block number, 1 mean downsample, 0 means the size of feature map don't change after this block + config_info = { + 'input_size': None, + 'output_size': None, + 'block_num': None, + 'block_mask': [0, 1, 1, 1, 1, 0, 1, 
0] + } + config = [('MobileNetV2BlockSpace', config_info)] + + search_mobilenetv2_block(config, args, image_size) diff --git a/demo/nas/block_sa_nas_mobilenetv2_cifar10.py b/demo/nas/block_sa_nas_mobilenetv2_cifar10.py deleted file mode 100644 index a363ba7e..00000000 --- a/demo/nas/block_sa_nas_mobilenetv2_cifar10.py +++ /dev/null @@ -1,143 +0,0 @@ -import sys -sys.path.append('..') -import numpy as np -import argparse -import ast -import paddle -import paddle.fluid as fluid -from paddle.fluid.param_attr import ParamAttr -from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory -from paddleslim.analysis import flops -from paddleslim.nas import SANAS - - -def create_data_loader(): - data = fluid.data(name='data', shape=[-1, 3, 32, 32], dtype='float32') - label = fluid.data(name='label', shape=[-1, 1], dtype='int64') - data_loader = fluid.io.DataLoader.from_generator( - feed_list=[data, label], - capacity=1024, - use_double_buffer=True, - iterable=True) - return data_loader, data, label - - -def init_sa_nas(config): - factory = SearchSpaceFactory() - space = factory.get_search_space(config) - model_arch = space.token2arch()[0] - main_program = fluid.Program() - startup_program = fluid.Program() - - with fluid.program_guard(main_program, startup_program): - data_loader, data, label = create_data_loader() - output = model_arch(data) - output = fluid.layers.fc( - input=output, - size=args.class_dim, - param_attr=ParamAttr(name='mobilenetv2_fc_weights'), - bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) - cost = fluid.layers.mean( - fluid.layers.softmax_with_cross_entropy( - logits=output, label=label)) - - base_flops = flops(main_program) - search_steps = 10000000 - - ### start a server and a client - sa_nas = SANAS(config, max_flops=base_flops, search_steps=search_steps) - - ### start a client, server_addr is server address - #sa_nas = SANAS(config, max_flops = base_flops, server_addr=("10.255.125.38", 18607), search_steps = search_steps, 
is_server=False) - - return sa_nas, search_steps - - -def search_mobilenetv2_cifar10(config, args): - sa_nas, search_steps = init_sa_nas(config) - for i in range(search_steps): - print('search step: ', i) - archs = sa_nas.next_archs()[0] - - train_program = fluid.Program() - test_program = fluid.Program() - startup_program = fluid.Program() - with fluid.program_guard(train_program, startup_program): - train_loader, data, label = create_data_loader() - output = archs(data) - output = fluid.layers.fc( - input=output, - size=args.class_dim, - param_attr=ParamAttr(name='mobilenetv2_fc_weights'), - bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) - cost = fluid.layers.mean( - fluid.layers.softmax_with_cross_entropy( - logits=output, label=label))[0] - test_program = train_program.clone(for_test=True) - - optimizer = fluid.optimizer.Momentum( - learning_rate=0.1, - momentum=0.9, - regularization=fluid.regularizer.L2Decay(1e-4)) - optimizer.minimize(cost) - - place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() - exe = fluid.Executor(place) - exe.run(startup_program) - - train_reader = paddle.reader.shuffle( - paddle.dataset.cifar.train10(cycle=False), buf_size=1024) - train_loader.set_sample_generator( - train_reader, - batch_size=512, - places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) - - test_loader, _, _ = create_data_loader() - test_reader = paddle.dataset.cifar.test10(cycle=False) - test_loader.set_sample_generator( - test_reader, - batch_size=256, - drop_last=False, - places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) - - for epoch_id in range(10): - for batch_id, data in enumerate(train_loader()): - loss = exe.run(train_program, - feed=data, - fetch_list=[cost.name])[0] - if batch_id % 5 == 0: - print('epoch: {}, batch: {}, loss: {}'.format( - epoch_id, batch_id, loss[0])) - - for data in test_loader(): - reward = exe.run(test_program, feed=data, - fetch_list=[cost.name])[0] - - print('reward:', reward) - 
sa_nas.reward(float(reward)) - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser( - description='SA NAS MobileNetV2 cifar10 argparase') - parser.add_argument( - '--use_gpu', - type=ast.literal_eval, - default=True, - help='Whether to use GPU in train/test model.') - parser.add_argument( - '--class_dim', type=int, default=1000, help='classify number.') - args = parser.parse_args() - print(args) - - # block mask means block number, 1 mean downsample, 0 means the size of feature map don't change after this block - config_info = { - 'input_size': 32, - 'output_size': 1, - 'block_num': 5, - 'block_mask': [0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0] - } - config = [('MobileNetV2BlockSpace', config_info)] - - search_mobilenetv2_cifar10(config, args) diff --git a/demo/nas/sa_nas_mobilenetv2.py b/demo/nas/sa_nas_mobilenetv2.py index 142c2c08..4df1c364 100644 --- a/demo/nas/sa_nas_mobilenetv2.py +++ b/demo/nas/sa_nas_mobilenetv2.py @@ -9,7 +9,7 @@ import ast import logging import paddle import paddle.fluid as fluid -from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory +from paddle.fluid.param_attr import ParamAttr from paddleslim.analysis import flops from paddleslim.nas import SANAS from paddleslim.common import get_logger @@ -40,6 +40,7 @@ def build_program(main_program, with fluid.program_guard(main_program, startup_program): data_loader, data, label = create_data_loader(image_shape) output = archs(data) + output = fluid.layers.fc(input=output, size=args.class_dim, param_attr=ParamAttr(name='mobilenetv2_fc_weights'), bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) softmax_out = fluid.layers.softmax(input=output, use_cudnn=False) cost = fluid.layers.cross_entropy(input=softmax_out, label=label) @@ -54,13 +55,11 @@ def build_program(main_program, def search_mobilenetv2(config, args, image_size, is_server=True): - factory = SearchSpaceFactory() - space = factory.get_search_space(config) if is_server: ### start a server and a client sa_nas 
= SANAS( config, - server_addr=("", 8883), + server_addr=("", args.port), init_temperature=args.init_temperature, reduce_rate=args.reduce_rate, search_steps=args.search_steps, @@ -69,7 +68,7 @@ def search_mobilenetv2(config, args, image_size, is_server=True): ### start a client sa_nas = SANAS( config, - server_addr=("10.255.125.38", 8883), + server_addr=(args.server_address, args.port), init_temperature=args.init_temperature, reduce_rate=args.reduce_rate, search_steps=args.search_steps, @@ -215,6 +214,8 @@ if __name__ == '__main__': help='controller server number.') parser.add_argument( '--server_address', type=str, default=None, help='server address.') + parser.add_argument( + '--port', type=int, default=8889, help='server port.') # optimizer args parser.add_argument( '--lr_strategy', @@ -224,6 +225,8 @@ if __name__ == '__main__': parser.add_argument('--lr', type=float, default=0.1, help='learning rate.') parser.add_argument( '--l2_decay', type=float, default=1e-4, help='learning rate decay.') + parser.add_argument( + '--class_dim', type=int, default=1000, help='classify number.') parser.add_argument( '--step_epochs', nargs='+', @@ -265,12 +268,6 @@ if __name__ == '__main__': 'data must in [cifar10, imagenet], but received: {}'.format( args.data)) - config_info = { - 'input_size': image_size, - 'output_size': 1, - 'block_num': block_num, - 'block_mask': None - } - config = [('MobileNetV2Space', config_info)] + config = [('MobileNetV2Space')] search_mobilenetv2(config, args, image_size, is_server=args.is_server) diff --git a/paddleslim/nas/search_space/__init__.py b/paddleslim/nas/search_space/__init__.py index af946669..b415930b 100644 --- a/paddleslim/nas/search_space/__init__.py +++ b/paddleslim/nas/search_space/__init__.py @@ -14,8 +14,8 @@ import mobilenetv2 from .mobilenetv2 import * -import mobilenetv2_block -from .mobilenetv2_block import * +import mobilenet_block +from .mobilenet_block import * import mobilenetv1 from .mobilenetv1 import * import resnet 
diff --git a/paddleslim/nas/search_space/combine_search_space.py b/paddleslim/nas/search_space/combine_search_space.py index 17ebbd39..2f862955 100644 --- a/paddleslim/nas/search_space/combine_search_space.py +++ b/paddleslim/nas/search_space/combine_search_space.py @@ -19,12 +19,15 @@ from __future__ import print_function import numpy as np import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr +import logging +from ...common import get_logger from .search_space_base import SearchSpaceBase from .search_space_registry import SEARCHSPACE from .base_layer import conv_bn_layer __all__ = ["CombineSearchSpace"] +_logger = get_logger(__name__, level=logging.INFO) class CombineSearchSpace(object): """ @@ -37,7 +40,13 @@ class CombineSearchSpace(object): self.lens = len(config_lists) self.spaces = [] for config_list in config_lists: - key, config = config_list + if isinstance(config_list, tuple): + key, config = config_list + if isinstance(config_list, str): + key = config_list + config = None + else: + raise NotImplementedError('the type of config is Error!!! Please check the config information. 
Receive the type of config is {}'.format(type(config_list))) self.spaces.append(self._get_single_search_space(key, config)) self.init_tokens() @@ -52,10 +61,28 @@ class CombineSearchSpace(object): model space(class) """ cls = SEARCHSPACE.get(key) - block_mask = config['block_mask'] if 'block_mask' in config else None - space = cls(config['input_size'], - config['output_size'], - config['block_num'], + + if config is None: + block_mask = None + input_size = None + output_size = None + block_num = None + else: + if 'Block' not in cls.__name__: + _logger.warn('if space is not a Block space, config is useless, current space is {}'.format(cls.__name__)) + + block_mask = config['block_mask'] if 'block_mask' in config else None + input_size = config['input_size'] if 'input_size' in config else None + output_size = config['output_size'] if 'output_size' in config else None + block_num = config['block_num'] if 'block_num' in config else None + + if 'Block' in cls.__name__: + if block_mask == None and (self.block_num == None or self.input_size == None or self.output_size == None): + raise NotImplementedError("block_mask or (block num and input_size and output_size) can NOT be None at the same time in Block SPACE!") + + space = cls(input_size, + output_size, + block_num, block_mask=block_mask) return space diff --git a/paddleslim/nas/search_space/mobilenet_block.py b/paddleslim/nas/search_space/mobilenet_block.py index dd8d0611..f0f25af4 100644 --- a/paddleslim/nas/search_space/mobilenet_block.py +++ b/paddleslim/nas/search_space/mobilenet_block.py @@ -98,19 +98,19 @@ class MobileNetV2BlockSpace(SearchSpaceBase): num_minus = self.block_num % self.downsample_num ### if block_num > downsample_num, add stride=1 block at last (block_num-downsample_num) layers for i in range(self.downsample_num): - self.bottleneck_params_list.append(self.mutiply[tokens[i * 4], self.filter_num[tokens[i * 4 + 1]], - self.repeat[tokens[i * 4 + 2]], 2, self.k_size[tokens[i * 4 + 3]]) + 
self.bottleneck_params_list.append((self.mutiply[tokens[i * 4]], self.filter_num[tokens[i * 4 + 1]], + self.repeat[tokens[i * 4 + 2]], 2, self.k_size[tokens[i * 4 + 3]])) ### if block_num / downsample_num > 1, add (block_num / downsample_num) times stride=1 block for k in range(repeat_num - 1): kk = k * self.downsample_num + i - self.bottleneck_params_list.append(self.mutiply[tokens[kk * 4], self.filter_num[tokens[kk * 4 + 1]], - self.repeat[tokens[kk * 4 + 2]], 1, self.k_size[tokens[kk * 4 + 3]]) + self.bottleneck_params_list.append((self.mutiply[tokens[kk * 4]], self.filter_num[tokens[kk * 4 + 1]], + self.repeat[tokens[kk * 4 + 2]], 1, self.k_size[tokens[kk * 4 + 3]])) if self.downsample_num - i <= num_minus: j = self.downsample_num * repeat_num + i - self.bottleneck_params_list.append(self.mutiply[tokens[j * 4], self.filter_num[tokens[j * 4 + 1]], - self.repeat[tokens[j * 4 + 2]], 1, self.k_size[tokens[j * 4 + 3]]) + self.bottleneck_params_list.append((self.mutiply[tokens[j * 4]], self.filter_num[tokens[j * 4 + 1]], + self.repeat[tokens[j * 4 + 2]], 1, self.k_size[tokens[j * 4 + 3]])) def net_arch(input, return_mid_layer=False, return_block=[]): assert isinstance(return_block, list), 'return_block must be a list.' 
@@ -288,7 +288,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase): range_table_base.append(len(self.filter_num)) range_table_base.append(len(self.k_size)) else: - for i in range(self.block_num)): + for i in range(self.block_num): range_table_base.append(len(self.filter_num)) range_table_base.append(len(self.filter_num)) range_table_base.append(len(self.k_size)) @@ -300,26 +300,26 @@ class MobileNetV1BlockSpace(SearchSpaceBase): tokens = self.init_tokens() self.bottleneck_param_list = [] - if self.block_mask != None + if self.block_mask != None: for i in range(len(self.block_mask)): - self.bottleneck_params_list.append(self.filter_num[tokens[i * 3]], self.filter_num[tokens[i * 3 + 1]], 2 if self.block_mask[i] == 1 else 1, self.k_size[tokens[i * 3 + 2]) + self.bottleneck_params_list.append((self.filter_num[tokens[i * 3]], self.filter_num[tokens[i * 3 + 1]], 2 if self.block_mask[i] == 1 else 1, self.k_size[tokens[i * 3 + 2]])) else: repeat_num = self.block_num / self.downsample_num num_minus = self.block_num % self.downsample_num for i in range(self.block_num): ### if block_num > downsample_num, add stride=1 block at last (block_num-downsample_num) layers - self.bottleneck_params_list.append(self.filter_num[tokens[i * 3]], self.filter_num[tokens[i * 3 + 1]], 2, self.k_size[tokens[i * 3 + 2]) + self.bottleneck_params_list.append((self.filter_num[tokens[i * 3]], self.filter_num[tokens[i * 3 + 1]], 2, self.k_size[tokens[i * 3 + 2]])) ### if block_num / downsample_num > 1, add (block_num / downsample_num) times stride=1 block for k in range(repeat_num - 1): kk = k * self.downsample_num + i - self.bottleneck_params_list.append(self.filter_num[tokens[kk * 3], self.filter_num[tokens[kk * 3 + 1]], - 1, self.k_size[tokens[kk * 3 + 2]]) + self.bottleneck_params_list.append((self.filter_num[tokens[kk * 3]], self.filter_num[tokens[kk * 3 + 1]], + 1, self.k_size[tokens[kk * 3 + 2]])) if self.downsample_num - i <= num_minus: j = self.downsample_num * repeat_num + i - 
self.bottleneck_params_list.append(self.filter_num[tokens[j * 3], self.filter_num[tokens[j * 3 + 1]], - 1, self.k_size[tokens[j * 3 + 2]]) + self.bottleneck_params_list.append((self.filter_num[tokens[j * 3]], self.filter_num[tokens[j * 3 + 1]], + 1, self.k_size[tokens[j * 3 + 2]])) def net_arch(input, return_mid_layer=False, return_block=[]): diff --git a/paddleslim/nas/search_space/mobilenetv1.py b/paddleslim/nas/search_space/mobilenetv1.py index 3976d21d..c0fa05b3 100644 --- a/paddleslim/nas/search_space/mobilenetv1.py +++ b/paddleslim/nas/search_space/mobilenetv1.py @@ -32,14 +32,9 @@ class MobileNetV1Space(SearchSpaceBase): input_size, output_size, block_num, - block_mask, - scale=1.0, - class_dim=1000): + block_mask): super(MobileNetV1Space, self).__init__(input_size, output_size, block_num, block_mask) - assert self.block_mask == None, 'MobileNetV1Space will use origin MobileNetV1 as seach space, so use input_size, output_size and block_num to search' - self.scale = scale - self.class_dim = class_dim # self.head_num means the channel of first convolution self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) # 7 # self.filter_num1 ~ self.filtet_num9 means channel of the following convolution @@ -67,8 +62,6 @@ class MobileNetV1Space(SearchSpaceBase): # self.repeat means repeat_num in forth downsample self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6 - assert self.block_num < 6, 'MobileNetV1: block number must less than 6, but receive block number is {}'.format( - self.block_num) def init_tokens(self): """ @@ -90,11 +83,7 @@ class MobileNetV1Space(SearchSpaceBase): 8, 10, 0, # 512, 1024, 3 10, 10, 0] # 1024, 1024, 3 # yapf: enable - if self.block_num < 5: - self.token_len = 1 + (self.block_num * 2 - 1) * 3 - else: - self.token_len = 2 + (self.block_num * 2 - 1) * 3 - return base_init_tokens[:self.token_len] + return base_init_tokens def range_table(self): """ @@ -113,63 +102,88 @@ class MobileNetV1Space(SearchSpaceBase): len(self.filter_num8), len(self.filter_num9), 
len(self.k_size), len(self.filter_num9), len(self.filter_num9), len(self.k_size)] # yapf: enable - return base_range_table[:self.token_len] + return base_range_table def token2arch(self, tokens=None): if tokens is None: tokens = self.tokens() - bottleneck_param_list = [] - - if self.block_num >= 1: - # tokens[0] = 32 - # 32, 64 - bottleneck_param_list.append( - (self.filter_num1[tokens[1]], self.filter_num2[tokens[2]], 1, - self.k_size[tokens[3]])) - if self.block_num >= 2: - # 64 128 128 128 - bottleneck_param_list.append( - (self.filter_num2[tokens[4]], self.filter_num3[tokens[5]], 2, - self.k_size[tokens[6]])) - bottleneck_param_list.append( - (self.filter_num3[tokens[7]], self.filter_num4[tokens[8]], 1, - self.k_size[tokens[9]])) - if self.block_num >= 3: - # 128 256 256 256 - bottleneck_param_list.append( - (self.filter_num4[tokens[10]], self.filter_num5[tokens[11]], 2, - self.k_size[tokens[12]])) - bottleneck_param_list.append( - (self.filter_num5[tokens[13]], self.filter_num6[tokens[14]], 1, - self.k_size[tokens[15]])) - if self.block_num >= 4: - # 256 512 (512 512) * 5 - bottleneck_param_list.append( - (self.filter_num6[tokens[16]], self.filter_num7[tokens[17]], 2, - self.k_size[tokens[18]])) - for i in range(self.repeat[tokens[19]]): - bottleneck_param_list.append( - (self.filter_num7[tokens[20]], - self.filter_num8[tokens[21]], 1, self.k_size[tokens[22]])) - if self.block_num >= 5: - # 512 1024 1024 1024 - bottleneck_param_list.append( - (self.filter_num8[tokens[23]], self.filter_num9[tokens[24]], 2, - self.k_size[tokens[25]])) - bottleneck_param_list.append( - (self.filter_num9[tokens[26]], self.filter_num9[tokens[27]], 1, - self.k_size[tokens[28]])) - - def net_arch(input): + self.bottleneck_param_list = [] + + # tokens[0] = 32 + # 32, 64 + self.bottleneck_param_list.append( + (self.filter_num1[tokens[1]], self.filter_num2[tokens[2]], 1, + self.k_size[tokens[3]])) + # 64 128 128 128 + self.bottleneck_param_list.append( + (self.filter_num2[tokens[4]], 
self.filter_num3[tokens[5]], 2, + self.k_size[tokens[6]])) + self.bottleneck_param_list.append( + (self.filter_num3[tokens[7]], self.filter_num4[tokens[8]], 1, + self.k_size[tokens[9]])) + # 128 256 256 256 + self.bottleneck_param_list.append( + (self.filter_num4[tokens[10]], self.filter_num5[tokens[11]], 2, + self.k_size[tokens[12]])) + self.bottleneck_param_list.append( + (self.filter_num5[tokens[13]], self.filter_num6[tokens[14]], 1, + self.k_size[tokens[15]])) + # 256 512 (512 512) * 5 + self.bottleneck_param_list.append( + (self.filter_num6[tokens[16]], self.filter_num7[tokens[17]], 2, + self.k_size[tokens[18]])) + for i in range(self.repeat[tokens[19]]): + self.bottleneck_param_list.append( + (self.filter_num7[tokens[20]], + self.filter_num8[tokens[21]], 1, self.k_size[tokens[22]])) + # 512 1024 1024 1024 + self.bottleneck_param_list.append( + (self.filter_num8[tokens[23]], self.filter_num9[tokens[24]], 2, + self.k_size[tokens[25]])) + self.bottleneck_param_list.append( + (self.filter_num9[tokens[26]], self.filter_num9[tokens[27]], 1, + self.k_size[tokens[28]])) + + def _modify_bottle_params(output_stride=None): + if output_stride is not None and output_stride % 2 != 0: + raise Exception("output stride must to be even number") + if output_stride is None: + return + else: + stride = 2 + for i, layer_setting in enumerate(self.bottleneck_params_list): + f1, f2, s, ks = layer_setting + stride = stride * s + if stride > output_stride: + s = 1 + self.bottleneck_params_list[i] = (f1, f2, s, ks) + + + def net_arch(input, scale=1.0, return_block=[], end_points=None, output_stride=None): + self.scale = scale + _modify_bottle_params(output_stride) + + decode_ends = dict() + + def check_points(count, points): + if points is None: + return False + else: + if isinstance(points, list): + return (True if count in points else False) + else: + return (True if count == points else False) + input = conv_bn_layer( input=input, filter_size=3, num_filters=self.head_num[tokens[0]], 
stride=2, - name='mobilenetv1') + name='mobilenetv1_conv1') + layer_count = 1 for i, layer_setting in enumerate(bottleneck_param_list): filter_num1, filter_num2, stride, kernel_size = layer_setting input = self._depthwise_separable( @@ -181,19 +195,15 @@ class MobileNetV1Space(SearchSpaceBase): scale=self.scale, kernel_size=kernel_size, name='mobilenetv1_{}'.format(str(i + 1))) + layer_count += 1 + ### return_block and end_points means block num + if check_points(layer_count, return_block): + decode_ends[layer_count] = depthwise_output - if self.output_size == 1: - print('NOTE: if output_size is 1, add fc layer in the end!!!') - input = fluid.layers.fc( - input=input, - size=self.class_dim, - param_attr=ParamAttr(name='mobilenetv2_fc_weights'), - bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) - else: - assert self.output_size == input.shape[2], \ - ("output_size must EQUAL to input_size / (2^block_num)." - "But receive input_size={}, output_size={}, block_num={}".format( - self.input_size, self.output_size, self.block_num)) + if check_points(layer_count, end_points): + return input, decode_ends + + input = fluid.layers.pool2d(input=input, pool_type='avg', global_pooling=True, name='mobilenetv1_last_pool') return input diff --git a/paddleslim/nas/search_space/mobilenetv2.py b/paddleslim/nas/search_space/mobilenetv2.py index 36231912..09c3aef3 100644 --- a/paddleslim/nas/search_space/mobilenetv2.py +++ b/paddleslim/nas/search_space/mobilenetv2.py @@ -32,12 +32,9 @@ class MobileNetV2Space(SearchSpaceBase): input_size, output_size, block_num, - block_mask=None, - scale=1.0, - class_dim=1000): + block_mask=None): super(MobileNetV2Space, self).__init__(input_size, output_size, block_num, block_mask) - assert self.block_mask == None, 'MobileNetV2Space will use origin MobileNetV2 as seach space, so use input_size, output_size and block_num to search' # self.head_num means the first convolution channel self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7 # 
self.filter_num1 ~ self.filter_num6 means following convlution channel @@ -56,11 +53,7 @@ class MobileNetV2Space(SearchSpaceBase): self.multiply = np.array([1, 2, 3, 4, 6]) #5 # self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6 - self.scale = scale - self.class_dim = class_dim - assert self.block_num < 7, 'MobileNetV2: block number must less than 7, but receive block number is {}'.format( - self.block_num) def init_tokens(self): """ @@ -80,13 +73,7 @@ class MobileNetV2Space(SearchSpaceBase): 4, 9, 0, 0] # 6, 320, 1 # yapf: enable - if self.block_num < 5: - self.token_len = 1 + (self.block_num - 1) * 4 - else: - self.token_len = 1 + (self.block_num + 2 * - (self.block_num - 5)) * 4 - - return init_token_base[:self.token_len] + return init_token_base def range_table(self): """ @@ -102,9 +89,8 @@ class MobileNetV2Space(SearchSpaceBase): len(self.multiply), len(self.filter_num4), len(self.repeat), len(self.k_size), len(self.multiply), len(self.filter_num5), len(self.repeat), len(self.k_size), len(self.multiply), len(self.filter_num6), len(self.repeat), len(self.k_size)] - range_table_base = list(np.array(range_table_base) - 1) # yapf: enable - return range_table_base[:self.token_len] + return range_table_base def token2arch(self, tokens=None): """ @@ -115,35 +101,29 @@ class MobileNetV2Space(SearchSpaceBase): tokens = self.init_tokens() self.bottleneck_params_list = [] - if self.block_num >= 1: - self.bottleneck_params_list.append( - (1, self.head_num[tokens[0]], 1, 1, 3)) - if self.block_num >= 2: - self.bottleneck_params_list.append( - (self.multiply[tokens[1]], self.filter_num1[tokens[2]], - self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) - if self.block_num >= 3: - self.bottleneck_params_list.append( - (self.multiply[tokens[5]], self.filter_num1[tokens[6]], - self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) - if self.block_num >= 4: - self.bottleneck_params_list.append( - 
(self.multiply[tokens[9]], self.filter_num2[tokens[10]], - self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) - if self.block_num >= 5: - self.bottleneck_params_list.append( - (self.multiply[tokens[13]], self.filter_num3[tokens[14]], - self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) - self.bottleneck_params_list.append( - (self.multiply[tokens[17]], self.filter_num4[tokens[18]], - self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) - if self.block_num >= 6: - self.bottleneck_params_list.append( - (self.multiply[tokens[21]], self.filter_num5[tokens[22]], - self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) - self.bottleneck_params_list.append( - (self.multiply[tokens[25]], self.filter_num6[tokens[26]], - self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) + self.bottleneck_params_list.append( + (1, self.head_num[tokens[0]], 1, 1, 3)) + self.bottleneck_params_list.append( + (self.multiply[tokens[1]], self.filter_num1[tokens[2]], + self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) + self.bottleneck_params_list.append( + (self.multiply[tokens[5]], self.filter_num1[tokens[6]], + self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) + self.bottleneck_params_list.append( + (self.multiply[tokens[9]], self.filter_num2[tokens[10]], + self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) + self.bottleneck_params_list.append( + (self.multiply[tokens[13]], self.filter_num3[tokens[14]], + self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) + self.bottleneck_params_list.append( + (self.multiply[tokens[17]], self.filter_num4[tokens[18]], + self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) + self.bottleneck_params_list.append( + (self.multiply[tokens[21]], self.filter_num5[tokens[22]], + self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) + self.bottleneck_params_list.append( + (self.multiply[tokens[25]], self.filter_num6[tokens[26]], + self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) def _modify_bottle_params(output_stride=None): if output_stride is 
not None and output_stride % 2 != 0: @@ -160,9 +140,11 @@ class MobileNetV2Space(SearchSpaceBase): self.bottleneck_params_list[i] = (t, c, n, s, ks) def net_arch(input, + scale = 1.0, + return_block = [], end_points=None, - decode_points=None, output_stride=None): + self.scale = scale _modify_bottle_params(output_stride) decode_ends = dict() @@ -185,9 +167,9 @@ class MobileNetV2Space(SearchSpaceBase): stride=2, padding='SAME', act='relu6', - name='mobilenetv2_conv1_1') + name='mobilenetv2_conv1') layer_count = 1 - if check_points(layer_count, decode_points): + if check_points(layer_count, return_block): decode_ends[layer_count] = input if check_points(layer_count, end_points): @@ -212,8 +194,8 @@ class MobileNetV2Space(SearchSpaceBase): in_c = int(c * self.scale) layer_count += 1 - ### decode_points and end_points means block num - if check_points(layer_count, decode_points): + ### return_block and end_points means block num + if check_points(layer_count, return_block): decode_ends[layer_count] = depthwise_output if check_points(layer_count, end_points): @@ -232,25 +214,10 @@ class MobileNetV2Space(SearchSpaceBase): input = fluid.layers.pool2d( input=input, - pool_size=7, - pool_stride=1, pool_type='avg', global_pooling=True, name='mobilenetv2_last_pool') - # if output_size is 1, add fc layer in the end - if self.output_size == 1: - input = fluid.layers.fc( - input=input, - size=self.class_dim, - param_attr=ParamAttr(name='mobilenetv2_fc_weights'), - bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) - else: - assert self.output_size == input.shape[2], \ - ("output_size must EQUAL to input_size / (2^block_num)." 
- "But receive input_size={}, output_size={}, block_num={}".format( - self.input_size, self.output_size, self.block_num)) - return input return net_arch diff --git a/paddleslim/nas/search_space/resnet.py b/paddleslim/nas/search_space/resnet.py index fd761d41..842c3e0c 100644 --- a/paddleslim/nas/search_space/resnet.py +++ b/paddleslim/nas/search_space/resnet.py @@ -32,12 +32,9 @@ class ResNetSpace(SearchSpaceBase): input_size, output_size, block_num, - block_mask=None, - extract_feature=False, - class_dim=1000): + block_mask=None): super(ResNetSpace, self).__init__(input_size, output_size, block_num, block_mask) - assert self.block_mask == None, 'ResNetSpace will use origin ResNet as seach space, so use input_size, output_size and block_num to search' # self.filter_num1 ~ self.filter_num4 means convolution channel self.filter_num1 = np.array([48, 64, 96, 128, 160, 192, 224]) #7 self.filter_num2 = np.array([64, 96, 128, 160, 192, 256, 320]) #7 @@ -48,31 +45,24 @@ class ResNetSpace(SearchSpaceBase): self.repeat2 = [2, 3, 4, 5, 6, 7] #6 self.repeat3 = [2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24] #13 self.repeat4 = [2, 3, 4, 5, 6, 7] #6 - self.class_dim = class_dim - self.extract_feature = extract_feature - assert self.block_num < 5, 'ResNet: block number must less than 5, but receive block number is {}'.format( - self.block_num) def init_tokens(self): """ The initial token. - return 2 * self.block_num, 2 means depth and num_filter """ init_token_base = [0, 0, 0, 0, 0, 0, 0, 0] - self.token_len = self.block_num * 2 - return init_token_base[:self.token_len] + return init_token_base def range_table(self): """ Get range table of current search space, constrains the range of tokens. 
""" - #2 * self.block_num, 2 means depth and num_filter range_table_base = [ len(self.filter_num1), len(self.repeat1), len(self.filter_num2), len(self.repeat2), len(self.filter_num3), len(self.repeat3), len(self.filter_num4), len(self.repeat4) ] - return range_table_base[:self.token_len] + return range_table_base def token2arch(self, tokens=None): """ @@ -83,26 +73,23 @@ class ResNetSpace(SearchSpaceBase): depth = [] num_filters = [] - if self.block_num >= 1: - filter1 = self.filter_num1[tokens[0]] - repeat1 = self.repeat1[tokens[1]] - num_filters.append(filter1) - depth.append(repeat1) - if self.block_num >= 2: - filter2 = self.filter_num2[tokens[2]] - repeat2 = self.repeat2[tokens[3]] - num_filters.append(filter2) - depth.append(repeat2) - if self.block_num >= 3: - filter3 = self.filter_num3[tokens[4]] - repeat3 = self.repeat3[tokens[5]] - num_filters.append(filter3) - depth.append(repeat3) - if self.block_num >= 4: - filter4 = self.filter_num4[tokens[6]] - repeat4 = self.repeat4[tokens[7]] - num_filters.append(filter4) - depth.append(repeat4) + + filter1 = self.filter_num1[tokens[0]] + repeat1 = self.repeat1[tokens[1]] + num_filters.append(filter1) + depth.append(repeat1) + filter2 = self.filter_num2[tokens[2]] + repeat2 = self.repeat2[tokens[3]] + num_filters.append(filter2) + depth.append(repeat2) + filter3 = self.filter_num3[tokens[4]] + repeat3 = self.repeat3[tokens[5]] + num_filters.append(filter3) + depth.append(repeat3) + filter4 = self.filter_num4[tokens[6]] + repeat4 = self.repeat4[tokens[7]] + num_filters.append(filter4) + depth.append(repeat4) def net_arch(input): conv = conv_bn_layer( @@ -120,16 +107,6 @@ class ResNetSpace(SearchSpaceBase): stride=2 if i == 0 and block != 0 else 1, name='resnet_depth{}_block{}'.format(i, block)) - if self.output_size == 1: - conv = fluid.layers.fc( - input=conv, - size=self.class_dim, - act=None, - param_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.NormalInitializer(0.0, - 0.01)), - 
bias_attr=fluid.param_attr.ParamAttr( - initializer=fluid.initializer.ConstantInitializer(0))) return conv diff --git a/paddleslim/nas/search_space/search_space_base.py b/paddleslim/nas/search_space/search_space_base.py index 53799154..9dee1431 100644 --- a/paddleslim/nas/search_space/search_space_base.py +++ b/paddleslim/nas/search_space/search_space_base.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +from ...common import get_logger + __all__ = ['SearchSpaceBase'] +_logger = get_logger(__name__, level=logging.INFO) class SearchSpaceBase(object): """Controller for Neural Architecture Search. @@ -29,12 +33,10 @@ class SearchSpaceBase(object): if self.block_mask != None: assert isinstance(self.block_mask, list), 'Block_mask must be a list.' - print( + _logger.warn( "If block_mask is NOT None, we will use block_mask as major configs!" ) self.block_num = None - if self.block_mask == None and (self.block_num == None or self.input_size == None or self.output_size == None): - print("block_mask and (block num or input_size or output_size) can NOT be None at the same time!") def init_tokens(self): """Get init tokens in search space. diff --git a/paddleslim/nas/search_space/utils.py b/paddleslim/nas/search_space/utils.py index daed1d59..1338e387 100644 --- a/paddleslim/nas/search_space/utils.py +++ b/paddleslim/nas/search_space/utils.py @@ -1,3 +1,17 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + import math def compute_downsample_num(input_size, output_size): -- GitLab