未验证 提交 b83d27ac 编写于 作者: H handiz 提交者: GitHub

fix bug in PostTrainingProgram for certain cases (#45616)

* fix bug in PostTrainingProgram for certain cases
上级 808df649
......@@ -26,6 +26,7 @@ except:
from inspect import isgeneratorfunction
from .... import io
from .... import core
from .... import reader
from .... import framework
from .... import unique_name
from ....executor import global_scope, Executor
......@@ -141,7 +142,6 @@ class PostTrainingQuantization(object):
is_use_cache_file=False,
skip_tensor_list=None,
same_scale_tensor_list=None,
scale_trainable=False,
cache_dir=None,
scale_dict=None,
return_graph=False):
......@@ -231,7 +231,6 @@ class PostTrainingQuantization(object):
`conv2d/depthwise_conv2d + bn`, the weights scale for all channel will
be different. In address this problem, fuse the pattern before
quantization. Default False.
scale_trainable(bool, optional): whether scale can be train.
is_use_cache_file(bool, optional): This param is deprecated.
cache_dir(str, optional): This param is deprecated.
Returns:
......@@ -296,7 +295,7 @@ class PostTrainingQuantization(object):
batch_generator, data_loader]), "The sample_generator, batch_generator " \
"and data_loader cannot be None in the same time."
if data_loader is not None:
assert isinstance(data_loader, (io.DataLoader, type(isgeneratorfunction))), \
assert isinstance(data_loader, (io.DataLoader, type(isgeneratorfunction), reader.GeneratorLoader)), \
"data_loader only accepts `paddle.io.DataLoader` or Generator instance."
assert batch_size > 0, "The batch_size should be greater than 0."
assert algo in self._support_algo_type, \
......@@ -366,9 +365,11 @@ class PostTrainingQuantization(object):
self._quantized_threshold = {}
self._same_scale_tensor_list = same_scale_tensor_list
self._freeze_model = freeze_model
self._scale_trainable = scale_trainable
self._scale_dict = scale_dict
self._return_graph = return_graph
self.FLAG = False
if self._program is not None:
self.FLAG = True
def quantize(self):
'''
......@@ -440,6 +441,7 @@ class PostTrainingQuantization(object):
self._update_program()
# save out_threshold for quantized ops.
if not self.FLAG:
self._save_output_threshold()
if any(op_type in self._quantizable_op_type
......@@ -1001,8 +1003,7 @@ class PostTrainingQuantization(object):
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types,
is_test=not self._scale_trainable)
quantizable_op_type=major_quantizable_op_types)
else:
transform_pass = QuantizationTransformPassV2(
scope=self._scope,
......@@ -1011,8 +1012,7 @@ class PostTrainingQuantization(object):
activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types,
is_test=not self._scale_trainable)
quantizable_op_type=major_quantizable_op_types)
for sub_graph in graph.all_sub_graphs():
# Insert fake_quant/fake_dequantize op must in test graph, so
......@@ -1029,15 +1029,13 @@ class PostTrainingQuantization(object):
add_quant_dequant_pass = AddQuantDequantPass(
scope=self._scope,
place=self._place,
quantizable_op_type=minor_quantizable_op_types,
is_test=not self._scale_trainable)
quantizable_op_type=minor_quantizable_op_types)
else:
add_quant_dequant_pass = AddQuantDequantPassV2(
scope=self._scope,
place=self._place,
quantizable_op_type=minor_quantizable_op_types,
is_full_quantized=self._is_full_quantize,
is_test=not self._scale_trainable)
is_full_quantized=self._is_full_quantize)
for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True
......@@ -1055,11 +1053,11 @@ class PostTrainingQuantization(object):
max_scale = None
tmp_tensor_list = []
for tensor_name in tensor_list:
if tensor_name not in scale_dict.keys():
continue
if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split(
'#')
if real_tensor_name not in scale_dict.keys():
continue
if opera == '*':
scale_dict[real_tensor_name] = float(
scale_dict[real_tensor_name]) * float(
......@@ -1072,16 +1070,18 @@ class PostTrainingQuantization(object):
real_tensor_name] if max_scale is None else max(
max_scale, scale_dict[real_tensor_name])
else:
if tensor_name not in scale_dict.keys():
continue
max_scale = scale_dict[
tensor_name] if max_scale is None else max(
max_scale, scale_dict[tensor_name])
for tensor_name in tensor_list:
if tensor_name not in scale_dict.keys():
continue
if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split(
'#')
if real_tensor_name not in scale_dict.keys():
continue
if opera == '*':
scale_dict[
real_tensor_name] = max_scale / float(
......@@ -1091,6 +1091,8 @@ class PostTrainingQuantization(object):
real_tensor_name] = max_scale * float(
scalar)
else:
if tensor_name not in scale_dict.keys():
continue
scale_dict[tensor_name] = max_scale
self._scale_dict = scale_dict
......@@ -1265,7 +1267,6 @@ class PostTrainingQuantizationProgram(PostTrainingQuantization):
is_use_cache_file=False,
skip_tensor_list=None,
same_scale_tensor_list=None,
scale_trainable=False,
cache_dir=None,
scale_dict=None,
return_graph=True):
......@@ -1276,9 +1277,12 @@ class PostTrainingQuantizationProgram(PostTrainingQuantization):
activation_bits, weight_bits, activation_quantize_type,
weight_quantize_type, onnx_format, freeze_model,
optimize_model, is_use_cache_file, skip_tensor_list,
same_scale_tensor_list, scale_trainable, cache_dir,
scale_dict, return_graph)
same_scale_tensor_list, cache_dir, scale_dict,
return_graph)
self.FLAG = False
self._program = program
if self._program is not None:
self.FLAG = True
assert feed_list is not None, \
"Feed list should not be None."
assert fetch_list is not None, \
......
......@@ -1470,9 +1470,10 @@ class OutScaleForTrainingPass(object):
data_type = 'float64' if in_node.dtype() \
== core.VarDesc.VarType.FP64 else 'float32'
try:
scale_node = graph._find_node_by_name(
graph._find_node_by_name(
graph.all_var_nodes(),
self._scale_name(in_node.name()))
continue
except:
scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
......
......@@ -186,7 +186,11 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
], ['batch_norm_27.tmp_2', 'batch_norm_26.tmp_2'],
[
'test_scale_name_not_in_scale_dict1',
'test_scale_name_not_in_scale_dict1'
'test_scale_name_not_in_scale_dict2'
],
[
'test_scale_name_not_in_scale_dict1#/#1',
'test_scale_name_not_in_scale_dict2#/#1'
]]
ptq = PostTrainingQuantizationProgram(
executor=exe,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册