Unverified commit 75eec3d1, authored by cc, committed by GitHub

Post training quantization supports optimize model by fusing (#24822)

* Post_training_quantization supports optimize model by fusing, test=develop
Parent 12bffdc0
@@ -65,6 +65,56 @@ def _all_persistable_var_names(program):
    return persistable_var_names
def _remove_unused_var_nodes(graph):
    """Remove var nodes that no op in the graph consumes or produces anymore."""
    all_used_vars = set()
    ops = graph.all_op_nodes()
    for op_node in ops:
        for input_node in op_node.inputs:
            all_used_vars.add(input_node)
        for output_node in op_node.outputs:
            all_used_vars.add(output_node)

    all_used_vars = {n.node for n in all_used_vars}
    all_unused_vars = {
        n
        for n in filter(lambda node: node.node not in all_used_vars,
                        graph.all_var_nodes())
    }
    graph.safe_remove_nodes(all_unused_vars)
    return graph


def _remove_ctrl_vars(graph):
    """Remove control-dependency var nodes from the graph."""
    remove_ctr_vars = set()
    for node in graph.all_var_nodes():
        if node.is_ctrl_var():
            remove_ctr_vars.add(node)
    graph.safe_remove_nodes(remove_ctr_vars)
    return graph
def _apply_pass(scope,
                graph,
                pass_name,
                attrs=None,
                attr_values=None,
                debug=False):
    """Apply a named IR pass to the graph and drop var nodes left unused by it."""
    ir_pass = core.get_pass(pass_name)
    cpp_graph = graph.graph
    if not cpp_graph.has('__param_scope__'):
        # Attach the scope so the pass can read and rewrite parameter tensors.
        cpp_graph.set_not_owned('__param_scope__', scope)
    if attrs:
        assert attr_values and len(attrs) == len(
            attr_values), "Different number of pass attributes and their values."
        for attr, value in zip(attrs, attr_values):
            ir_pass.set(attr, value)
    ir_pass.apply(cpp_graph)
    if debug:
        graph.draw('.', 'qat_fp32_{}'.format(pass_name), graph.all_op_nodes())
    _remove_unused_var_nodes(graph)
    return graph
class PostTrainingQuantization(object):
    """
    Utilizing the post training quantization method to quantize the FP32 model,
@@ -89,6 +139,7 @@ class PostTrainingQuantization(object):
                 weight_bits=8,
                 activation_quantize_type='range_abs_max',
                 weight_quantize_type='channel_wise_abs_max',
                 optimize_model=False,
                 is_use_cache_file=False,
                 cache_dir="./temp_post_training"):
        '''
@@ -145,6 +196,14 @@ class PostTrainingQuantization(object):
                the fake ops in saving quantized model, and we save the scale obtained
                by post training quantization in fake ops. Compared to 'abs_max',
                the model accuracy is usually higher when it is 'channel_wise_abs_max'.
            optimize_model(bool, optional): If optimize_model is set to True, some
                passes are applied to the model before quantization; so far the
                `conv2d/depthwise_conv2d + bn` fuse pass is supported. Some targets
                require the weights to be quantized by the tensor-wise method, which
                means the weight scale is the same for all channels. However, if
                `conv2d/depthwise_conv2d + bn` is fused, the weight scales of the
                channels become different. To address this problem, the pattern is
                fused before quantization. Default is False.
            is_use_cache_file(bool, optional): If set is_use_cache_file as False,
                all temp data will be saved in memory. If set is_use_cache_file as True,
                it will save temp data to disk. When the fp32 model is complex or
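For context, a minimal end-to-end usage sketch of the new flag follows. It is hedged: the import path, the executor setup, the dummy calibration reader, and the model paths are assumptions about typical usage of this class and are not part of this diff; only `optimize_model` and the keyword names visible in the hunks come from the patch.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

def calib_reader():
    # Hypothetical calibration sample generator: yields one sample per iteration.
    for _ in range(100):
        yield [np.random.random((3, 224, 224)).astype("float32")]

exe = fluid.Executor(fluid.CPUPlace())            # assumed executor
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./mobilenetv1_fp32",               # hypothetical FP32 model path
    sample_generator=calib_reader,
    batch_size=10,
    batch_nums=10,
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    optimize_model=True,                          # fuse conv2d/depthwise_conv2d + bn first
    is_use_cache_file=False)
ptq.quantize()
ptq.save_quantized_model("./mobilenetv1_int8")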
@@ -240,6 +299,7 @@ class PostTrainingQuantization(object):
        for op_type in self._quantizable_op_type:
            assert op_type in self._support_quantize_op_type, \
                op_type + " is not supported for quantization."
        self._optimize_model = optimize_model
        self._is_use_cache_file = is_use_cache_file
        self._cache_dir = cache_dir
        if self._is_use_cache_file and not os.path.exists(self._cache_dir):
@@ -344,6 +404,10 @@ class PostTrainingQuantization(object):
                executor=self._executor,
                model_filename=self._model_filename,
                params_filename=self._params_filename)
        if self._optimize_model:
            self._optimize_fp32_model()
        feed_vars = [framework._get_var(str(var_name), self._program) \
            for var_name in self._feed_list]
        self._data_loader = io.DataLoader.from_generator(
@@ -358,6 +422,16 @@ class PostTrainingQuantization(object):
            self._data_loader.set_batch_generator(
                self._batch_generator, places=self._place)
    def _optimize_fp32_model(self):
        '''
        Fuse the `conv2d/depthwise_conv2d + bn` pattern in the FP32 model.
        '''
        _logger.info("Optimize FP32 model ...")
        graph = IrGraph(core.Graph(self._program.desc), for_test=True)
        graph = _remove_ctrl_vars(graph)
        graph = _apply_pass(self._scope, graph, 'conv_bn_fuse_pass')
        self._program = graph.to_program()
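For intuition about why this fuse matters for the weight scales, the standard conv + bn folding that a conv_bn fuse performs can be sketched in NumPy. This is an illustrative formula only, not Paddle API; the function name and tensor layout are assumptions.

import numpy as np

def fold_bn_into_conv(conv_w, conv_b, gamma, beta, mean, var, eps=1e-5):
    # conv_w: [out_channels, in_channels, kh, kw]. Every output channel is
    # multiplied by its own factor gamma / sqrt(var + eps), so the folded
    # weights have per-channel magnitudes that a single tensor-wise scale
    # cannot represent well unless the fuse happens before scales are sampled.
    scale = gamma / np.sqrt(var + eps)                # shape: [out_channels]
    folded_w = conv_w * scale.reshape(-1, 1, 1, 1)
    folded_b = (conv_b - mean) * scale + beta
    return folded_w, folded_b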
    def _collect_target_varnames(self):
        '''
        Collect the variable names for sampling, and set activation
......
@@ -239,7 +239,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                  quantizable_op_type,
                                  algo="KL",
                                  is_full_quantize=False,
                                  is_use_cache_file=False,
                                  is_optimize_model=False):
        try:
            os.system("mkdir " + self.int8_model)
        except Exception as e:
@@ -259,12 +260,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
            algo=algo,
            quantizable_op_type=quantizable_op_type,
            is_full_quantize=is_full_quantize,
            optimize_model=is_optimize_model,
            is_use_cache_file=is_use_cache_file)
        ptq.quantize()
        ptq.save_quantized_model(self.int8_model)
    def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
                 is_full_quantize, is_use_cache_file, is_optimize_model,
                 diff_threshold):
        infer_iterations = self.infer_iterations
        batch_size = self.batch_size
        sample_iterations = self.sample_iterations
@@ -278,9 +281,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
        print("Start INT8 post training quantization for {0} on {1} images ...".
              format(model, sample_iterations * batch_size))
        self.generate_quantized_model(
            model_cache_folder + "/model", quantizable_op_type, algo,
            is_full_quantize, is_use_cache_file, is_optimize_model)
        print("Start INT8 inference for {0} on {1} images ...".format(
            model, infer_iterations * batch_size))
@@ -316,9 +319,11 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
        ]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.025
        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold)


class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
@@ -335,10 +340,12 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
        ]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = False
        # The accuracy diff of post-training quantization (abs_max) may be bigger
        diff_threshold = 0.05
        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold)


if __name__ == '__main__':
......
@@ -28,9 +28,11 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
        quantizable_op_type = ["conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
        is_optimize_model = False
        diff_threshold = 0.025
        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold)


if __name__ == '__main__':
......