diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6c61a3d63df3f1734cf4a27e7e27e6b954232af3..70a4d7b40b154ff80ff6d30adaa147556749e905 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -377,23 +377,9 @@ paddle.fluid.contrib.Calibrator.__init__ (ArgSpec(args=['self'], varargs='args', paddle.fluid.contrib.Calibrator.sample_data (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '3b8c85ca1e2cf753cc8c90a6c6992958')) paddle.fluid.contrib.Calibrator.save_int8_model (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.contrib.reader.ctr_reader.ctr_reader (ArgSpec(args=['feed_dict', 'file_type', 'file_format', 'dense_slot_index', 'sparse_slot_index', 'capacity', 'thread_num', 'batch_size', 'file_list', 'slots', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b2ebf3de2a6ef1af2c3b88d2db7591ab')) -paddle.fluid.contrib.build_compressor (ArgSpec(args=['place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'config'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.CompressPass.__init__ (ArgSpec(args=['self', 'place', 'data_reader', 'data_feeder', 'scope', 'metrics', 'epoch', 'program_exe'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.CompressPass.add_strategy (ArgSpec(args=['self', 'strategy'], varargs=None, keywords=None, defaults=None), ('document', '3bf6010b6f47d3c86df0ec8957be95e0')) -paddle.fluid.contrib.CompressPass.apply (ArgSpec(args=['self', 'graph'], varargs=None, keywords=None, defaults=None), ('document', 'a92bf85d4b59bd4f2ac1706d7c4899a6')) -paddle.fluid.contrib.ImitationGraph.__init__ (ArgSpec(args=['self', 'program'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.ImitationGraph.all_parameters (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.SensitivePruneStrategy.__init__ (ArgSpec(args=['self', 'pruner', 'start_epoch', 'end_epoch', 'delta_rate', 'acc_loss_threshold', 'sensitivities'], varargs=None, keywords=None, defaults=(None, 0, 10, 0.2, 0.2, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.SensitivePruneStrategy.on_batch_begin (ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.SensitivePruneStrategy.on_batch_end (ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.SensitivePruneStrategy.on_compress_begin (ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.SensitivePruneStrategy.on_compress_end (ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.SensitivePruneStrategy.on_epoch_begin (ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.SensitivePruneStrategy.on_epoch_end (ArgSpec(args=['self', 'context'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.MagnitudePruner.__init__ (ArgSpec(args=['self', 'threshold'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.MagnitudePruner.prune (ArgSpec(args=['self', 'param', 'threshold'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.contrib.RatioPruner.__init__ (ArgSpec(args=['self', 'ratios'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e7a81a325b296a9ca502ee5adb4fc85d')) -paddle.fluid.contrib.RatioPruner.prune (ArgSpec(args=['self', 'param', 'ratio'], varargs=None, keywords=None, defaults=(None,)), ('document', '358cbf2978c91028fb96a195a9884645')) +paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, [], './checkpoints', None, None)), ('document', '31ae143830c9bf6b43547dd546c5ba80')) +paddle.fluid.contrib.Compressor.config (ArgSpec(args=['self', 'config_file'], varargs=None, keywords=None, defaults=None), ('document', '780d9c007276ccbb95b292400d7807b0')) +paddle.fluid.contrib.Compressor.run (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'c6e43d6a078d307672283c1f36e04fe9')) paddle.fluid.contrib.load_persistables_for_increment (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None), ('document', '2ab36d4f7a564f5f65e455807ad06c67')) paddle.fluid.contrib.load_persistables_for_inference (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var_name'], varargs=None, keywords=None, defaults=None), ('document', '59066bac9db0ac6ce414d05780b7333f')) paddle.fluid.contrib.convert_dist_to_sparse_program (ArgSpec(args=['program'], varargs=None, keywords=None, defaults=None), ('document', '74c39c595dc70d6be2f16d8e462d282b')) diff --git a/python/paddle/fluid/contrib/slim/__init__.py b/python/paddle/fluid/contrib/slim/__init__.py index 22dbf7c8b6bb2da7c310a20bdcbaffca248575b0..4a71fab6d0fc73aa3bbe9c9fe56278e473f354e1 100644 --- a/python/paddle/fluid/contrib/slim/__init__.py +++ b/python/paddle/fluid/contrib/slim/__init__.py @@ -13,13 +13,4 @@ # limitations under the License. from .core import * -from .graph import * -from .prune import * -__all__ = [ - 'build_compressor', - 'CompressPass', - 'ImitationGraph', - 'SensitivePruneStrategy', - 'MagnitudePruner', - 'RatioPruner', -] +__all__ = ['Compressor', ] diff --git a/python/paddle/fluid/contrib/slim/core/__init__.py b/python/paddle/fluid/contrib/slim/core/__init__.py index 7826d5830a6f7f6d42cb1275c2289695c080e52f..831bd70ecc62f8d576b304c52b0abea994fd2ceb 100644 --- a/python/paddle/fluid/contrib/slim/core/__init__.py +++ b/python/paddle/fluid/contrib/slim/core/__init__.py @@ -14,11 +14,9 @@ from . import config from .config import * -from . import compress_pass -from .compress_pass import * +from . import compressor +from .compressor import * from . import strategy from .strategy import * -from . import pass_builder -from .pass_builder import * -__all__ = config.__all__ + compress_pass.__all__ + strategy.__all__ + pass_builder.__all__ +__all__ = config.__all__ + compressor.__all__ + strategy.__all__ diff --git a/python/paddle/fluid/contrib/slim/core/compress_pass.py b/python/paddle/fluid/contrib/slim/core/compress_pass.py deleted file mode 100644 index c4c348b878a1df43d7fb909f506c8cf65366866f..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/core/compress_pass.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ....core import CPUPlace -from ..graph import get_executor - -__all__ = ['Context', 'CompressPass'] - - -class Context(object): - """ - The context in the process of compression. - Args: - exe: The executor used to execute graph. - graph: The graph to be compressed. - scope: The scope used to execute graph. - program_exe: The program_exe is used to execute the program - created for modifying the variables in scope. - """ - - def __init__(self, exe, graph, scope, program_exe=None): - # The total number of epoches to be trained. - self.epoch = 0 - # Current epoch - self.epoch_id = 0 - # Current batch - self.batch_id = 0 - self.exe = exe - self.graph = graph - self.scope = scope - self.program_exe = program_exe - - -class CompressPass(object): - """ - The pass used to compress model. - Args: - place: The device used in compression. - data_reader: The data_reader used to run graph. - data_feeder: The data_feeder used to run graph. - scope: The scope used to run graph. - metrics: The metrics for evaluating model. - epoch: The total epoches of trainning in compression. - program_exe: The program_exe is used to execute the program - created for modifying the variables in scope. - """ - - def __init__(self, - place=None, - data_reader=None, - data_feeder=None, - scope=None, - metrics=None, - epoch=None, - program_exe=None): - self.strategies = [] - self.place = CPUPlace() if place is None else place - self.data_reader = data_reader - self.data_feeder = data_feeder - self.scope = scope - self.metrics = metrics - self.epoch = epoch - self.program_exe = program_exe - - def add_strategy(self, strategy): - """ - Add a strategy to current compress pass. - Args: - strategy: The strategy to be added into current compress pass. - """ - self.strategies.append(strategy) - self.epoch = max(strategy.end_epoch, self.epoch) - - def apply(self, graph): - """ - Compress a model. - Args: - graph: The target graph to be compressed. - """ - self.executor = get_executor(graph, self.place) - context = Context( - self.executor, graph, self.scope, program_exe=self.program_exe) - - for strategy in self.strategies: - strategy.on_compress_begin(context) - - for epoch in range(self.epoch): - - for strategy in self.strategies: - strategy.on_epoch_begin(context) - - for data in self.data_reader(): - - for strategy in self.strategies: - strategy.on_batch_begin(context) - fetches = None - if self.metrics: - fetches = self.metrics.values() - feed = None - if self.data_feeder: - feed = self.data_feeder.feed(data) - results = self.executor.run(graph, - fetches=fetches, - scope=self.scope, - feed=feed) - if results: - print("results: {}".format( - zip(self.metrics.keys(), results))) - for strategy in self.strategies: - strategy.on_batch_end(context) - context.batch_id += 1 - - for strategy in self.strategies: - strategy.on_epoch_end(context) - context.epoch_id += 1 - - for strategy in self.strategies: - strategy.on_compress_end(context) diff --git a/python/paddle/fluid/contrib/slim/core/compressor.py b/python/paddle/fluid/contrib/slim/core/compressor.py new file mode 100644 index 0000000000000000000000000000000000000000..832ade497c67ee16b6068cad4f0edace94128989 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/core/compressor.py @@ -0,0 +1,481 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ....core import CPUPlace +from .... import compiler +from .... import io +from .... import profiler +from .... import scope_guard +from ....data_feeder import DataFeeder +from ..graph import * +from .config import ConfigFactory +import numpy as np +from collections import Iterable +import time +import os +import logging +import sys +import pickle +import functools + +__all__ = ['Context', 'Compressor'] + +logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +def cached_reader(reader, sampled_rate, cache_path, cached_id): + """ + Sample partial data from reader and cache them into local file system. + Args: + reader: Iterative data source. + sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None. + cache_path(str): The path to cache the sampled data. + cached_id(int): The id of dataset sampled. Evaluations with same cached_id use the same sampled dataset. default: 0. + """ + np.random.seed(cached_id) + cache_path = os.path.join(cache_path, str(cached_id)) + _logger.debug('read data from: {}'.format(cache_path)) + + def s_reader(): + if os.path.isdir(cache_path): + for file_name in open(os.path.join(cache_path, "list")): + yield np.load(os.path.join(cache_path, file_name.strip())) + else: + os.makedirs(cache_path) + list_file = open(os.path.join(cache_path, "list"), 'w') + batch = 0 + dtype = None + for data in reader(): + if batch == 0 or (np.random.uniform() < sampled_rate): + np.save( + os.path.join(cache_path, 'batch' + str(batch)), data) + list_file.write('batch' + str(batch) + '.npy\n') + batch += 1 + yield data + + return s_reader + + +class Context(object): + """ + The context in the process of compression. + """ + + def __init__(self, + place, + scope, + train_graph=None, + train_reader=None, + eval_graph=None, + eval_reader=None, + teacher_graphs=None, + train_optimizer=None, + distiller_optimizer=None): + """ + Args: + place: The device place where the compression job running. + scope: The scope used in compression job. + train_graph: The graph with loss as output node. + eval_graph: The graph used for evaluation. + eval_reader: The data reader used for evaluation. + teacher_graphs: The teacher graphs used in distillation strategies. + train_optimizer: The optimizer used to append backward ops and + optimization ops into train_graph. + distiller_optimizer: The optimizer used by distillation strategies. + """ + # The total number of epoches to be trained. + self.epoch = 0 + # Current epoch + self.epoch_id = 0 + # Current batch + self.batch_id = 0 + + self.k_v = {} + + self.place = place + self.scope = scope + self.train_graph = train_graph + self.train_reader = train_reader + self.eval_graph = eval_graph + self.eval_reader = eval_reader + self.executor = None + self.teacher_graphs = teacher_graphs + self.train_optimizer = train_optimizer + self.distiller_optimizer = distiller_optimizer + self.optimize_graph = None + self.cache_path = './eval_cache' + self.eval_results = {} + + def to_file(self, file_name): + """ + Save the context into file. + """ + data = {} + data['epoch_id'] = self.epoch_id + data['eval_results'] = self.eval_results + with open(file_name, 'wb') as context_file: + pickle.dump(data, context_file) + + def from_file(self, file_name): + """ + Load the context from file. + """ + with open(file_name) as context_file: + if sys.version_info < (3, 0): + data = pickle.load(context_file) + else: + data = pickle.load(context_file, encoding='bytes') + self.epoch_id = data['epoch_id'] + self.eval_results = data['eval_results'] + + def eval_converged(self, metric_name, delta=0.001): + """ + Check whether the training has been converged. + Args: + metric_name(str): The metric used to check convergence. + delta(float): '(metric[k] - metric[k-1] / metric[k-1]) < delta' + means that the training has been converged. + Returns: + bool: True means the training has been converged. + """ + # TODO(wanghaoshuang@baidu.com): enhence this method. + if (metric_name not in self.eval_results + ) or len(self.eval_results[metric_name]) < 2: + return False + results = self.eval_results[metric_name][-2:] + _logger.info('Latest evaluations: {}'.format(results)) + return abs(results[1] - results[0]) / results[0] < delta + + def run_eval_graph(self, sampled_rate=None, cached_id=0): + """ + Evaluate the current mode in context. + Args: + sampled_rate(float): The sampled rate used to sample partial data + for evaluation. None means using all data in eval_reader. default: None. + cached_id(int): The id of dataset sampled. Evaluations with same + cached_id use the same sampled dataset. default: 0. + """ + _logger.info('Running evaluation') + assert self.eval_graph is not None + assert self.eval_reader is not None + eval_graph = self.eval_graph.clone(for_test=True) + + executor = SlimGraphExecutor(self.place) + results = [] + batch_id = 0 + s_time = time.time() + reader = self.eval_reader + if sampled_rate: + reader = cached_reader(reader, sampled_rate, self.cache_path, + cached_id) + for data in reader(): + result = executor.run(eval_graph, self.scope, data=data) + result = [np.mean(r) for r in result] + results.append(result) + if batch_id % 20 == 0: + _logger.info("batch-{}; {}={}".format( + batch_id, eval_graph.out_nodes.keys(), result)) + batch_id += 1 + result = np.mean(np.array(results), axis=0) + _logger.info("Final eval result: {}={}".format( + eval_graph.out_nodes.keys(), result)) + if not isinstance(result, Iterable): + result = [result] + _logger.info('Finish evaluation') + return result, eval_graph.out_nodes.keys() + + def put(self, key, value): + self.k_v[key] = value + + def get(self, key): + return self.k_v.get(key) + + +class Compressor(object): + """ + The pass used to compress model. + """ + + def __init__(self, + place, + scope, + train_program, + train_reader=None, + train_feed_list=None, + train_fetch_list=None, + eval_program=None, + eval_reader=None, + eval_feed_list=None, + eval_fetch_list=None, + teacher_programs=[], + checkpoint_path='./checkpoints', + train_optimizer=None, + distiller_optimizer=None): + """ + Args: + place(fluid.Place): The device place where the compression job running. + scope(fluid.core.Scope): The scope used to run graph. + train_program(Program): The main program to be compressed. It must have loss op. + train_reader: The data reader used for training. + train_feed_list(dict): A dict to indicate the input variable of the training program. + The key is user-defined and human-readable name. + The value is the name of Variable. + train_fetch_list(dict): A dict to indicate the output variable of the training program. + The key is user-defined and human-readable name. + The value is the name of Variable. + eval_program(Program): The program used for evaluation. + eval_reader: The data reader used for evaluation. + eval_feed_list(dict): A dict to indicate the input variable of the evaluation program. + The key is user-defined and human-readable name. + The value is the name of Variable. + eval_fetch_list(dict): A dict to indicate the output variable of the evaluation program. + The key is user-defined and human-readable name. + The value is the name of Variable. + teacher_programs: The teacher graphs used in distillation strategies. + train_optimizer: The optimizer used to append backward ops and + optimization ops into train_graph. + distiller_optimizer: The optimizer used by distillation strategies. In distillation strategy, + this optimizer is used to minimize the combined loss of student-net and + teacher-net while train_optimizer is used to minimize loss of + student-net in fine-tune stage. + + """ + assert isinstance( + train_feed_list, list + ), "train_feed_list should be a list of tuple, such as [('image', image.name), ('label', gt.name)]" + assert isinstance( + eval_feed_list, list + ), "eval_feed_list should be a list of tuple, such as [('image', image.name), ('label', gt.name)]" + self.strategies = [] + self.epoch = 0 + self.place = CPUPlace() if place is None else place + self.scope = scope + self.train_graph = GraphWrapper( + train_program, in_nodes=train_feed_list, out_nodes=train_fetch_list) + self.eval_graph = GraphWrapper( + eval_program, in_nodes=eval_feed_list, out_nodes=eval_fetch_list) + self.train_reader = train_reader + self.eval_reader = eval_reader + self.teacher_graphs = [] + for teacher in teacher_programs: + self.teacher_graphs.append(ImitationGraph(teacher, scope=scope)) + + self.checkpoint = None + self.checkpoint_path = checkpoint_path + self.eval_epoch = 1 + + self.train_optimizer = train_optimizer + self.distiller_optimizer = distiller_optimizer + self.init_model = None + + def _add_strategy(self, strategy): + """ + Add a strategy to current compress pass. + Args: + strategy: The strategy to be added into current compress pass. + """ + self.strategies.append(strategy) + self.epoch = max(strategy.end_epoch, self.epoch) + + def config(self, config_file): + """ + Configure the compress pass from file with yaml format. + Args: + config_file(str): The config file in local file system. + """ + factory = ConfigFactory(config_file) + self.epoch = factory.compressor['epoch'] + for strategy in factory.compressor['strategies']: + self._add_strategy(strategy) + if 'checkpoint_path' in factory.compressor: + self.checkpoint_path = factory.compressor['checkpoint_path'] + + if 'init_model' in factory.compressor: + self.init_model = factory.compressor['init_model'] + + def _init_model(self, context): + """ + Load model that has been compressed. + """ + if self.init_model and os.path.exists(self.init_model): + exe = SlimGraphExecutor(context.place) + with scope_guard(context.scope): + context.train_graph.load_persistables(self.init_model, exe) + flops = context.eval_graph.flops() + conv_flops = context.eval_graph.flops(only_conv=True) + context.eval_graph.update_param_shape(context.scope) + context.eval_graph.update_groups_of_conv() + _logger.info("conv flops: -{}".format(1 - float( + context.eval_graph.flops(only_conv=True)) / conv_flops)) + _logger.info("total flops: -{}".format(1 - float( + context.eval_graph.flops()) / flops)) + context.train_graph.update_param_shape(context.scope) + context.train_graph.update_groups_of_conv() + context.train_graph.infer_shape() + _logger.info("Init model from: {}".format(self.init_model)) + + def _load_checkpoint(self, context): + """ + Load checkpoints from file. + """ + _logger.debug('_load_checkpoint') + strategies = self.strategies + if self.checkpoint_path: + if not os.path.exists(self.checkpoint_path): + _logger.warning("Checkpints path doesn't exist: [{}]".format( + self.checkpoint_path)) + return context, strategies + checkpoints = [ + dir for dir in os.listdir(self.checkpoint_path) + if os.path.isdir(os.path.join(self.checkpoint_path, dir)) + ] + _logger.debug('self.checkpoint_path: {}'.format( + self.checkpoint_path)) + _logger.info('checkpoints: {}'.format(checkpoints)) + if len(checkpoints) > 0: + latest = max([int(ck) for ck in checkpoints]) + latest_ck_path = os.path.join(self.checkpoint_path, str(latest)) + + model_path = os.path.join(latest_ck_path, 'model') + context_path = os.path.join(latest_ck_path, 'context') + strategy_path = os.path.join(latest_ck_path, 'strategies') + if os.path.exists(context_path): + context.from_file(context_path) + context.epoch_id += 1 + if os.path.exists(strategy_path): + with open(strategy_path, 'rb') as strategy_file: + if sys.version_info < (3, 0): + strategies = pickle.load(strategy_file) + else: + strategies = pickle.load( + strategy_file, encoding='bytes') + + if os.path.exists(model_path): + exe = SlimGraphExecutor(context.place) + with scope_guard(context.scope): + context.optimize_graph.load_persistables(model_path, + exe) + context.optimize_graph.update_param_shape(context.scope) + context.optimize_graph.update_groups_of_conv() + context.eval_graph.update_param_shape(context.scope) + context.eval_graph.update_groups_of_conv() + _logger.info("Loaded params from: {}".format(model_path)) + return context, strategies + + def _save_checkpoint(self, context): + """ + Save checkpoints to file. + """ + if context.epoch_id % 1 == 0 and self.checkpoint_path: + checkpoint_path = os.path.join(self.checkpoint_path, + str(context.epoch_id)) + model_path = os.path.join(checkpoint_path, 'model') + context_path = os.path.join(checkpoint_path, 'context') + strategy_path = os.path.join(checkpoint_path, 'strategies') + if not os.path.isdir(model_path): + os.makedirs(model_path) + exe = SlimGraphExecutor(context.place) + with scope_guard(context.scope): + context.optimize_graph.save_persistables(model_path, exe) + context.to_file(context_path) + with open(strategy_path, 'wb') as strategy_file: + pickle.dump(self.strategies, strategy_file) + _logger.info('Saved checkpoint to: {}'.format(checkpoint_path)) + + def _train_one_epoch(self, context): + """ + Train one epoch. + """ + + executor = SlimGraphExecutor(self.place) + + if context.optimize_graph.compiled_graph is None: + context.optimize_graph.compiled_graph = compiler.CompiledProgram( + context.optimize_graph.program).with_data_parallel( + loss_name=context.optimize_graph.out_nodes['loss']) + + for data in context.train_reader(): + for strategy in self.strategies: + strategy.on_batch_begin(context) + results = executor.run(context.optimize_graph, + context.scope, + data=data) + results = [float(np.mean(result)) for result in results] + if context.batch_id % 20 == 0: + _logger.info("epoch:{}; batch_id:{}; {} = {}".format( + context.epoch_id, context.batch_id, + context.optimize_graph.out_nodes.keys( + ), [round(r, 3) for r in results])) + for strategy in self.strategies: + strategy.on_batch_end(context) + context.batch_id += 1 + context.batch_id = 0 + + def _eval(self, context): + """ + Runing evaluation. + """ + results, names = context.run_eval_graph() + for name, result in zip(names, results): + if name not in context.eval_results: + context.eval_results[name] = [] + context.eval_results[name].append(result) + + def run(self): + """ + Execute compressiong pass. + """ + context = Context( + place=self.place, + scope=self.scope, + train_graph=self.train_graph, + train_reader=self.train_reader, + eval_graph=self.eval_graph, + eval_reader=self.eval_reader, + teacher_graphs=self.teacher_graphs, + train_optimizer=self.train_optimizer, + distiller_optimizer=self.distiller_optimizer) + self.context = context + if self.teacher_graphs: + context.put('teachers', self.teacher_graphs) + self._init_model(context) + if not context.optimize_graph: + if context.train_optimizer: + context.train_optimizer._name = 'train_opt' + context.optimize_graph = context.train_graph.get_optimize_graph( + context.train_optimizer, context.place, context.scope) + else: + context.optimize_graph = context.train_graph + + context, self.strategies = self._load_checkpoint(context) + + for strategy in self.strategies: + strategy.on_compression_begin(context) + start = context.epoch_id + self._eval(context) + for epoch in range(start, self.epoch): + context.epoch_id = epoch + for strategy in self.strategies: + strategy.on_epoch_begin(context) + self._train_one_epoch(context) + for strategy in self.strategies: + strategy.on_epoch_end(context) + if self.eval_epoch and epoch % self.eval_epoch == 0: + self._eval(context) + self._save_checkpoint(context) + for strategy in self.strategies: + strategy.on_compression_end(context) + return context.eval_graph diff --git a/python/paddle/fluid/contrib/slim/core/config.py b/python/paddle/fluid/contrib/slim/core/config.py index 811c45700376aff9883fe197007b582f63817f03..12df9fcd1b0042c26aabac88d6ecba5fb827cba0 100644 --- a/python/paddle/fluid/contrib/slim/core/config.py +++ b/python/paddle/fluid/contrib/slim/core/config.py @@ -17,7 +17,7 @@ import funcsigs import yaml from collections import OrderedDict from ..prune import * -from .compress_pass import * +from ..quantization import * from .strategy import * __all__ = ['ConfigFactory'] @@ -29,15 +29,10 @@ class ConfigFactory(object): def __init__(self, config): """Init a factory from configure file.""" self.instances = {} + self.compressor = {} self.version = None self._parse_config(config) - def get_compress_pass(self): - """ - Get compress pass from factory. - """ - return self.instance('compress_pass') - def instance(self, name): """ Get instance from factory. @@ -59,8 +54,16 @@ class ConfigFactory(object): args = {} for key in keys: value = attrs[key] + if isinstance(value, str) and value.lower() == 'none': + value = None if isinstance(value, str) and value in self.instances: value = self.instances[value] + if isinstance(value, list): + for i in range(len(value)): + if isinstance(value[i], + str) and value[i] in self.instances: + value[i] = self.instances[value[i]] + args[key] = value self.instances[name] = class_(**args) return self.instances.get(name) @@ -76,16 +79,23 @@ class ConfigFactory(object): assert self.version == int(key_values['version']) # parse pruners - if key == 'pruners' or key == 'strategies': + if key == 'distillers' or key == 'pruners' or key == 'quantizers' or key == 'strategies': instances = key_values[key] for name in instances: self._new_instance(name, instances[name]) - if key == 'compress_pass': - compress_pass = self._new_instance(key, key_values[key]) - for name in key_values[key]['strategies']: - strategy = self.instance(name) - compress_pass.add_strategy(strategy) + if key == 'compressor': + self.compressor['strategies'] = [] + self.compressor['epoch'] = key_values[key]['epoch'] + if 'init_model' in key_values[key]: + self.compressor['init_model'] = key_values[key][ + 'init_model'] + self.compressor['checkpoint_path'] = key_values[key][ + 'checkpoint_path'] + if 'strategies' in key_values[key]: + for name in key_values[key]['strategies']: + strategy = self.instance(name) + self.compressor['strategies'].append(strategy) if key == 'include': for config_file in key_values[key]: diff --git a/python/paddle/fluid/contrib/slim/core/pass_builder.py b/python/paddle/fluid/contrib/slim/core/pass_builder.py deleted file mode 100644 index fc1ddc94e04f1d606292071ba7e5cc74fedd5d36..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/core/pass_builder.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .compress_pass import CompressPass -from .config import ConfigFactory - -__all__ = ['build_compressor'] - - -def build_compressor(place=None, - data_reader=None, - data_feeder=None, - scope=None, - metrics=None, - epoch=None, - config=None): - if config is not None: - factory = ConfigFactory(config) - comp_pass = factory.get_compress_pass() - else: - comp_pass = CompressPass() - comp_pass.place = place - comp_pass.data_reader = data_reader - comp_pass.data_feeder = data_feeder - comp_pass.scope = scope - comp_pass.metrics = metrics - comp_pass.epoch = epoch - return comp_pass diff --git a/python/paddle/fluid/contrib/slim/core/strategy.py b/python/paddle/fluid/contrib/slim/core/strategy.py index 74d98e98b0c390599acfaefeb0636a599b46d391..28bf24f4e341dd528d2cd25f6fb24543886150d6 100644 --- a/python/paddle/fluid/contrib/slim/core/strategy.py +++ b/python/paddle/fluid/contrib/slim/core/strategy.py @@ -20,7 +20,7 @@ class Strategy(object): Base class for all strategies. """ - def __init__(self, start_epoch=0, end_epoch=10): + def __init__(self, start_epoch=0, end_epoch=0): """ Args: start_epoch: The first epoch to apply the strategy. @@ -29,7 +29,7 @@ class Strategy(object): self.start_epoch = start_epoch self.end_epoch = end_epoch - def on_compress_begin(self, context): + def on_compression_begin(self, context): pass def on_epoch_begin(self, context): @@ -44,5 +44,5 @@ class Strategy(object): def on_batch_end(self, context): pass - def on_compress_end(self, context): + def on_compression_end(self, context): pass diff --git a/python/paddle/fluid/contrib/slim/demo/filter_prune/config.yaml b/python/paddle/fluid/contrib/slim/demo/filter_prune/config.yaml deleted file mode 100644 index ea888fa2c74a23b4769f75dce6a776afcca41a51..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/demo/filter_prune/config.yaml +++ /dev/null @@ -1,28 +0,0 @@ -version: 1.0 -pruners: - pruner_1: - class: 'RatioPruner' - ratios: - 'conv1_1.w': 0.3 - 'conv1_2.w': 0.4 - '*': 0.9 - group_dims: - '*': [1, 2, 3] - criterions: - '*': 'l1-norm' -strategies: - strategy_1: - class: 'SensitivePruneStrategy' - pruner: 'pruner_1' - start_epoch: 0 - end_epoch: 10 - delta_rate: 0.20 - acc_loss_threshold: 0.2 - sensitivities: - 'conv1_1.w': 0.4 - -compress_pass: - class: 'CompressPass' - epoch: 100 - strategies: - - strategy_1 diff --git a/python/paddle/fluid/contrib/slim/demo/filter_prune/demo.py b/python/paddle/fluid/contrib/slim/demo/filter_prune/demo.py deleted file mode 100644 index 21c59c0c9d2d9b76932ab6eeff73754940a3bfa0..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/demo/filter_prune/demo.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle.fluid as fluid -import paddle -import os -import sys -from paddle.fluid.contrib.slim import CompressPass -from paddle.fluid.contrib.slim import build_compressor -from paddle.fluid.contrib.slim import ImitationGraph - - -class LinearModel(object): - def __init__(slef): - pass - - def train(self): - train_program = fluid.Program() - startup_program = fluid.Program() - startup_program.random_seed = 10 - with fluid.program_guard(train_program, startup_program): - x = fluid.layers.data(name='x', shape=[13], dtype='float32') - y = fluid.layers.data(name='y', shape=[1], dtype='float32') - predict = fluid.layers.fc(input=x, size=1, act=None) - cost = fluid.layers.square_error_cost(input=predict, label=y) - avg_cost = fluid.layers.mean(cost) - eval_program = train_program.clone() - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) - sgd_optimizer.minimize(avg_cost) - - train_reader = paddle.batch( - paddle.dataset.uci_housing.train(), batch_size=1) - eval_reader = paddle.batch( - paddle.dataset.uci_housing.test(), batch_size=1) - place = fluid.CPUPlace() - train_feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) - eval_feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) - exe = fluid.Executor(place) - exe.run(startup_program) - train_metrics = {"loss": avg_cost.name} - eval_metrics = {"loss": avg_cost.name} - - graph = ImitationGraph(train_program) - config = './config.yaml' - comp_pass = build_compressor( - place, - data_reader=train_reader, - data_feeder=train_feeder, - scope=fluid.global_scope(), - metrics=train_metrics, - epoch=1, - config=config) - comp_pass.apply(graph) - - -if __name__ == "__main__": - model = LinearModel() - model.train() diff --git a/python/paddle/fluid/contrib/slim/graph/__init__.py b/python/paddle/fluid/contrib/slim/graph/__init__.py index d65472d193b639f0766e278ec14b5dc36c5d62bc..c5d1c4dbdfb208ea66bb3dc315e502309799492e 100644 --- a/python/paddle/fluid/contrib/slim/graph/__init__.py +++ b/python/paddle/fluid/contrib/slim/graph/__init__.py @@ -14,10 +14,7 @@ from . import executor from .executor import * -from . import graph -from .graph import * -from . import graph_pass -from .graph_pass import * +from . import graph_wrapper +from .graph_wrapper import * __all__ = executor.__all__ -__all__ += graph.__all__ -__all__ += graph_pass.__all__ +__all__ += graph_wrapper.__all__ diff --git a/python/paddle/fluid/contrib/slim/graph/executor.py b/python/paddle/fluid/contrib/slim/graph/executor.py index c02c3af82013287bf19e1869cb60dc65239b720a..70438a90eb790e7ca5d00be0bc09efc6c00cafe4 100644 --- a/python/paddle/fluid/contrib/slim/graph/executor.py +++ b/python/paddle/fluid/contrib/slim/graph/executor.py @@ -12,51 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc -from abc import abstractmethod +from ....compiler import CompiledProgram +from ....data_feeder import DataFeeder from .... import executor -from .graph import IRGraph, ImitationGraph +from .graph_wrapper import GraphWrapper -__all__ = ['get_executor'] +__all__ = ['SlimGraphExecutor'] -class GraphExecutor(object): - __metaclass__ = abc.ABCMeta +class SlimGraphExecutor(object): + """ + Wrapper of executor used to run GraphWrapper. + """ def __init__(self, place): - self.place = place - - @abstractmethod - def run(self, graph, feches=None, feed=None): - pass - - -class IRGraphExecutor(GraphExecutor): - def run(self, grah, fetches, feed=None): - pass - - -class ImitationGraphExecutor(GraphExecutor): - def __init__(self, place): - super(ImitationGraphExecutor, self).__init__(place) self.exe = executor.Executor(place) + self.place = place - def run(self, graph, scope=None, fetches=None, feed=None): - assert isinstance(graph, ImitationGraph) - fetch_list = None - if fetches: - fetch_list = [ - graph.program.global_block().var(name) for name in fetches - ] - results = self.exe.run(graph.program, + def run(self, graph, scope, data=None): + """ + Runing a graph with a batch of data. + Args: + graph(GraphWrapper): The graph to be executed. + scope(fluid.core.Scope): The scope to be used. + data(list): A batch of data. Each tuple in this list is a sample. + It will feed the items of tuple to the in_nodes of graph. + Returns: + results(list): A list of result with the same order indicated by graph.out_nodes. + """ + assert isinstance(graph, GraphWrapper) + if data is not None: + feeder = DataFeeder( + feed_list=graph.in_nodes.values(), + place=self.place, + program=graph.program) + feed = feeder.feed(data) + + fetch_list = graph.out_nodes.values() + program = graph.compiled_graph if graph.compiled_graph else graph.program + results = self.exe.run(program, scope=scope, fetch_list=fetch_list, feed=feed) return results - - -def get_executor(graph, place): - if isinstance(graph, ImitationGraph): - return ImitationGraphExecutor(place) - if isinstance(graph, IRGraph): - return IRGraphExecutor(place) diff --git a/python/paddle/fluid/contrib/slim/graph/graph.py b/python/paddle/fluid/contrib/slim/graph/graph.py deleted file mode 100644 index f38d9783413a01cd1005a014c0aba5ecf5cc79c2..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/graph/graph.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import print_function -import os -import subprocess -from ....framework import Program -from ....framework import Block -from .... import core - -__all__ = ['Graph', 'ImitationGraph', 'IRGraph'] - - -class Graph(object): - """ - Base class for all graph. - """ - - def __init__(self): - pass - - def all_parameters(self): - """ - Return all the parameters in current graph. - """ - pass - - -class ImitationGraph(Graph): - def __init__(self, program=None): - super(ImitationGraph, self).__init__() - self.program = Program() if program is None else program - - def all_parameters(self): - return self.program.global_block().all_parameters() - - -class IRGraph(Graph): - pass diff --git a/python/paddle/fluid/contrib/slim/graph/graph_pass.py b/python/paddle/fluid/contrib/slim/graph/graph_pass.py deleted file mode 100644 index 1db6c4f110daa44be7fcbcc36f47224797b6dc88..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/graph/graph_pass.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__all__ = ['GraphPass', 'PruneParameterPass'] - - -class GraphPass(object): - """ - Base class for all graph pass. - """ - - def __init__(self): - pass - - def apply(self, graph): - pass - - -class PruneParameterPass(GraphPass): - """ - Generate a graph for pruning parameters from target graph. - """ - - def __init__(self, pruned_params, thresholds): - super(PruneParameterPass, self).__init__() - self.pruned_params = pruned_params - self.thresholds = thresholds - self.default_threshold = thresholds['*'] - - def apply(self, graph): - pass diff --git a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..8694be782708a6d47b3e1450305975d34fd3bd7f --- /dev/null +++ b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py @@ -0,0 +1,500 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +from .... import io +from .... import compiler +from ....framework import Program +from ....framework import program_guard +from ....framework import Parameter +from ....framework import Variable +from ....executor import Executor +import copy +from collections import Iterable +from ....io import save_inference_model, load_inference_model, save_persistables +import numpy as np +import pickle +import os + +__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper'] + +OPTIMIZER_OPS = [ + 'momentum', + 'lars_momentum', + 'adagrad', + 'adam', + 'adamax', + 'decayed_adagrad', + 'adadelta', + 'rmsprop', +] + + +class VarWrapper(object): + def __init__(self, var, graph): + assert isinstance(var, Variable) + assert isinstance(graph, GraphWrapper) + self._var = var + self._graph = graph + + def __eq__(self, v): + """ + Overwrite this function for ...in... syntax in python. + """ + return self._var.name == v._var.name + + def name(self): + """ + Get the name of the variable. + """ + return self._var.name + + def shape(self): + """ + Get the shape of the varibale. + """ + return self._var.shape + + def set_shape(self, shape): + """ + Set the shape of the variable. + """ + self._var.desc.set_shape(shape) + + def inputs(self): + """ + Get all the operators that use this variable as output. + Returns: + list: A list of operators. + """ + ops = [] + for op in self._graph.ops(): + if self in op.all_inputs(): + ops.append(op) + return ops + + def outputs(self): + """ + Get all the operators that use this variable as input. + Returns: + list: A list of operators. + """ + ops = [] + for op in self._graph.ops(): + if self in op.all_outputs(): + ops.append(op) + return ops + + +class OpWrapper(object): + def __init__(self, op, graph): + assert isinstance(graph, GraphWrapper) + self._op = op + self._graph = graph + + def __eq__(self, op): + """ + Overwrite this function for ...in... syntax in python. + """ + return self.idx() == op.idx() + + def all_inputs(self): + """ + Get all the input variables of this operator. + """ + return [ + self._graph.var(var_name) for var_name in self._op.input_arg_names + ] + + def all_outputs(self): + """ + Get all the output variables of this operator. + """ + return [ + self._graph.var(var_name) for var_name in self._op.output_arg_names + ] + + def idx(self): + """ + Get the id of this operator. + """ + return self._op.idx + + def type(self): + """ + Get the type of this operator. + """ + return self._op.type + + def is_bwd_op(self): + """ + Whether this operator is backward op. + """ + return self.type().endswith('_grad') + + def is_opt_op(self): + """ + Whether this operator is optimizer op. + """ + return self.type() in OPTIMIZER_OPS + + def inputs(self, name): + """ + Get all the varibales by the input name. + """ + return [self._graph.var(var_name) for var_name in self._op.input(name)] + + def outputs(self, name): + """ + Get all the varibales by the output name. + """ + return [self._graph.var(var_name) for var_name in self._op.output(name)] + + def set_attr(self, key, value): + """ + Set the value of attribute by attribute's name. + + Args: + key(str): the attribute name. + value(bool|int|str|float|list): the value of the attribute. + """ + self._op._set_attr(key, value) + + def attr(self, name): + """ + Get the attribute by name. + + Args: + name(str): the attribute name. + + Returns: + bool|int|str|float|list: The attribute value. The return value + can be any valid attribute type. + """ + return self._op.attr(name) + + +class GraphWrapper(object): + """ + It is a wrapper of paddle.fluid.framework.IrGraph with some special functions + for paddle slim framework. + """ + + def __init__(self, program=None, in_nodes=[], out_nodes=[]): + """ + Args: + program(framework.Program): A program with + in_nodes(dict): A dict to indicate the input nodes of the graph. + The key is user-defined and human-readable name. + The value is the name of Variable. + out_nodes(dict): A dict to indicate the input nodes of the graph. + The key is user-defined and human-readable name. + The value is the name of Variable. + """ + super(GraphWrapper, self).__init__() + self.program = Program() if program is None else program + self.compiled_graph = None + self.in_nodes = OrderedDict(in_nodes) + self.out_nodes = OrderedDict(out_nodes) + self._attrs = OrderedDict() + + def all_parameters(self): + """ + Get all the parameters in this graph. + Returns: + list: A list of VarWrapper instances. + """ + params = [] + for block in self.program.blocks: + for param in block.all_parameters(): + params.append(VarWrapper(param, self)) + return params + + def is_parameter(self, var): + """ + Whether the given variable is parameter. + Args: + var(VarWrapper): The given varibale. + """ + return isinstance(var._var, Parameter) + + def is_persistable(self, var): + """ + Whether the given variable is persistable. + Args: + var(VarWrapper): The given varibale. + """ + return var._var.persistable + + def compile(self, for_parallel=True, for_test=False): + """ + Compile the program in this wrapper to framework.CompiledProgram for next running. + This function must be called if the program is modified. + Args: + for_parallel(bool): Whether the program to run in data parallel way. default: True. + for_test(bool): Whether the compiled program is used for test. + """ + target = self.program + if for_test: + loss = None + else: + loss = self.out_nodes['loss'] + if for_parallel: + # disable memory optimize for stable training + build_strategy = compiler.BuildStrategy() + build_strategy.enable_inplace = False + build_strategy.memory_optimize = False + self.compiled_graph = compiler.CompiledProgram( + target).with_data_parallel( + loss_name=loss, build_strategy=build_strategy) + else: + self.compiled_graph = compiler.CompiledProgram(target) + + def ops(self): + """ + Return all operator nodes included in the graph as a set. + """ + ops = [] + for block in self.program.blocks: + for op in block.ops: + ops.append(OpWrapper(op, self)) + return ops + + def vars(self): + """ + Get all the variables. + """ + return [VarWrapper(var, self) for var in self.program.list_vars()] + + def var(self, name): + """ + Get the variable by variable name. + """ + return VarWrapper(self.program.global_block().var(name), self) + + def clone(self, for_test=False): + """ + Clone a new graph from current graph. + Returns: + (GraphWrapper): The wrapper of a new graph. + """ + return GraphWrapper( + self.program.clone(for_test), + copy.deepcopy(self.in_nodes), copy.deepcopy(self.out_nodes)) + + def merge(self, graph): + """ + Merge a graph into current graph. + Args: + graph(GraphWrapper): The graph to be merged by current graph. + """ + for var in graph.program.list_vars(): + self.program.global_block()._clone_variable(var) + # TODO: parameters should be cloned + for op in graph.ops(): + op = op._op + inputs = {} + outputs = {} + attrs = {} + for input_name in op.input_names: + inputs[input_name] = [ + self.var(in_var_name) + for in_var_name in op.inputs(input_name) + ] + for output_name in op.output_names: + outputs[output_name] = [ + self.var(out_var_name) + for out_var_name in op.output(output_name) + ] + for attr_name in op.attr_names: + attrs[attr_name] = op.attr(attr_name) + self.program.global_block().append_op( + type=op.type, inputs=inputs, outputs=outputs, attrs=attrs) + + def program(self): + """ + Get the program in current wrapper. + """ + return self.program + + def pre_ops(self, op): + """ + Get all the previous operators of target operator. + Args: + op(OpWrapper): Target operator.. + Returns: + list: A list of operators. + """ + ops = [] + for p in self.ops(): + for in_var in op.all_inputs(): + if in_var in p.all_outputs(): + ops.append(p) + return ops + + def next_ops(self, op): + """ + Get all the next operators of target operator. + Args: + op(OpWrapper): Target operator.. + Returns: + list: A list of operators. + """ + ops = [] + for p in self.ops(): + for out_var in op.all_outputs(): + if out_var in p.all_inputs(): + ops.append(p) + return ops + + def get_param_by_op(self, op): + """ + Get the parameters used by target operator. + """ + assert isinstance(op, OpWrapper) + params = [] + for var in op.all_inputs(): + if isinstance(var._var, Parameter): + params.append(var) + assert len(params) > 0 + return params + + def numel_params(self): + """ + Get the number of elements in all parameters. + """ + ret = 0 + for param in self.all_parameters(): + ret += np.product(param.shape()) + return ret + + def get_optimize_graph(self, optimizer, place, scope, no_grad_var_names=[]): + """ + Get a new graph for training by appending some backward operators and optimization operators. + Args: + optimizer: The optimzier used to generate training graph. + place: The place to run the graph. + scope: The scope used to run the graph. Some new variable will be added into this scope. + no_grad_var_names(list): Names of variables that should be ignored while computing gradients. default: []. + Returns: + (GraphWrapper): The wrapper of new graph with backward ops and optimization ops. + """ + graph = self.clone() + startup_program = Program() + with program_guard( + main_program=graph.program, startup_program=startup_program): + target_name = None + if 'loss' in graph.out_nodes: + target_name = graph.out_nodes['loss'] + elif 'cost' in graph.out_nodes: + target_name = graph.out_nodes['cost'] + target = graph.var(target_name)._var + optimizer.minimize(target, no_grad_set=no_grad_var_names) + + exe = Executor(place) + exe.run(program=startup_program, scope=scope) + return graph + + def flops(self, only_conv=False): + """ + Get the flops of current graph. + Args: + only_conv: Only calculating the conv layers. default: False. + Returns: + int: The flops of current graph. + """ + flops = 0 + for op in self.ops(): + if op.type() in ['conv2d', 'depthwise_conv2d']: + filter_shape = op.inputs("Filter")[0].shape() + input_shape = op.inputs("Input")[0].shape() + output_shape = op.outputs("Output")[0].shape() + c_out, c_in, k_h, k_w = filter_shape + _, _, h_out, w_out = output_shape + groups = op.attr("groups") + kernel_ops = k_h * k_w * (c_in / groups) + if len(op.inputs("Bias")) > 0: + with_bias = 1 + else: + with_bias = 0 + flops += 2 * h_out * w_out * c_out * (kernel_ops + with_bias) + elif op.type() == 'pool2d' and not only_conv: + input_shape = op.inputs("X")[0].shape() + output_shape = op.outputs("Out")[0].shape() + _, c_out, h_out, w_out = output_shape + k_size = op.attr("ksize") + flops += h_out * w_out * c_out * (k_size[0]**2) + + elif op.type() == 'mul' and not only_conv: + x_shape = list(op.inputs("X")[0].shape()) + y_shape = op.inputs("Y")[0].shape() + if x_shape[0] == -1: + x_shape[0] = 1 + flops += 2 * x_shape[0] * x_shape[1] * y_shape[1] + + elif op.type() in ['relu', 'sigmoid', 'batch_norm' + ] and not only_conv: + input_shape = list(op.inputs("X")[0].shape()) + if input_shape[0] == -1: + input_shape[0] = 1 + flops += np.product(input_shape) + + return flops + + def save_persistables(self, path, exe): + """ + Save all the persistable variables into file. + Args: + path(str): The path to save the persistables. + exe(framework.Executor): The executor used to save the persistables. + """ + io.save_persistables(exe.exe, path, main_program=self.program) + + def load_persistables(self, path, exe): + """ + Load the persistable variables from file. + Args: + path(str): The path to load the persistables. + exe(framework.Executor): The executor used to load the persistables. + """ + + def if_exist(var): + return os.path.exists(os.path.join(path, var.name)) + + io.load_vars( + exe.exe, path, main_program=self.program, predicate=if_exist) + + def update_param_shape(self, scope): + """ + Update the shape of parameters in the graph according to tensors in scope. + It is used after loading pruned parameters from file. + """ + for param in self.all_parameters(): + tensor_shape = np.array(scope.find_var(param.name()).get_tensor( + )).shape + param.set_shape(tensor_shape) + + def infer_shape(self): + """ + Update the groups of convolution layer according to current filters. + It is used after loading pruned parameters from file. + """ + for op in self.ops(): + if op.type() != 'conditional_block': + op._op.desc.infer_shape(op._op.block.desc) + + def update_groups_of_conv(self): + for op in self.ops(): + if op.type() == 'depthwise_conv2d': + op.set_attr('groups', op.inputs('Filter')[0].shape()[0]) diff --git a/python/paddle/fluid/contrib/slim/prune/prune_strategy.py b/python/paddle/fluid/contrib/slim/prune/prune_strategy.py index 34c5107daa3cde10e7995902be37e34e19664da8..7a25c3a61e0815a20fa9b0477a6c69a4f8d2a066 100644 --- a/python/paddle/fluid/contrib/slim/prune/prune_strategy.py +++ b/python/paddle/fluid/contrib/slim/prune/prune_strategy.py @@ -13,54 +13,919 @@ # limitations under the License. from ..core.strategy import Strategy -from ....framework import Program, program_guard +from ..graph import VarWrapper, OpWrapper, GraphWrapper +from ....framework import Program, program_guard, Parameter from .... import layers +import prettytable as pt import numpy as np +from scipy.optimize import leastsq +import copy +import re +import os +import pickle +import logging +import sys -__all__ = ['SensitivePruneStrategy', 'PruneStrategy'] +__all__ = ['SensitivePruneStrategy', 'UniformPruneStrategy'] +logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +class PruneStrategy(Strategy): + """ + The base class of all pruning strategies. + """ -class SensitivePruneStrategy(Strategy): def __init__(self, pruner=None, start_epoch=0, - end_epoch=10, - delta_rate=0.20, - acc_loss_threshold=0.2, - sensitivities=None): - super(SensitivePruneStrategy, self).__init__(start_epoch, end_epoch) + end_epoch=0, + target_ratio=0.5, + metric_name=None, + pruned_params='conv.*_weights'): + """ + Args: + pruner(slim.Pruner): The pruner used to prune the parameters. + start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0 + end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0 + target_ratio(float): The flops ratio to be pruned from current model. + metric_name(str): The metric used to evaluate the model. + It should be one of keys in out_nodes of graph wrapper. + pruned_params(str): The pattern str to match the parameter names to be pruned. + """ + super(PruneStrategy, self).__init__(start_epoch, end_epoch) self.pruner = pruner - self.delta_rate = delta_rate - self.acc_loss_threshold = acc_loss_threshold - self.sensitivities = sensitivities + self.target_ratio = target_ratio + self.metric_name = metric_name + self.pruned_params = pruned_params + self.pruned_list = [] + self.backup = {} + self.param_shape_backup = {} + def _eval_graph(self, context, sampled_rate=None, cached_id=0): + """ + Evaluate the current mode in context. + Args: + context(slim.core.Context): The context storing all information used to evaluate the current model. + sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None. + cached_id(int): The id of dataset sampled. Evaluations with same cached_id use the same sampled dataset. default: 0. + """ + results, names = context.run_eval_graph(sampled_rate, cached_id) + metric = np.mean(results[list(names).index(self.metric_name)]) + return metric -class PruneStrategy(Strategy): + def _prune_filters_by_ratio(self, + scope, + params, + ratio, + place, + lazy=False, + only_graph=False): + """ + Pruning filters by given ratio. + Args: + scope(fluid.core.Scope): The scope used to pruning filters. + params(list): A list of filter parameters. + ratio(float): The ratio to be pruned. + place(fluid.Place): The device place of filter parameters. + lazy(bool): True means setting the pruned elements to zero. + False means cutting down the pruned elements. + only_graph(bool): True means only modifying the graph. + False means modifying graph and variables in scope. + """ + if params[0].name() in self.pruned_list[0]: + return + param_t = scope.find_var(params[0].name()).get_tensor() + pruned_idx = self.pruner.cal_pruned_idx( + params[0].name(), np.array(param_t), ratio, axis=0) + for param in params: + assert isinstance(param, VarWrapper) + param_t = scope.find_var(param.name()).get_tensor() + if lazy: + self.backup[param.name()] = copy.deepcopy(np.array(param_t)) + pruned_param = self.pruner.prune_tensor( + np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy) + if not only_graph: + param_t.set(pruned_param, place) + ori_shape = param.shape() + if param.name() not in self.param_shape_backup: + self.param_shape_backup[param.name()] = copy.deepcopy( + param.shape()) + new_shape = list(param.shape()) + new_shape[0] = pruned_param.shape[0] + param.set_shape(new_shape) + _logger.debug( + '|----------------------------------------+----+------------------------------+------------------------------|' + ) + _logger.debug('|{:^40}|{:^4}|{:^30}|{:^30}|'.format( + str(param.name()), str(0), str(ori_shape), str(param.shape()))) + self.pruned_list[0].append(param.name()) + return pruned_idx + + def _prune_parameter_by_idx(self, + scope, + params, + pruned_idx, + pruned_axis, + place, + lazy=False, + only_graph=False): + """ + Pruning parameters in given axis. + Args: + scope(fluid.core.Scope): The scope storing paramaters to be pruned. + params(VarWrapper): The parameter to be pruned. + pruned_idx(list): The index of elements to be pruned. + pruned_axis(int): The pruning axis. + place(fluid.Place): The device place of filter parameters. + lazy(bool): True means setting the pruned elements to zero. + False means cutting down the pruned elements. + only_graph(bool): True means only modifying the graph. + False means modifying graph and variables in scope. + """ + if params[0].name() in self.pruned_list[pruned_axis]: + return + for param in params: + assert isinstance(param, VarWrapper) + param_t = scope.find_var(param.name()).get_tensor() + if lazy: + self.backup[param.name()] = copy.deepcopy(np.array(param_t)) + pruned_param = self.pruner.prune_tensor( + np.array(param_t), pruned_idx, pruned_axis, lazy=lazy) + if not only_graph: + param_t.set(pruned_param, place) + ori_shape = param.shape() + if param.name() not in self.param_shape_backup: + self.param_shape_backup[param.name()] = copy.deepcopy( + param.shape()) + new_shape = list(param.shape()) + new_shape[pruned_axis] = pruned_param.shape[pruned_axis] + param.set_shape(new_shape) + _logger.debug( + '|----------------------------------------+----+------------------------------+------------------------------|' + ) + _logger.debug('|{:^40}|{:^4}|{:^30}|{:^30}|'.format( + str(param.name()), + str(pruned_axis), str(ori_shape), str(param.shape()))) + self.pruned_list[pruned_axis].append(param.name()) + + def _forward_search_related_op(self, graph, param): + """ + Forward search operators that will be affected by pruning of param. + Args: + graph(GraphWrapper): The graph to be searched. + param(VarWrapper): The current pruned parameter. + Returns: + list: A list of operators. + """ + assert isinstance(param, VarWrapper) + visited = {} + for op in graph.ops(): + visited[op.idx()] = False + stack = [] + for op in graph.ops(): + if (not op.is_bwd_op()) and (param in op.all_inputs()): + stack.append(op) + visit_path = [] + while len(stack) > 0: + top_op = stack[len(stack) - 1] + if visited[top_op.idx()] == False: + visit_path.append(top_op) + visited[top_op.idx()] = True + next_ops = None + if top_op.type() == "conv2d" and param not in top_op.all_inputs(): + next_ops = None + elif top_op.type() == "mul": + next_ops = None + else: + next_ops = self._get_next_unvisited_op(graph, visited, top_op) + if next_ops == None: + stack.pop() + else: + stack += next_ops + return visit_path + + def _get_next_unvisited_op(self, graph, visited, top_op): + """ + Get next unvisited adjacent operators of given operators. + Args: + graph(GraphWrapper): The graph used to search. + visited(list): The ids of operators that has been visited. + top_op: The given operator. + Returns: + list: A list of operators. + """ + assert isinstance(top_op, OpWrapper) + next_ops = [] + for op in graph.next_ops(top_op): + if (visited[op.idx()] == False) and (not op.is_bwd_op()): + next_ops.append(op) + return next_ops if len(next_ops) > 0 else None + + def _get_accumulator(self, graph, param): + """ + Get accumulators of given parameter. The accumulator was created by optimizer. + Args: + graph(GraphWrapper): The graph used to search. + param(VarWrapper): The given parameter. + Returns: + list: A list of accumulators which are variables. + """ + assert isinstance(param, VarWrapper) + params = [] + for op in param.outputs(): + if op.is_opt_op(): + for out_var in op.all_outputs(): + if graph.is_persistable(out_var) and out_var.name( + ) != param.name(): + params.append(out_var) + return params + + def _forward_pruning_ralated_params(self, + graph, + scope, + param, + place, + ratio=None, + pruned_idxs=None, + lazy=False, + only_graph=False): + """ + Pruning all the parameters affected by the pruning of given parameter. + Args: + graph(GraphWrapper): The graph to be searched. + scope(fluid.core.Scope): The scope storing paramaters to be pruned. + param(VarWrapper): The given parameter. + place(fluid.Place): The device place of filter parameters. + ratio(float): The target ratio to be pruned. + pruned_idx(list): The index of elements to be pruned. + lazy(bool): True means setting the pruned elements to zero. + False means cutting down the pruned elements. + only_graph(bool): True means only modifying the graph. + False means modifying graph and variables in scope. + """ + assert isinstance( + graph, + GraphWrapper), "graph must be instance of slim.core.GraphWrapper" + assert isinstance( + param, VarWrapper), "param must be instance of slim.core.VarWrapper" + + if param.name() in self.pruned_list[0]: + return + related_ops = self._forward_search_related_op(graph, param) + + if ratio is None: + assert pruned_idxs is not None + self._prune_parameter_by_idx( + scope, [param] + self._get_accumulator(graph, param), + pruned_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph) + + else: + pruned_idxs = self._prune_filters_by_ratio( + scope, [param] + self._get_accumulator(graph, param), + ratio, + place, + lazy=lazy, + only_graph=only_graph) + corrected_idxs = pruned_idxs[:] + + for idx, op in enumerate(related_ops): + if op.type() == "conv2d" and (param not in op.all_inputs()): + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + conv_param = in_var + self._prune_parameter_by_idx( + scope, [conv_param] + self._get_accumulator( + graph, conv_param), + corrected_idxs, + pruned_axis=1, + place=place, + lazy=lazy, + only_graph=only_graph) + if op.type() == "depthwise_conv2d": + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + conv_param = in_var + self._prune_parameter_by_idx( + scope, [conv_param] + self._get_accumulator( + graph, conv_param), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph) + elif op.type() == "elementwise_add": + # pruning bias + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + bias_param = in_var + self._prune_parameter_by_idx( + scope, [bias_param] + self._get_accumulator( + graph, bias_param), + pruned_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph) + elif op.type() == "mul": # pruning fc layer + fc_input = None + fc_param = None + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + fc_param = in_var + else: + fc_input = in_var + + idx = [] + feature_map_size = fc_input.shape()[2] * fc_input.shape()[3] + range_idx = np.array(range(feature_map_size)) + for i in corrected_idxs: + idx += list(range_idx + i * feature_map_size) + corrected_idxs = idx + self._prune_parameter_by_idx( + scope, [fc_param] + self._get_accumulator(graph, fc_param), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph) + + elif op.type() == "concat": + concat_inputs = op.all_inputs() + last_op = related_ops[idx - 1] + for out_var in last_op.all_outputs(): + if out_var in concat_inputs: + concat_idx = concat_inputs.index(out_var) + offset = 0 + for ci in range(concat_idx): + offset += concat_inputs[ci].shape()[1] + corrected_idxs = [x + offset for x in pruned_idxs] + elif op.type() == "batch_norm": + bn_inputs = op.all_inputs() + mean = bn_inputs[2] + variance = bn_inputs[3] + alpha = bn_inputs[0] + beta = bn_inputs[1] + self._prune_parameter_by_idx( + scope, [mean] + self._get_accumulator(graph, mean), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph) + self._prune_parameter_by_idx( + scope, [variance] + self._get_accumulator(graph, variance), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph) + self._prune_parameter_by_idx( + scope, [alpha] + self._get_accumulator(graph, alpha), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph) + self._prune_parameter_by_idx( + scope, [beta] + self._get_accumulator(graph, beta), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph) + + def _prune_parameters(self, + graph, + scope, + params, + ratios, + place, + lazy=False, + only_graph=False): + """ + Pruning the given parameters. + Args: + graph(GraphWrapper): The graph to be searched. + scope(fluid.core.Scope): The scope storing paramaters to be pruned. + params(list): A list of parameter names to be pruned. + ratios(list): A list of ratios to be used to pruning parameters. + place(fluid.Place): The device place of filter parameters. + pruned_idx(list): The index of elements to be pruned. + lazy(bool): True means setting the pruned elements to zero. + False means cutting down the pruned elements. + only_graph(bool): True means only modifying the graph. + False means modifying graph and variables in scope. + + """ + _logger.debug('\n################################') + _logger.debug('# pruning parameters #') + _logger.debug('################################\n') + _logger.debug( + '|----------------------------------------+----+------------------------------+------------------------------|' + ) + _logger.debug('|{:^40}|{:^4}|{:^30}|{:^30}|'.format('parameter', 'axis', + 'from', 'to')) + assert len(params) == len(ratios) + self.pruned_list = [[], []] + for param, ratio in zip(params, ratios): + assert isinstance(param, str) or isinstance(param, unicode) + param = graph.var(param) + self._forward_pruning_ralated_params( + graph, + scope, + param, + place, + ratio=ratio, + lazy=lazy, + only_graph=only_graph) + ops = param.outputs() + for op in ops: + if op.type() == 'conv2d': + brother_ops = self._search_brother_ops(graph, op) + for broher in brother_ops: + for p in graph.get_param_by_op(broher): + self._forward_pruning_ralated_params( + graph, + scope, + p, + place, + ratio=ratio, + lazy=lazy, + only_graph=only_graph) + _logger.debug( + '|----------------------------------------+----+------------------------------+------------------------------|' + ) + + def _search_brother_ops(self, graph, op_node): + """ + Search brother operators that was affected by pruning of given operator. + Args: + graph(GraphWrapper): The graph to be searched. + op_node(OpWrapper): The start node for searching. + Returns: + list: A list of operators. + """ + visited = [op_node.idx()] + stack = [] + brothers = [] + for op in graph.next_ops(op_node): + if (op.type() != 'conv2d') and (op.type() != 'fc') and ( + not op._is_bwd_op()): + stack.append(op) + visited.append(op.idx()) + while len(stack) > 0: + top_op = stack.pop() + for parent in graph.pre_ops(top_op): + if parent.idx() not in visited and (not parent._is_bwd_op()): + if ((parent.type == 'conv2d') or (parent.type == 'fc')): + brothers.append(parent) + else: + stack.append(parent) + visited.append(parent.idx()) + + for child in graph.next_ops(top_op): + if (child.type != 'conv2d') and (child.type != 'fc') and ( + child.idx() not in visited) and ( + not child._is_bwd_op()): + stack.append(child) + visited.append(child.idx()) + return brothers + + def _prune_graph(self, graph, target_graph): + """ + Pruning parameters of graph according to target graph. + Args: + graph(GraphWrapper): The graph to be pruned. + target_graph(GraphWrapper): The reference graph. + Return: None + """ + count = 1 + _logger.debug( + '|----+----------------------------------------+------------------------------+------------------------------|' + ) + _logger.debug('|{:^4}|{:^40}|{:^30}|{:^30}|'.format('id', 'parammeter', + 'from', 'to')) + for param in target_graph.all_parameters(): + var = graph.var(param.name()) + ori_shape = var.shape() + var.set_shape(param.shape()) + _logger.debug( + '|----+----------------------------------------+------------------------------+------------------------------|' + ) + _logger.debug('|{:^4}|{:^40}|{:^30}|{:^30}|'.format( + str(count), + str(param.name()), str(ori_shape), str(param.shape()))) + count += 1 + _logger.debug( + '|----+----------------------------------------+------------------------------+------------------------------|' + ) + + +class UniformPruneStrategy(PruneStrategy): """ - The strategy that pruning weights by threshold or ratio iteratively. + The uniform pruning strategy. The parameters will be pruned by uniform ratio. """ def __init__(self, - pruner, - mini_batch_pruning_frequency=1, + pruner=None, start_epoch=0, - end_epoch=10): - super(PruneStrategy, self).__init__(start_epoch, end_epoch) - self.pruner = pruner - self.mini_batch_pruning_frequency = mini_batch_pruning_frequency - - def _triger(self, context): - return (context.batch_id % self.mini_batch_pruning_frequency == 0 and - self.start_epoch <= context.epoch_id < self.end_epoch) - - def on_batch_end(self, context): - if self._triger(context): - prune_program = Program() - with program_guard(prune_program): - for param in context.graph.all_parameters(): - prune_program.global_block().clone_variable(param) - p = prune_program.global_block().var(param.name) - zeros_mask = self.pruner.prune(p) - pruned_param = p * zeros_mask - layers.assign(input=pruned_param, output=param) - context.program_exe.run(prune_program, scope=context.scope) + end_epoch=0, + target_ratio=0.5, + metric_name=None, + pruned_params='conv.*_weights'): + """ + Args: + pruner(slim.Pruner): The pruner used to prune the parameters. + start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0 + end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0 + target_ratio(float): The flops ratio to be pruned from current model. + metric_name(str): The metric used to evaluate the model. + It should be one of keys in out_nodes of graph wrapper. + pruned_params(str): The pattern str to match the parameter names to be pruned. + """ + super(UniformPruneStrategy, self).__init__(pruner, start_epoch, + end_epoch, target_ratio, + metric_name, pruned_params) + + def _get_best_ratios(self, context): + """ + Search a group of ratios for pruning target flops. + """ + _logger.info('_get_best_ratios') + pruned_params = [] + for param in context.eval_graph.all_parameters(): + if re.match(self.pruned_params, param.name()): + pruned_params.append(param.name()) + + min_ratio = 0. + max_ratio = 1. + + flops = context.eval_graph.flops() + model_size = context.eval_graph.numel_params() + + while min_ratio < max_ratio: + ratio = (max_ratio + min_ratio) / 2 + _logger.debug( + '-----------Try pruning ratio: {:.2f}-----------'.format(ratio)) + ratios = [ratio] * len(pruned_params) + self._prune_parameters( + context.eval_graph, + context.scope, + pruned_params, + ratios, + context.place, + only_graph=True) + + pruned_flops = 1 - (float(context.eval_graph.flops()) / flops) + pruned_size = 1 - (float(context.eval_graph.numel_params()) / + model_size) + _logger.debug('Pruned flops: {:.2f}'.format(pruned_flops)) + _logger.debug('Pruned model size: {:.2f}'.format(pruned_size)) + for param in self.param_shape_backup.keys(): + context.eval_graph.var(param).set_shape(self.param_shape_backup[ + param]) + self.param_shape_backup = {} + + if abs(pruned_flops - self.target_ratio) < 1e-2: + break + if pruned_flops > self.target_ratio: + max_ratio = ratio + else: + min_ratio = ratio + _logger.info('Get ratios: {}'.format([round(r, 2) for r in ratios])) + return pruned_params, ratios + + def on_epoch_begin(self, context): + if context.epoch_id == self.start_epoch: + params, ratios = self._get_best_ratios(context) + + self._prune_parameters(context.optimize_graph, context.scope, + params, ratios, context.place) + + model_size = context.eval_graph.numel_params() + flops = context.eval_graph.flops() + _logger.debug('\n################################') + _logger.debug('# pruning eval graph #') + _logger.debug('################################\n') + self._prune_graph(context.eval_graph, context.optimize_graph) + context.optimize_graph.update_groups_of_conv() + context.eval_graph.update_groups_of_conv() + + _logger.info( + '------------------finish pruning--------------------------------' + ) + _logger.info('Pruned size: {:.2f}'.format(1 - (float( + context.eval_graph.numel_params()) / model_size))) + _logger.info('Pruned flops: {:.2f}'.format(1 - (float( + context.eval_graph.flops()) / flops))) + # metric = self._eval_graph(context) + # _logger.info('Metric after pruning: {:.2f}'.format(metric)) + _logger.info( + '------------------UniformPruneStrategy.on_compression_begin finish--------------------------------' + ) + + +class SensitivePruneStrategy(PruneStrategy): + """ + Sensitive pruning strategy. Different pruned ratio was applied on each layer. + """ + + def __init__(self, + pruner=None, + start_epoch=0, + end_epoch=0, + delta_rate=0.20, + target_ratio=0.5, + metric_name='top1_acc', + pruned_params='conv.*_weights', + sensitivities_file='./sensitivities.data', + sensitivities={}, + num_steps=1, + eval_rate=None): + """ + Args: + pruner(slim.Pruner): The pruner used to prune the parameters. + start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0. + end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 10. + delta_rate(float): The delta used to generate ratios when calculating sensitivities. default: 0.2 + target_ratio(float): The flops ratio to be pruned from current model. default: 0.5 + metric_name(str): The metric used to evaluate the model. + It should be one of keys in out_nodes of graph wrapper. default: 'top1_acc' + pruned_params(str): The pattern str to match the parameter names to be pruned. default: 'conv.*_weights'. + sensitivities_file(str): The sensitivities file. default: './sensitivities.data' + sensitivities(dict): The user-defined sensitivities. default: {}. + num_steps(int): The number of pruning steps. default: 1. + eval_rate(float): The rate of sampled data used to calculate sensitivities. + None means using all the data. default: None. + """ + super(SensitivePruneStrategy, self).__init__(pruner, start_epoch, + end_epoch, target_ratio, + metric_name, pruned_params) + self.delta_rate = delta_rate + self.pruned_list = [] + self.sensitivities = sensitivities + self.sensitivities_file = sensitivities_file + self.backup = {} + self.param_shape_backup = {} + self.num_steps = num_steps + self.eval_rate = eval_rate + self.pruning_step = 1 - pow((1 - target_ratio), 1.0 / self.num_steps) + + def _save_sensitivities(self, sensitivities, sensitivities_file): + """ + Save sensitivities into file. + """ + with open(sensitivities_file, 'wb') as f: + pickle.dump(sensitivities, f) + + def _load_sensitivities(self, sensitivities_file): + """ + Load sensitivities from file. + """ + sensitivities = {} + if sensitivities_file and os.path.exists(sensitivities_file): + with open(sensitivities_file, 'rb') as f: + if sys.version_info < (3, 0): + sensitivities = pickle.load(f) + else: + sensitivities = pickle.load(f, encoding='bytes') + + for param in sensitivities: + sensitivities[param]['pruned_percent'] = [ + round(p, 2) for p in sensitivities[param]['pruned_percent'] + ] + self._format_sensitivities(sensitivities) + return sensitivities + + def _format_sensitivities(self, sensitivities): + """ + Print formated sensitivities in debug log level. + """ + tb = pt.PrettyTable() + tb.field_names = ["parameter", "size"] + [ + str(round(i, 2)) + for i in np.arange(self.delta_rate, 1, self.delta_rate) + ] + for param in sensitivities: + if len(sensitivities[param]['loss']) == (len(tb.field_names) - 2): + tb.add_row([param, sensitivities[param]['size']] + [ + round(loss, 2) for loss in sensitivities[param]['loss'] + ]) + _logger.debug('\n################################') + _logger.debug('# sensitivities table #') + _logger.debug('################################\n') + _logger.debug(tb) + + def _compute_sensitivities(self, context): + """ + Computing the sensitivities of all parameters. + """ + _logger.info("calling _compute_sensitivities.") + self.param_shape_backup = {} + self.backup = {} + cached_id = np.random.randint(1000) + if self.start_epoch == context.epoch_id: + sensitivities_file = self.sensitivities_file + else: + sensitivities_file = self.sensitivities_file + ".epoch" + str( + context.epoch_id) + sensitivities = self._load_sensitivities(sensitivities_file) + + for param in context.eval_graph.all_parameters(): + if not re.match(self.pruned_params, param.name()): + continue + if param.name() not in sensitivities: + sensitivities[param.name()] = { + 'pruned_percent': [], + 'loss': [], + 'size': param.shape()[0] + } + + metric = None + + for param in sensitivities.keys(): + ratio = self.delta_rate + while ratio < 1: + ratio = round(ratio, 2) + if ratio in sensitivities[param]['pruned_percent']: + _logger.debug('{}, {} has computed.'.format(param, ratio)) + ratio += self.delta_rate + continue + if metric is None: + metric = self._eval_graph(context, self.eval_rate, + cached_id) + # prune parameter by ratio + self._prune_parameters( + context.eval_graph, + context.scope, [param], [ratio], + context.place, + lazy=True) + self.pruned_list[0] + # get accuracy after pruning and update self.sensitivities + pruned_metric = self._eval_graph(context, self.eval_rate, + cached_id) + loss = metric - pruned_metric + _logger.info("pruned param: {}; {}; loss={}".format( + param, ratio, loss)) + for brother in self.pruned_list[0]: + if re.match(self.pruned_params, brother): + if brother not in sensitivities: + sensitivities[brother] = { + 'pruned_percent': [], + 'loss': [] + } + sensitivities[brother]['pruned_percent'].append(ratio) + sensitivities[brother]['loss'].append(loss) + + self._save_sensitivities(sensitivities, sensitivities_file) + + # restore pruned parameters + for param_name in self.backup.keys(): + param_t = context.scope.find_var(param_name).get_tensor() + param_t.set(self.backup[param_name], context.place) + +# pruned_metric = self._eval_graph(context) + self.backup = {} + + ratio += self.delta_rate + return sensitivities + + def _get_best_ratios(self, context, sensitivities, target_ratio): + """ + Search a group of ratios for pruning target flops. + """ + _logger.info('_get_best_ratios for pruning ratie: {}'.format( + target_ratio)) + self.param_shape_backup = {} + self.backup = {} + + def func(params, x): + a, b, c, d = params + return a * x * x * x + b * x * x + c * x + d + + def error(params, x, y): + return func(params, x) - y + + def slove_coefficient(x, y): + init_coefficient = [10, 10, 10, 10] + coefficient, loss = leastsq(error, init_coefficient, args=(x, y)) + return coefficient + + min_loss = 0. + max_loss = 0. + + # step 1: fit curve by sensitivities + coefficients = {} + for param in sensitivities: + losses = np.array([0] * 5 + sensitivities[param]['loss']) + precents = np.array([0] * 5 + sensitivities[param][ + 'pruned_percent']) + coefficients[param] = slove_coefficient(precents, losses) + loss = np.max(losses) + max_loss = np.max([max_loss, loss]) + + # step 2: Find a group of ratios by binary searching. + flops = context.eval_graph.flops() + model_size = context.eval_graph.numel_params() + ratios = [] + while min_loss < max_loss: + loss = (max_loss + min_loss) / 2 + _logger.info( + '-----------Try pruned ratios while acc loss={:.4f}-----------'. + format(loss)) + ratios = [] + # step 2.1: Get ratios according to current loss + for param in sensitivities: + coefficient = copy.deepcopy(coefficients[param]) + coefficient[-1] = coefficient[-1] - loss + roots = np.roots(coefficient) + for root in roots: + min_root = 1 + if np.isreal(root) and root > 0 and root < 1: + selected_root = min(root.real, min_root) + ratios.append(selected_root) + _logger.info('Pruned ratios={}'.format( + [round(ratio, 3) for ratio in ratios])) + # step 2.2: Pruning by current ratios + self._prune_parameters( + context.eval_graph, + context.scope, + sensitivities.keys(), + ratios, + context.place, + only_graph=True) + + pruned_flops = 1 - (float(context.eval_graph.flops()) / flops) + pruned_size = 1 - (float(context.eval_graph.numel_params()) / + model_size) + _logger.info('Pruned flops: {:.4f}'.format(pruned_flops)) + _logger.info('Pruned model size: {:.4f}'.format(pruned_size)) + for param in self.param_shape_backup.keys(): + context.eval_graph.var(param).set_shape(self.param_shape_backup[ + param]) + self.param_shape_backup = {} + + # step 2.3: Check whether current ratios is enough + if abs(pruned_flops - target_ratio) < 0.015: + break + if pruned_flops > target_ratio: + max_loss = loss + else: + min_loss = loss + return sensitivities.keys(), ratios + + def _current_pruning_target(self, context): + ''' + Get the target pruning rate in current epoch. + ''' + _logger.info('Left number of pruning steps: {}'.format(self.num_steps)) + if self.num_steps <= 0: + return None + if (self.start_epoch == context.epoch_id) or context.eval_converged( + self.metric_name, 0.005): + self.num_steps -= 1 + return self.pruning_step + + def on_epoch_begin(self, context): + current_ratio = self._current_pruning_target(context) + if current_ratio is not None: + sensitivities = self._compute_sensitivities(context) + params, ratios = self._get_best_ratios(context, sensitivities, + current_ratio) + self._prune_parameters(context.optimize_graph, context.scope, + params, ratios, context.place) + + self.param_shape_backup = {} + self.backup = {} + + model_size = context.eval_graph.numel_params() + flops = context.eval_graph.flops() + _logger.debug('################################') + _logger.debug('# pruning eval graph #') + _logger.debug('################################') + self._prune_graph(context.eval_graph, context.optimize_graph) + context.optimize_graph.update_groups_of_conv() + context.eval_graph.update_groups_of_conv() + context.optimize_graph.compile() # to update the compiled program + context.eval_graph.compile( + for_parallel=False, + for_test=True) # to update the compiled program + _logger.info( + '------------------finish pruning--------------------------------' + ) + _logger.info('Pruned size: {:.3f}'.format(1 - (float( + context.eval_graph.numel_params()) / model_size))) + _logger.info('Pruned flops: {:.3f}'.format(1 - (float( + context.eval_graph.flops()) / flops))) + metric = self._eval_graph(context) + _logger.info('Metric after pruning: {:.2f}'.format(metric)) + _logger.info( + '------------------SensitivePruneStrategy.on_epoch_begin finish--------------------------------' + ) diff --git a/python/paddle/fluid/contrib/slim/prune/pruner.py b/python/paddle/fluid/contrib/slim/prune/pruner.py index ca72bcb6f6004c18f3ec794850e0aeaecb92d7ac..506b8fbe1de2e0f8a036f591bd2baacd5759c9c8 100644 --- a/python/paddle/fluid/contrib/slim/prune/pruner.py +++ b/python/paddle/fluid/contrib/slim/prune/pruner.py @@ -13,9 +13,10 @@ # limitations under the License. import numpy as np +import collections from .... import layers -__all__ = ['Pruner', 'MagnitudePruner', 'RatioPruner'] +__all__ = ['Pruner', 'StructurePruner'] class Pruner(object): @@ -30,54 +31,77 @@ class Pruner(object): pass -class MagnitudePruner(Pruner): +class StructurePruner(Pruner): """ - Pruner used to pruning a parameter by threshold. + Pruner used to pruning parameters by groups. """ - def __init__(self, threshold): - self.threshold = threshold - - def prune(self, param, threshold=None): - if threshold is None: - thres = layers.fill_constant( - shape=[1], dtype='float32', value=self.threshold) - else: - thres = threshold - zeros_mask = layers.less_than(x=param, y=thres) - return zeros_mask - - -class RatioPruner(Pruner): - """ - Pruner used to pruning a parameter by ratio. - """ + def __init__(self, pruning_axis, criterions): + """ + Args: + pruning_axis(dict): The key is the name of parameter to be pruned, + '*' means all the parameters. + The value is the axis to be used. Given a parameter + with shape [3, 4], the result of pruning 50% on aixs 1 + is a parameter with shape [3, 2]. + criterions(dict): The key is the name of parameter to be pruned, + '*' means all the parameters. + The value is the criterion used to sort groups for pruning. + It only supports 'l1_norm' currently. + """ + self.pruning_axis = pruning_axis + self.criterions = criterions - def __init__(self, ratios=None): + def cal_pruned_idx(self, name, param, ratio, axis=None): """ + Calculate the index to be pruned on axis by given pruning ratio. Args: - ratios: dict with pair (paramer_name, pruned_ratio). + name(str): The name of parameter to be pruned. + param(np.array): The data of parameter to be pruned. + ratio(float): The ratio to be pruned. + axis(int): The axis to be used for pruning given parameter. + If it is None, the value in self.pruning_axis will be used. + default: None. + Returns: + list: The indexes to be pruned on axis. """ - self.ratios = ratios + criterion = self.criterions[ + name] if name in self.criterions else self.criterions['*'] + if axis is None: + assert self.pruning_axis is not None, "pruning_axis should set if axis is None." + axis = self.pruning_axis[ + name] if name in self.pruning_axis else self.pruning_axis['*'] + prune_num = int(round(param.shape[axis] * ratio)) + reduce_dims = [i for i in range(len(param.shape)) if i != axis] + if criterion == 'l1_norm': + criterions = np.sum(np.abs(param), axis=tuple(reduce_dims)) + pruned_idx = criterions.argsort()[:prune_num] + return pruned_idx - def prune(self, param, ratio=None): + def prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False): """ + Pruning a array by indexes on given axis. Args: - ratio: `ratio=40%` means pruning (1 - 40%) weights to zero. + tensor(numpy.array): The target array to be pruned. + pruned_idx(list): The indexes to be pruned. + pruned_axis(int): The axis of given array to be pruned on. + lazy(bool): True means setting the pruned elements to zero. + False means remove the pruned elements from memory. + default: False. + Returns: + numpy.array: The pruned array. """ - if ratio is None: - rat = self.ratios[ - param.name] if param.name in self.ratios else self.ratios['*'] - else: - rat = ratio - if rat < 1.0: - k = max(int(rat * np.prod(param.shape)), 1) - param_vec = layers.reshape(x=param, shape=[1, -1]) - param_topk, _ = layers.topk(param_vec, k=k) - threshold = layers.slice( - param_topk, axes=[1], starts=[-1], ends=[k]) - threshold = layers.reshape(x=threshold, shape=[1]) - zeros_mask = layers.less_than(x=param, y=threshold) + mask = np.zeros(tensor.shape[pruned_axis], dtype=bool) + mask[pruned_idx] = True + + def func(data): + return data[~mask] + + def lazy_func(data): + data[mask] = 0 + return data + + if lazy: + return np.apply_along_axis(lazy_func, pruned_axis, tensor) else: - zeros_mask = layers.ones(param.shape) - return zeros_mask + return np.apply_along_axis(func, pruned_axis, tensor) diff --git a/python/paddle/fluid/contrib/slim/tests/configs/config.yaml b/python/paddle/fluid/contrib/slim/tests/configs/config.yaml deleted file mode 100644 index d9b49029d3e34d487ad65fe0f7e54e2cee1d5838..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/tests/configs/config.yaml +++ /dev/null @@ -1,29 +0,0 @@ -version: 1.0 -include: ["./configs/pruners.yaml", "./configs/pruners_0.yaml"] -pruners: - pruner_1: - class: 'RatioPruner' - ratios: - 'conv1_1.w': 0.3 - 'conv1_2.w': 0.4 - '*': 0.9 - group_dims: - '*': [1, 2, 3] - criterions: - '*': 'l1-norm' -strategies: - strategy_1: - class: 'SensitivePruneStrategy' - pruner: 'pruner_2' - start_epoch: 0 - end_epoch: 10 - delta_rate: 0.20 - acc_loss_threshold: 0.2 - sensitivities: - 'conv1_1.w': 0.4 - -compress_pass: - class: 'CompressPass' - epoch: 100 - strategies: - - strategy_1 diff --git a/python/paddle/fluid/contrib/slim/tests/configs/filter_pruning.yaml b/python/paddle/fluid/contrib/slim/tests/configs/filter_pruning.yaml new file mode 100644 index 0000000000000000000000000000000000000000..570c60026d55c242106f7e2dc5c3f47bfbdbe884 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/configs/filter_pruning.yaml @@ -0,0 +1,34 @@ +#start_epoch: The 'on_epoch_begin' function will be called in start_epoch. default: 0. +#end_epoch: The 'on_epoch_end' function will be called in end_epoch. default: 10. +#delta_rate: The delta used to generate ratios when calculating sensitivities. +#target_ratio: The flops ratio to be pruned from current model. +#metric_name: The metric used to evaluate the model. +#pruned_params: The pattern str to match the parameter names to be pruned. +#sensitivities_file: The sensitivities file. +#num_steps: The number of pruning steps. +#eval_rate: The rate of sampled data used to calculate sensitivities. +version: 1.0 +pruners: + pruner_1: + class: 'StructurePruner' + pruning_axis: + '*': 0 + criterions: + '*': 'l1_norm' +strategies: + sensitive_pruning_strategy: + class: 'SensitivePruneStrategy' + pruner: 'pruner_1' + start_epoch: 0 + delta_rate: 0.1 + target_ratio: 0.3 + num_steps: 1 + eval_rate: 0.5 + pruned_params: '.*_sep_weights' + sensitivities_file: 'mobilenet_acc_top1_sensitive.data' + metric_name: 'acc_top1' +compressor: + epoch: 120 + checkpoint_path: './checkpoints/' + strategies: + - sensitive_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/configs/pruners.yaml b/python/paddle/fluid/contrib/slim/tests/configs/pruners.yaml deleted file mode 100644 index 235092c595bf7c653221c7fe2b381fecf487fa49..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/tests/configs/pruners.yaml +++ /dev/null @@ -1,12 +0,0 @@ -version: 1.0 -pruners: - pruner_2: - class: 'RatioPruner' - ratios: - 'conv1_1.w': 0.5 - 'conv1_2.w': 0.2 - '*': 0.7 - group_dims: - '*': [1, 2, 3] - criterions: - '*': 'l1-norm' diff --git a/python/paddle/fluid/contrib/slim/tests/configs/pruners_0.yaml b/python/paddle/fluid/contrib/slim/tests/configs/pruners_0.yaml deleted file mode 100644 index cd2ef9eb56ddbc1367ce2e3b413372fbcd542bde..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/contrib/slim/tests/configs/pruners_0.yaml +++ /dev/null @@ -1,12 +0,0 @@ -version: 1.0 -pruners: - pruner_3: - class: 'RatioPruner' - ratios: - 'conv1_1.w': 0.5 - 'conv1_2.w': 0.2 - '*': 0.7 - group_dims: - '*': [1, 2, 3] - criterions: - '*': 'l1-norm' diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/__init__.py b/python/paddle/fluid/contrib/slim/tests/filter_pruning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c32e26092f6ea25771279418582a24ea449ab2 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml new file mode 100644 index 0000000000000000000000000000000000000000..232276feac5023c45d594015cf7084b000cd5b4a --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml @@ -0,0 +1,34 @@ +#start_epoch: The 'on_epoch_begin' function will be called in start_epoch. default: 0. +#end_epoch: The 'on_epoch_end' function will be called in end_epoch. default: 10. +#delta_rate: The delta used to generate ratios when calculating sensitivities. +#target_ratio: The flops ratio to be pruned from current model. +#metric_name: The metric used to evaluate the model. +#pruned_params: The pattern str to match the parameter names to be pruned. +#sensitivities_file: The sensitivities file. +#num_steps: The number of pruning steps. +#eval_rate: The rate of sampled data used to calculate sensitivities. +version: 1.0 +pruners: + pruner_1: + class: 'StructurePruner' + pruning_axis: + '*': 0 + criterions: + '*': 'l1_norm' +strategies: + sensitive_pruning_strategy: + class: 'SensitivePruneStrategy' + pruner: 'pruner_1' + start_epoch: 1 + delta_rate: 0.2 + target_ratio: 0.08 + num_steps: 1 + eval_rate: 0.5 + pruned_params: 'conv6_sep_weights' + sensitivities_file: 'mobilenet_acc_top1_sensitive.data' + metric_name: 'acc_top1' +compressor: + epoch: 2 + checkpoint_path: './checkpoints/' + strategies: + - sensitive_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/mobilenet.py b/python/paddle/fluid/contrib/slim/tests/filter_pruning/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..0148325a642a2bcbebd3d7794056ff2778a3992d --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/mobilenet.py @@ -0,0 +1,210 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr + +__all__ = ['MobileNet'] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class MobileNet(): + def __init__(self): + self.params = train_parameters + + def net(self, input, class_dim=1000, scale=1.0): + # conv1: 112x112 + input = self.conv_bn_layer( + input, + filter_size=3, + channels=3, + num_filters=int(32 * scale), + stride=2, + padding=1, + name="conv1") + + # 56x56 + input = self.depthwise_separable( + input, + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale, + name="conv2_1") + + input = self.depthwise_separable( + input, + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=2, + scale=scale, + name="conv2_2") + + # 28x28 + input = self.depthwise_separable( + input, + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale, + name="conv3_1") + + input = self.depthwise_separable( + input, + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=2, + scale=scale, + name="conv3_2") + + # 14x14 + input = self.depthwise_separable( + input, + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale, + name="conv4_1") + + input = self.depthwise_separable( + input, + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=2, + scale=scale, + name="conv4_2") + + # 14x14 + for i in range(5): + input = self.depthwise_separable( + input, + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + scale=scale, + name="conv5" + "_" + str(i + 1)) + # 7x7 + input = self.depthwise_separable( + input, + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=2, + scale=scale, + name="conv5_6") + + input = self.depthwise_separable( + input, + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=1, + scale=scale, + name="conv6") + + input = fluid.layers.pool2d( + input=input, + pool_size=0, + pool_stride=1, + pool_type='avg', + global_pooling=True) + + output = fluid.layers.fc(input=input, + size=class_dim, + act='softmax', + param_attr=ParamAttr( + initializer=MSRA(), name="fc7_weights"), + bias_attr=ParamAttr(name="fc7_offset")) + return output + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='relu', + use_cudnn=True, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr( + initializer=MSRA(), name=name + "_weights"), + bias_attr=False) + bn_name = name + "_bn" + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def depthwise_separable(self, + input, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + name=None): + depthwise_conv = self.conv_bn_layer( + input=input, + filter_size=3, + num_filters=int(num_filters1 * scale), + stride=stride, + padding=1, + num_groups=int(num_groups * scale), + use_cudnn=False, + name=name + "_dw") + + pointwise_conv = self.conv_bn_layer( + input=depthwise_conv, + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0, + name=name + "_sep") + return pointwise_conv diff --git a/python/paddle/fluid/contrib/slim/tests/test_factory.py b/python/paddle/fluid/contrib/slim/tests/test_factory.py index 2fc72b6475e6bdd977dafb57696046a1100d0087..90eb8bd4b3caa44880f6df21c7f9f6d460655a8c 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_factory.py +++ b/python/paddle/fluid/contrib/slim/tests/test_factory.py @@ -12,29 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.fluid.contrib.slim import ConfigFactory +from paddle.fluid.contrib.slim.core import ConfigFactory import unittest class TestFactory(unittest.TestCase): - def test_parse(self): - factory = ConfigFactory('./configs/config.yaml') + def test_parse_pruning(self): + factory = ConfigFactory('./configs/filter_pruning.yaml') - pruner = factory.instance('pruner_1') - self.assertEquals(pruner.ratios['conv1_1.w'], 0.3) + pruner_1 = factory.instance('pruner_1') + self.assertEquals(pruner_1.pruning_axis['*'], 0) + self.assertEquals(pruner_1.criterions['*'], 'l1_norm') - pruner = factory.instance('pruner_2') - self.assertEquals(pruner.ratios['*'], 0.7) + strategy = factory.instance('sensitive_pruning_strategy') + pruner_1 = strategy.pruner + self.assertEquals(pruner_1.criterions['*'], 'l1_norm') - strategy = factory.instance('strategy_1') - pruner = strategy.pruner - self.assertEquals(pruner.ratios['*'], 0.7) - - compress_pass = factory.get_compress_pass() - self.assertEquals(compress_pass.epoch, 100) - - strategy = compress_pass.strategies[0] - self.assertEquals(strategy.delta_rate, 0.2) + self.assertEquals(strategy.start_epoch, 0) + self.assertEquals(strategy.sensitivities_file, + 'mobilenet_acc_top1_sensitive.data') if __name__ == '__main__': diff --git a/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py b/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py new file mode 100644 index 0000000000000000000000000000000000000000..d73ee27779a0d17a0f60df645a6d2946d665c01e --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py @@ -0,0 +1,89 @@ +# copyright (c) 2019 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import paddle +import unittest +import paddle.fluid as fluid +from filter_pruning.mobilenet import MobileNet +from paddle.fluid.contrib.slim.core import Compressor +from paddle.fluid.contrib.slim.graph import GraphWrapper + + +class TestFilterPruning(unittest.TestCase): + def test_compression(self): + """ + Model: mobilenet_v1 + data: mnist + step1: Training one epoch + step2: pruning flops + step3: fine-tune one epoch + step4: check top1_acc. + """ + if not fluid.core.is_compiled_with_cuda(): + return + class_dim = 10 + image_shape = [1, 28, 28] + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + out = MobileNet().net(input=image, class_dim=class_dim) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + val_program = fluid.default_main_program().clone(for_test=False) + + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + + optimizer = fluid.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + regularization=fluid.regularizer.L2Decay(4e-5)) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + + val_feed_list = [('img', image.name), ('label', label.name)] + val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', + acc_top5.name)] + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128) + train_feed_list = [('img', image.name), ('label', label.name)] + train_fetch_list = [('loss', avg_cost.name)] + + com_pass = Compressor( + place, + fluid.global_scope(), + fluid.default_main_program(), + train_reader=train_reader, + train_feed_list=train_feed_list, + train_fetch_list=train_fetch_list, + eval_program=val_program, + eval_reader=val_reader, + eval_feed_list=val_feed_list, + eval_fetch_list=val_fetch_list, + train_optimizer=optimizer) + com_pass.config('./filter_pruning/compress.yaml') + eval_graph = com_pass.run() + self.assertTrue( + abs((com_pass.context.eval_results['acc_top1'][-1] - 0.969) / 0.969) + < 0.02) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py b/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ad82aa941183d72353dae19527b21286d6473a63 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py @@ -0,0 +1,140 @@ +# copyright (c) 2019 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +from __future__ import print_function +import unittest +import paddle.fluid as fluid +import six +import numpy as np +from paddle.fluid.contrib.slim.graph import GraphWrapper +from paddle.fluid import core + + +def residual_block(num): + def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + act='relu', + bias_attr=False): + tmp = fluid.layers.conv2d( + input=input, + filter_size=filter_size, + num_filters=ch_out, + stride=stride, + padding=padding, + act=None, + bias_attr=bias_attr) + return fluid.layers.batch_norm(input=tmp, act=act) + + data = fluid.layers.data(name='image', shape=[1, 8, 8], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + data.stop_gradinet = False + hidden = data + for _ in six.moves.xrange(num): + conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True) + short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) + hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') + fc = fluid.layers.fc(input=hidden, size=10) + + loss = fluid.layers.cross_entropy(input=fc, label=label) + loss = fluid.layers.mean(loss) + return data, label, loss + + +class TestGraphWrapper(unittest.TestCase): + def build_program(self): + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + image, label, self.loss = residual_block(2) + eval_program = main.clone() + opt = fluid.optimizer.SGD(learning_rate=0.001) + opt.minimize(self.loss) + self.scope = core.Scope() + exe = fluid.Executor(place) + exe.run(startup, scope=self.scope) + self.eval_graph = GraphWrapper( + program=eval_program, + in_nodes={'image': image.name, + 'label': label.name}, + out_nodes={'loss': self.loss.name}) + self.train_graph = GraphWrapper( + program=main, + in_nodes={'image': image.name, + 'label': label.name}, + out_nodes={'loss': self.loss.name}) + + def test_all_parameters(self): + self.build_program() + self.assertEquals(len(self.train_graph.all_parameters()), 24) + + def test_all_vars(self): + self.build_program() + self.assertEquals(len(self.train_graph.vars()), 90) + + def test_numel_params(self): + self.build_program() + self.assertEquals(self.train_graph.numel_params(), 13258) + + def test_compile(self): + self.build_program() + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + self.train_graph.compile() + exe.run(self.train_graph.compiled_graph, + scope=self.scope, + feed={ + 'image': + np.random.randint(0, 40, [16, 1, 8, 8]).astype('float32'), + 'label': np.random.randint(0, 10, [16, 1]).astype('int64') + }) + + def test_pre_and_next_ops(self): + self.build_program() + for op in self.train_graph.ops(): + for next_op in self.train_graph.next_ops(op): + self.assertTrue(op in self.train_graph.pre_ops(next_op)) + + def test_get_optimize_graph(self): + self.build_program() + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + opt = fluid.optimizer.SGD(learning_rate=0.001) + train_graph = self.eval_graph.get_optimize_graph( + opt, place, self.scope, no_grad_var_names=['image']) + self.assertEquals(len(self.train_graph.ops()), len(train_graph.ops())) + exe = fluid.Executor(place) + train_graph.compile() + image = np.random.randint(0, 225, [16, 1, 8, 8]).astype('float32') + label = np.random.randint(0, 10, [16, 1]).astype('int64') + exe.run(train_graph.compiled_graph, + scope=self.scope, + feed={'image': image, + 'label': label}) + + def test_flops(self): + self.build_program() + self.assertEquals(self.train_graph.flops(), 354624) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/requirements.txt b/python/requirements.txt index 36bd5d4261cc7aa78d26b8c8ddfd87abd4f4e2e2..ce56462fac9c69df79c3c542202d21c0c67a91b8 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -12,3 +12,4 @@ six funcsigs pyyaml decorator +prettytable