From 88ef556a9994de31e9f6029778349a4d1b74489a Mon Sep 17 00:00:00 2001 From: whs Date: Mon, 10 Jun 2019 20:44:00 +0800 Subject: [PATCH] Add Light-NAS demo. (#2322) * Add Light-NAS demo. * Update arguments. --- PaddleSlim/light_nas/compress.yaml | 23 ++ PaddleSlim/light_nas/data | 1 + PaddleSlim/light_nas/light_nas_space.py | 277 +++++++++++++++++++ PaddleSlim/light_nas/run.sh | 5 + PaddleSlim/light_nas/search.py | 58 ++++ PaddleSlim/models/__init__.py | 1 + PaddleSlim/models/light_nasnet.py | 339 ++++++++++++++++++++++++ 7 files changed, 704 insertions(+) create mode 100644 PaddleSlim/light_nas/compress.yaml create mode 120000 PaddleSlim/light_nas/data create mode 100644 PaddleSlim/light_nas/light_nas_space.py create mode 100644 PaddleSlim/light_nas/run.sh create mode 100644 PaddleSlim/light_nas/search.py create mode 100644 PaddleSlim/models/light_nasnet.py diff --git a/PaddleSlim/light_nas/compress.yaml b/PaddleSlim/light_nas/compress.yaml new file mode 100644 index 00000000..5ad1a3dc --- /dev/null +++ b/PaddleSlim/light_nas/compress.yaml @@ -0,0 +1,23 @@ +version: 1.0 +controllers: + sa_controller: + class: 'SAController' + reduce_rate: 0.85 + init_temperature: 10.24 + max_iter_number: 300 +strategies: + light_nas_strategy: + class: 'LightNASStrategy' + controller: 'sa_controller' + target_flops: 592948064 + end_epoch: 500 + retrain_epoch: 5 + metric_name: 'acc_top1' + server_ip: '' + server_port: 8871 + is_server: True + search_steps: 100 +compressor: + epoch: 500 + strategies: + - light_nas_strategy diff --git a/PaddleSlim/light_nas/data b/PaddleSlim/light_nas/data new file mode 120000 index 00000000..4909e06e --- /dev/null +++ b/PaddleSlim/light_nas/data @@ -0,0 +1 @@ +../data \ No newline at end of file diff --git a/PaddleSlim/light_nas/light_nas_space.py b/PaddleSlim/light_nas/light_nas_space.py new file mode 100644 index 00000000..63d4473a --- /dev/null +++ b/PaddleSlim/light_nas/light_nas_space.py @@ -0,0 +1,277 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.contrib.slim.nas import SearchSpace +import paddle.fluid as fluid +import paddle +import sys +sys.path.append('..') +from models import LightNASNet +import reader + +total_images = 1281167 +lr = 0.1 +num_epochs = 1 +batch_size = 128 +lr_strategy = "cosine_decay" +l2_decay = 4e-5 +momentum_rate = 0.9 +image_shape = [3, 224, 224] +class_dim = 1000 + +__all__ = ['LightNASSpace'] + +NAS_FILTER_SIZE = [[18, 24, 30], [24, 32, 40], [48, 64, 80], [72, 96, 120], + [120, 160, 192]] +NAS_LAYERS_NUMBER = [[1, 2, 3], [2, 3, 4], [3, 4, 5], [2, 3, 4], [2, 3, 4]] +NAS_KERNEL_SIZE = [3, 5] +NAS_FILTERS_MULTIPLIER = [3, 4, 5, 6] +NAS_SHORTCUT = [0, 1] +NAS_SE = [0, 1] + + +def get_bottleneck_params_list(var): + """Get bottleneck_params_list from var. + Args: + var: list, variable list. + Returns: + list, bottleneck_params_list. 
+ """ + params_list = [ + 1, 16, 1, 1, 3, 1, 0, \ + 6, 24, 2, 2, 3, 1, 0, \ + 6, 32, 3, 2, 3, 1, 0, \ + 6, 64, 4, 2, 3, 1, 0, \ + 6, 96, 3, 1, 3, 1, 0, \ + 6, 160, 3, 2, 3, 1, 0, \ + 6, 320, 1, 1, 3, 1, 0, \ + ] + for i in range(5): + params_list[i * 7 + 7] = NAS_FILTERS_MULTIPLIER[var[i * 6]] + params_list[i * 7 + 8] = NAS_FILTER_SIZE[i][var[i * 6 + 1]] + params_list[i * 7 + 9] = NAS_LAYERS_NUMBER[i][var[i * 6 + 2]] + params_list[i * 7 + 11] = NAS_KERNEL_SIZE[var[i * 6 + 3]] + params_list[i * 7 + 12] = NAS_SHORTCUT[var[i * 6 + 4]] + params_list[i * 7 + 13] = NAS_SE[var[i * 6 + 5]] + return params_list + + +class LightNASSpace(SearchSpace): + def __init__(self): + super(LightNASSpace, self).__init__() + + def init_tokens(self): + """Get init tokens in search space. + """ + return [ + 0, 1, 2, 0, 1, 0, 0, 2, 1, 1, 1, 0, 3, 2, 0, 1, 1, 0, 3, 1, 0, 0, 1, + 0, 3, 2, 2, 1, 1, 0 + ] + + def range_table(self): + """Get range table of current search space. + """ + # [NAS_FILTER_SIZE, NAS_LAYERS_NUMBER, NAS_KERNEL_SIZE, NAS_FILTERS_MULTIPLIER, NAS_SHORTCUT, NAS_SE] + return [ + 4, 3, 3, 2, 2, 2, 4, 3, 3, 2, 2, 2, 4, 3, 3, 2, 2, 2, 4, 3, 3, 2, 2, + 2, 4, 3, 3, 2, 2, 2 + ] + + def create_net(self, tokens=None): + """Create a network for training by tokens. + """ + if tokens is None: + tokens = self.init_tokens() + + bottleneck_params_list = get_bottleneck_params_list(tokens) + + startup_prog = fluid.Program() + train_prog = fluid.Program() + test_prog = fluid.Program() + train_py_reader, train_cost, train_acc1, train_acc5, global_lr = build_program( + is_train=True, + main_prog=train_prog, + startup_prog=startup_prog, + bottleneck_params_list=bottleneck_params_list) + test_py_reader, test_cost, test_acc1, test_acc5 = build_program( + is_train=False, + main_prog=test_prog, + startup_prog=startup_prog, + bottleneck_params_list=bottleneck_params_list) + test_prog = test_prog.clone(for_test=True) + train_batch_size = batch_size + test_batch_size = batch_size + train_reader = paddle.batch( + reader.train(), batch_size=train_batch_size, drop_last=True) + test_reader = paddle.batch(reader.val(), batch_size=test_batch_size) + + with fluid.program_guard(train_prog, startup_prog): + train_py_reader.decorate_paddle_reader(train_reader) + + with fluid.program_guard(test_prog, startup_prog): + test_py_reader.decorate_paddle_reader(test_reader) + return startup_prog, train_prog, test_prog, ( + train_cost, train_acc1, train_acc5, + global_lr), (test_cost, test_acc1, + test_acc5), train_py_reader, test_py_reader + + +def build_program(is_train, + main_prog, + startup_prog, + bottleneck_params_list=None): + with fluid.program_guard(main_prog, startup_prog): + py_reader = fluid.layers.py_reader( + capacity=16, + shapes=[[-1] + image_shape, [-1, 1]], + lod_levels=[0, 0], + dtypes=["float32", "int64"], + use_double_buffer=False) + with fluid.unique_name.guard(): + image, label = fluid.layers.read_file(py_reader) + model = LightNASNet() + avg_cost, acc_top1, acc_top5 = net_config( + image, + label, + model, + class_dim=class_dim, + bottleneck_params_list=bottleneck_params_list, + scale_loss=1.0) + + avg_cost.persistable = True + acc_top1.persistable = True + acc_top5.persistable = True + if is_train: + params = model.params + params["total_images"] = total_images + params["lr"] = lr + params["num_epochs"] = num_epochs + params["learning_strategy"]["batch_size"] = batch_size + params["learning_strategy"]["name"] = lr_strategy + params["l2_decay"] = l2_decay + params["momentum_rate"] = momentum_rate + optimizer = 
optimizer_setting(params) + optimizer.minimize(avg_cost) + global_lr = optimizer._global_learning_rate() + + if is_train: + return py_reader, avg_cost, acc_top1, acc_top5, global_lr + else: + return py_reader, avg_cost, acc_top1, acc_top5 + + +def net_config(image, + label, + model, + class_dim=1000, + bottleneck_params_list=None, + scale_loss=1.0): + bottleneck_params_list = [ + bottleneck_params_list[i:i + 7] + for i in range(0, len(bottleneck_params_list), 7) + ] + out = model.net(input=image, + bottleneck_params_list=bottleneck_params_list, + class_dim=class_dim) + cost, pred = fluid.layers.softmax_with_cross_entropy( + out, label, return_softmax=True) + if scale_loss > 1: + avg_cost = fluid.layers.mean(x=cost) * float(scale_loss) + else: + avg_cost = fluid.layers.mean(x=cost) + acc_top1 = fluid.layers.accuracy(input=pred, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=pred, label=label, k=5) + return avg_cost, acc_top1, acc_top5 + + +def optimizer_setting(params): + """optimizer setting. + Args: + params: dict, params. + """ + ls = params["learning_strategy"] + l2_decay = params["l2_decay"] + momentum_rate = params["momentum_rate"] + if ls["name"] == "piecewise_decay": + if "total_images" not in params: + total_images = IMAGENET1000 + else: + total_images = params["total_images"] + batch_size = ls["batch_size"] + step = int(total_images / batch_size + 1) + bd = [step * e for e in ls["epochs"]] + base_lr = params["lr"] + lr = [] + lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + momentum=momentum_rate, + regularization=fluid.regularizer.L2Decay(l2_decay)) + elif ls["name"] == "cosine_decay": + if "total_images" not in params: + total_images = IMAGENET1000 + else: + total_images = params["total_images"] + batch_size = ls["batch_size"] + step = int(total_images / batch_size + 1) + lr = params["lr"] + num_epochs = params["num_epochs"] + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.cosine_decay( + learning_rate=lr, step_each_epoch=step, epochs=num_epochs), + momentum=momentum_rate, + regularization=fluid.regularizer.L2Decay(l2_decay)) + elif ls["name"] == "cosine_warmup_decay": + if "total_images" not in params: + total_images = IMAGENET1000 + else: + total_images = params["total_images"] + batch_size = ls["batch_size"] + l2_decay = params["l2_decay"] + momentum_rate = params["momentum_rate"] + step = int(math.ceil(float(total_images) / batch_size)) + lr = params["lr"] + num_epochs = params["num_epochs"] + optimizer = fluid.optimizer.Momentum( + learning_rate=cosine_decay_with_warmup( + learning_rate=lr, step_each_epoch=step, epochs=num_epochs), + momentum=momentum_rate, + regularization=fluid.regularizer.L2Decay(l2_decay)) + elif ls["name"] == "linear_decay": + if "total_images" not in params: + total_images = IMAGENET1000 + else: + total_images = params["total_images"] + batch_size = ls["batch_size"] + num_epochs = params["num_epochs"] + start_lr = params["lr"] + end_lr = 0 + total_step = int((total_images / batch_size) * num_epochs) + lr = fluid.layers.polynomial_decay( + start_lr, total_step, end_lr, power=1) + optimizer = fluid.optimizer.Momentum( + learning_rate=lr, + momentum=momentum_rate, + regularization=fluid.regularizer.L2Decay(l2_decay)) + elif ls["name"] == "adam": + lr = params["lr"] + optimizer = fluid.optimizer.Adam(learning_rate=lr) + else: + lr = params["lr"] + optimizer = fluid.optimizer.Momentum( + learning_rate=lr, 
+ momentum=momentum_rate, + regularization=fluid.regularizer.L2Decay(l2_decay)) + return optimizer diff --git a/PaddleSlim/light_nas/run.sh b/PaddleSlim/light_nas/run.sh new file mode 100644 index 00000000..61eef5b6 --- /dev/null +++ b/PaddleSlim/light_nas/run.sh @@ -0,0 +1,5 @@ +# enable GC strategy +export FLAGS_fast_eager_deletion_mode=1 +export FLAGS_eager_delete_tensor_gb=0.0 +export CUDA_VISIBLE_DEVICES=0,1 +python search.py diff --git a/PaddleSlim/light_nas/search.py b/PaddleSlim/light_nas/search.py new file mode 100644 index 00000000..b8174722 --- /dev/null +++ b/PaddleSlim/light_nas/search.py @@ -0,0 +1,58 @@ +# copyright (c) 2019 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import paddle +import paddle.fluid as fluid +from paddle.fluid.contrib.slim.core import Compressor +from light_nas_space import LightNASSpace + + +def search(): + if not fluid.core.is_compiled_with_cuda(): + return + + space = LightNASSpace() + + startup_prog, train_prog, test_prog, train_metrics, test_metrics, train_reader, test_reader = space.create_net( + ) + train_cost, train_acc1, train_acc5, global_lr = train_metrics + test_cost, test_acc1, test_acc5 = test_metrics + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup_prog) + + val_fetch_list = [('acc_top1', test_acc1.name), ('acc_top5', + test_acc5.name)] + train_fetch_list = [('loss', train_cost.name)] + + com_pass = Compressor( + place, + fluid.global_scope(), + train_prog, + train_reader=train_reader, + train_feed_list=None, + train_fetch_list=train_fetch_list, + eval_program=test_prog, + eval_reader=test_reader, + eval_feed_list=None, + eval_fetch_list=val_fetch_list, + train_optimizer=None, + search_space=space) + com_pass.config('./compress.yaml') + eval_graph = com_pass.run() + + +if __name__ == '__main__': + search() diff --git a/PaddleSlim/models/__init__.py b/PaddleSlim/models/__init__.py index 2141eb9a..2a51d746 100644 --- a/PaddleSlim/models/__init__.py +++ b/PaddleSlim/models/__init__.py @@ -1,3 +1,4 @@ from .mobilenet import MobileNet from .resnet import ResNet50, ResNet101, ResNet152 from .googlenet import GoogleNet +from .light_nasnet import LightNASNet diff --git a/PaddleSlim/models/light_nasnet.py b/PaddleSlim/models/light_nasnet.py new file mode 100644 index 00000000..3c741e90 --- /dev/null +++ b/PaddleSlim/models/light_nasnet.py @@ -0,0 +1,339 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""LightNASNet.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + +__all__ = ['LightNASNet'] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class LightNASNet(object): + """LightNASNet.""" + + def __init__(self): + self.params = train_parameters + + def net(self, input, bottleneck_params_list=None, class_dim=1000, + scale=1.0): + """Build network. + Args: + input: Variable, input. + class_dim: int, class dim. + scale: float, scale. + Returns: + Variable, network output. + """ + if bottleneck_params_list is None: + # MobileNetV2 + # bottleneck_params_list = [ + # (1, 16, 1, 1, 3, 1, 0), + # (6, 24, 2, 2, 3, 1, 0), + # (6, 32, 3, 2, 3, 1, 0), + # (6, 64, 4, 2, 3, 1, 0), + # (6, 96, 3, 1, 3, 1, 0), + # (6, 160, 3, 2, 3, 1, 0), + # (6, 320, 1, 1, 3, 1, 0), + # ] + bottleneck_params_list = [ + (1, 16, 1, 1, 3, 1, 0), + (3, 24, 3, 2, 3, 1, 0), + (3, 40, 3, 2, 5, 1, 0), + (6, 80, 3, 2, 5, 1, 0), + (6, 96, 2, 1, 3, 1, 0), + (6, 192, 4, 2, 5, 1, 0), + (6, 320, 1, 1, 3, 1, 0), + ] + + #conv1 + input = self.conv_bn_layer( + input, + num_filters=int(32 * scale), + filter_size=3, + stride=2, + padding=1, + if_act=True, + name='conv1_1') + + # bottleneck sequences + i = 1 + in_c = int(32 * scale) + for layer_setting in bottleneck_params_list: + t, c, n, s, k, ifshortcut, ifse = layer_setting + i += 1 + input = self.invresi_blocks( + input=input, + in_c=in_c, + t=t, + c=int(c * scale), + n=n, + s=s, + k=k, + ifshortcut=ifshortcut, + ifse=ifse, + name='conv' + str(i)) + in_c = int(c * scale) + #last_conv + input = self.conv_bn_layer( + input=input, + num_filters=int(1280 * scale) if scale > 1.0 else 1280, + filter_size=1, + stride=1, + padding=0, + if_act=True, + name='conv9') + + input = fluid.layers.pool2d( + input=input, + pool_size=7, + pool_stride=1, + pool_type='avg', + global_pooling=True) + + output = fluid.layers.fc(input=input, + size=class_dim, + param_attr=ParamAttr(name='fc10_weights'), + bias_attr=ParamAttr(name='fc10_offset')) + return output + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + num_groups=1, + if_act=True, + name=None, + use_cudnn=True): + """Build convolution and batch normalization layers. + Args: + input: Variable, input. + filter_size: int, filter size. + num_filters: int, number of filters. + stride: int, stride. + padding: int, padding. + num_groups: int, number of groups. + if_act: bool, whether using activation. + name: str, name. + use_cudnn: bool, whether use cudnn. + Returns: + Variable, layers output. 
+        """
+        conv = fluid.layers.conv2d(
+            input=input,
+            num_filters=num_filters,
+            filter_size=filter_size,
+            stride=stride,
+            padding=padding,
+            groups=num_groups,
+            act=None,
+            use_cudnn=use_cudnn,
+            param_attr=ParamAttr(name=name + '_weights'),
+            bias_attr=False)
+        bn_name = name + '_bn'
+        bn = fluid.layers.batch_norm(
+            input=conv,
+            param_attr=ParamAttr(name=bn_name + "_scale"),
+            bias_attr=ParamAttr(name=bn_name + "_offset"),
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance')
+        if if_act:
+            return fluid.layers.relu6(bn)
+        else:
+            return bn
+
+    def shortcut(self, input, data_residual):
+        """Build shortcut layer.
+        Args:
+            input: Variable, input.
+            data_residual: Variable, residual layer.
+        Returns:
+            Variable, layer output.
+        """
+        return fluid.layers.elementwise_add(input, data_residual)
+
+    def squeeze_excitation(self,
+                           input,
+                           num_channels,
+                           reduction_ratio,
+                           name=None):
+        """Build squeeze excitation layers.
+        Args:
+            input: Variable, input.
+            num_channels: int, number of channels.
+            reduction_ratio: float, reduction ratio.
+            name: str, name.
+        Returns:
+            Variable, layers output.
+        """
+        pool = fluid.layers.pool2d(
+            input=input, pool_size=0, pool_type='avg', global_pooling=True)
+        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
+        squeeze = fluid.layers.fc(
+            input=pool,
+            size=num_channels // reduction_ratio,
+            act='relu',
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=name + '_sqz_weights'),
+            bias_attr=ParamAttr(name=name + '_sqz_offset'))
+        stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
+        excitation = fluid.layers.fc(
+            input=squeeze,
+            size=num_channels,
+            act='sigmoid',
+            param_attr=fluid.param_attr.ParamAttr(
+                initializer=fluid.initializer.Uniform(-stdv, stdv),
+                name=name + '_exc_weights'),
+            bias_attr=ParamAttr(name=name + '_exc_offset'))
+        scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
+        return scale
+
+    def inverted_residual_unit(self,
+                               input,
+                               num_in_filter,
+                               num_filters,
+                               ifshortcut,
+                               ifse,
+                               stride,
+                               filter_size,
+                               expansion_factor,
+                               reduction_ratio=4,
+                               name=None):
+        """Build inverted residual unit.
+        Args:
+            input: Variable, input.
+            num_in_filter: int, number of in filters.
+            num_filters: int, number of filters.
+            ifshortcut: bool, whether using shortcut.
+            ifse: bool, whether using squeeze excitation.
+            stride: int, stride.
+            filter_size: int, filter size.
+            expansion_factor: float, expansion factor.
+            name: str, name.
+        Returns:
+            Variable, layers output.
+ """ + num_expfilter = int(round(num_in_filter * expansion_factor)) + channel_expand = self.conv_bn_layer( + input=input, + num_filters=num_expfilter, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name=name + '_expand') + + bottleneck_conv = self.conv_bn_layer( + input=channel_expand, + num_filters=num_expfilter, + filter_size=filter_size, + stride=stride, + padding=int((filter_size - 1) / 2), + num_groups=num_expfilter, + if_act=True, + name=name + '_dwise', + use_cudnn=False) + + linear_out = self.conv_bn_layer( + input=bottleneck_conv, + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=False, + name=name + '_linear') + out = linear_out + if ifshortcut: + out = self.shortcut(input=input, data_residual=out) + if ifse: + scale = self.squeeze_excitation( + input=linear_out, + num_channels=num_filters, + reduction_ratio=reduction_ratio, + name=name + '_fc') + out = fluid.layers.elementwise_add(x=out, y=scale, act='relu') + return out + + def invresi_blocks(self, + input, + in_c, + t, + c, + n, + s, + k, + ifshortcut, + ifse, + name=None): + """Build inverted residual blocks. + Args: + input: Variable, input. + in_c: int, number of in filters. + t: float, expansion factor. + c: int, number of filters. + n: int, number of layers. + s: int, stride. + k: int, filter size. + ifshortcut: bool, if adding shortcut layers or not. + ifse: bool, if adding squeeze excitation layers or not. + name: str, name. + Returns: + Variable, layers output. + """ + first_block = self.inverted_residual_unit( + input=input, + num_in_filter=in_c, + num_filters=c, + ifshortcut=False, + ifse=ifse, + stride=s, + filter_size=k, + expansion_factor=t, + name=name + '_1') + + last_residual_block = first_block + last_c = c + + for i in range(1, n): + last_residual_block = self.inverted_residual_unit( + input=last_residual_block, + num_in_filter=last_c, + num_filters=c, + ifshortcut=ifshortcut, + ifse=ifse, + stride=1, + filter_size=k, + expansion_factor=t, + name=name + '_' + str(i + 1)) + return last_residual_block -- GitLab
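
Usage note (editor's sketch, not part of the patch): each candidate architecture is encoded as 30 tokens, six decisions per searched stage (expansion multiplier, filter count, layer count, kernel size, shortcut, squeeze-and-excitation) for each of the five searched stages; the first and last bottleneck stages keep fixed MobileNetV2-style settings. The snippet below, assuming the patch is applied and the script is run from PaddleSlim/light_nas, decodes a token vector into the (t, c, n, s, k, shortcut, se) tuples consumed by LightNASNet.net():

# Decode controller tokens into per-stage bottleneck settings.
# LightNASSpace and get_bottleneck_params_list come from light_nas_space.py above.
from light_nas_space import LightNASSpace, get_bottleneck_params_list

space = LightNASSpace()
tokens = space.init_tokens()  # 30 ints, bounded element-wise by space.range_table()
flat = get_bottleneck_params_list(tokens)  # flat list: 7 stages x 7 settings
# Regroup into (t, c, n, s, k, shortcut, se) per stage, as net_config() does.
stages = [tuple(flat[i:i + 7]) for i in range(0, len(flat), 7)]
for t, c, n, s, k, shortcut, se in stages:
    print(t, c, n, s, k, shortcut, se)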