From 6db7c2a500f04ab2b4f54ca2657e1e3ba5bd8e46 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Fri, 29 Mar 2019 01:38:38 +0800
Subject: [PATCH] Fix checkpointing of quantization.

---
 .../fluid/contrib/slim/graph/graph_wrapper.py | 13 ++-
 .../quantization/quantization_strategy.py     | 93 +++++++++++++------
 2 files changed, 75 insertions(+), 31 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py
index 7388ecd3b0..e7f5f0d6a2 100644
--- a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py
+++ b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py
@@ -204,6 +204,10 @@ class GraphWrapper(object):
         """
         super(GraphWrapper, self).__init__()
         self.program = Program() if program is None else program
+        self.persistables = {}
+        for var in self.program.list_vars():
+            if var.persistable:
+                self.persistables[var.name] = var
         self.compiled_graph = None
         self.in_nodes = OrderedDict(in_nodes)
         self.out_nodes = OrderedDict(out_nodes)
@@ -467,7 +471,12 @@ class GraphWrapper(object):
             path(str): The path to save the persistables.
             exe(framework.Executor): The executor used to save the persistables.
         """
-        io.save_persistables(exe.exe, path, main_program=self.program)
+        # Update the persistables dict with variables newly added to the program.
+        for var in self.program.list_vars():
+            if var.persistable and var.name not in self.persistables:
+                self.persistables[var.name] = var
+
+        io.save_vars(exe.exe, path, vars=self.persistables.values())
 
     def load_persistables(self, path, exe):
         """
@@ -481,7 +490,7 @@ class GraphWrapper(object):
             return os.path.exists(os.path.join(path, var.name))
 
         io.load_vars(
-            exe.exe, path, main_program=self.program, predicate=if_exist)
+            exe.exe, path, vars=self.persistables.values(), predicate=if_exist)
 
     def update_param_shape(self, scope):
         """
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
index 6812b4c633..7f79991952 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py
@@ -20,7 +20,7 @@ from .... import io
 from .... import core
 from ....compiler import CompiledProgram
 from ....compiler import BuildStrategy
-from ....framework import IrGraph
+from ....framework import IrGraph, Program
 from ..core.strategy import Strategy
 from .quantization_pass import *
 
@@ -84,40 +84,75 @@ class QuantizationStrategy(Strategy):
         self.save_out_nodes = save_out_nodes
         self.save_in_nodes = save_in_nodes
 
+    def on_compression_begin(self, context):
+        """
+        Restore the graph when the compression task is initialized from a checkpoint.
+        """
+        # The task was restored from a checkpoint and has passed the start epoch.
+        if context.epoch_id != 0 and context.epoch_id > self.start_epoch:
+            _logger.info("Restore quantization task from checkpoint")
+            self._modify_graph_for_quantization(context)
+            _logger.info("Finish restoring quantization task from checkpoint")
+
+    def _modify_graph_for_quantization(self, context):
+        """
+        Insert fake_quantize_op and fake_dequantize_op before training and testing.
+        """
+        train_ir_graph = IrGraph(
+            core.Graph(context.optimize_graph.program.clone().desc),
+            for_test=False)
+        test_ir_graph = IrGraph(
+            core.Graph(context.eval_graph.program.clone().desc), for_test=True)
+        transform_pass = QuantizationTransformPass(
+            scope=context.scope,
+            place=context.place,
+            weight_bits=self.weight_bits,
+            activation_bits=self.activation_bits,
+            activation_quantize_type=self.activation_quantize_type)
+        transform_pass.apply(train_ir_graph)
+        transform_pass.apply(test_ir_graph)
+        # Put persistables created by transform_pass into
+        # context.optimize_graph.persistables so that they are saved into checkpoints.
+        program_persistables = set()
+        for var in context.optimize_graph.program.list_vars():
+            if var.persistable:
+                program_persistables.add(var.name)
+
+        program = Program()
+        for var_node in train_ir_graph.all_persistable_nodes():
+            if var_node.name() not in program_persistables:
+                var_desc = var_node.var()
+                var = program.global_block().create_var(
+                    name=var_node.name(),
+                    shape=var_desc.shape(),
+                    dtype=var_desc.dtype(),
+                    type=var_desc.type(),
+                    lod_level=var_desc.lod_level())
+                context.optimize_graph.persistables[var.name] = var
+
+        build_strategy = BuildStrategy()
+        build_strategy.enable_inplace = False
+        build_strategy.memory_optimize = False
+        # For quantization training.
+        context.optimize_graph.compiled_graph = CompiledProgram(
+            train_ir_graph.graph).with_data_parallel(
+                loss_name=context.optimize_graph.out_nodes['loss'],
+                build_strategy=build_strategy)
+        # For evaluation. A program compiled from an IR graph must use data parallel.
+        context.eval_graph.compiled_graph = CompiledProgram(
+            test_ir_graph.graph).with_data_parallel(
+                build_strategy=build_strategy)
+        # For saving the inference model after training.
+        context.put('quantization_test_ir_graph_backup', test_ir_graph)
+
     def on_epoch_begin(self, context):
         """
         Insert fake_quantize_op and fake_dequantize_op before training and testing.
         """
-        super(QuantizationStrategy, self).on_compression_begin(context)
+        super(QuantizationStrategy, self).on_epoch_begin(context)
         if self.start_epoch == context.epoch_id:
             _logger.info('QuantizationStrategy::on_epoch_begin')
-            train_ir_graph = IrGraph(
-                core.Graph(context.optimize_graph.program.desc), for_test=False)
-            test_ir_graph = IrGraph(
-                core.Graph(context.eval_graph.program.desc), for_test=True)
-            transform_pass = QuantizationTransformPass(
-                scope=context.scope,
-                place=context.place,
-                weight_bits=self.weight_bits,
-                activation_bits=self.activation_bits,
-                activation_quantize_type=self.activation_quantize_type)
-            transform_pass.apply(train_ir_graph)
-            transform_pass.apply(test_ir_graph)
-
-            build_strategy = BuildStrategy()
-            build_strategy.enable_inplace = False
-            build_strategy.memory_optimize = False
-            # for quantization training
-            context.optimize_graph.compiled_graph = CompiledProgram(
-                train_ir_graph.graph).with_data_parallel(
-                    loss_name=context.optimize_graph.out_nodes['loss'],
-                    build_strategy=build_strategy)
-            # for evaluation. And program compiled from ir graph must be with data parallel.
-            context.eval_graph.compiled_graph = CompiledProgram(
-                test_ir_graph.graph).with_data_parallel(
-                    build_strategy=build_strategy)
-            # for saving inference model after training
-            context.put('quantization_test_ir_graph_backup', test_ir_graph)
+            self._modify_graph_for_quantization(context)
             _logger.info('Finish QuantizationStrategy::on_epoch_begin')
 
     def on_epoch_end(self, context):
--
GitLab
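
To make the intent of the patch easier to check, below is a minimal, framework-free sketch of the two behaviors it adds: GraphWrapper refreshes its persistables set at save time, so variables created by IR passes after construction still reach the checkpoint, and the strategy re-applies the quantization transform in on_compression_begin when resuming past start_epoch, so the restored variables have somewhere to load into. SimpleGraph, QuantStrategy, and the dict-based store are hypothetical stand-ins for GraphWrapper, QuantizationStrategy, and fluid.io, not Paddle APIs.

# Minimal, framework-free sketch (hypothetical names; not Paddle APIs).

class SimpleGraph(object):
    """Tracks persistables explicitly, as the patched GraphWrapper does."""

    def __init__(self, variables):
        self.variables = dict(variables)
        # Snapshot the persistables known at construction time.
        self.persistables = {
            name: var
            for name, var in self.variables.items() if var["persistable"]
        }

    def save_persistables(self, store):
        # Refresh first: pick up persistables created after construction,
        # e.g. scale variables inserted by a transform pass.
        for name, var in self.variables.items():
            if var["persistable"] and name not in self.persistables:
                self.persistables[name] = var
        for name, var in self.persistables.items():
            store[name] = var["value"]

    def load_persistables(self, store):
        # Load only the tracked variables that exist in the checkpoint.
        for name, var in self.persistables.items():
            if name in store:
                var["value"] = store[name]


class QuantStrategy(object):
    """Mirrors the resume logic added to QuantizationStrategy."""

    def __init__(self, start_epoch):
        self.start_epoch = start_epoch

    def on_compression_begin(self, epoch_id, graph):
        # Resuming past start_epoch: the freshly built graph lacks the
        # quantization variables, so re-insert them before loading.
        if epoch_id != 0 and epoch_id > self.start_epoch:
            self._modify_graph(graph)

    def _modify_graph(self, graph):
        # Stand-in for QuantizationTransformPass: it creates a new
        # persistable variable that must be part of every checkpoint.
        scale = {"persistable": True, "value": 0.0}
        graph.variables["conv1.w.scale"] = scale
        graph.persistables["conv1.w.scale"] = scale


if __name__ == "__main__":
    # First run: the pass ran at start_epoch, a scale was learned and saved.
    g1 = SimpleGraph({"conv1.w": {"persistable": True, "value": 1.0}})
    QuantStrategy(start_epoch=2)._modify_graph(g1)
    g1.variables["conv1.w.scale"]["value"] = 0.05
    checkpoint = {}
    g1.save_persistables(checkpoint)

    # Resume at epoch 3: without on_compression_begin, "conv1.w.scale"
    # would not exist in the new graph and its value would be dropped.
    g2 = SimpleGraph({"conv1.w": {"persistable": True, "value": 1.0}})
    QuantStrategy(start_epoch=2).on_compression_begin(epoch_id=3, graph=g2)
    g2.load_persistables(checkpoint)
    assert g2.variables["conv1.w.scale"]["value"] == 0.05

The design point worth noting is that the persistables set is refreshed at save time rather than fixed at construction, which is what keeps checkpoints complete after IR passes mutate the program.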