From bdb3e376d07e2eece98710e0dde567e0b1940597 Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 24 Sep 2019 14:42:14 +0800 Subject: [PATCH] [PaddleSlim] Enhence compressor api in PaddleSlim (#19894) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Support customize eval function instead of eval program. 2. Fix loading checkpoint in quantization strategy. 3. Support saving eval model when saving a checkpoint. 4. Fix decoder of loading context in PaddleSlim. 5. Fix restoring from the checkpoint of uniform prune strategy. 6. Support saving eval model and infer model during training. 7. Add ‘unitest’ for saving eval model, saving infer model and uniform pruning restoring from the checkpoint. 8. Fix pruning of depthwise_conv_grad op by updating the groups. --- paddle/fluid/API.spec | 2 +- .../fluid/contrib/slim/core/compressor.py | 61 +++++++++--- .../fluid/contrib/slim/graph/graph_wrapper.py | 52 +++++++++- .../contrib/slim/prune/prune_strategy.py | 48 ++++----- .../quantization/quantization_strategy.py | 60 +++++------ .../contrib/slim/tests/configs/compress.yaml | 4 + .../slim/tests/filter_pruning/compress.yaml | 2 +- .../tests/filter_pruning/uniform_restore.yaml | 21 ++++ .../filter_pruning/uniform_restore_0.yaml | 21 ++++ .../filter_pruning/uniform_restore_1.yaml | 21 ++++ .../slim/tests/quantization/compress_1.yaml | 50 ++++++++++ .../contrib/slim/tests/test_compressor.py | 99 +++++++++++++++++++ .../contrib/slim/tests/test_filter_pruning.py | 75 ++++++++++++++ .../slim/tests/test_quantization_strategy.py | 44 ++++++--- 14 files changed, 481 insertions(+), 79 deletions(-) create mode 100644 python/paddle/fluid/contrib/slim/tests/configs/compress.yaml create mode 100644 python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore.yaml create mode 100644 python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_0.yaml create mode 100644 python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_1.yaml create mode 100644 python/paddle/fluid/contrib/slim/tests/quantization/compress_1.yaml create mode 100644 python/paddle/fluid/contrib/slim/tests/test_compressor.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6ebce107fad..ae419928355 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -491,7 +491,7 @@ paddle.fluid.contrib.QuantizeTranspiler.freeze_program (ArgSpec(args=['self', 'p paddle.fluid.contrib.QuantizeTranspiler.training_transpile (ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6dd9909f10b283ba2892a99058a72884')) paddle.fluid.contrib.distributed_batch_reader (ArgSpec(args=['batch_reader'], varargs=None, keywords=None, defaults=None), ('document', 'b60796eb0a481484dd34e345f0eaa4d5')) paddle.fluid.contrib.Compressor ('paddle.fluid.contrib.slim.core.compressor.Compressor', ('document', 'a5417774a94aa9ae5560a42b96527e7d')) -paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, [], None, None, None, None)), ('document', 'c195b3bba26169cff9439e8c467557c0')) +paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'eval_func', 'save_eval_model', 'prune_infer_model', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, None, True, None, [], None, None, None, None)), ('document', '05119e0fa0fc07f5cf848ebf0a2cf070')) paddle.fluid.contrib.Compressor.config (ArgSpec(args=['self', 'config_file'], varargs=None, keywords=None, defaults=None), ('document', '780d9c007276ccbb95b292400d7807b0')) paddle.fluid.contrib.Compressor.run (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'c6e43d6a078d307672283c1f36e04fe9')) paddle.fluid.contrib.load_persistables_for_increment (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None), ('document', '2ab36d4f7a564f5f65e455807ad06c67')) diff --git a/python/paddle/fluid/contrib/slim/core/compressor.py b/python/paddle/fluid/contrib/slim/core/compressor.py index 2627f7f004b..6ede756599f 100644 --- a/python/paddle/fluid/contrib/slim/core/compressor.py +++ b/python/paddle/fluid/contrib/slim/core/compressor.py @@ -139,7 +139,7 @@ class Context(object): """ Load the context from file. """ - with open(file_name) as context_file: + with open(file_name, 'rb') as context_file: if sys.version_info < (3, 0): data = pickle.load(context_file) else: @@ -242,6 +242,9 @@ class Compressor(object): eval_reader=None, eval_feed_list=None, eval_fetch_list=None, + eval_func=None, + save_eval_model=True, + prune_infer_model=None, teacher_programs=[], checkpoint_path=None, train_optimizer=None, @@ -260,13 +263,28 @@ class Compressor(object): The key is user-defined and human-readable name. The value is the name of Variable. eval_program(Program): The program used for evaluation. - eval_reader: The data reader used for evaluation. + eval_reader: The data reader used for evaluation. It can be None if eval_func is not None. eval_feed_list(dict): A dict to indicate the input variable of the evaluation program. The key is user-defined and human-readable name. The value is the name of Variable. + It can be None if eval_func is not None. eval_fetch_list(dict): A dict to indicate the output variable of the evaluation program. The key is user-defined and human-readable name. The value is the name of Variable. + eval_func(dict|function): Callback functions used to evaluate the compressed model. + The eval_func is a dict, the key is user-defined name and the value is + a callback function. And the score returned from callback functions + can be referenced in config file by the key of eval_func. + The args of callback function are compressed eval_program and scope which + store the compressed parameters. + Default: None. + save_eval_model(bool): Whether to save eval model when saving checkpoints. Default: True. + prune_infer_model(tuple|list): If prune_infer_model is not None, compressor will prune + eval program into inference program according to inputs and outputs + defined in prune_infer_model. prune_infer_model[0] is a list of input + variables' names and prune_infer_model[1] is a list of output variables' + names. If prune_infer_model is None, it will not save inference model. + Default: None. teacher_programs: The teacher graphs used in distillation strategies. train_optimizer: The optimizer used to append backward ops and optimization ops into train_graph. @@ -294,6 +312,10 @@ class Compressor(object): eval_program, in_nodes=eval_feed_list, out_nodes=eval_fetch_list) self.train_reader = train_reader self.eval_reader = eval_reader + self.eval_func = eval_func + self.save_eval_model = save_eval_model + self.prune_infer_model = prune_infer_model + self.teacher_graphs = [] for teacher in teacher_programs: self.teacher_graphs.append(GraphWrapper(teacher)) @@ -393,6 +415,9 @@ class Compressor(object): strategies = pickle.load( strategy_file, encoding='bytes') + for s, s1 in zip(self.strategies, strategies): + s1.__dict__.update(s.__dict__) + for strategy in strategies: strategy.restore_from_checkpoint(context) @@ -401,10 +426,6 @@ class Compressor(object): with scope_guard(context.scope): context.optimize_graph.load_persistables(model_path, exe) - context.optimize_graph.update_param_shape(context.scope) - context.optimize_graph.update_groups_of_conv() - context.eval_graph.update_param_shape(context.scope) - context.eval_graph.update_groups_of_conv() _logger.info("Loaded params from: {}".format(model_path)) return context, strategies @@ -416,6 +437,7 @@ class Compressor(object): checkpoint_path = os.path.join(self.checkpoint_path, str(context.epoch_id)) model_path = os.path.join(checkpoint_path, 'model') + eval_model_path = os.path.join(checkpoint_path, 'eval_model') context_path = os.path.join(checkpoint_path, 'context') strategy_path = os.path.join(checkpoint_path, 'strategies') if not os.path.isdir(model_path): @@ -423,6 +445,15 @@ class Compressor(object): exe = SlimGraphExecutor(context.place) with scope_guard(context.scope): context.optimize_graph.save_persistables(model_path, exe) + if self.save_eval_model: + context.eval_graph.save_model(eval_model_path, exe) + if self.prune_infer_model: + context.eval_graph.save_infer_model( + eval_model_path, + exe, + self.prune_infer_model, + program_only=self.save_eval_model) + context.to_file(context_path) with open(strategy_path, 'wb') as strategy_file: pickle.dump(self.strategies, strategy_file) @@ -485,11 +516,19 @@ class Compressor(object): """ Runing evaluation. """ - results, names = context.run_eval_graph() - for name, result in zip(names, results): - if name not in context.eval_results: - context.eval_results[name] = [] - context.eval_results[name].append(result) + if self.eval_func is not None: + for key in self.eval_func: + func = self.eval_func[key] + if key not in context.eval_results: + context.eval_results[key] = [] + context.eval_results[key].append( + func(self.eval_graph.program, self.scope)) + else: + results, names = context.run_eval_graph() + for name, result in zip(names, results): + if name not in context.eval_results: + context.eval_results[name] = [] + context.eval_results[name].append(result) def run(self): """ diff --git a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py index 3ed07a287bd..4c21427e6b9 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py @@ -211,6 +211,7 @@ class GraphWrapper(object): self.persistables[var.name] = var self.compiled_graph = None in_nodes = [] if in_nodes is None else in_nodes + out_nodes = [] if out_nodes is None else out_nodes self.in_nodes = OrderedDict(in_nodes) self.out_nodes = OrderedDict(out_nodes) self._attrs = OrderedDict() @@ -471,6 +472,54 @@ class GraphWrapper(object): return flops + def save_model(self, path, exe): + """ + Save network and parameters into file which can be load by load_inference_model api. + Args: + path(str): The path to save the persistables. + exe(framework.Executor): The executor used to save the persistables. + """ + out_vars = [ + self.var(var_name)._var for var_name in self.out_nodes.values() + ] + in_vars = list(self.in_nodes.values()) + assert (len(in_vars) > 0) + assert (len(out_vars) > 0) + io.save_inference_model( + path, + in_vars, + out_vars, + exe.exe, + model_filename="__model__", + params_filename="__params__", + main_program=self.program.clone(), + export_for_deployment=True) + + def save_infer_model(self, path, exe, in_out, program_only=False): + """ + Save network and parameters into file which can be load by load_inference_model api. + Args: + path(str): The path to save the persistables. + exe(framework.Executor): The executor used to save the persistables. + in_out(tuple|list): in_out[0] is a list of input nodes' names + and in_out[1] is a list of output nodes' names. + program_only(bool): Whether to save program only. + """ + out_vars = [self.var(var_name)._var for var_name in in_out[1]] + in_vars = list(in_out[0]) + assert (len(in_vars) > 0) + assert (len(out_vars) > 0) + io.save_inference_model( + path, + in_vars, + out_vars, + exe.exe, + model_filename="__model__.infer", + params_filename="__params__", + program_only=program_only, + main_program=self.program.clone(), + export_for_deployment=True) + def save_persistables(self, path, exe): """ Save all the persistable variables into file. @@ -527,5 +576,6 @@ class GraphWrapper(object): def update_groups_of_conv(self): for op in self.ops(): - if op.type() == 'depthwise_conv2d': + if op.type() == 'depthwise_conv2d' or op.type( + ) == 'depthwise_conv2d_grad': op.set_attr('groups', op.inputs('Filter')[0].shape()[0]) diff --git a/python/paddle/fluid/contrib/slim/prune/prune_strategy.py b/python/paddle/fluid/contrib/slim/prune/prune_strategy.py index 6f430bc9e2f..ffce2711370 100644 --- a/python/paddle/fluid/contrib/slim/prune/prune_strategy.py +++ b/python/paddle/fluid/contrib/slim/prune/prune_strategy.py @@ -635,31 +635,35 @@ class UniformPruneStrategy(PruneStrategy): _logger.info('Get ratios: {}'.format([round(r, 2) for r in ratios])) return pruned_params, ratios - def on_epoch_begin(self, context): - if context.epoch_id == self.start_epoch: - params, ratios = self._get_best_ratios(context) + def restore_from_checkpoint(self, context): + self._prune(context, self.params, self.ratios) - self._prune_parameters(context.optimize_graph, context.scope, - params, ratios, context.place) + def _prune(self, context, params, ratios): + self._prune_parameters(context.optimize_graph, context.scope, params, + ratios, context.place) - model_size = context.eval_graph.numel_params() - flops = context.eval_graph.flops() - _logger.debug('\n################################') - _logger.debug('# pruning eval graph #') - _logger.debug('################################\n') - self._prune_graph(context.eval_graph, context.optimize_graph) - context.optimize_graph.update_groups_of_conv() - context.eval_graph.update_groups_of_conv() + model_size = context.eval_graph.numel_params() + flops = context.eval_graph.flops() + _logger.debug('\n################################') + _logger.debug('# pruning eval graph #') + _logger.debug('################################\n') + self._prune_graph(context.eval_graph, context.optimize_graph) + context.optimize_graph.update_groups_of_conv() + context.eval_graph.update_groups_of_conv() - _logger.info( - '------------------finish pruning--------------------------------' - ) - _logger.info('Pruned size: {:.2f}'.format(1 - (float( - context.eval_graph.numel_params()) / model_size))) - _logger.info('Pruned flops: {:.2f}'.format(1 - (float( - context.eval_graph.flops()) / flops))) - # metric = self._eval_graph(context) - # _logger.info('Metric after pruning: {:.2f}'.format(metric)) + _logger.info( + '------------------finish pruning--------------------------------') + _logger.info('Pruned size: {:.2f}'.format(1 - (float( + context.eval_graph.numel_params()) / model_size))) + _logger.info('Pruned flops: {:.2f}'.format(1 - (float( + context.eval_graph.flops()) / flops))) + + def on_epoch_begin(self, context): + if context.epoch_id == self.start_epoch: + params, ratios = self._get_best_ratios(context) + self.params = params + self.ratios = ratios + self._prune(context, params, ratios) _logger.info( '------------------UniformPruneStrategy.on_compression_begin finish--------------------------------' ) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py index bf4313784ea..5d2b6ea369d 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py @@ -17,7 +17,7 @@ import sys import numpy as np from .... import Executor from .... import io -from .... import core +from .... import core, scope_guard from ....compiler import CompiledProgram from ....compiler import BuildStrategy from ....framework import IrGraph, Variable, Program @@ -199,15 +199,16 @@ class QuantizationStrategy(Strategy): # save float model if self.float_model_save_path: executor = Executor(context.place) - io.save_inference_model( - self.float_model_save_path, - in_vars, - out_vars, - executor, - main_program=test_ir_graph.to_program(), - model_filename='model', - params_filename='weights', - export_for_deployment=True) + with scope_guard(context.scope): + io.save_inference_model( + self.float_model_save_path, + in_vars, + out_vars, + executor, + main_program=test_ir_graph.to_program(), + model_filename='model', + params_filename='weights', + export_for_deployment=True) # save int8 model if self.int8_model_save_path: @@ -216,15 +217,17 @@ class QuantizationStrategy(Strategy): convert_int8_pass.apply(test_ir_graph) executor = Executor(context.place) - io.save_inference_model( - self.int8_model_save_path, - in_vars, - out_vars, - executor, - main_program=test_ir_graph.to_program(), - model_filename='model', - params_filename='weights', - export_for_deployment=True) + + with scope_guard(context.scope): + io.save_inference_model( + self.int8_model_save_path, + in_vars, + out_vars, + executor, + main_program=test_ir_graph.to_program(), + model_filename='model', + params_filename='weights', + export_for_deployment=True) # save mobile model if self.mobile_model_save_path: @@ -237,13 +240,14 @@ class QuantizationStrategy(Strategy): mobile_pass = TransformForMobilePass() mobile_pass.apply(test_ir_graph) executor = Executor(context.place) - io.save_inference_model( - self.mobile_model_save_path, - in_vars, - out_vars, - executor, - main_program=test_ir_graph.to_program(), - model_filename='model', - params_filename='weights', - export_for_deployment=True) + with scope_guard(context.scope): + io.save_inference_model( + self.mobile_model_save_path, + in_vars, + out_vars, + executor, + main_program=test_ir_graph.to_program(), + model_filename='model', + params_filename='weights', + export_for_deployment=True) _logger.info('Finish QuantizationStrategy::on_epoch_end') diff --git a/python/paddle/fluid/contrib/slim/tests/configs/compress.yaml b/python/paddle/fluid/contrib/slim/tests/configs/compress.yaml new file mode 100644 index 00000000000..604cdf3f447 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/configs/compress.yaml @@ -0,0 +1,4 @@ +version: 1.0 +compressor: + epoch: 1 + checkpoint_path: './checkpoints/' diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml index 5f747a049e9..b21a3626372 100644 --- a/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml @@ -28,7 +28,7 @@ strategies: sensitivities_file: 'mobilenet_acc_top1_sensitive.data' metric_name: 'acc_top1' compressor: - epoch: 2 + epoch: 1 checkpoint_path: './checkpoints_pruning/' strategies: - sensitive_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore.yaml new file mode 100644 index 00000000000..9e437aedc9d --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore.yaml @@ -0,0 +1,21 @@ +version: 1.0 +pruners: + pruner_1: + class: 'StructurePruner' + pruning_axis: + '*': 0 + criterions: + '*': 'l1_norm' +strategies: + uniform_pruning_strategy: + class: 'UniformPruneStrategy' + pruner: 'pruner_1' + start_epoch: 0 + target_ratio: 0.5 + pruned_params: 'conv.*' + metric_name: 'acc_top1' +compressor: + epoch: 2 + checkpoint_path: './checkpoints_uniform_restore_tmp/' + strategies: + - uniform_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_0.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_0.yaml new file mode 100644 index 00000000000..49f104f98f3 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_0.yaml @@ -0,0 +1,21 @@ +version: 1.0 +pruners: + pruner_1: + class: 'StructurePruner' + pruning_axis: + '*': 0 + criterions: + '*': 'l1_norm' +strategies: + uniform_pruning_strategy: + class: 'UniformPruneStrategy' + pruner: 'pruner_1' + start_epoch: 0 + target_ratio: 0.5 + pruned_params: 'conv.*' + metric_name: 'acc_top1' +compressor: + epoch: 1 + checkpoint_path: './checkpoints_uniform_restore/' + strategies: + - uniform_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_1.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_1.yaml new file mode 100644 index 00000000000..82e6793aff9 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_1.yaml @@ -0,0 +1,21 @@ +version: 1.0 +pruners: + pruner_1: + class: 'StructurePruner' + pruning_axis: + '*': 0 + criterions: + '*': 'l1_norm' +strategies: + uniform_pruning_strategy: + class: 'UniformPruneStrategy' + pruner: 'pruner_1' + start_epoch: 0 + target_ratio: 0.5 + pruned_params: 'conv.*' + metric_name: 'acc_top1' +compressor: + epoch: 2 + checkpoint_path: './checkpoints_uniform_restore/' + strategies: + - uniform_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/quantization/compress_1.yaml b/python/paddle/fluid/contrib/slim/tests/quantization/compress_1.yaml new file mode 100644 index 00000000000..44e2dc985aa --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/quantization/compress_1.yaml @@ -0,0 +1,50 @@ +#start_epoch(int): The epoch to insert quantization operators. default: 0 +# +#end_epoch(int): The epoch to save inference model. default: 0 +# +#float_model_save_path(str): The path to save model with float weights. +# None means it doesn't save float model. default: None. +# +#mobile_model_save_path(str): The path to save model for paddle-mobile execution. +# None means it doesn't save mobile model. default: None. +# +#int8_model_save_path(str): The path to save model with int8_t weight. +# None means it doesn't save int8 model. default: None. +# +#activation_bits(int): quantization bit number for activation. default: 8. +# +#weight_bits(int): quantization bit number for weights. The bias is not quantized. +# default: 8. +# +#activation_quantize_type(str): quantization type for activation, +# now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'. +# If use 'abs_max' mode, the quantization scale will be calculated +# dynamically each step in both training and testing period. If use +# 'range_abs_max', a static quantization scale will be calculated +# during training and used in inference. +# +#save_in_nodes(list): A list of variable names used to prune graph +# for saving inference model. +# +#save_out_nodes(list): A list of variable names used to prune graph +# for saving inference model. +version: 1.0 +strategies: + quantization_strategy: + class: 'QuantizationStrategy' + start_epoch: 0 + end_epoch: 0 + float_model_save_path: './output/float' + mobile_model_save_path: './output/mobile' + int8_model_save_path: './output/int8' + weight_bits: 8 + activation_bits: 8 + weight_quantize_type: 'abs_max' + activation_quantize_type: 'abs_max' + save_in_nodes: ['image'] + save_out_nodes: ['quan.tmp_2'] +compressor: + epoch: 2 + checkpoint_path: './checkpoints_quan/' + strategies: + - quantization_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/test_compressor.py b/python/paddle/fluid/contrib/slim/tests/test_compressor.py new file mode 100644 index 00000000000..330c6e3543d --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_compressor.py @@ -0,0 +1,99 @@ +# copyright (c) 2019 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import paddle +import unittest +import os +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.contrib.slim.core import Compressor +from paddle.fluid.contrib.slim.graph import GraphWrapper + + +class TestCompressor(unittest.TestCase): + def test_eval_func(self): + class_dim = 10 + image_shape = [1, 28, 28] + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + out = fluid.layers.fc(input=image, size=class_dim) + out = fluid.layers.softmax(out) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + val_program = fluid.default_main_program().clone(for_test=False) + + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + + optimizer = fluid.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + regularization=fluid.regularizer.L2Decay(4e-5)) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128) + train_feed_list = [('img', image.name), ('label', label.name)] + train_fetch_list = [('loss', avg_cost.name)] + eval_feed_list = [('img', image.name), ('label', label.name)] + eval_fetch_list = [('acc_top1', acc_top1.name)] + + def eval_func(program, scope): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + feeder = fluid.DataFeeder( + feed_list=[image.name, label.name], + place=place, + program=program) + results = [] + for data in val_reader(): + result = exe.run(program=program, + scope=scope, + fetch_list=[acc_top1.name], + feed=feeder.feed(data)) + results.append(np.array(result)) + result = np.mean(results) + return result + + com_pass = Compressor( + place, + fluid.global_scope(), + fluid.default_main_program(), + train_reader=train_reader, + train_feed_list=train_feed_list, + train_fetch_list=train_fetch_list, + eval_program=val_program, + eval_feed_list=eval_feed_list, + eval_fetch_list=eval_fetch_list, + eval_func={"score": eval_func}, + prune_infer_model=[[image.name], [out.name]], + train_optimizer=optimizer) + com_pass.config('./configs/compress.yaml') + com_pass.run() + self.assertTrue('score' in com_pass.context.eval_results) + self.assertTrue(float(com_pass.context.eval_results['score'][0]) > 0.9) + self.assertTrue(os.path.exists("./checkpoints/0/eval_model/__model__")) + self.assertTrue( + os.path.exists("./checkpoints/0/eval_model/__model__.infer")) + self.assertTrue(os.path.exists("./checkpoints/0/eval_model/__params__")) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py b/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py index e1763039b3a..cb956ef6bf0 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py +++ b/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py @@ -15,6 +15,7 @@ import paddle import unittest import paddle.fluid as fluid +import numpy as np from mobilenet import MobileNet from paddle.fluid.contrib.slim.core import Compressor from paddle.fluid.contrib.slim.graph import GraphWrapper @@ -84,6 +85,80 @@ class TestFilterPruning(unittest.TestCase): abs((com_pass.context.eval_results['acc_top1'][-1] - 0.969) / 0.969) < 0.02) + def test_uniform_restore_from_checkpoint(self): + np.random.seed(0) + self.uniform_restore_from_checkpoint( + "./filter_pruning/uniform_restore_0.yaml") + acc_0 = self.uniform_restore_from_checkpoint( + "./filter_pruning/uniform_restore_1.yaml") + np.random.seed(0) + acc_1 = self.uniform_restore_from_checkpoint( + "./filter_pruning/uniform_restore.yaml") + self.assertTrue(abs((acc_0 - acc_1) / acc_1) < 0.001) + + def uniform_restore_from_checkpoint(self, config_file): + + class_dim = 10 + image_shape = [1, 28, 28] + + train_program = fluid.Program() + startup_program = fluid.Program() + train_program.random_seed = 10 + startup_program.random_seed = 10 + + with fluid.program_guard(train_program, startup_program): + with fluid.unique_name.guard(): + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + out = fluid.layers.conv2d(image, 4, 1) + out = fluid.layers.fc(out, size=class_dim) + out = fluid.layers.softmax(out) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + val_program = train_program.clone(for_test=False) + + optimizer = fluid.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + regularization=fluid.regularizer.L2Decay(4e-5)) + + place = fluid.CPUPlace() + scope = fluid.Scope() + exe = fluid.Executor(place) + exe.run(startup_program, scope=scope) + + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + + val_feed_list = [('img', image.name), ('label', label.name)] + val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', + acc_top5.name)] + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128) + train_feed_list = [('img', image.name), ('label', label.name)] + train_fetch_list = [('loss', avg_cost.name)] + + com_pass = Compressor( + place, + scope, + train_program, + train_reader=train_reader, + train_feed_list=train_feed_list, + train_fetch_list=train_fetch_list, + eval_program=val_program, + eval_reader=val_reader, + eval_feed_list=val_feed_list, + eval_fetch_list=val_fetch_list, + train_optimizer=optimizer) + com_pass.config(config_file) + eval_graph = com_pass.run() + return com_pass.context.eval_results['acc_top1'][-1] + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py index 92afd892afe..a1ca7108ff0 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py @@ -26,30 +26,44 @@ class TestQuantizationStrategy(unittest.TestCase): """ def test_compression(self): + self.quan("./quantization/compress.yaml") + self.quan("./quantization/compress_1.yaml") + + def quan(self, config_file): if not fluid.core.is_compiled_with_cuda(): return class_dim = 10 image_shape = [1, 28, 28] - image = fluid.layers.data( - name='image', shape=image_shape, dtype='float32') - image.stop_gradient = False - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - out = MobileNet(name='quan').net(input=image, class_dim=class_dim) - acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) - acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) - val_program = fluid.default_main_program().clone(for_test=False) - - cost = fluid.layers.cross_entropy(input=out, label=label) - avg_cost = fluid.layers.mean(x=cost) + + train_program = fluid.Program() + startup_program = fluid.Program() + + with fluid.program_guard(train_program, startup_program): + with fluid.unique_name.guard(): + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + out = MobileNet(name='quan').net(input=image, + class_dim=class_dim) + print("out: {}".format(out.name)) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + + val_program = train_program.clone(for_test=False) optimizer = fluid.optimizer.Momentum( momentum=0.9, learning_rate=0.01, regularization=fluid.regularizer.L2Decay(4e-5)) + scope = fluid.Scope() place = fluid.CUDAPlace(0) exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) + exe.run(startup_program, scope=scope) val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) @@ -64,8 +78,8 @@ class TestQuantizationStrategy(unittest.TestCase): com_pass = Compressor( place, - fluid.global_scope(), - fluid.default_main_program(), + scope, + train_program, train_reader=train_reader, train_feed_list=train_feed_list, train_fetch_list=train_fetch_list, @@ -74,7 +88,7 @@ class TestQuantizationStrategy(unittest.TestCase): eval_feed_list=val_feed_list, eval_fetch_list=val_fetch_list, train_optimizer=optimizer) - com_pass.config('./quantization/compress.yaml') + com_pass.config(config_file) eval_graph = com_pass.run() -- GitLab