Unverified commit fad86fc1, authored by Bai Yifan, committed by GitHub

Fix preprocess_func quant_aware convert issue (#401)

Parent: 98247f3a
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import copy
+import json
 import logging
 import paddle
@@ -33,8 +35,7 @@ from ..common import get_logger
 _logger = get_logger(__name__, level=logging.INFO)
 
 WEIGHT_QUANTIZATION_TYPES = [
-    'abs_max', 'channel_wise_abs_max', 'range_abs_max',
-    'moving_average_abs_max'
+    'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
 ]
 
 WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max']
@@ -55,6 +56,8 @@ TENSORRT_OP_TYPES = [
     'leaky_relu'
 ]
 
+VARS_MAPPING_TABLE = './mapping_table_for_saving_inference_model'
+
 _quant_config_default = {
     # weight quantize type, default is 'channel_wise_abs_max'
     'weight_quantize_type': 'channel_wise_abs_max',
@@ -81,6 +84,18 @@ _quant_config_default = {
 }
+
+
+def load_dict():
+    with open(VARS_MAPPING_TABLE, 'r') as file:
+        data = file.read()
+        data = json.loads(data)
+        return data
+
+
+def save_dict(table):
+    with open(VARS_MAPPING_TABLE, 'w') as file:
+        file.write(json.dumps(table))
 
 
 def _parse_configs(user_config):
     """
     check if user's configs are valid.
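
These two helpers simply round-trip a Python dict through JSON at a fixed path. A minimal, self-contained sketch of the behavior, with the helpers restated for runnability; the toy table below is hypothetical, as the real dict comes from the IR graph's out_node_mapping_table:

```python
import json

VARS_MAPPING_TABLE = './mapping_table_for_saving_inference_model'

def save_dict(table):
    # Serialize the mapping table to disk, mirroring the helper in the diff.
    with open(VARS_MAPPING_TABLE, 'w') as file:
        file.write(json.dumps(table))

def load_dict():
    # Read it back; convert() uses this to restore variable-name mappings.
    with open(VARS_MAPPING_TABLE, 'r') as file:
        return json.loads(file.read())

# Hypothetical table: maps renamed graph variables back to their originals.
table = {'conv2d_0.w_0.quantized': 'conv2d_0.w_0'}
save_dict(table)
assert load_dict() == table
```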
@@ -267,6 +282,15 @@ def quant_aware(program,
             scope=scope, place=place, moving_rate=config['moving_rate'])
         out_scale_training_pass.apply(main_graph)
 
+    if (weight_preprocess_func is not None or
+            act_preprocess_func is not None) and not for_test:
+        _logger.info(
+            "When a preprocess_func is used in quant_aware, a mapping table "
+            "must be saved to match variable names in the convert phase.")
+        _logger.info("The mapping table is saved as '{}'.".format(
+            VARS_MAPPING_TABLE))
+        save_dict(main_graph.out_node_mapping_table)
+
     if for_test:
         quant_program = main_graph.to_program()
     else:
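
The condition guarding the new save is easy to misread, so here is a standalone sketch of just that predicate, with hypothetical stand-ins for the two callbacks; this mirrors the guard in the hunk above and is not PaddleSlim API:

```python
def needs_mapping_table(weight_preprocess_func, act_preprocess_func, for_test):
    # Save the table only when a custom preprocess function is in play
    # and we are building the training graph (not the eval/test graph).
    return (weight_preprocess_func is not None or
            act_preprocess_func is not None) and not for_test

assert needs_mapping_table(lambda v: v, None, for_test=False)
assert not needs_mapping_table(None, None, for_test=False)        # no preprocess
assert not needs_mapping_table(lambda v: v, None, for_test=True)  # test graph
```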
@@ -274,7 +298,8 @@ def quant_aware(program,
     return quant_program
 
 
-def quant_post_static(executor,
+def quant_post_static(
+        executor,
         model_dir,
         quantize_model_path,
         batch_generator=None,
@@ -381,6 +406,7 @@ def quant_post_static(
         model_filename=save_model_filename,
         params_filename=save_params_filename)
 
+
 # We have changed the quant_post to quant_post_static.
 # For compatibility, we keep quant_post api for now, and it will be
 # deprecated in the future.
@@ -438,6 +464,9 @@ def convert(program, place, config=None, scope=None, save_int8=False):
         activation_bits=config['activation_bits'],
         weight_quantize_type=config['weight_quantize_type'])
 
+    if os.path.exists(VARS_MAPPING_TABLE):
+        test_graph.out_node_mapping_table = load_dict()
+
     freeze_pass.apply(test_graph)
     freezed_program = test_graph.to_program()
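
Taken together, the two hunks form a handoff: quant_aware persists out_node_mapping_table when a preprocess func rewrote variable names, and convert restores it before the freeze pass runs. A hedged sketch of the intended call flow, assuming the signatures shown in this file; the wrapper and its arguments are placeholders for user code:

```python
from paddleslim.quant import quant_aware, convert

def qat_with_preprocess_flow(train_program, place, config, my_preprocess):
    # Hypothetical wrapper; assumes a built Paddle static-graph `train_program`.
    quant_program = quant_aware(
        train_program, place, config,
        weight_preprocess_func=my_preprocess,  # triggers the mapping-table save
        for_test=False)
    # ... run quantization-aware training on quant_program here ...
    eval_program = quant_aware(train_program, place, config, for_test=True)
    # convert() reloads VARS_MAPPING_TABLE (if present) before freezing,
    # so output nodes renamed by the preprocess pass are matched correctly.
    return convert(eval_program, place, config)
```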
...