Commit 672be806 authored by root

add scope default val None

Parent 5e0c1cc3
@@ -26,6 +26,7 @@ WEIGHT_QUANTIZATION_TYPES=['abs_max', 'channel_wise_abs_max']
 ACTIVATION_QUANTIZATION_TYPES=['abs_max','range_abs_max', 'moving_average_abs_max']
 VALID_DTYPES = ['int8']
 
+
 _quant_config_default = {
     # weight quantize type, default is 'abs_max'
     'weight_quantize_type': 'abs_max',
@@ -50,6 +51,7 @@ _quant_config_default = {
     'quant_weight_only': False
 }
 
+
 def _parse_configs(user_config):
     """
     check user configs is valid, and set default value if user not config.
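For context, `_parse_configs` fills in the `_quant_config_default` values for any key the user omits and validates the chosen types against the `WEIGHT_QUANTIZATION_TYPES`, `ACTIVATION_QUANTIZATION_TYPES`, and `VALID_DTYPES` lists above. A minimal sketch of that merge-and-validate pattern, using only the names shown in this diff; the real function body is not part of these hunks:

```python
# Sketch only -- the actual body of _parse_configs is not shown in this diff.
def _parse_configs(user_config):
    assert isinstance(user_config, dict), "config must be dict"
    configs = dict(_quant_config_default)  # start from the defaults
    configs.update(user_config)            # user-supplied keys override them
    assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \
        "unknown weight_quantize_type"
    assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \
        "unknown activation_quantize_type"
    return configs
```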
@@ -107,8 +109,7 @@ def _parse_configs(user_config):
     return configs
 
 
-
-def quant_aware(program, scope, place, config, for_test=False):
+def quant_aware(program, place, config, scope=None, for_test=False):
     """
     add trainable quantization ops in program.
     Args:
@@ -151,7 +152,8 @@ def quant_aware(program, scope, place, config, for_test=False):
     quant_program = fluid.CompiledProgram(main_graph.graph)
     return quant_program
 
-def quant_post(program, scope, place, config):
+
+def quant_post(program, place, config, scope=None):
     """
     add quantization ops in program. the program returned is not trainable.
     Args:
@@ -162,26 +164,9 @@ def quant_post(program, scope, place, config):
         for_test: is for test program.
     Return:
         fluid.Program: the quantization program is not trainable.
     """
-    scope = fluid.global_scope() if not scope else scope
-    assert isinstance(config, dict), "config must be dict"
-    assert 'weight_quantize_type' in config.keys(), 'weight_quantize_type must be configured'
-    assert 'activation_quantize_type' in config.keys(), 'activation_quantize_type must be configured'
-    config = _parse_configs(config)
-    main_graph = IrGraph(core.Graph(program.desc), for_test=True)
-    transform_pass = QuantizationTransformPass(
-        scope=scope, place=place,
-        activation_quantize_type=config['activation_quantize_type'],
-        weight_quantize_type=config['weight_quantize_type'])
-    transform_pass.apply(main_graph)
-    quant_program = main_graph.to_program()
-    return quant_program
+    pass
 
 
 def convert(program, scope, place, config, save_int8=False):
     """
...
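The effect of the signature change is easiest to see at the call site: `scope` moves to the end of the parameter list with a default of `None`, and a missing scope falls back to `fluid.global_scope()` (the `scope = fluid.global_scope() if not scope else scope` guard visible in the diff). A hedged usage sketch for `quant_aware` under Paddle 1.x; the import path and the surrounding setup are illustrative, not taken from this commit:

```python
import paddle.fluid as fluid
from quant import quant_aware  # hypothetical module path, for illustration only

place = fluid.CPUPlace()
config = {
    'weight_quantize_type': 'abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
}
train_program = fluid.Program()  # stands in for a real training program

# Before this commit, scope was a required positional argument:
#   quant_aware(train_program, fluid.global_scope(), place, config, for_test=False)
# After it, scope can be omitted and resolves to the global scope:
quant_program = quant_aware(train_program, place, config, for_test=False)

# An explicit Scope can still be supplied as a keyword argument:
my_scope = fluid.Scope()
quant_program = quant_aware(train_program, place, config, scope=my_scope, for_test=False)
```

Note that this diff also reduces the body of `quant_post` to `pass`, so only its new signature, not any behavior, is exercised by callers for now.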