Unverified · Commit 72b9dcd8, authored by Bai Yifan, committed by GitHub

Merge branch 'develop' into pact_clip

@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import json
import logging
import paddle
@@ -33,8 +35,7 @@ from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
WEIGHT_QUANTIZATION_TYPES = [
    'abs_max', 'channel_wise_abs_max', 'range_abs_max',
    'moving_average_abs_max'
    'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
]
WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max']
@@ -55,6 +56,8 @@ TENSORRT_OP_TYPES = [
    'leaky_relu'
]
VARS_MAPPING_TABLE = './mapping_table_for_saving_inference_model'
_quant_config_default = {
    # weight quantize type, default is 'channel_wise_abs_max'
    'weight_quantize_type': 'channel_wise_abs_max',
@@ -81,6 +84,18 @@ _quant_config_default = {
}
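The hunk shows only the first default key. As a rough sketch of how these defaults are overridden (the keys other than 'weight_quantize_type' are assumptions based on the quantize-type lists earlier in the file, not part of this diff), a user config passed to quant_aware and validated by _parse_configs might look like this:

```python
# Hypothetical user config; any omitted key falls back to
# _quant_config_default. 'weight_quantize_type' must be one of
# WEIGHT_QUANTIZATION_TYPES (or WEIGHT_QUANTIZATION_TYPES_TENSORRT
# when targeting TensorRT).
user_config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',  # assumed key
    'weight_bits': 8,                                      # assumed key
    'activation_bits': 8,                                  # assumed key
}
```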
def load_dict():
    # Read back the variable-name mapping table that quant_aware saved as JSON.
    with open(VARS_MAPPING_TABLE, 'r') as file:
        data = file.read()
        data = json.loads(data)
        return data


def save_dict(table):
    # Persist the variable-name mapping table as JSON for the convert phase.
    with open(VARS_MAPPING_TABLE, 'w') as file:
        file.write(json.dumps(table))
def _parse_configs(user_config):
    """
    Check whether the user's configs are valid.
@@ -267,6 +282,15 @@ def quant_aware(program,
        scope=scope, place=place, moving_rate=config['moving_rate'])
    out_scale_training_pass.apply(main_graph)

    if (weight_preprocess_func is not None or
            act_preprocess_func is not None) and not for_test:
        _logger.info(
            "When a preprocess_func is used in quant_aware, a mapping table "
            "needs to be saved to match variable names in the convert phase.")
        _logger.info("The mapping table is saved as '{}'.".format(
            VARS_MAPPING_TABLE))
        save_dict(main_graph.out_node_mapping_table)
    if for_test:
        quant_program = main_graph.to_program()
    else:
@@ -274,27 +298,28 @@
    return quant_program
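To make the mapping-table flow concrete, here is a hedged end-to-end sketch. The PACT-style clip, the programs, executor, and config are placeholders; only quant_aware's weight_preprocess_func/act_preprocess_func parameters and the reload in convert (see the hunk further down) come from this diff:

```python
import paddle
import paddleslim.quant as quant

def pact(x):
    # Placeholder PACT-style preprocess: clip activations before fake-quant.
    # A real PACT learns its clipping threshold; a constant keeps this
    # sketch self-contained.
    return paddle.clip(x, min=-20.0, max=20.0)

# Training graph: because act_preprocess_func is set, quant_aware writes
# main_graph.out_node_mapping_table to VARS_MAPPING_TABLE.
quant_train_program = quant.quant_aware(
    train_program, place, config=user_config,
    act_preprocess_func=pact, executor=exe, for_test=False)

# ... train quant_train_program ...

# Test graph and convert: convert() reloads the saved table so the renamed
# output variables can be matched up again.
quant_eval_program = quant.quant_aware(
    eval_program, place, config=user_config, for_test=True)
inference_program = quant.convert(quant_eval_program, place, config=user_config)
```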
def quant_post_static(executor,
                      model_dir,
                      quantize_model_path,
                      batch_generator=None,
                      sample_generator=None,
                      model_filename=None,
                      params_filename=None,
                      save_model_filename='__model__',
                      save_params_filename='__params__',
                      batch_size=16,
                      batch_nums=None,
                      scope=None,
                      algo='KL',
                      quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                      is_full_quantize=False,
                      weight_bits=8,
                      activation_bits=8,
                      activation_quantize_type='range_abs_max',
                      weight_quantize_type='channel_wise_abs_max',
                      is_use_cache_file=False,
                      cache_dir="./temp_post_training"):
def quant_post_static(
        executor,
        model_dir,
        quantize_model_path,
        batch_generator=None,
        sample_generator=None,
        model_filename=None,
        params_filename=None,
        save_model_filename='__model__',
        save_params_filename='__params__',
        batch_size=16,
        batch_nums=None,
        scope=None,
        algo='KL',
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        is_full_quantize=False,
        weight_bits=8,
        activation_bits=8,
        activation_quantize_type='range_abs_max',
        weight_quantize_type='channel_wise_abs_max',
        is_use_cache_file=False,
        cache_dir="./temp_post_training"):
"""
The function utilizes static post training quantization method to
quantize the fp32 model. It uses calibrate data to calculate the
@@ -381,6 +406,7 @@ def quant_post_static(executor,
        model_filename=save_model_filename,
        params_filename=save_params_filename)
# We have renamed quant_post to quant_post_static.
# For compatibility, we keep the quant_post api for now, and it will be
# deprecated in the future.
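For reference, a hedged call sketch against the reformatted signature above; the model paths and the calibration reader are placeholders:

```python
import paddle
import paddleslim.quant as quant

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# sample_generator yields one calibration sample at a time; batch_nums
# bounds how many batches are used to compute the activation scales.
quant.quant_post_static(
    executor=exe,
    model_dir='./fp32_model',              # placeholder path
    quantize_model_path='./quant_model',   # placeholder path
    sample_generator=calib_reader,         # assumed: user-supplied reader
    batch_size=16,
    batch_nums=10,
    algo='KL')
```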
@@ -438,6 +464,9 @@ def convert(program, place, config=None, scope=None, save_int8=False):
        activation_bits=config['activation_bits'],
        weight_quantize_type=config['weight_quantize_type'])
    if os.path.exists(VARS_MAPPING_TABLE):
        test_graph.out_node_mapping_table = load_dict()

    freeze_pass.apply(test_graph)
    freezed_program = test_graph.to_program()
@@ -451,14 +480,14 @@ def convert(program, place, config=None, scope=None, save_int8=False):
def quant_post_dynamic(model_dir,
                           save_model_dir,
                           model_filename=None,
                           params_filename=None,
                           save_model_filename=None,
                           save_params_filename=None,
                           quantizable_op_type=["conv2d", "mul"],
                           weight_bits=8,
                           generate_test_model=False):
                       save_model_dir,
                       model_filename=None,
                       params_filename=None,
                       save_model_filename=None,
                       save_params_filename=None,
                       quantizable_op_type=["conv2d", "mul"],
                       weight_bits=8,
                       generate_test_model=False):
    '''
    The function utilizes the dynamic post-training quantization method to
    quantize the fp32 model. In detail, it quantizes the weights of some
@@ -517,4 +546,4 @@ def quant_post_dynamic(model_dir,
# We have renamed quant_post_only_weight to quant_post_dynamic.
# For compatibility, we keep the quant_post_only_weight api for now,
# and it will be deprecated in the future.
quant_post_only_weight = quant_post_dynamic
\ No newline at end of file
quant_post_only_weight = quant_post_dynamic
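And the dynamic (weight-only) counterpart, again as a hedged sketch with placeholder paths:

```python
import paddleslim.quant as quant

# No executor or calibration data is needed: only the weights of the
# listed op types are quantized offline.
quant.quant_post_dynamic(
    model_dir='./fp32_model',                # placeholder path
    save_model_dir='./quant_dynamic_model',  # placeholder path
    quantizable_op_type=["conv2d", "mul"],
    weight_bits=8,
    generate_test_model=False)
```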