diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6ebce107fad35b093aeaf2fc4af779bf247fa13f..ae41992835584d8106337aacae0a09ba20e72ce3 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -491,7 +491,7 @@ paddle.fluid.contrib.QuantizeTranspiler.freeze_program (ArgSpec(args=['self', 'p paddle.fluid.contrib.QuantizeTranspiler.training_transpile (ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6dd9909f10b283ba2892a99058a72884')) paddle.fluid.contrib.distributed_batch_reader (ArgSpec(args=['batch_reader'], varargs=None, keywords=None, defaults=None), ('document', 'b60796eb0a481484dd34e345f0eaa4d5')) paddle.fluid.contrib.Compressor ('paddle.fluid.contrib.slim.core.compressor.Compressor', ('document', 'a5417774a94aa9ae5560a42b96527e7d')) -paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, [], None, None, None, None)), ('document', 'c195b3bba26169cff9439e8c467557c0')) +paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'eval_func', 'save_eval_model', 'prune_infer_model', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, None, True, None, [], None, None, None, None)), ('document', '05119e0fa0fc07f5cf848ebf0a2cf070')) paddle.fluid.contrib.Compressor.config (ArgSpec(args=['self', 'config_file'], varargs=None, keywords=None, defaults=None), 
('document', '780d9c007276ccbb95b292400d7807b0')) paddle.fluid.contrib.Compressor.run (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'c6e43d6a078d307672283c1f36e04fe9')) paddle.fluid.contrib.load_persistables_for_increment (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None), ('document', '2ab36d4f7a564f5f65e455807ad06c67')) diff --git a/python/paddle/fluid/contrib/slim/core/compressor.py b/python/paddle/fluid/contrib/slim/core/compressor.py index 2627f7f004bc47a5d1b2e5e22d7fe05373ae3ec8..6ede756599fde04ef9151ea1c1e10d91ac7ee507 100644 --- a/python/paddle/fluid/contrib/slim/core/compressor.py +++ b/python/paddle/fluid/contrib/slim/core/compressor.py @@ -139,7 +139,7 @@ class Context(object): """ Load the context from file. """ - with open(file_name) as context_file: + with open(file_name, 'rb') as context_file: if sys.version_info < (3, 0): data = pickle.load(context_file) else: @@ -242,6 +242,9 @@ class Compressor(object): eval_reader=None, eval_feed_list=None, eval_fetch_list=None, + eval_func=None, + save_eval_model=True, + prune_infer_model=None, teacher_programs=[], checkpoint_path=None, train_optimizer=None, @@ -260,13 +263,28 @@ class Compressor(object): The key is user-defined and human-readable name. The value is the name of Variable. eval_program(Program): The program used for evaluation. - eval_reader: The data reader used for evaluation. + eval_reader: The data reader used for evaluation. It can be None if eval_func is not None. eval_feed_list(dict): A dict to indicate the input variable of the evaluation program. The key is user-defined and human-readable name. The value is the name of Variable. + It can be None if eval_func is not None. eval_fetch_list(dict): A dict to indicate the output variable of the evaluation program. The key is user-defined and human-readable name. The value is the name of Variable. 
+ eval_func(dict|function): Callback functions used to evaluate the compressed model. + The eval_func is a dict, the key is user-defined name and the value is + a callback function. And the score returned from callback functions + can be referenced in config file by the key of eval_func. + The args of callback function are compressed eval_program and scope which + store the compressed parameters. + Default: None. + save_eval_model(bool): Whether to save eval model when saving checkpoints. Default: True. + prune_infer_model(tuple|list): If prune_infer_model is not None, compressor will prune + eval program into inference program according to inputs and outputs + defined in prune_infer_model. prune_infer_model[0] is a list of input + variables' names and prune_infer_model[1] is a list of output variables' + names. If prune_infer_model is None, it will not save inference model. + Default: None. teacher_programs: The teacher graphs used in distillation strategies. train_optimizer: The optimizer used to append backward ops and optimization ops into train_graph. 
@@ -294,6 +312,10 @@ class Compressor(object): eval_program, in_nodes=eval_feed_list, out_nodes=eval_fetch_list) self.train_reader = train_reader self.eval_reader = eval_reader + self.eval_func = eval_func + self.save_eval_model = save_eval_model + self.prune_infer_model = prune_infer_model + self.teacher_graphs = [] for teacher in teacher_programs: self.teacher_graphs.append(GraphWrapper(teacher)) @@ -393,6 +415,9 @@ class Compressor(object): strategies = pickle.load( strategy_file, encoding='bytes') + for s, s1 in zip(self.strategies, strategies): + s1.__dict__.update(s.__dict__) + for strategy in strategies: strategy.restore_from_checkpoint(context) @@ -401,10 +426,6 @@ class Compressor(object): with scope_guard(context.scope): context.optimize_graph.load_persistables(model_path, exe) - context.optimize_graph.update_param_shape(context.scope) - context.optimize_graph.update_groups_of_conv() - context.eval_graph.update_param_shape(context.scope) - context.eval_graph.update_groups_of_conv() _logger.info("Loaded params from: {}".format(model_path)) return context, strategies @@ -416,6 +437,7 @@ class Compressor(object): checkpoint_path = os.path.join(self.checkpoint_path, str(context.epoch_id)) model_path = os.path.join(checkpoint_path, 'model') + eval_model_path = os.path.join(checkpoint_path, 'eval_model') context_path = os.path.join(checkpoint_path, 'context') strategy_path = os.path.join(checkpoint_path, 'strategies') if not os.path.isdir(model_path): @@ -423,6 +445,15 @@ class Compressor(object): exe = SlimGraphExecutor(context.place) with scope_guard(context.scope): context.optimize_graph.save_persistables(model_path, exe) + if self.save_eval_model: + context.eval_graph.save_model(eval_model_path, exe) + if self.prune_infer_model: + context.eval_graph.save_infer_model( + eval_model_path, + exe, + self.prune_infer_model, + program_only=self.save_eval_model) + context.to_file(context_path) with open(strategy_path, 'wb') as strategy_file: 
pickle.dump(self.strategies, strategy_file) @@ -485,11 +516,19 @@ class Compressor(object): """ Runing evaluation. """ - results, names = context.run_eval_graph() - for name, result in zip(names, results): - if name not in context.eval_results: - context.eval_results[name] = [] - context.eval_results[name].append(result) + if self.eval_func is not None: + for key in self.eval_func: + func = self.eval_func[key] + if key not in context.eval_results: + context.eval_results[key] = [] + context.eval_results[key].append( + func(self.eval_graph.program, self.scope)) + else: + results, names = context.run_eval_graph() + for name, result in zip(names, results): + if name not in context.eval_results: + context.eval_results[name] = [] + context.eval_results[name].append(result) def run(self): """ diff --git a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py index 3ed07a287bd280ba241794ca4423261247017cb7..4c21427e6b91f164aab07947f787ded6ae2cca02 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py @@ -211,6 +211,7 @@ class GraphWrapper(object): self.persistables[var.name] = var self.compiled_graph = None in_nodes = [] if in_nodes is None else in_nodes + out_nodes = [] if out_nodes is None else out_nodes self.in_nodes = OrderedDict(in_nodes) self.out_nodes = OrderedDict(out_nodes) self._attrs = OrderedDict() @@ -471,6 +472,54 @@ class GraphWrapper(object): return flops + def save_model(self, path, exe): + """ + Save network and parameters into file which can be load by load_inference_model api. + Args: + path(str): The path to save the persistables. + exe(framework.Executor): The executor used to save the persistables. 
+ """ + out_vars = [ + self.var(var_name)._var for var_name in self.out_nodes.values() + ] + in_vars = list(self.in_nodes.values()) + assert (len(in_vars) > 0) + assert (len(out_vars) > 0) + io.save_inference_model( + path, + in_vars, + out_vars, + exe.exe, + model_filename="__model__", + params_filename="__params__", + main_program=self.program.clone(), + export_for_deployment=True) + + def save_infer_model(self, path, exe, in_out, program_only=False): + """ + Save network and parameters into file which can be load by load_inference_model api. + Args: + path(str): The path to save the persistables. + exe(framework.Executor): The executor used to save the persistables. + in_out(tuple|list): in_out[0] is a list of input nodes' names + and in_out[1] is a list of output nodes' names. + program_only(bool): Whether to save program only. + """ + out_vars = [self.var(var_name)._var for var_name in in_out[1]] + in_vars = list(in_out[0]) + assert (len(in_vars) > 0) + assert (len(out_vars) > 0) + io.save_inference_model( + path, + in_vars, + out_vars, + exe.exe, + model_filename="__model__.infer", + params_filename="__params__", + program_only=program_only, + main_program=self.program.clone(), + export_for_deployment=True) + def save_persistables(self, path, exe): """ Save all the persistable variables into file. 
@@ -527,5 +576,6 @@ class GraphWrapper(object): def update_groups_of_conv(self): for op in self.ops(): - if op.type() == 'depthwise_conv2d': + if op.type() == 'depthwise_conv2d' or op.type( + ) == 'depthwise_conv2d_grad': op.set_attr('groups', op.inputs('Filter')[0].shape()[0]) diff --git a/python/paddle/fluid/contrib/slim/prune/prune_strategy.py b/python/paddle/fluid/contrib/slim/prune/prune_strategy.py index 6f430bc9e2fee375c813aeac1e05045b3b42afa4..ffce27113705b86bec963a77d67a448ec4cf360e 100644 --- a/python/paddle/fluid/contrib/slim/prune/prune_strategy.py +++ b/python/paddle/fluid/contrib/slim/prune/prune_strategy.py @@ -635,31 +635,35 @@ class UniformPruneStrategy(PruneStrategy): _logger.info('Get ratios: {}'.format([round(r, 2) for r in ratios])) return pruned_params, ratios - def on_epoch_begin(self, context): - if context.epoch_id == self.start_epoch: - params, ratios = self._get_best_ratios(context) + def restore_from_checkpoint(self, context): + self._prune(context, self.params, self.ratios) - self._prune_parameters(context.optimize_graph, context.scope, - params, ratios, context.place) + def _prune(self, context, params, ratios): + self._prune_parameters(context.optimize_graph, context.scope, params, + ratios, context.place) - model_size = context.eval_graph.numel_params() - flops = context.eval_graph.flops() - _logger.debug('\n################################') - _logger.debug('# pruning eval graph #') - _logger.debug('################################\n') - self._prune_graph(context.eval_graph, context.optimize_graph) - context.optimize_graph.update_groups_of_conv() - context.eval_graph.update_groups_of_conv() + model_size = context.eval_graph.numel_params() + flops = context.eval_graph.flops() + _logger.debug('\n################################') + _logger.debug('# pruning eval graph #') + _logger.debug('################################\n') + self._prune_graph(context.eval_graph, context.optimize_graph) + context.optimize_graph.update_groups_of_conv() 
+ context.eval_graph.update_groups_of_conv() - _logger.info( - '------------------finish pruning--------------------------------' - ) - _logger.info('Pruned size: {:.2f}'.format(1 - (float( - context.eval_graph.numel_params()) / model_size))) - _logger.info('Pruned flops: {:.2f}'.format(1 - (float( - context.eval_graph.flops()) / flops))) - # metric = self._eval_graph(context) - # _logger.info('Metric after pruning: {:.2f}'.format(metric)) + _logger.info( + '------------------finish pruning--------------------------------') + _logger.info('Pruned size: {:.2f}'.format(1 - (float( + context.eval_graph.numel_params()) / model_size))) + _logger.info('Pruned flops: {:.2f}'.format(1 - (float( + context.eval_graph.flops()) / flops))) + + def on_epoch_begin(self, context): + if context.epoch_id == self.start_epoch: + params, ratios = self._get_best_ratios(context) + self.params = params + self.ratios = ratios + self._prune(context, params, ratios) _logger.info( '------------------UniformPruneStrategy.on_compression_begin finish--------------------------------' ) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py index bf4313784eab0f846c2a2a45d7ed98509103d94f..5d2b6ea369dedfd1f1437ae626f7f3b3eb6a21a7 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py @@ -17,7 +17,7 @@ import sys import numpy as np from .... import Executor from .... import io -from .... import core +from .... 
import core, scope_guard from ....compiler import CompiledProgram from ....compiler import BuildStrategy from ....framework import IrGraph, Variable, Program @@ -199,15 +199,16 @@ class QuantizationStrategy(Strategy): # save float model if self.float_model_save_path: executor = Executor(context.place) - io.save_inference_model( - self.float_model_save_path, - in_vars, - out_vars, - executor, - main_program=test_ir_graph.to_program(), - model_filename='model', - params_filename='weights', - export_for_deployment=True) + with scope_guard(context.scope): + io.save_inference_model( + self.float_model_save_path, + in_vars, + out_vars, + executor, + main_program=test_ir_graph.to_program(), + model_filename='model', + params_filename='weights', + export_for_deployment=True) # save int8 model if self.int8_model_save_path: @@ -216,15 +217,17 @@ class QuantizationStrategy(Strategy): convert_int8_pass.apply(test_ir_graph) executor = Executor(context.place) - io.save_inference_model( - self.int8_model_save_path, - in_vars, - out_vars, - executor, - main_program=test_ir_graph.to_program(), - model_filename='model', - params_filename='weights', - export_for_deployment=True) + + with scope_guard(context.scope): + io.save_inference_model( + self.int8_model_save_path, + in_vars, + out_vars, + executor, + main_program=test_ir_graph.to_program(), + model_filename='model', + params_filename='weights', + export_for_deployment=True) # save mobile model if self.mobile_model_save_path: @@ -237,13 +240,14 @@ class QuantizationStrategy(Strategy): mobile_pass = TransformForMobilePass() mobile_pass.apply(test_ir_graph) executor = Executor(context.place) - io.save_inference_model( - self.mobile_model_save_path, - in_vars, - out_vars, - executor, - main_program=test_ir_graph.to_program(), - model_filename='model', - params_filename='weights', - export_for_deployment=True) + with scope_guard(context.scope): + io.save_inference_model( + self.mobile_model_save_path, + in_vars, + out_vars, + 
executor, + main_program=test_ir_graph.to_program(), + model_filename='model', + params_filename='weights', + export_for_deployment=True) _logger.info('Finish QuantizationStrategy::on_epoch_end') diff --git a/python/paddle/fluid/contrib/slim/tests/configs/compress.yaml b/python/paddle/fluid/contrib/slim/tests/configs/compress.yaml new file mode 100644 index 0000000000000000000000000000000000000000..604cdf3f447ae0ed17700fe53f1daf6ded77399a --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/configs/compress.yaml @@ -0,0 +1,4 @@ +version: 1.0 +compressor: + epoch: 1 + checkpoint_path: './checkpoints/' diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml index 5f747a049e95a5920236336c69a80a9492e6190d..b21a36263727f39f7eb13778d9b326dd045d9627 100644 --- a/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/compress.yaml @@ -28,7 +28,7 @@ strategies: sensitivities_file: 'mobilenet_acc_top1_sensitive.data' metric_name: 'acc_top1' compressor: - epoch: 2 + epoch: 1 checkpoint_path: './checkpoints_pruning/' strategies: - sensitive_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9e437aedc9d2427394fb697ca1898baffb00a109 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore.yaml @@ -0,0 +1,21 @@ +version: 1.0 +pruners: + pruner_1: + class: 'StructurePruner' + pruning_axis: + '*': 0 + criterions: + '*': 'l1_norm' +strategies: + uniform_pruning_strategy: + class: 'UniformPruneStrategy' + pruner: 'pruner_1' + start_epoch: 0 + target_ratio: 0.5 + pruned_params: 'conv.*' + metric_name: 'acc_top1' +compressor: + epoch: 2 + checkpoint_path: 
'./checkpoints_uniform_restore_tmp/' + strategies: + - uniform_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_0.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..49f104f98f3854ee831ebbea1ff6fa9c7817a15b --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_0.yaml @@ -0,0 +1,21 @@ +version: 1.0 +pruners: + pruner_1: + class: 'StructurePruner' + pruning_axis: + '*': 0 + criterions: + '*': 'l1_norm' +strategies: + uniform_pruning_strategy: + class: 'UniformPruneStrategy' + pruner: 'pruner_1' + start_epoch: 0 + target_ratio: 0.5 + pruned_params: 'conv.*' + metric_name: 'acc_top1' +compressor: + epoch: 1 + checkpoint_path: './checkpoints_uniform_restore/' + strategies: + - uniform_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_1.yaml b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82e6793aff97d261a83d88dbc077e76e652e1fe1 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/filter_pruning/uniform_restore_1.yaml @@ -0,0 +1,21 @@ +version: 1.0 +pruners: + pruner_1: + class: 'StructurePruner' + pruning_axis: + '*': 0 + criterions: + '*': 'l1_norm' +strategies: + uniform_pruning_strategy: + class: 'UniformPruneStrategy' + pruner: 'pruner_1' + start_epoch: 0 + target_ratio: 0.5 + pruned_params: 'conv.*' + metric_name: 'acc_top1' +compressor: + epoch: 2 + checkpoint_path: './checkpoints_uniform_restore/' + strategies: + - uniform_pruning_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/quantization/compress_1.yaml b/python/paddle/fluid/contrib/slim/tests/quantization/compress_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..44e2dc985aac65306a3b05860a26a1d60fa5cf44 --- /dev/null +++ 
b/python/paddle/fluid/contrib/slim/tests/quantization/compress_1.yaml @@ -0,0 +1,50 @@ +#start_epoch(int): The epoch to insert quantization operators. default: 0 +# +#end_epoch(int): The epoch to save inference model. default: 0 +# +#float_model_save_path(str): The path to save model with float weights. +# None means it doesn't save float model. default: None. +# +#mobile_model_save_path(str): The path to save model for paddle-mobile execution. +# None means it doesn't save mobile model. default: None. +# +#int8_model_save_path(str): The path to save model with int8_t weight. +# None means it doesn't save int8 model. default: None. +# +#activation_bits(int): quantization bit number for activation. default: 8. +# +#weight_bits(int): quantization bit number for weights. The bias is not quantized. +# default: 8. +# +#activation_quantize_type(str): quantization type for activation, +# now supports 'abs_max', 'range_abs_max' and 'moving_average_abs_max'. +# If 'abs_max' mode is used, the quantization scale will be calculated +# dynamically at each step in both the training and testing periods. If +# 'range_abs_max' is used, a static quantization scale will be calculated +# during training and used in inference. +# +#save_in_nodes(list): A list of input variable names used to prune the graph +# when saving the inference model. +# +#save_out_nodes(list): A list of output variable names used to prune the graph +# when saving the inference model. 
+version: 1.0 +strategies: + quantization_strategy: + class: 'QuantizationStrategy' + start_epoch: 0 + end_epoch: 0 + float_model_save_path: './output/float' + mobile_model_save_path: './output/mobile' + int8_model_save_path: './output/int8' + weight_bits: 8 + activation_bits: 8 + weight_quantize_type: 'abs_max' + activation_quantize_type: 'abs_max' + save_in_nodes: ['image'] + save_out_nodes: ['quan.tmp_2'] +compressor: + epoch: 2 + checkpoint_path: './checkpoints_quan/' + strategies: + - quantization_strategy diff --git a/python/paddle/fluid/contrib/slim/tests/test_compressor.py b/python/paddle/fluid/contrib/slim/tests/test_compressor.py new file mode 100644 index 0000000000000000000000000000000000000000..330c6e3543ddb44e1016ffdbf14d65116422e54e --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_compressor.py @@ -0,0 +1,99 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import unittest +import os +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.contrib.slim.core import Compressor +from paddle.fluid.contrib.slim.graph import GraphWrapper + + +class TestCompressor(unittest.TestCase): + def test_eval_func(self): + class_dim = 10 + image_shape = [1, 28, 28] + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + out = fluid.layers.fc(input=image, size=class_dim) + out = fluid.layers.softmax(out) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + val_program = fluid.default_main_program().clone(for_test=False) + + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + + optimizer = fluid.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + regularization=fluid.regularizer.L2Decay(4e-5)) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128) + train_feed_list = [('img', image.name), ('label', label.name)] + train_fetch_list = [('loss', avg_cost.name)] + eval_feed_list = [('img', image.name), ('label', label.name)] + eval_fetch_list = [('acc_top1', acc_top1.name)] + + def eval_func(program, scope): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + feeder = fluid.DataFeeder( + feed_list=[image.name, label.name], + place=place, + program=program) + results = [] + for data in val_reader(): + result = exe.run(program=program, + scope=scope, + fetch_list=[acc_top1.name], + feed=feeder.feed(data)) + results.append(np.array(result)) + result = np.mean(results) + return result + + com_pass = Compressor( + place, + fluid.global_scope(), + fluid.default_main_program(), + train_reader=train_reader, + 
train_feed_list=train_feed_list, + train_fetch_list=train_fetch_list, + eval_program=val_program, + eval_feed_list=eval_feed_list, + eval_fetch_list=eval_fetch_list, + eval_func={"score": eval_func}, + prune_infer_model=[[image.name], [out.name]], + train_optimizer=optimizer) + com_pass.config('./configs/compress.yaml') + com_pass.run() + self.assertTrue('score' in com_pass.context.eval_results) + self.assertTrue(float(com_pass.context.eval_results['score'][0]) > 0.9) + self.assertTrue(os.path.exists("./checkpoints/0/eval_model/__model__")) + self.assertTrue( + os.path.exists("./checkpoints/0/eval_model/__model__.infer")) + self.assertTrue(os.path.exists("./checkpoints/0/eval_model/__params__")) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py b/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py index e1763039b3a962a43f2fe3a22c05cb32cba596ed..cb956ef6bf09e0172d9e0caea1c76d5bf78fcfef 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py +++ b/python/paddle/fluid/contrib/slim/tests/test_filter_pruning.py @@ -15,6 +15,7 @@ import paddle import unittest import paddle.fluid as fluid +import numpy as np from mobilenet import MobileNet from paddle.fluid.contrib.slim.core import Compressor from paddle.fluid.contrib.slim.graph import GraphWrapper @@ -84,6 +85,80 @@ class TestFilterPruning(unittest.TestCase): abs((com_pass.context.eval_results['acc_top1'][-1] - 0.969) / 0.969) < 0.02) + def test_uniform_restore_from_checkpoint(self): + np.random.seed(0) + self.uniform_restore_from_checkpoint( + "./filter_pruning/uniform_restore_0.yaml") + acc_0 = self.uniform_restore_from_checkpoint( + "./filter_pruning/uniform_restore_1.yaml") + np.random.seed(0) + acc_1 = self.uniform_restore_from_checkpoint( + "./filter_pruning/uniform_restore.yaml") + self.assertTrue(abs((acc_0 - acc_1) / acc_1) < 0.001) + + def uniform_restore_from_checkpoint(self, config_file): + + class_dim = 10 
+ image_shape = [1, 28, 28] + + train_program = fluid.Program() + startup_program = fluid.Program() + train_program.random_seed = 10 + startup_program.random_seed = 10 + + with fluid.program_guard(train_program, startup_program): + with fluid.unique_name.guard(): + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + out = fluid.layers.conv2d(image, 4, 1) + out = fluid.layers.fc(out, size=class_dim) + out = fluid.layers.softmax(out) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + val_program = train_program.clone(for_test=False) + + optimizer = fluid.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + regularization=fluid.regularizer.L2Decay(4e-5)) + + place = fluid.CPUPlace() + scope = fluid.Scope() + exe = fluid.Executor(place) + exe.run(startup_program, scope=scope) + + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + + val_feed_list = [('img', image.name), ('label', label.name)] + val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', + acc_top5.name)] + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=128) + train_feed_list = [('img', image.name), ('label', label.name)] + train_fetch_list = [('loss', avg_cost.name)] + + com_pass = Compressor( + place, + scope, + train_program, + train_reader=train_reader, + train_feed_list=train_feed_list, + train_fetch_list=train_fetch_list, + eval_program=val_program, + eval_reader=val_reader, + eval_feed_list=val_feed_list, + eval_fetch_list=val_fetch_list, + train_optimizer=optimizer) + com_pass.config(config_file) + eval_graph = com_pass.run() + return com_pass.context.eval_results['acc_top1'][-1] + if __name__ == '__main__': unittest.main() diff --git 
a/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py index 92afd892afed86e69266c9ab9c97d90daebb86d5..a1ca7108ff08678236d6bbd17de6bd9408d8136c 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_strategy.py @@ -26,30 +26,44 @@ class TestQuantizationStrategy(unittest.TestCase): """ def test_compression(self): + self.quan("./quantization/compress.yaml") + self.quan("./quantization/compress_1.yaml") + + def quan(self, config_file): if not fluid.core.is_compiled_with_cuda(): return class_dim = 10 image_shape = [1, 28, 28] - image = fluid.layers.data( - name='image', shape=image_shape, dtype='float32') - image.stop_gradient = False - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - out = MobileNet(name='quan').net(input=image, class_dim=class_dim) - acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) - acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) - val_program = fluid.default_main_program().clone(for_test=False) - - cost = fluid.layers.cross_entropy(input=out, label=label) - avg_cost = fluid.layers.mean(x=cost) + + train_program = fluid.Program() + startup_program = fluid.Program() + + with fluid.program_guard(train_program, startup_program): + with fluid.unique_name.guard(): + image = fluid.layers.data( + name='image', shape=image_shape, dtype='float32') + image.stop_gradient = False + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + out = MobileNet(name='quan').net(input=image, + class_dim=class_dim) + print("out: {}".format(out.name)) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + + val_program = train_program.clone(for_test=False) optimizer = 
fluid.optimizer.Momentum( momentum=0.9, learning_rate=0.01, regularization=fluid.regularizer.L2Decay(4e-5)) + scope = fluid.Scope() place = fluid.CUDAPlace(0) exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) + exe.run(startup_program, scope=scope) val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) @@ -64,8 +78,8 @@ class TestQuantizationStrategy(unittest.TestCase): com_pass = Compressor( place, - fluid.global_scope(), - fluid.default_main_program(), + scope, + train_program, train_reader=train_reader, train_feed_list=train_feed_list, train_fetch_list=train_fetch_list, @@ -74,7 +88,7 @@ class TestQuantizationStrategy(unittest.TestCase): eval_feed_list=val_feed_list, eval_fetch_list=val_fetch_list, train_optimizer=optimizer) - com_pass.config('./quantization/compress.yaml') + com_pass.config(config_file) eval_graph = com_pass.run()