未验证 提交 01fc84a1 编写于 作者: Y yukavio 提交者: GitHub

Saving inference model for user defined quantization (#25799)

* Saving inference model for user defined quantization

* Saving inference model for user defined quantization
上级 e947d11e
......@@ -380,6 +380,7 @@ class QuantizationTransformPass(object):
shape=var_node.shape(),
dtype='float32')
out_node = func(in_node)
graph.out_node_mapping_table[out_node.name] = var_node.name()
# loss shape must be 1 when minimize
loss = mean(out_node)
if not graph._for_test:
......@@ -590,6 +591,8 @@ class QuantizationTransformPass(object):
if op.name() in self._quantizable_ops or \
op.name() in self._quantizable_grad_ops:
_quant_preprocess(op)
# Insert mapping table to solve the problem in saving inference model.
graph.out_node_mapping_table = dict()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for op in ops:
......@@ -1018,6 +1021,10 @@ class QuantizationFreezePass(object):
op_name = op_node.name()
if op_name in self._fake_quant_op_names:
input_arg_name = op_node.input('X')[0]
if hasattr(graph, 'out_node_mapping_table'):
if input_arg_name in graph.out_node_mapping_table.keys():
input_arg_name = graph.out_node_mapping_table[
input_arg_name]
if input_arg_name in persistable_vars:
if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name)
......
......@@ -14,6 +14,7 @@
import os
import unittest
import json
import random
import numpy as np
import six
......@@ -38,6 +39,7 @@ def residual_block(img, label, num=1):
filter_size,
stride,
padding,
use_cudnn=False,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
......@@ -109,6 +111,16 @@ class TestUserDefinedQuantization(unittest.TestCase):
def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
def load_dict():
with open('mapping_table_for_saving_inference_model', 'r') as file:
data = file.read()
data = json.loads(data)
return data
def save_dict(Dict):
with open('mapping_table_for_saving_inference_model', 'w') as file:
file.write(json.dumps(Dict))
random.seed(0)
np.random.seed(0)
......@@ -151,6 +163,7 @@ class TestUserDefinedQuantization(unittest.TestCase):
executor=exe)
test_transform_pass.apply(test_graph)
save_dict(test_graph.out_node_mapping_table)
add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place)
add_quant_dequant_pass.apply(main_graph)
......@@ -182,6 +195,21 @@ class TestUserDefinedQuantization(unittest.TestCase):
feed=feeder.feed(data),
fetch_list=[loss])
out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
out_scale_infer_pass.apply(test_graph)
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_bits=8,
activation_bits=8,
weight_quantize_type=weight_quant_type)
mapping_table = load_dict()
test_graph.out_node_mapping_table = mapping_table
if act_quantize_func == None and weight_quantize_func == None:
freeze_pass.apply(test_graph)
def test_act_preprocess_cuda(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册