未验证 提交 c9285a18 编写于 作者: Y yukavio 提交者: GitHub

saving inference model when user define activation or weight preprocess function (#25749)

* saving inference model for user defined quantization model

* saving inference model for user defined quantization model

* fixed ci coverage
上级 eef98b7f
...@@ -435,6 +435,8 @@ class QuantizationTransformPass(object): ...@@ -435,6 +435,8 @@ class QuantizationTransformPass(object):
if op.name() in self._quantizable_ops or \ if op.name() in self._quantizable_ops or \
op.name() in self._quantizable_grad_ops: op.name() in self._quantizable_grad_ops:
_quant_preprocess(op) _quant_preprocess(op)
# Insert mapping table to solve the problem in saving inference model.
graph.out_node_mapping_table = dict()
# The process of _transform_forward and _transform_backward is needed in two for loops. # The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph: # The loop for transforming the forward graph:
for op in ops: for op in ops:
...@@ -853,6 +855,7 @@ class QuantizationTransformPass(object): ...@@ -853,6 +855,7 @@ class QuantizationTransformPass(object):
shape=var_node.shape(), shape=var_node.shape(),
dtype='float32') dtype='float32')
out_node = func(in_node) out_node = func(in_node)
graph.out_node_mapping_table[out_node.name] = var_node.name()
# loss shape must be 1 when minimize # loss shape must be 1 when minimize
loss = mean(out_node) loss = mean(out_node)
if not graph._for_test: if not graph._for_test:
...@@ -1037,6 +1040,10 @@ class QuantizationFreezePass(object): ...@@ -1037,6 +1040,10 @@ class QuantizationFreezePass(object):
op_name = op_node.name() op_name = op_node.name()
if op_name in self._fake_quant_op_names: if op_name in self._fake_quant_op_names:
input_arg_name = op_node.input('X')[0] input_arg_name = op_node.input('X')[0]
if hasattr(graph, 'out_node_mapping_table'):
if input_arg_name in graph.out_node_mapping_table.keys():
input_arg_name = graph.out_node_mapping_table[
input_arg_name]
if input_arg_name in persistable_vars: if input_arg_name in persistable_vars:
if self._weight_quantize_type == 'abs_max': if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name) param = self._load_var(input_arg_name)
......
...@@ -298,7 +298,6 @@ list(REMOVE_ITEM TEST_OPS ...@@ -298,7 +298,6 @@ list(REMOVE_ITEM TEST_OPS
#TODO(wanghaoshuang): Fix this unitest failed on GCC8. #TODO(wanghaoshuang): Fix this unitest failed on GCC8.
LIST(REMOVE_ITEM TEST_OPS test_auto_pruning) LIST(REMOVE_ITEM TEST_OPS test_auto_pruning)
LIST(REMOVE_ITEM TEST_OPS test_filter_pruning) LIST(REMOVE_ITEM TEST_OPS test_filter_pruning)
LIST(REMOVE_ITEM TEST_OPS test_user_defined_quantization)
foreach(src ${TEST_OPS}) foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py) py_test(${src} SRCS ${src}.py)
endforeach() endforeach()
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import os import os
import unittest import unittest
import json
import random import random
import numpy as np import numpy as np
import six import six
...@@ -109,6 +110,16 @@ class TestUserDefinedQuantization(unittest.TestCase): ...@@ -109,6 +110,16 @@ class TestUserDefinedQuantization(unittest.TestCase):
def get_optimizer(): def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.0001, 0.9) return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
def load_dict():
with open('mapping_table_for_saving_inference_model', 'r') as file:
data = file.read()
data = json.loads(data)
return data
def save_dict(Dict):
with open('mapping_table_for_saving_inference_model', 'w') as file:
file.write(json.dumps(Dict))
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -151,6 +162,7 @@ class TestUserDefinedQuantization(unittest.TestCase): ...@@ -151,6 +162,7 @@ class TestUserDefinedQuantization(unittest.TestCase):
executor=exe) executor=exe)
test_transform_pass.apply(test_graph) test_transform_pass.apply(test_graph)
save_dict(test_graph.out_node_mapping_table)
add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place) add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place)
add_quant_dequant_pass.apply(main_graph) add_quant_dequant_pass.apply(main_graph)
...@@ -182,6 +194,21 @@ class TestUserDefinedQuantization(unittest.TestCase): ...@@ -182,6 +194,21 @@ class TestUserDefinedQuantization(unittest.TestCase):
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=[loss]) fetch_list=[loss])
out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
out_scale_infer_pass.apply(test_graph)
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_bits=8,
activation_bits=8,
weight_quantize_type=weight_quant_type)
mapping_table = load_dict()
test_graph.out_node_mapping_table = mapping_table
if act_quantize_func == None and weight_quantize_func == None:
freeze_pass.apply(test_graph)
def test_act_preprocess_cuda(self): def test_act_preprocess_cuda(self):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard(): with fluid.unique_name.guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册