未验证 提交 72b9dcd8 编写于 作者: B Bai Yifan 提交者: GitHub

Merge branch 'develop' into pact_clip

......@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import json
import logging
import paddle
......@@ -33,8 +35,7 @@ from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
WEIGHT_QUANTIZATION_TYPES = [
'abs_max', 'channel_wise_abs_max', 'range_abs_max',
'moving_average_abs_max'
'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
]
WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max']
......@@ -55,6 +56,8 @@ TENSORRT_OP_TYPES = [
'leaky_relu'
]
VARS_MAPPING_TABLE = './mapping_table_for_saving_inference_model'
_quant_config_default = {
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
......@@ -81,6 +84,18 @@ _quant_config_default = {
}
def load_dict():
with open(VARS_MAPPING_TABLE, 'r') as file:
data = file.read()
data = json.loads(data)
return data
def save_dict(table):
with open(VARS_MAPPING_TABLE, 'w') as file:
file.write(json.dumps(table))
def _parse_configs(user_config):
"""
check if user's configs are valid.
......@@ -267,6 +282,15 @@ def quant_aware(program,
scope=scope, place=place, moving_rate=config['moving_rate'])
out_scale_training_pass.apply(main_graph)
if (weight_preprocess_func is not None or
act_preprocess_func is not None) and not for_test:
_logger.info(
"When a preprocess_func is used in quant_aware, Need to save a mapping table to match variable names in the convert phase."
)
_logger.info("The mapping table is saved as '{}'.".format(
VARS_MAPPING_TABLE))
save_dict(main_graph.out_node_mapping_table)
if for_test:
quant_program = main_graph.to_program()
else:
......@@ -274,7 +298,8 @@ def quant_aware(program,
return quant_program
def quant_post_static(executor,
def quant_post_static(
executor,
model_dir,
quantize_model_path,
batch_generator=None,
......@@ -381,6 +406,7 @@ def quant_post_static(executor,
model_filename=save_model_filename,
params_filename=save_params_filename)
# We have changed the quant_post to quant_post_static.
# For compatibility, we keep quant_post api for now, and it will be
# deprecated in the future.
......@@ -438,6 +464,9 @@ def convert(program, place, config=None, scope=None, save_int8=False):
activation_bits=config['activation_bits'],
weight_quantize_type=config['weight_quantize_type'])
if os.path.exists(VARS_MAPPING_TABLE):
test_graph.out_node_mapping_table = load_dict()
freeze_pass.apply(test_graph)
freezed_program = test_graph.to_program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册