diff --git a/demo/sensitive_prune/train.py b/demo/sensitive_prune/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3a37e315f3b15bbcdcadb35b0304bbdc5d84254
--- /dev/null
+++ b/demo/sensitive_prune/train.py
@@ -0,0 +1,214 @@
+import os
+import sys
+import logging
+import paddle
+import argparse
+import functools
+import math
+import time
+import numpy as np
+import paddle.fluid as fluid
+from paddleslim.prune import SensitivePruner
+from paddleslim.common import get_logger
+from paddleslim.analysis import flops
+sys.path.append(sys.path[0] + "/../")
+import models
+from utility import add_arguments, print_arguments
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('batch_size', int, 64 * 4, "Minibatch size.")
+add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
+add_arg('model', str, "MobileNet", "The target model.")
+add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained", "The path of the pretrained model.")
+add_arg('lr', float, 0.1, "The learning rate used to fine-tune the pruned model.")
+add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
+add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
+add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
+add_arg('num_epochs', int, 120, "The number of total epochs.")
+add_arg('total_images', int, 1281167, "The number of total training images.")
+parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
+add_arg('config_file', str, None, "The config file for compression with yaml format.")
+add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
+add_arg('log_period', int, 10, "Log period in batches.")
+add_arg('test_period', int, 10, "Test period in epochs.")
+# yapf: enable
+
+model_list = [m for m in dir(models) if "__" not in m]
+
+
+def piecewise_decay(args):
+    step = int(math.ceil(float(args.total_images) / args.batch_size))
+    bd = [step * e for e in args.step_epochs]
+    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
+    learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=learning_rate,
+        momentum=args.momentum_rate,
+        regularization=fluid.regularizer.L2Decay(args.l2_decay))
+    return optimizer
+
+
+def cosine_decay(args):
+    step = int(math.ceil(float(args.total_images) / args.batch_size))
+    learning_rate = fluid.layers.cosine_decay(
+        learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=learning_rate,
+        momentum=args.momentum_rate,
+        regularization=fluid.regularizer.L2Decay(args.l2_decay))
+    return optimizer
+
+
+def create_optimizer(args):
+    if args.lr_strategy == "piecewise_decay":
+        return piecewise_decay(args)
+    elif args.lr_strategy == "cosine_decay":
+        return cosine_decay(args)
+    else:
+        raise ValueError(
+            "{} is not a supported lr_strategy.".format(args.lr_strategy))
+
+
+def compress(args):
+
+    train_reader = None
+    test_reader = None
+    if args.data == "mnist":
+        import paddle.dataset.mnist as reader
+        train_reader = reader.train()
+        val_reader = reader.test()
+        class_dim = 10
+        image_shape = "1,28,28"
+    elif args.data == "imagenet":
+        import imagenet_reader as reader
+        train_reader = reader.train()
+        val_reader = reader.val()
+        class_dim = 1000
+        image_shape = "3,224,224"
+    else:
+        raise ValueError("{} is not supported.".format(args.data))
+
+    image_shape = [int(m) for m in image_shape.split(",")]
+    assert args.model in model_list, "{} is not in list: {}".format(
+        args.model, model_list)
+    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    # model definition
+    model = models.__dict__[args.model]()
+    out = model.net(input=image, class_dim=class_dim)
+    cost = fluid.layers.cross_entropy(input=out, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+    val_program = fluid.default_main_program().clone(for_test=True)
+    opt = create_optimizer(args)
+    opt.minimize(avg_cost)
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    exe.run(fluid.default_startup_program())
+
+    if args.pretrained_model:
+
+        def if_exist(var):
+            return os.path.exists(
+                os.path.join(args.pretrained_model, var.name))
+
+        fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
+
+    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
+    train_reader = paddle.batch(
+        train_reader, batch_size=args.batch_size, drop_last=True)
+
+    train_feeder = fluid.DataFeeder([image, label], place)
+    val_feeder = fluid.DataFeeder([image, label], place, program=val_program)
+
+    def test(epoch, program):
+        batch_id = 0
+        acc_top1_ns = []
+        acc_top5_ns = []
+        for data in val_reader():
+            start_time = time.time()
+            acc_top1_n, acc_top5_n = exe.run(
+                program,
+                feed=val_feeder.feed(data),
+                fetch_list=[acc_top1.name, acc_top5.name])
+            end_time = time.time()
+            if batch_id % args.log_period == 0:
+                _logger.info(
+                    "Eval epoch[{}] batch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
+                    format(epoch, batch_id,
+                           np.mean(acc_top1_n),
+                           np.mean(acc_top5_n), end_time - start_time))
+            acc_top1_ns.append(np.mean(acc_top1_n))
+            acc_top5_ns.append(np.mean(acc_top5_n))
+            batch_id += 1
+
+        _logger.info(
+            "Final eval epoch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}".format(
+                epoch,
+                np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
+        return np.mean(np.array(acc_top1_ns))
+
+    def train(epoch, program):
+
+        build_strategy = fluid.BuildStrategy()
+        exec_strategy = fluid.ExecutionStrategy()
+        train_program = fluid.compiler.CompiledProgram(
+            program).with_data_parallel(
+                loss_name=avg_cost.name,
+                build_strategy=build_strategy,
+                exec_strategy=exec_strategy)
+
+        batch_id = 0
+        for data in train_reader():
+            start_time = time.time()
+            loss_n, acc_top1_n, acc_top5_n = exe.run(
+                train_program,
+                feed=train_feeder.feed(data),
+                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
+            end_time = time.time()
+            loss_n = np.mean(loss_n)
+            acc_top1_n = np.mean(acc_top1_n)
+            acc_top5_n = np.mean(acc_top5_n)
+            if batch_id % args.log_period == 0:
+                _logger.info(
+                    "epoch[{}]-batch[{}] - loss: {:.3f}; acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
+                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
+                           end_time - start_time))
+            batch_id += 1
+
+    # Prune only the separable-convolution weights of MobileNetV1
+    # (parameter names containing "_sep_weights").
+    params = []
+    for param in fluid.default_main_program().global_block().all_parameters():
+        if "_sep_weights" in param.name:
+            params.append(param.name)
+
+    def eval_func(program):
+        return test(0, program)
+
+    pruner = SensitivePruner(place, eval_func)
+
+    if args.data == "mnist":
+        # There is no pretrained model for MNIST, so train one epoch first.
+        train(0, fluid.default_main_program())
+    pruned_program = fluid.default_main_program()
+    pruned_val_program = val_program
+    for i in range(6):
+        pruned_program, pruned_val_program = pruner.prune(
+            pruned_program, pruned_val_program, params, 0.1)
+        train(i, pruned_program)
+        test(i, pruned_val_program)
+
+    print("before flops: {}".format(flops(fluid.default_main_program())))
+    print("after flops: {}".format(flops(pruned_val_program)))
+
+
+def main():
+    args = parser.parse_args()
+    print_arguments(args)
+    compress(args)
+
+
+if __name__ == '__main__':
+    main()
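
Note: for a quick smoke test of this demo, an invocation along the following lines should work ('mnist' is already the default dataset, and both flags are defined above):

    python demo/sensitive_prune/train.py --data mnist --use_gpu False
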
{:.3f}". + format(epoch, batch_id, + np.mean(acc_top1_n), + np.mean(acc_top5_n), end_time - start_time)) + acc_top1_ns.append(np.mean(acc_top1_n)) + acc_top5_ns.append(np.mean(acc_top5_n)) + batch_id += 1 + + _logger.info( + "Final eval epoch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}".format( + epoch, + np.mean(np.array(acc_top1_ns)), np.mean( + np.array(acc_top5_ns)))) + return np.mean(np.array(acc_top1_ns)) + + def train(epoch, program): + + build_strategy = fluid.BuildStrategy() + exec_strategy = fluid.ExecutionStrategy() + train_program = fluid.compiler.CompiledProgram( + program).with_data_parallel( + loss_name=avg_cost.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + batch_id = 0 + for data in train_reader(): + start_time = time.time() + loss_n, acc_top1_n, acc_top5_n = exe.run( + train_program, + feed=train_feeder.feed(data), + fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name]) + end_time = time.time() + loss_n = np.mean(loss_n) + acc_top1_n = np.mean(acc_top1_n) + acc_top5_n = np.mean(acc_top5_n) + if batch_id % args.log_period == 0: + _logger.info( + "epoch[{}]-batch[{}] - loss: {:.3f}; acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}". + format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n, + end_time - start_time)) + batch_id += 1 + + params = [] + for param in fluid.default_main_program().global_block().all_parameters(): + if "_sep_weights" in param.name: + params.append(param.name) + + def eval_func(program): + return test(0, program) + + pruner = SensitivePruner(place, eval_func) + + if args.data == "mnist": + train(0, fluid.default_main_program()) + pruned_program = fluid.default_main_program() + pruned_val_program = val_program + for iter in range(6): + pruned_program, pruned_val_program = pruner.prune( + pruned_program, pruned_val_program, params, 0.1) + train(iter, pruned_program) + test(iter, pruned_val_program) + + print("before flops: {}".format(flops(fluid.default_main_program()))) + print("after flops: {}".format(flops(pruned_val_program))) + + +def main(): + args = parser.parse_args() + print_arguments(args) + compress(args) + + +if __name__ == '__main__': + main() diff --git a/paddleslim/analysis/sensitive.py b/paddleslim/analysis/sensitive.py index 09dd2a875ae21caf64034cf79421d7cc1661b817..ca9ee6f4ae7a790481a8e3b46c03cf37d096b3dc 100644 --- a/paddleslim/analysis/sensitive.py +++ b/paddleslim/analysis/sensitive.py @@ -17,6 +17,7 @@ import os import logging import pickle import numpy as np +import paddle.fluid as fluid from ..core import GraphWrapper from ..common import get_logger from ..prune import Pruner @@ -27,13 +28,12 @@ __all__ = ["sensitivity"] def sensitivity(program, - scope, place, param_names, eval_func, sensitivities_file=None, step_size=0.2): - + scope = fluid.global_scope() graph = GraphWrapper(program) sensitivities = _load_sensitivities(sensitivities_file) @@ -55,7 +55,7 @@ def sensitivity(program, ratio += step_size continue if baseline is None: - baseline = eval_func(graph.program, scope) + baseline = eval_func(graph.program) param_backup = {} pruner = Pruner() @@ -68,7 +68,7 @@ def sensitivity(program, lazy=True, only_graph=False, param_backup=param_backup) - pruned_metric = eval_func(pruned_program, scope) + pruned_metric = eval_func(pruned_program) loss = (baseline - pruned_metric) / baseline _logger.info("pruned param: {}; {}; loss={}".format(name, ratio, loss)) @@ -81,7 +81,7 @@ def sensitivity(program, param_t = scope.find_var(param_name).get_tensor() param_t.set(param_backup[param_name], place) ratio += 
diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py
index 98b314ab6d144924bff6b68e3fb176ce73583f5c..2794cd4d86c0996155fd8d6e9dd830cdc8775e09 100644
--- a/paddleslim/common/__init__.py
+++ b/paddleslim/common/__init__.py
@@ -23,6 +23,8 @@ import controller_client
 from controller_client import *
 import lock_utils
 from lock_utils import *
+import cached_reader as cached_reader_module
+from cached_reader import *
 
 __all__ = []
 __all__ += controller.__all__
@@ -30,3 +32,4 @@ __all__ += sa_controller.__all__
 __all__ += controller_server.__all__
 __all__ += controller_client.__all__
 __all__ += lock_utils.__all__
+__all__ += cached_reader_module.__all__
diff --git a/paddleslim/common/cached_reader.py b/paddleslim/common/cached_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..55f27054efe55d9df90352b3e707fe51c8996023
--- /dev/null
+++ b/paddleslim/common/cached_reader.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import logging
+import numpy as np
+from .log_helper import get_logger
+
+__all__ = ['cached_reader']
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+def cached_reader(reader, sampled_rate, cache_path, cached_id):
+    """
+    Sample partial data from reader and cache it into the local file system.
+    Args:
+        reader: Iterable data source.
+        sampled_rate(float): The rate used to sample partial data for evaluation. None means using all data in `reader`.
+        cache_path(str): The path to cache the sampled data.
+        cached_id(int): The id of the sampled dataset. Evaluations with the same cached_id use the same sampled dataset.
+    """
+    np.random.seed(cached_id)
+    cache_path = os.path.join(cache_path, str(cached_id))
+    _logger.debug('read data from: {}'.format(cache_path))
+
+    def s_reader():
+        if os.path.isdir(cache_path):
+            # The sampled dataset is already cached; replay it from disk.
+            for file_name in open(os.path.join(cache_path, "list")):
+                yield np.load(
+                    os.path.join(cache_path, file_name.strip()),
+                    allow_pickle=True)
+        else:
+            # First pass: sample batches, cache them and yield them.
+            os.makedirs(cache_path)
+            list_file = open(os.path.join(cache_path, "list"), 'w')
+            batch = 0
+            for data in reader():
+                if (sampled_rate is None or batch == 0 or
+                        np.random.uniform() < sampled_rate):
+                    np.save(
+                        os.path.join(cache_path, 'batch' + str(batch)), data)
+                    list_file.write('batch' + str(batch) + '.npy\n')
+                    batch += 1
+                    yield data
+            list_file.close()
+
+    return s_reader
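
Note: a minimal sketch of how cached_reader might be used (the batch size and cache path are assumptions, not part of this patch):

    import paddle
    from paddleslim.common import cached_reader

    val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=64)
    # Sample roughly 20% of the batches on the first pass, cache them under
    # ./cache/0, and replay the same subset on every later call.
    sampled_reader = cached_reader(val_reader, 0.2, "./cache", 0)
    for data in sampled_reader():
        pass  # feed `data` to an evaluation step
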
+ """ + np.random.seed(cached_id) + cache_path = os.path.join(cache_path, str(cached_id)) + _logger.debug('read data from: {}'.format(cache_path)) + + def s_reader(): + if os.path.isdir(cache_path): + for file_name in open(os.path.join(cache_path, "list")): + yield np.load( + os.path.join(cache_path, file_name.strip()), + allow_pickle=True) + else: + os.makedirs(cache_path) + list_file = open(os.path.join(cache_path, "list"), 'w') + batch = 0 + dtype = None + for data in reader(): + if batch == 0 or (np.random.uniform() < sampled_rate): + np.save( + os.path.join(cache_path, 'batch' + str(batch)), data) + list_file.write('batch' + str(batch) + '.npy\n') + batch += 1 + yield data + + return s_reader diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py index bb615b9dfca03ed2b289f902f6d75c73543f6fb2..f8f87862f7c0e9c09c23b753be600eed5c915a90 100644 --- a/paddleslim/prune/__init__.py +++ b/paddleslim/prune/__init__.py @@ -19,9 +19,12 @@ import controller_server from controller_server import * import controller_client from controller_client import * +import sensitive_pruner +from sensitive_pruner import * __all__ = [] __all__ += pruner.__all__ __all__ += auto_pruner.__all__ __all__ += controller_server.__all__ __all__ += controller_client.__all__ +__all__ += sensitive_pruner.__all__ diff --git a/paddleslim/prune/sensitive_pruner.py b/paddleslim/prune/sensitive_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..03f33d4059ac16a47e3deb60a12e234e174e5973 --- /dev/null +++ b/paddleslim/prune/sensitive_pruner.py @@ -0,0 +1,162 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import copy +from scipy.optimize import leastsq +import numpy as np +import paddle.fluid as fluid +from ..common import get_logger +from ..analysis import sensitivity +from ..analysis import flops +from .pruner import Pruner + +__all__ = ["SensitivePruner"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class SensitivePruner(object): + def __init__(self, place, eval_func, scope=None): + """ + Pruner used to prune parameters iteratively according to sensitivities of parameters in each step. + Args: + place(fluid.CUDAPlace | fluid.CPUPlace): The device place where program execute. + eval_func(function): A callback function used to evaluate pruned program. The argument of this function is pruned program. And it return a score of given program. + scope(fluid.scope): The scope used to execute program. + """ + self._eval_func = eval_func + self._iter = 0 + self._place = place + self._scope = fluid.global_scope() if scope is None else scope + self._pruner = Pruner() + + def prune(self, train_program, eval_program, params, pruned_flops): + """ + Pruning parameters of training and evaluation network by sensitivities in current step. + Args: + train_program(fluid.Program): The training program to be pruned. + eval_program(fluid.Program): The evaluation program to be pruned. 
And it is also used to calculate sensitivities of parameters. + params(list): The parameters to be pruned. + pruned_flops(float): The ratio of FLOPS to be pruned in current step. + Return: + tuple: A tuple of pruned training program and pruned evaluation program. + """ + _logger.info("Pruning: {}".format(params)) + sensitivities_file = "sensitivities_iter{}.data".format(self._iter) + with fluid.scope_guard(self._scope): + sensitivities = sensitivity( + eval_program, + self._place, + params, + self._eval_func, + sensitivities_file=sensitivities_file, + step_size=0.1) + print sensitivities + _, ratios = self._get_ratios_by_sensitive(sensitivities, pruned_flops, + eval_program) + + pruned_program = self._pruner.prune( + train_program, + self._scope, + params, + ratios, + place=self._place, + only_graph=False) + pruned_val_program = None + if eval_program is not None: + pruned_val_program = self._pruner.prune( + eval_program, + self._scope, + params, + ratios, + place=self._place, + only_graph=True) + self._iter += 1 + return pruned_program, pruned_val_program + + def _get_ratios_by_sensitive(self, sensitivities, pruned_flops, + eval_program): + """ + Search a group of ratios for pruning target flops. + """ + + def func(params, x): + a, b, c, d = params + return a * x * x * x + b * x * x + c * x + d + + def error(params, x, y): + return func(params, x) - y + + def slove_coefficient(x, y): + init_coefficient = [10, 10, 10, 10] + coefficient, loss = leastsq(error, init_coefficient, args=(x, y)) + return coefficient + + min_loss = 0. + max_loss = 0. + + # step 1: fit curve by sensitivities + coefficients = {} + for param in sensitivities: + losses = np.array([0] * 5 + sensitivities[param]['loss']) + precents = np.array([0] * 5 + sensitivities[param][ + 'pruned_percent']) + coefficients[param] = slove_coefficient(precents, losses) + loss = np.max(losses) + max_loss = np.max([max_loss, loss]) + + # step 2: Find a group of ratios by binary searching. + base_flops = flops(eval_program) + ratios = [] + max_times = 20 + while min_loss < max_loss and max_times > 0: + loss = (max_loss + min_loss) / 2 + _logger.info( + '-----------Try pruned ratios while acc loss={}-----------'. + format(loss)) + ratios = [] + # step 2.1: Get ratios according to current loss + for param in sensitivities: + coefficient = copy.deepcopy(coefficients[param]) + coefficient[-1] = coefficient[-1] - loss + roots = np.roots(coefficient) + for root in roots: + min_root = 1 + if np.isreal(root) and root > 0 and root < 1: + selected_root = min(root.real, min_root) + ratios.append(selected_root) + _logger.info('Pruned ratios={}'.format( + [round(ratio, 3) for ratio in ratios])) + # step 2.2: Pruning by current ratios + param_shape_backup = {} + pruned_program = self._pruner.prune( + eval_program, + None, # scope + sensitivities.keys(), + ratios, + None, # place + only_graph=True) + pruned_ratio = 1 - (float(flops(pruned_program)) / base_flops) + _logger.info('Pruned flops: {:.4f}'.format(pruned_ratio)) + + # step 2.3: Check whether current ratios is enough + if abs(pruned_ratio - pruned_flops) < 0.015: + break + if pruned_ratio > pruned_flops: + max_loss = loss + else: + min_loss = loss + max_times -= 1 + return sensitivities.keys(), ratios
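
Note: taken together, the intended workflow condenses to the sketch below (mirroring the demo above; eval_func and the params list are as defined there):

    import paddle.fluid as fluid
    from paddleslim.prune import SensitivePruner

    pruner = SensitivePruner(fluid.CUDAPlace(0), eval_func)
    train_prog = fluid.default_main_program()
    val_prog = train_prog.clone(for_test=True)
    # Six pruning steps, each removing another 10% of FLOPS; sensitivities
    # are recomputed on the pruned evaluation program at every step.
    for i in range(6):
        train_prog, val_prog = pruner.prune(train_prog, val_prog, params, 0.1)
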