提交 6db7c2a5 编写于 作者: W wanghaoshuang

Fix checkpoint of quantization.

上级 e41d5813
......@@ -204,6 +204,10 @@ class GraphWrapper(object):
"""
super(GraphWrapper, self).__init__()
self.program = Program() if program is None else program
self.persistables = {}
for var in self.program.list_vars():
if var.persistable:
self.persistables[var.name] = var
self.compiled_graph = None
self.in_nodes = OrderedDict(in_nodes)
self.out_nodes = OrderedDict(out_nodes)
......@@ -467,7 +471,12 @@ class GraphWrapper(object):
path(str): The path to save the persistables.
exe(framework.Executor): The executor used to save the persistables.
"""
io.save_persistables(exe.exe, path, main_program=self.program)
# update persistables from program
for var in self.program.list_vars():
if var.persistable and var.name not in self.persistables:
self.persistables[var.name] = var
io.save_vars(exe.exe, path, vars=self.persistables.values())
def load_persistables(self, path, exe):
"""
......@@ -481,7 +490,7 @@ class GraphWrapper(object):
return os.path.exists(os.path.join(path, var.name))
io.load_vars(
exe.exe, path, main_program=self.program, predicate=if_exist)
exe.exe, path, vars=self.persistables.values(), predicate=if_exist)
def update_param_shape(self, scope):
"""
......
......@@ -20,7 +20,7 @@ from .... import io
from .... import core
from ....compiler import CompiledProgram
from ....compiler import BuildStrategy
from ....framework import IrGraph
from ....framework import IrGraph, Variable, Program
from ..core.strategy import Strategy
from .quantization_pass import *
......@@ -84,17 +84,25 @@ class QuantizationStrategy(Strategy):
self.save_out_nodes = save_out_nodes
self.save_in_nodes = save_in_nodes
def on_epoch_begin(self, context):
def on_compression_begin(self, context):
"""
Restore graph when the compressoin task is inited from checkpoint.
"""
# It is inited from checkpoint and has missed start epoch.
if context.epoch_id != 0 and context.epoch_id > self.start_epoch:
_logger.info("Restore quantization task from checkpoint")
self._modify_graph_for_quantization(context)
_logger.info("Finish restoring quantization task from checkpoint")
def _modify_graph_for_quantization(self, context):
"""
Insert fake_quantize_op and fake_dequantize_op before trainging and testing.
"""
super(QuantizationStrategy, self).on_compression_begin(context)
if self.start_epoch == context.epoch_id:
_logger.info('QuantizationStrategy::on_epoch_begin')
train_ir_graph = IrGraph(
core.Graph(context.optimize_graph.program.desc), for_test=False)
core.Graph(context.optimize_graph.program.clone().desc),
for_test=False)
test_ir_graph = IrGraph(
core.Graph(context.eval_graph.program.desc), for_test=True)
core.Graph(context.eval_graph.program.clone().desc), for_test=True)
transform_pass = QuantizationTransformPass(
scope=context.scope,
place=context.place,
......@@ -103,6 +111,24 @@ class QuantizationStrategy(Strategy):
activation_quantize_type=self.activation_quantize_type)
transform_pass.apply(train_ir_graph)
transform_pass.apply(test_ir_graph)
# Put persistables created by transform_pass into context.optimize_graph.persistables
# for saving checkpoint.
program_persistables = set()
for var in context.optimize_graph.program.list_vars():
if var.persistable:
program_persistables.add(var.name)
program = Program()
for var_node in train_ir_graph.all_persistable_nodes():
if var_node.name() not in program_persistables:
var_desc = var_node.var()
var = program.global_block().create_var(
name=var_node.name(),
shape=var_desc.shape(),
dtype=var_desc.dtype(),
type=var_desc.type(),
lod_level=var_desc.lod_level())
context.optimize_graph.persistables[var.name] = var
build_strategy = BuildStrategy()
build_strategy.enable_inplace = False
......@@ -118,6 +144,15 @@ class QuantizationStrategy(Strategy):
build_strategy=build_strategy)
# for saving inference model after training
context.put('quantization_test_ir_graph_backup', test_ir_graph)
def on_epoch_begin(self, context):
"""
Insert fake_quantize_op and fake_dequantize_op before trainging and testing.
"""
super(QuantizationStrategy, self).on_epoch_begin(context)
if self.start_epoch == context.epoch_id:
_logger.info('QuantizationStrategy::on_epoch_begin')
self._modify_graph_for_quantization(context)
_logger.info('Finish QuantizationStrategy::on_epoch_begin')
def on_epoch_end(self, context):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册