From 486f7d8ed61518508c0c653c84015b63285a1521 Mon Sep 17 00:00:00 2001 From: whs Date: Thu, 18 Apr 2019 23:54:00 +0800 Subject: [PATCH] Restore quantization and distillation stategy before loading persistables. (#16958) test=develop --- python/paddle/fluid/contrib/slim/core/compressor.py | 3 +++ python/paddle/fluid/contrib/slim/core/strategy.py | 3 +++ .../contrib/slim/distillation/distillation_strategy.py | 2 +- .../contrib/slim/quantization/quantization_strategy.py | 9 ++++----- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/core/compressor.py b/python/paddle/fluid/contrib/slim/core/compressor.py index 1547b6abb..b97508018 100644 --- a/python/paddle/fluid/contrib/slim/core/compressor.py +++ b/python/paddle/fluid/contrib/slim/core/compressor.py @@ -363,6 +363,9 @@ class Compressor(object): strategies = pickle.load( strategy_file, encoding='bytes') + for strategy in strategies: + strategy.restore_from_checkpoint(context) + if os.path.exists(model_path): exe = SlimGraphExecutor(context.place) with scope_guard(context.scope): diff --git a/python/paddle/fluid/contrib/slim/core/strategy.py b/python/paddle/fluid/contrib/slim/core/strategy.py index 28bf24f4e..f2cd2a283 100644 --- a/python/paddle/fluid/contrib/slim/core/strategy.py +++ b/python/paddle/fluid/contrib/slim/core/strategy.py @@ -46,3 +46,6 @@ class Strategy(object): def on_compression_end(self, context): pass + + def restore_from_checkpoint(self, context): + pass diff --git a/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py index 2fc6b4518..d8e08c3eb 100644 --- a/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py +++ b/python/paddle/fluid/contrib/slim/distillation/distillation_strategy.py @@ -38,7 +38,7 @@ class DistillationStrategy(Strategy): super(DistillationStrategy, self).__init__(start_epoch, end_epoch) self.distillers = distillers - def on_compression_begin(self, context): + def restore_from_checkpoint(self, context): # load from checkpoint if context.epoch_id > 0: if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch: diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py index a22b6da02..12c1ce989 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_strategy.py @@ -88,7 +88,7 @@ class QuantizationStrategy(Strategy): self.save_out_nodes = save_out_nodes self.save_in_nodes = save_in_nodes - def on_compression_begin(self, context): + def restore_from_checkpoint(self, context): """ Restore graph when the compressoin task is inited from checkpoint. """ @@ -143,10 +143,9 @@ class QuantizationStrategy(Strategy): train_ir_graph.graph).with_data_parallel( loss_name=context.optimize_graph.out_nodes['loss'], build_strategy=build_strategy) - # for evaluation. And program compiled from ir graph must be with data parallel. - context.eval_graph.compiled_graph = CompiledProgram( - test_ir_graph.graph).with_data_parallel( - build_strategy=build_strategy) + + context.eval_graph.program = test_ir_graph.to_program() + # for saving inference model after training context.put('quantization_test_ir_graph_backup', test_ir_graph) -- GitLab