提交 6db7c2a5 编写于 作者: W wanghaoshuang

Fix checkpoint of quantization.

上级 e41d5813
...@@ -204,6 +204,10 @@ class GraphWrapper(object): ...@@ -204,6 +204,10 @@ class GraphWrapper(object):
""" """
super(GraphWrapper, self).__init__() super(GraphWrapper, self).__init__()
self.program = Program() if program is None else program self.program = Program() if program is None else program
self.persistables = {}
for var in self.program.list_vars():
if var.persistable:
self.persistables[var.name] = var
self.compiled_graph = None self.compiled_graph = None
self.in_nodes = OrderedDict(in_nodes) self.in_nodes = OrderedDict(in_nodes)
self.out_nodes = OrderedDict(out_nodes) self.out_nodes = OrderedDict(out_nodes)
...@@ -467,7 +471,12 @@ class GraphWrapper(object): ...@@ -467,7 +471,12 @@ class GraphWrapper(object):
path(str): The path to save the persistables. path(str): The path to save the persistables.
exe(framework.Executor): The executor used to save the persistables. exe(framework.Executor): The executor used to save the persistables.
""" """
io.save_persistables(exe.exe, path, main_program=self.program) # update persistables from program
for var in self.program.list_vars():
if var.persistable and var.name not in self.persistables:
self.persistables[var.name] = var
io.save_vars(exe.exe, path, vars=self.persistables.values())
def load_persistables(self, path, exe): def load_persistables(self, path, exe):
""" """
...@@ -481,7 +490,7 @@ class GraphWrapper(object): ...@@ -481,7 +490,7 @@ class GraphWrapper(object):
return os.path.exists(os.path.join(path, var.name)) return os.path.exists(os.path.join(path, var.name))
io.load_vars( io.load_vars(
exe.exe, path, main_program=self.program, predicate=if_exist) exe.exe, path, vars=self.persistables.values(), predicate=if_exist)
def update_param_shape(self, scope): def update_param_shape(self, scope):
""" """
......
...@@ -20,7 +20,7 @@ from .... import io ...@@ -20,7 +20,7 @@ from .... import io
from .... import core from .... import core
from ....compiler import CompiledProgram from ....compiler import CompiledProgram
from ....compiler import BuildStrategy from ....compiler import BuildStrategy
from ....framework import IrGraph from ....framework import IrGraph, Variable, Program
from ..core.strategy import Strategy from ..core.strategy import Strategy
from .quantization_pass import * from .quantization_pass import *
...@@ -84,40 +84,75 @@ class QuantizationStrategy(Strategy): ...@@ -84,40 +84,75 @@ class QuantizationStrategy(Strategy):
self.save_out_nodes = save_out_nodes self.save_out_nodes = save_out_nodes
self.save_in_nodes = save_in_nodes self.save_in_nodes = save_in_nodes
def on_compression_begin(self, context):
    """
    Restore the quantized graphs when the compression task is resumed
    from a checkpoint.

    The graphs are rebuilt only when ``context.epoch_id`` is non-zero and
    greater than ``self.start_epoch``; the original author treats that
    state as "initialized from a checkpoint after the start epoch was
    missed". In every other case this hook does nothing.

    Args:
        context: The compression context; its ``epoch_id`` is read here
            and the context is passed on to
            ``_modify_graph_for_quantization``.
    """
    resumed_after_start = (context.epoch_id != 0 and
                           context.epoch_id > self.start_epoch)
    if not resumed_after_start:
        return
    _logger.info("Restore quantization task from checkpoint")
    self._modify_graph_for_quantization(context)
    _logger.info("Finish restoring quantization task from checkpoint")
def _modify_graph_for_quantization(self, context):
    """
    Insert fake_quantize_op and fake_dequantize_op before training and
    testing, then recompile the training and evaluation graphs.

    Args:
        context: The compression context. This method reads its
            ``optimize_graph``, ``eval_graph``, ``scope`` and ``place``;
            it replaces ``compiled_graph`` on both graphs, adds entries
            to ``optimize_graph.persistables`` and stores the transformed
            test IR graph under 'quantization_test_ir_graph_backup'.
    """
    # Build the IR graphs from the desc of a clone of each program.
    train_desc = context.optimize_graph.program.clone().desc
    train_ir_graph = IrGraph(core.Graph(train_desc), for_test=False)
    eval_desc = context.eval_graph.program.clone().desc
    test_ir_graph = IrGraph(core.Graph(eval_desc), for_test=True)

    transform_pass = QuantizationTransformPass(
        scope=context.scope,
        place=context.place,
        weight_bits=self.weight_bits,
        activation_bits=self.activation_bits,
        activation_quantize_type=self.activation_quantize_type)
    # The same pass object is applied to the training graph first and
    # then to the test graph.
    for ir_graph in (train_ir_graph, test_ir_graph):
        transform_pass.apply(ir_graph)

    # Put persistables created by transform_pass into
    # context.optimize_graph.persistables for saving checkpoint.
    known_names = {
        v.name
        for v in context.optimize_graph.program.list_vars() if v.persistable
    }
    # The new variables are created in a separate, otherwise empty Program.
    holder_program = Program()
    for node in train_ir_graph.all_persistable_nodes():
        if node.name() in known_names:
            continue
        node_desc = node.var()
        new_var = holder_program.global_block().create_var(
            name=node.name(),
            shape=node_desc.shape(),
            dtype=node_desc.dtype(),
            type=node_desc.type(),
            lod_level=node_desc.lod_level())
        context.optimize_graph.persistables[new_var.name] = new_var

    build_strategy = BuildStrategy()
    build_strategy.enable_inplace = False
    build_strategy.memory_optimize = False

    # for quantization training
    train_compiled = CompiledProgram(train_ir_graph.graph)
    loss_name = context.optimize_graph.out_nodes['loss']
    context.optimize_graph.compiled_graph = train_compiled.with_data_parallel(
        loss_name=loss_name, build_strategy=build_strategy)

    # for evaluation. And program compiled from ir graph must be with
    # data parallel.
    eval_compiled = CompiledProgram(test_ir_graph.graph)
    context.eval_graph.compiled_graph = eval_compiled.with_data_parallel(
        build_strategy=build_strategy)

    # for saving inference model after training
    context.put('quantization_test_ir_graph_backup', test_ir_graph)
def on_epoch_begin(self, context): def on_epoch_begin(self, context):
""" """
Insert fake_quantize_op and fake_dequantize_op before training and testing. Insert fake_quantize_op and fake_dequantize_op before training and testing.
""" """
super(QuantizationStrategy, self).on_compression_begin(context) super(QuantizationStrategy, self).on_epoch_begin(context)
if self.start_epoch == context.epoch_id: if self.start_epoch == context.epoch_id:
_logger.info('QuantizationStrategy::on_epoch_begin') _logger.info('QuantizationStrategy::on_epoch_begin')
train_ir_graph = IrGraph( self._modify_graph_for_quantization(context)
core.Graph(context.optimize_graph.program.desc), for_test=False)
test_ir_graph = IrGraph(
core.Graph(context.eval_graph.program.desc), for_test=True)
transform_pass = QuantizationTransformPass(
scope=context.scope,
place=context.place,
weight_bits=self.weight_bits,
activation_bits=self.activation_bits,
activation_quantize_type=self.activation_quantize_type)
transform_pass.apply(train_ir_graph)
transform_pass.apply(test_ir_graph)
build_strategy = BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
# for quantization training
context.optimize_graph.compiled_graph = CompiledProgram(
train_ir_graph.graph).with_data_parallel(
loss_name=context.optimize_graph.out_nodes['loss'],
build_strategy=build_strategy)
# for evaluation. And program compiled from ir graph must be with data parallel.
context.eval_graph.compiled_graph = CompiledProgram(
test_ir_graph.graph).with_data_parallel(
build_strategy=build_strategy)
# for saving inference model after training
context.put('quantization_test_ir_graph_backup', test_ir_graph)
_logger.info('Finish QuantizationStrategy::on_epoch_begin') _logger.info('Finish QuantizationStrategy::on_epoch_begin')
def on_epoch_end(self, context): def on_epoch_end(self, context):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册