Unverified commit 72b9dcd8, authored by Bai Yifan, committed by GitHub

Merge branch 'develop' into pact_clip

@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import json
import logging
import paddle
@@ -33,8 +35,7 @@ from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
WEIGHT_QUANTIZATION_TYPES = [
    'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
]
WEIGHT_QUANTIZATION_TYPES_TENSORRT = ['channel_wise_abs_max']
@@ -55,6 +56,8 @@ TENSORRT_OP_TYPES = [
    'leaky_relu'
]
VARS_MAPPING_TABLE = './mapping_table_for_saving_inference_model'
_quant_config_default = {
    # weight quantize type, default is 'channel_wise_abs_max'
    'weight_quantize_type': 'channel_wise_abs_max',
@@ -81,6 +84,18 @@ _quant_config_default = {
}
def load_dict():
    with open(VARS_MAPPING_TABLE, 'r') as file:
        data = file.read()
        data = json.loads(data)
        return data


def save_dict(table):
    with open(VARS_MAPPING_TABLE, 'w') as file:
        file.write(json.dumps(table))


def _parse_configs(user_config):
    """
    check if user's configs are valid.
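The two helper functions added above simply round-trip a Python dict through the JSON file named by VARS_MAPPING_TABLE. A minimal sketch of that round trip (the mapping entry below is hypothetical, not taken from PaddleSlim):

# Illustrative only: the key/value names are made up.
example_table = {'conv2d_0.tmp_0': 'conv2d_0.tmp_0.quantized'}
save_dict(example_table)              # writes JSON to VARS_MAPPING_TABLE
assert load_dict() == example_table   # reads it back as a dict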
@@ -267,6 +282,15 @@ def quant_aware(program,
        scope=scope, place=place, moving_rate=config['moving_rate'])
    out_scale_training_pass.apply(main_graph)

    if (weight_preprocess_func is not None or
            act_preprocess_func is not None) and not for_test:
        _logger.info(
            "When a preprocess_func is used in quant_aware, Need to save a mapping table to match variable names in the convert phase."
        )
        _logger.info("The mapping table is saved as '{}'.".format(
            VARS_MAPPING_TABLE))
        save_dict(main_graph.out_node_mapping_table)

    if for_test:
        quant_program = main_graph.to_program()
    else:
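The block above is the new behavior: when quant_aware is given a weight or activation preprocess function on the training graph, it persists main_graph.out_node_mapping_table so convert can later match renamed variables. A hedged sketch of the intended call pattern, assuming quant_aware accepts the weight_preprocess_func and for_test keyword arguments referenced in its body (train_program, place and config are placeholders):

# Sketch only; not the definitive PaddleSlim API surface.
def my_weight_preprocess(x):
    # hypothetical user-defined preprocessing applied to weights
    return x

train_quant_prog = quant_aware(
    train_program, place, config,
    weight_preprocess_func=my_weight_preprocess,
    for_test=False)
# After this call, './mapping_table_for_saving_inference_model'
# should exist on disk for the later convert step.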
@@ -274,27 +298,28 @@
    return quant_program

def quant_post_static(
        executor,
        model_dir,
        quantize_model_path,
        batch_generator=None,
        sample_generator=None,
        model_filename=None,
        params_filename=None,
        save_model_filename='__model__',
        save_params_filename='__params__',
        batch_size=16,
        batch_nums=None,
        scope=None,
        algo='KL',
        quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
        is_full_quantize=False,
        weight_bits=8,
        activation_bits=8,
        activation_quantize_type='range_abs_max',
        weight_quantize_type='channel_wise_abs_max',
        is_use_cache_file=False,
        cache_dir="./temp_post_training"):
""" """
The function utilizes static post training quantization method to The function utilizes static post training quantization method to
quantize the fp32 model. It uses calibrate data to calculate the quantize the fp32 model. It uses calibrate data to calculate the
@@ -381,6 +406,7 @@ def quant_post_static(executor,
        model_filename=save_model_filename,
        params_filename=save_params_filename)

# We have changed the quant_post to quant_post_static.
# For compatibility, we keep quant_post api for now, and it will be
# deprecated in the future.
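For context, a hedged usage sketch of quant_post_static with the reformatted signature above, assuming the paddle.fluid executor API of this era; the directories and calibration reader below are placeholders:

# Illustrative only: paths and calibration data are made up.
import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

def calib_reader():
    # hypothetical calibration samples shaped like 3x224x224 images
    for _ in range(32):
        yield [np.random.random((3, 224, 224)).astype('float32')]

quant_post_static(
    executor=exe,
    model_dir='./fp32_inference_model',   # assumed input model directory
    quantize_model_path='./quant_model',  # assumed output directory
    sample_generator=calib_reader,
    batch_size=16,
    batch_nums=10,
    algo='KL')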
@@ -438,6 +464,9 @@ def convert(program, place, config=None, scope=None, save_int8=False):
        activation_bits=config['activation_bits'],
        weight_quantize_type=config['weight_quantize_type'])

    if os.path.exists(VARS_MAPPING_TABLE):
        test_graph.out_node_mapping_table = load_dict()

    freeze_pass.apply(test_graph)
    freezed_program = test_graph.to_program()
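The two added lines reload the mapping table (when the file exists) before the freeze pass, so variables renamed by a preprocess_func still resolve at convert time. A minimal sketch of the train-then-convert flow, using the convert signature shown in the hunk header (eval_program, place and config are placeholders):

# Sketch only: build the test-mode quant program, then freeze it.
eval_quant_prog = quant_aware(eval_program, place, config, for_test=True)
# ... train / evaluate the quantization-aware programs elsewhere ...
quant_infer_prog = convert(eval_quant_prog, place, config)
# If a preprocess_func was used during training, convert picks up
# './mapping_table_for_saving_inference_model' automatically here.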
@@ -451,14 +480,14 @@ def convert(program, place, config=None, scope=None, save_int8=False):
def quant_post_dynamic(model_dir,
                       save_model_dir,
                       model_filename=None,
                       params_filename=None,
                       save_model_filename=None,
                       save_params_filename=None,
                       quantizable_op_type=["conv2d", "mul"],
                       weight_bits=8,
                       generate_test_model=False):
    '''
    The function utilizes static post training quantization method to
    quantize the fp32 model. In details, it quantizes the weight of some
@@ -517,4 +546,4 @@ def quant_post_dynamic(model_dir,

# We have changed the quant_post_only_weight to quant_post_dynamic.
# For compatibility, we keep quant_post_only_weight api for now,
# and it will be deprecated in the future.
quant_post_only_weight = quant_post_dynamic
\ No newline at end of file
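Finally, a hedged usage sketch of quant_post_dynamic, which, per the docstring above, only quantizes the weights of the listed op types and therefore needs no calibration reader (directory names are placeholders):

quant_post_dynamic(
    model_dir='./fp32_inference_model',      # assumed input model directory
    save_model_dir='./dynamic_quant_model',  # assumed output directory
    quantizable_op_type=["conv2d", "mul"],
    weight_bits=8,
    generate_test_model=False)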