Unverified · Commit 4c1ec41d · Authored by Zhen Wang · Committed by GitHub

Merge pull request #16531 from wanghaoshuang/quan_ck

[slim] Fix checkpoint of quantization strategy.
@@ -204,6 +204,10 @@ class GraphWrapper(object):
         """
         super(GraphWrapper, self).__init__()
         self.program = Program() if program is None else program
+        self.persistables = {}
+        for var in self.program.list_vars():
+            if var.persistable:
+                self.persistables[var.name] = var
         self.compiled_graph = None
         self.in_nodes = OrderedDict(in_nodes)
         self.out_nodes = OrderedDict(out_nodes)
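The constructor now snapshots every persistable variable the wrapped Program knows about, so later additions can be diffed against this baseline. A minimal sketch of what the dictionary ends up holding, assuming the Paddle 1.x fluid API (the network itself is illustrative):

```python
import paddle.fluid as fluid


def snapshot_persistables(program):
    # Persistable variables survive across executor runs: parameters,
    # optimizer moments, batch-norm statistics, and so on.
    return {var.name: var for var in program.list_vars() if var.persistable}


main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)

persistables = snapshot_persistables(main)  # fc weight and bias, keyed by name
```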
@@ -467,7 +471,12 @@ class GraphWrapper(object):
             path(str): The path to save the persistables.
             exe(framework.Executor): The executor used to save the persistables.
         """
-        io.save_persistables(exe.exe, path, main_program=self.program)
+        # update persistables from program
+        for var in self.program.list_vars():
+            if var.persistable and var.name not in self.persistables:
+                self.persistables[var.name] = var
+
+        io.save_vars(exe.exe, path, vars=self.persistables.values())

     def load_persistables(self, path, exe):
         """
@@ -481,7 +490,7 @@ class GraphWrapper(object):
             return os.path.exists(os.path.join(path, var.name))

         io.load_vars(
-            exe.exe, path, main_program=self.program, predicate=if_exist)
+            exe.exe, path, vars=self.persistables.values(), predicate=if_exist)

     def update_param_shape(self, scope):
         """
...
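Why the switch from `io.save_persistables` to `io.save_vars`: the quantization transform creates scale and accumulator variables at the C++ graph level, so they never appear in `self.program`, and a Program-driven save would silently skip them. A minimal sketch of the resulting save/load pattern, assuming the Paddle 1.x `fluid.io` API; `ckpt_dir` and `tracked` are illustrative names, not identifiers from the patch, and the file-existence filter is applied by hand rather than via the `predicate` argument:

```python
import os

from paddle.fluid import io


def save_checkpoint(exe, ckpt_dir, tracked):
    # `tracked` maps name -> Variable and is maintained explicitly, because
    # a main_program-driven save would miss variables created by IR passes.
    io.save_vars(exe, ckpt_dir, vars=list(tracked.values()))


def load_checkpoint(exe, ckpt_dir, tracked):
    # Restore only variables whose files exist in the checkpoint, so a
    # checkpoint written before the quantization pass ran still loads.
    wanted = [v for v in tracked.values()
              if os.path.exists(os.path.join(ckpt_dir, v.name))]
    io.load_vars(exe, ckpt_dir, vars=wanted)
```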
@@ -20,7 +20,7 @@
 from .... import io
 from .... import core
 from ....compiler import CompiledProgram
 from ....compiler import BuildStrategy
-from ....framework import IrGraph
+from ....framework import IrGraph, Variable, Program
 from ..core.strategy import Strategy
 from .quantization_pass import *
@@ -88,17 +88,25 @@ class QuantizationStrategy(Strategy):
         self.save_out_nodes = save_out_nodes
         self.save_in_nodes = save_in_nodes

-    def on_epoch_begin(self, context):
+    def on_compression_begin(self, context):
+        """
+        Restore the graph when the compression task is initialized from a checkpoint.
+        """
+        # Initialized from a checkpoint, so the start epoch has already been missed.
+        if context.epoch_id != 0 and context.epoch_id > self.start_epoch:
+            _logger.info("Restore quantization task from checkpoint")
+            self._modify_graph_for_quantization(context)
+            _logger.info("Finish restoring quantization task from checkpoint")
+
+    def _modify_graph_for_quantization(self, context):
         """
         Insert fake_quantize_op and fake_dequantize_op before training and testing.
         """
-        super(QuantizationStrategy, self).on_compression_begin(context)
-        if self.start_epoch == context.epoch_id:
-            _logger.info('QuantizationStrategy::on_epoch_begin')
-            train_ir_graph = IrGraph(
-                core.Graph(context.optimize_graph.program.desc), for_test=False)
-            test_ir_graph = IrGraph(
-                core.Graph(context.eval_graph.program.desc), for_test=True)
+        train_ir_graph = IrGraph(
+            core.Graph(context.optimize_graph.program.clone().desc),
+            for_test=False)
+        test_ir_graph = IrGraph(
+            core.Graph(context.eval_graph.program.clone().desc), for_test=True)
         transform_pass = QuantizationTransformPass(
             scope=context.scope,
             place=context.place,
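The resume condition deserves a second look: a task restored from a checkpoint taken after `start_epoch` has already missed the `on_epoch_begin` trigger, so the fake-quantize ops must be re-inserted before training continues. A small, self-contained restatement of that predicate (the function name is hypothetical):

```python
def should_restore_quant_graph(epoch_id, start_epoch):
    # Fresh runs start at epoch 0 and get the transform via on_epoch_begin
    # once start_epoch arrives; a run resumed at exactly start_epoch is
    # likewise handled there. Only a resume past start_epoch needs the
    # graph rebuilt up front.
    return epoch_id != 0 and epoch_id > start_epoch


assert not should_restore_quant_graph(epoch_id=0, start_epoch=0)  # fresh run
assert not should_restore_quant_graph(epoch_id=3, start_epoch=5)  # resumed before quantization
assert should_restore_quant_graph(epoch_id=7, start_epoch=5)      # resumed after quantization
```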
@@ -108,6 +116,24 @@ class QuantizationStrategy(Strategy):
             weight_quantize_type=self.weight_quantize_type)
         transform_pass.apply(train_ir_graph)
         transform_pass.apply(test_ir_graph)
+        # Put persistables created by transform_pass into
+        # context.optimize_graph.persistables for saving checkpoint.
+        program_persistables = set()
+        for var in context.optimize_graph.program.list_vars():
+            if var.persistable:
+                program_persistables.add(var.name)
+
+        program = Program()
+        for var_node in train_ir_graph.all_persistable_nodes():
+            if var_node.name() not in program_persistables:
+                var_desc = var_node.var()
+                var = program.global_block().create_var(
+                    name=var_node.name(),
+                    shape=var_desc.shape(),
+                    dtype=var_desc.dtype(),
+                    type=var_desc.type(),
+                    lod_level=var_desc.lod_level())
+                context.optimize_graph.persistables[var.name] = var

         build_strategy = BuildStrategy()
         build_strategy.enable_inplace = False
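The scratch `Program` in this hunk exists only to mint Python `Variable` handles for variables that `QuantizationTransformPass` created at the C++ graph level; registering them in `context.optimize_graph.persistables` is what lets `GraphWrapper.save_persistables` write them out. A hedged sketch of that desc-to-Variable copy, with a hypothetical `node_to_variable` helper:

```python
from paddle.fluid.framework import Program


def node_to_variable(var_node):
    # Rebuild a framework.Variable from an IR node's VarDesc so the
    # fluid.io save/load machinery can address it by name.
    desc = var_node.var()
    scratch = Program()
    return scratch.global_block().create_var(
        name=var_node.name(),
        shape=desc.shape(),
        dtype=desc.dtype(),
        type=desc.type(),
        lod_level=desc.lod_level())
```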
@@ -123,6 +149,15 @@ class QuantizationStrategy(Strategy):
                 build_strategy=build_strategy)
         # for saving inference model after training
         context.put('quantization_test_ir_graph_backup', test_ir_graph)
-        _logger.info('Finish QuantizationStrategy::on_epoch_begin')
+
+    def on_epoch_begin(self, context):
+        """
+        Insert fake_quantize_op and fake_dequantize_op before training and testing.
+        """
+        super(QuantizationStrategy, self).on_epoch_begin(context)
+        if self.start_epoch == context.epoch_id:
+            _logger.info('QuantizationStrategy::on_epoch_begin')
+            self._modify_graph_for_quantization(context)
+            _logger.info('Finish QuantizationStrategy::on_epoch_begin')

     def on_epoch_end(self, context):
...
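Factoring the graph surgery into `_modify_graph_for_quantization` lets the fresh-run path (`on_epoch_begin` at `start_epoch`) and the resume path (`on_compression_begin`) share one implementation. A hypothetical driver loop, assuming a Strategy-style interface, showing when each hook fires:

```python
def run_compression(strategy, context, num_epochs):
    # Fires once before any epoch; a task resumed past start_epoch
    # rebuilds its quantized graph here.
    strategy.on_compression_begin(context)
    for epoch in range(context.epoch_id, num_epochs):
        context.epoch_id = epoch
        # A fresh task inserts the fake-quantize ops the first time
        # epoch_id reaches start_epoch.
        strategy.on_epoch_begin(context)
        # ... train one epoch; checkpoints save graph.persistables ...
        strategy.on_epoch_end(context)
```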