Unverified commit 75eec3d1, authored by cc, committed by GitHub

Post-training quantization supports optimizing the model by fusing (#24822)

* Post_training_quantization supports optimizing the model by fusing, test=develop
Parent 12bffdc0
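For context, a minimal usage sketch of the new flag. This is hedged: `val_reader` and the paths are placeholders, and the constructor arguments besides those visible in this diff are assumed from the surrounding class, so names may differ.

import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

place = fluid.CPUPlace()
exe = fluid.Executor(place)
ptq = PostTrainingQuantization(
    executor=exe,
    sample_generator=val_reader,    # placeholder calibration data reader
    model_dir="./fp32_model",       # placeholder path to the FP32 model
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    optimize_model=True,            # new flag: fuse conv2d/depthwise_conv2d + bn first
    is_use_cache_file=False)
ptq.quantize()
ptq.save_quantized_model("./int8_model")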
@@ -65,6 +65,56 @@ def _all_persistable_var_names(program):
return persistable_var_names
def _remove_unused_var_nodes(graph):
    '''
    Remove the var nodes that no op node uses as input or output.
    '''
all_used_vars = set()
ops = graph.all_op_nodes()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
return graph
def _remove_ctrl_vars(graph):
    '''
    Remove all control-dependency var nodes from the graph.
    '''
remove_ctr_vars = set()
for node in graph.all_var_nodes():
if node.is_ctrl_var():
remove_ctr_vars.add(node)
graph.safe_remove_nodes(remove_ctr_vars)
return graph
def _apply_pass(scope,
                graph,
                pass_name,
                attrs=None,
                attr_values=None,
                debug=False):
    '''
    Apply the named IR pass to the graph, then remove the var nodes the
    pass left unused.
    '''
ir_pass = core.get_pass(pass_name)
cpp_graph = graph.graph
if not cpp_graph.has('__param_scope__'):
cpp_graph.set_not_owned('__param_scope__', scope)
if attrs:
assert attr_values and len(attrs) == len(
attr_values), "Different number of pass attributes and their values."
for attr, value in zip(attrs, attr_values):
ir_pass.set(attr, value)
ir_pass.apply(cpp_graph)
if debug:
graph.draw('.', 'qat_fp32_{}'.format(pass_name), graph.all_op_nodes())
_remove_unused_var_nodes(graph)
return graph
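For reference, the attrs branch lets a caller feed attributes into the IR pass before applying it; a hypothetical call might look like the line below (the pass and attribute names are modeled on other Paddle fuse passes and are not used in this patch):

graph = _apply_pass(scope, graph, 'fc_fuse_pass',
                    ['use_gpu', 'use_fc_padding'], [False, False])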
class PostTrainingQuantization(object):
"""
    Utilizing the post training quantization method to quantize the FP32 model,
@@ -89,6 +139,7 @@ class PostTrainingQuantization(object):
weight_bits=8,
activation_quantize_type='range_abs_max',
weight_quantize_type='channel_wise_abs_max',
optimize_model=False,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
'''
@@ -145,6 +196,14 @@
the fake ops in saving quantized model, and we save the scale obtained
by post training quantization in fake ops. Compared to 'abs_max',
the model accuracy is usually higher when it is 'channel_wise_abs_max'.
optimize_model(bool, optional): If optimize_model is set to True, several
passes are applied to the model before quantization; so far, only the
`conv2d/depthwise_conv2d + bn` fuse pass is supported. Some targets require
the weights to be quantized tensor-wise, meaning a single weight scale is
shared by all channels. However, fusing `conv2d/depthwise_conv2d + bn`
rescales the weights channel by channel, so a scale computed before fusing
no longer matches the deployed weights. To address this problem, the
pattern is fused before quantization. Default False.
is_use_cache_file(bool, optional): If set is_use_cache_file as False,
all temp data will be saved in memory. If set is_use_cache_file as True,
it will save temp data to disk. When the fp32 model is complex or
@@ -240,6 +299,7 @@ class PostTrainingQuantization(object):
for op_type in self._quantizable_op_type:
assert op_type in self._support_quantize_op_type, \
op_type + " is not supported for quantization."
self._optimize_model = optimize_model
self._is_use_cache_file = is_use_cache_file
self._cache_dir = cache_dir
if self._is_use_cache_file and not os.path.exists(self._cache_dir):
@@ -344,6 +404,10 @@
executor=self._executor,
model_filename=self._model_filename,
params_filename=self._params_filename)
if self._optimize_model:
self._optimize_fp32_model()
feed_vars = [framework._get_var(str(var_name), self._program) \
for var_name in self._feed_list]
self._data_loader = io.DataLoader.from_generator(
@@ -358,6 +422,16 @@
self._data_loader.set_batch_generator(
self._batch_generator, places=self._place)
def _optimize_fp32_model(self):
'''
        Fuse the `conv2d/depthwise_conv2d + bn` patterns in the FP32 model.
'''
_logger.info("Optimize FP32 model ...")
graph = IrGraph(core.Graph(self._program.desc), for_test=True)
graph = _remove_ctrl_vars(graph)
graph = _apply_pass(self._scope, graph, 'conv_bn_fuse_pass')
self._program = graph.to_program()
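For intuition, here is a schematic numpy sketch (illustrative only, not the actual conv_bn_fuse_pass) of what folding bn into the preceding conv computes. Each output channel is rescaled by its own factor, which is why a single tensor-wise weight scale computed before fusing would no longer fit the fused weights:

import numpy as np

def fold_bn_into_conv(W, b, gamma, beta, mean, var, eps=1e-5):
    # W: (out_c, in_c, kh, kw); gamma, beta, mean, var: (out_c,)
    scale = gamma / np.sqrt(var + eps)        # one factor per output channel
    W_fused = W * scale.reshape(-1, 1, 1, 1)  # every channel rescaled differently
    b_fused = (b - mean) * scale + beta
    return W_fused, b_fused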
def _collect_target_varnames(self):
'''
Collect the variable names for sampling, and set activation
......
@@ -239,7 +239,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
quantizable_op_type,
algo="KL",
is_full_quantize=False,
is_use_cache_file=False):
is_use_cache_file=False,
is_optimize_model=False):
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
@@ -259,12 +260,14 @@
algo=algo,
quantizable_op_type=quantizable_op_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, diff_threshold):
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
@@ -278,9 +281,9 @@
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size))
self.generate_quantized_model(model_cache_folder + "/model",
quantizable_op_type, algo,
is_full_quantize, is_use_cache_file)
self.generate_quantized_model(
model_cache_folder + "/model", quantizable_op_type, algo,
is_full_quantize, is_use_cache_file, is_optimize_model)
print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
@@ -316,9 +319,11 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, diff_threshold)
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
@@ -335,10 +340,12 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
        # The accuracy diff of post-training quantization (abs_max) may be bigger
diff_threshold = 0.05
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, diff_threshold)
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
if __name__ == '__main__':
......
@@ -28,9 +28,11 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, diff_threshold)
is_full_quantize, is_use_cache_file, is_optimize_model,
diff_threshold)
if __name__ == '__main__':
......