未验证 提交 ee116264 编写于 作者: H handiz 提交者: GitHub

fix bug in PostTrainingQuantizationProgram for not full quantization case (#45500)

* fix bug in PostTrainingQuantizationProgram for not full quantization case

* fix coverage ci
上级 60e1eccb
...@@ -1055,6 +1055,8 @@ class PostTrainingQuantization(object): ...@@ -1055,6 +1055,8 @@ class PostTrainingQuantization(object):
max_scale = None max_scale = None
tmp_tensor_list = [] tmp_tensor_list = []
for tensor_name in tensor_list: for tensor_name in tensor_list:
if tensor_name not in scale_dict.keys():
continue
if '#' in tensor_name: if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split( real_tensor_name, opera, scalar = tensor_name.split(
'#') '#')
...@@ -1075,6 +1077,8 @@ class PostTrainingQuantization(object): ...@@ -1075,6 +1077,8 @@ class PostTrainingQuantization(object):
max_scale, scale_dict[tensor_name]) max_scale, scale_dict[tensor_name])
for tensor_name in tensor_list: for tensor_name in tensor_list:
if tensor_name not in scale_dict.keys():
continue
if '#' in tensor_name: if '#' in tensor_name:
real_tensor_name, opera, scalar = tensor_name.split( real_tensor_name, opera, scalar = tensor_name.split(
'#') '#')
......
...@@ -183,7 +183,11 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization): ...@@ -183,7 +183,11 @@ class TestPostTrainingQuantizationProgram(TestPostTrainingQuantization):
val_reader = val() val_reader = val()
same_scale_tensor_list = [[ same_scale_tensor_list = [[
'batch_norm_3.tmp_2#/#1', 'batch_norm_4.tmp_2#*#1' 'batch_norm_3.tmp_2#/#1', 'batch_norm_4.tmp_2#*#1'
], ['batch_norm_27.tmp_2', 'batch_norm_26.tmp_2']] ], ['batch_norm_27.tmp_2', 'batch_norm_26.tmp_2'],
[
'test_scale_name_not_in_scale_dict1',
'test_scale_name_not_in_scale_dict1'
]]
ptq = PostTrainingQuantizationProgram( ptq = PostTrainingQuantizationProgram(
executor=exe, executor=exe,
program=program, program=program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册