未验证 提交 b83d27ac 编写于 作者: H handiz 提交者: GitHub

fix bug in PostTrainingProgram for certain cases (#45616)

* fix bug in PostTrainingProgram for certain cases
上级 808df649
...@@ -26,6 +26,7 @@ except: ...@@ -26,6 +26,7 @@ except:
from inspect import isgeneratorfunction from inspect import isgeneratorfunction
from .... import io from .... import io
from .... import core from .... import core
from .... import reader
from .... import framework from .... import framework
from .... import unique_name from .... import unique_name
from ....executor import global_scope, Executor from ....executor import global_scope, Executor
...@@ -141,7 +142,6 @@ class PostTrainingQuantization(object): ...@@ -141,7 +142,6 @@ class PostTrainingQuantization(object):
is_use_cache_file=False, is_use_cache_file=False,
skip_tensor_list=None, skip_tensor_list=None,
same_scale_tensor_list=None, same_scale_tensor_list=None,
scale_trainable=False,
cache_dir=None, cache_dir=None,
scale_dict=None, scale_dict=None,
return_graph=False): return_graph=False):
...@@ -231,7 +231,6 @@ class PostTrainingQuantization(object): ...@@ -231,7 +231,6 @@ class PostTrainingQuantization(object):
`conv2d/depthwise_conv2d + bn`, the weights scale for all channel will `conv2d/depthwise_conv2d + bn`, the weights scale for all channel will
be different. In address this problem, fuse the pattern before be different. In address this problem, fuse the pattern before
quantization. Default False. quantization. Default False.
scale_trainable(bool, optional): whether scale can be train.
is_use_cache_file(bool, optional): This param is deprecated. is_use_cache_file(bool, optional): This param is deprecated.
cache_dir(str, optional): This param is deprecated. cache_dir(str, optional): This param is deprecated.
Returns: Returns:
...@@ -296,7 +295,7 @@ class PostTrainingQuantization(object): ...@@ -296,7 +295,7 @@ class PostTrainingQuantization(object):
batch_generator, data_loader]), "The sample_generator, batch_generator " \ batch_generator, data_loader]), "The sample_generator, batch_generator " \
"and data_loader cannot be None in the same time." "and data_loader cannot be None in the same time."
if data_loader is not None: if data_loader is not None:
assert isinstance(data_loader, (io.DataLoader, type(isgeneratorfunction))), \ assert isinstance(data_loader, (io.DataLoader, type(isgeneratorfunction), reader.GeneratorLoader)), \
"data_loader only accepts `paddle.io.DataLoader` or Generator instance." "data_loader only accepts `paddle.io.DataLoader` or Generator instance."
assert batch_size > 0, "The batch_size should be greater than 0." assert batch_size > 0, "The batch_size should be greater than 0."
assert algo in self._support_algo_type, \ assert algo in self._support_algo_type, \
...@@ -366,9 +365,11 @@ class PostTrainingQuantization(object): ...@@ -366,9 +365,11 @@ class PostTrainingQuantization(object):
self._quantized_threshold = {} self._quantized_threshold = {}
self._same_scale_tensor_list = same_scale_tensor_list self._same_scale_tensor_list = same_scale_tensor_list
self._freeze_model = freeze_model self._freeze_model = freeze_model
self._scale_trainable = scale_trainable
self._scale_dict = scale_dict self._scale_dict = scale_dict
self._return_graph = return_graph self._return_graph = return_graph
self.FLAG = False
if self._program is not None:
self.FLAG = True
def quantize(self): def quantize(self):
''' '''
...@@ -440,6 +441,7 @@ class PostTrainingQuantization(object): ...@@ -440,6 +441,7 @@ class PostTrainingQuantization(object):
self._update_program() self._update_program()
# save out_threshold for quantized ops. # save out_threshold for quantized ops.
if not self.FLAG:
self._save_output_threshold() self._save_output_threshold()
if any(op_type in self._quantizable_op_type if any(op_type in self._quantizable_op_type
...@@ -1001,8 +1003,7 @@ class PostTrainingQuantization(object): ...@@ -1001,8 +1003,7 @@ class PostTrainingQuantization(object):
activation_bits=self._activation_bits, activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type, activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type, weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types, quantizable_op_type=major_quantizable_op_types)
is_test=not self._scale_trainable)
else: else:
transform_pass = QuantizationTransformPassV2( transform_pass = QuantizationTransformPassV2(
scope=self._scope, scope=self._scope,
...@@ -1011,8 +1012,7 @@ class PostTrainingQuantization(object): ...@@ -1011,8 +1012,7 @@ class PostTrainingQuantization(object):
activation_bits=self._activation_bits, activation_bits=self._activation_bits,
activation_quantize_type=self._activation_quantize_type, activation_quantize_type=self._activation_quantize_type,
weight_quantize_type=self._weight_quantize_type, weight_quantize_type=self._weight_quantize_type,
quantizable_op_type=major_quantizable_op_types, quantizable_op_type=major_quantizable_op_types)
is_test=not self._scale_trainable)
for sub_graph in graph.all_sub_graphs(): for sub_graph in graph.all_sub_graphs():
# Insert fake_quant/fake_dequantize op must in test graph, so # Insert fake_quant/fake_dequantize op must in test graph, so
...@@ -1029,15 +1029,13 @@ class PostTrainingQuantization(object): ...@@ -1029,15 +1029,13 @@ class PostTrainingQuantization(object):
add_quant_dequant_pass = AddQuantDequantPass( add_quant_dequant_pass = AddQuantDequantPass(
scope=self._scope, scope=self._scope,
place=self._place, place=self._place,
quantizable_op_type=minor_quantizable_op_types, quantizable_op_type=minor_quantizable_op_types)
is_test=not self._scale_trainable)
else: else:
add_quant_dequant_pass = AddQuantDequantPassV2( add_quant_dequant_pass = AddQuantDequantPassV2(
scope=self._scope, scope=self._scope,
place=self._place, place=self._place,
quantizable_op_type=minor_quantizable_op_types, quantizable_op_type=minor_quantizable_op_types,
is_full_quantized=self._is_full_quantize, is_full_quantized=self._is_full_quantize)
is_test=not self._scale_trainable)
for sub_graph in graph.all_sub_graphs(): for sub_graph in graph.all_sub_graphs():
sub_graph._for_test = True sub_graph._for_test = True
...@@ -1055,11 +1053,11 @@ class PostTrainingQuantization(object): ...@@ -1055,11 +1053,11 @@ class PostTrainingQuantization(object):
max_scale = None max_scale = None
tmp_tensor_list = [] tmp_tensor_list = []
for tensor_name in tensor_list: for tensor_name in tensor_list:
if tensor_name not in scale_dict.keys():
continue
if '#' in tensor_name: if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split( real_tensor_name, opera, scalar = tensor_name.split(
'#') '#')
if real_tensor_name not in scale_dict.keys():
continue
if opera == '*': if opera == '*':
scale_dict[real_tensor_name] = float( scale_dict[real_tensor_name] = float(
scale_dict[real_tensor_name]) * float( scale_dict[real_tensor_name]) * float(
...@@ -1072,16 +1070,18 @@ class PostTrainingQuantization(object): ...@@ -1072,16 +1070,18 @@ class PostTrainingQuantization(object):
real_tensor_name] if max_scale is None else max( real_tensor_name] if max_scale is None else max(
max_scale, scale_dict[real_tensor_name]) max_scale, scale_dict[real_tensor_name])
else: else:
if tensor_name not in scale_dict.keys():
continue
max_scale = scale_dict[ max_scale = scale_dict[
tensor_name] if max_scale is None else max( tensor_name] if max_scale is None else max(
max_scale, scale_dict[tensor_name]) max_scale, scale_dict[tensor_name])
for tensor_name in tensor_list: for tensor_name in tensor_list:
if tensor_name not in scale_dict.keys():
continue
if '#' in tensor_name: if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split( real_tensor_name, opera, scalar = tensor_name.split(
'#') '#')
if real_tensor_name not in scale_dict.keys():
continue
if opera == '*': if opera == '*':
scale_dict[ scale_dict[
real_tensor_name] = max_scale / float( real_tensor_name] = max_scale / float(
...@@ -1091,6 +1091,8 @@ class PostTrainingQuantization(object): ...@@ -1091,6 +1091,8 @@ class PostTrainingQuantization(object):
real_tensor_name] = max_scale * float( real_tensor_name] = max_scale * float(
scalar) scalar)
else: else:
if tensor_name not in scale_dict.keys():
continue
scale_dict[tensor_name] = max_scale scale_dict[tensor_name] = max_scale
self._scale_dict = scale_dict self._scale_dict = scale_dict
...@@ -1265,7 +1267,6 @@ class PostTrainingQuantizationProgram(PostTrainingQuantization): ...@@ -1265,7 +1267,6 @@ class PostTrainingQuantizationProgram(PostTrainingQuantization):
is_use_cache_file=False, is_use_cache_file=False,
skip_tensor_list=None, skip_tensor_list=None,
same_scale_tensor_list=None, same_scale_tensor_list=None,
scale_trainable=False,
cache_dir=None, cache_dir=None,
scale_dict=None, scale_dict=None,
return_graph=True): return_graph=True):
...@@ -1276,9 +1277,12 @@ class PostTrainingQuantizationProgram(PostTrainingQuantization): ...@@ -1276,9 +1277,12 @@ class PostTrainingQuantizationProgram(PostTrainingQuantization):
activation_bits, weight_bits, activation_quantize_type, activation_bits, weight_bits, activation_quantize_type,
weight_quantize_type, onnx_format, freeze_model, weight_quantize_type, onnx_format, freeze_model,
optimize_model, is_use_cache_file, skip_tensor_list, optimize_model, is_use_cache_file, skip_tensor_list,
same_scale_tensor_list, scale_trainable, cache_dir, same_scale_tensor_list, cache_dir, scale_dict,
scale_dict, return_graph) return_graph)
self.FLAG = False
self._program = program self._program = program
if self._program is not None:
self.FLAG = True
assert feed_list is not None, \ assert feed_list is not None, \
"Feed list should not be None." "Feed list should not be None."
assert fetch_list is not None, \ assert fetch_list is not None, \
......
...@@ -1470,9 +1470,10 @@ class OutScaleForTrainingPass(object): ...@@ -1470,9 +1470,10 @@ class OutScaleForTrainingPass(object):
data_type = 'float64' if in_node.dtype() \ data_type = 'float64' if in_node.dtype() \
== core.VarDesc.VarType.FP64 else 'float32' == core.VarDesc.VarType.FP64 else 'float32'
try: try:
scale_node = graph._find_node_by_name( graph._find_node_by_name(
graph.all_var_nodes(), graph.all_var_nodes(),
self._scale_name(in_node.name())) self._scale_name(in_node.name()))
continue
except: except:
scale_node = graph.create_persistable_node( scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()), name=self._scale_name(in_node.name()),
......
...@@ -186,7 +186,11 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization): ...@@ -186,7 +186,11 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
], ['batch_norm_27.tmp_2', 'batch_norm_26.tmp_2'], ], ['batch_norm_27.tmp_2', 'batch_norm_26.tmp_2'],
[ [
'test_scale_name_not_in_scale_dict1', 'test_scale_name_not_in_scale_dict1',
'test_scale_name_not_in_scale_dict1' 'test_scale_name_not_in_scale_dict2'
],
[
'test_scale_name_not_in_scale_dict1#/#1',
'test_scale_name_not_in_scale_dict2#/#1'
]] ]]
ptq = PostTrainingQuantizationProgram( ptq = PostTrainingQuantizationProgram(
executor=exe, executor=exe,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册