diff --git a/python/paddle/fluid/contrib/slim/core/compressor.py b/python/paddle/fluid/contrib/slim/core/compressor.py index 1547b6abbe660b6be7a681a4e270e3080a5dac36..b97508018ac6da47bfdefadd06a6c3788cb7bd77 100644 --- a/python/paddle/fluid/contrib/slim/core/compressor.py +++ b/python/paddle/fluid/contrib/slim/core/compressor.py @@ -363,6 +363,9 @@ class Compressor(object): strategies = pickle.load( strategy_file, encoding='bytes') + for strategy in strategies: + strategy.restore_from_checkpoint(context) + if os.path.exists(model_path): exe = SlimGraphExecutor(context.place) with scope_guard(context.scope): diff --git a/python/paddle/fluid/contrib/slim/core/strategy.py b/python/paddle/fluid/contrib/slim/core/strategy.py index 28bf24f4e341dd528d2cd25f6fb24543886150d6..f2cd2a2835b1c19a71679d74736a2d9fe7fc724e 100644 --- a/python/paddle/fluid/contrib/slim/core/strategy.py +++ b/python/paddle/fluid/contrib/slim/core/strategy.py @@ -46,3 +46,6 @@ class Strategy(object): def on_compression_end(self, context): pass + + def restore_from_checkpoint(self, context): + pass diff --git a/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py index 2fc6b45183164f135ae3ced08c1900ad526add45..d8e08c3ebef50c9808ed818dcf35443dc25f850e 100644 --- a/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py +++ b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py @@ -38,7 +38,7 @@ class DistillationStrategy(Strategy): super(DistillationStrategy, self).__init__(start_epoch, end_epoch) self.distillers = distillers - def on_compression_begin(self, context): + def restore_from_checkpoint(self, context): # load from checkpoint if context.epoch_id > 0: if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch: diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py index a22b6da020510838dc82fe7af87ab62db6e874ef..12c1ce98992c32caaa300045c4adc918dd88f427 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py @@ -88,7 +88,7 @@ class QuantizationStrategy(Strategy): self.save_out_nodes = save_out_nodes self.save_in_nodes = save_in_nodes - def on_compression_begin(self, context): + def restore_from_checkpoint(self, context): """ Restore graph when the compressoin task is inited from checkpoint. """ @@ -143,10 +143,9 @@ class QuantizationStrategy(Strategy): train_ir_graph.graph).with_data_parallel( loss_name=context.optimize_graph.out_nodes['loss'], build_strategy=build_strategy) - # for evaluation. And program compiled from ir graph must be with data parallel. - context.eval_graph.compiled_graph = CompiledProgram( - test_ir_graph.graph).with_data_parallel( - build_strategy=build_strategy) + + context.eval_graph.program = test_ir_graph.to_program() + # for saving inference model after training context.put('quantization_test_ir_graph_backup', test_ir_graph)