Unverified commit 9ab90c96, authored by: G Guanghua Yu, committed by: GitHub

support quant onnx format in ACT (#1322)

Parent 8b156124
@@ -65,6 +65,8 @@ add_arg('use_pact', bool, True,
         "Whether to use PACT or not.")
 add_arg('analysis', bool, False,
         "Whether analysis variables distribution.")
+add_arg('onnx_format', bool, False,
+        "Whether use onnx format or not.")
 add_arg('ce_test', bool, False, "Whether to CE test.")
 # yapf: enable
@@ -257,6 +259,8 @@ def compress(args):
         'window_size': 10000,
         # The decay coefficient of moving average, default is 0.9
         'moving_rate': 0.9,
+        # Whether use onnx format or not
+        'onnx_format': args.onnx_format,
     }
     # 2. quantization transform programs (training aware)
@@ -298,9 +302,9 @@ def compress(args):
         places,
         quant_config,
         scope=None,
-        act_preprocess_func=act_preprocess_func,
-        optimizer_func=optimizer_func,
-        executor=executor,
+        act_preprocess_func=None,
+        optimizer_func=None,
+        executor=None,
         for_test=True)
     compiled_train_prog = quant_aware(
         train_prog,
@@ -425,29 +429,23 @@ def compress(args):
     # 3. Freeze the graph after training by adjusting the quantize
     # operators' order for the inference.
     # The dtype of float_program's weights is float32, but in int8 range.
-    float_program, int8_program = convert(val_program, places, quant_config, \
-                                          scope=None, \
-                                          save_int8=True)
+    model_path = os.path.join(quantization_model_save_dir, args.model)
+    if not os.path.isdir(model_path):
+        os.makedirs(model_path)
+    float_program = convert(val_program, places, quant_config)
     _logger.info("eval best_model after convert")
     final_acc1 = test(best_epoch, float_program)
     _logger.info("final acc:{}".format(final_acc1))
     # 4. Save inference model
-    model_path = os.path.join(quantization_model_save_dir, args.model,
-                              'act_' + quant_config['activation_quantize_type']
-                              + '_w_' + quant_config['weight_quantize_type'])
-    float_path = os.path.join(model_path, 'float')
-    if not os.path.isdir(model_path):
-        os.makedirs(model_path)
     paddle.fluid.io.save_inference_model(
-        dirname=float_path,
+        dirname=model_path,
         feeded_var_names=[image.name],
         target_vars=[out],
         executor=exe,
         main_program=float_program,
-        model_filename=float_path + '/model',
-        params_filename=float_path + '/params')
 def main():
......
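Note on the demo change above: `onnx_format` is now read from `quant_config`, `convert()` returns only the float program, and the model is saved as `model.pdmodel`/`model.pdiparams` under a single directory. A minimal sketch of the new flow, assuming `val_program`, `places`, `exe`, `image`, `out`, `args`, and `quantization_model_save_dir` are defined as in the demo script:

```python
import os
import paddle
from paddleslim.quant import quant_aware, convert

# Surrounding names (args, val_program, places, exe, image, out,
# quantization_model_save_dir) are assumed from the demo script.
quant_config = {
    'window_size': 10000,
    'moving_rate': 0.9,
    # onnx_format now lives in the config dict instead of being
    # a keyword argument of convert().
    'onnx_format': args.onnx_format,
}

# convert() now returns a single float program (no int8_program).
float_program = convert(val_program, places, quant_config)

model_path = os.path.join(quantization_model_save_dir, args.model)
if not os.path.isdir(model_path):
    os.makedirs(model_path)

paddle.fluid.io.save_inference_model(
    dirname=model_path,
    feeded_var_names=[image.name],
    target_vars=[out],
    executor=exe,
    main_program=float_program,
    model_filename=model_path + '/model.pdmodel',
    params_filename=model_path + '/model.pdiparams')
```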
@@ -126,6 +126,8 @@ def compress(args):
         'window_size': 10000,
         # The decay coefficient of moving average, default is 0.9
         'moving_rate': 0.9,
+        # Whether use onnx format or not
+        'onnx_format': args.onnx_format,
     }
     pretrain = True
@@ -294,10 +296,7 @@ def compress(args):
     # operators' order for the inference.
     # The dtype of float_program's weights is float32, but in int8 range.
     ############################################################################################################
-    float_program, int8_program = convert(val_program, places, quant_config, \
-                                          scope=None, \
-                                          save_int8=True,
-                                          onnx_format=args.onnx_format)
+    float_program = convert(val_program, places, quant_config)
     print("eval best_model after convert")
     final_acc1 = test(best_epoch, float_program)
     ############################################################################################################
......
@@ -14,6 +14,7 @@ Distillation:
 Quantization:
   use_pact: true
+  onnx_format: False
   activation_quantize_type: 'moving_average_abs_max'
   quantize_op_types:
   - conv2d
......
@@ -787,15 +787,18 @@ class AutoCompression:
             os.remove(os.path.join(self.tmp_dir, 'best_model.pdopt'))
             os.remove(os.path.join(self.tmp_dir, 'best_model.pdparams'))
-        if 'qat' in strategy:
-            test_program, int8_program = convert(test_program, self._places, self._quant_config, \
-                                          scope=paddle.static.global_scope(), \
-                                          save_int8=True)
         model_dir = os.path.join(self.tmp_dir,
                                  'strategy_{}'.format(str(strategy_idx + 1)))
         if not os.path.exists(model_dir):
             os.makedirs(model_dir)
+        if 'qat' in strategy:
+            test_program = convert(
+                test_program,
+                self._places,
+                self._quant_config,
+                scope=paddle.static.global_scope())
         feed_vars = [
             test_program.global_block().var(name)
             for name in test_program_info.feed_target_names
......
@@ -65,6 +65,7 @@ class Quantization(BaseStrategy):
                  window_size=10000,
                  moving_rate=0.9,
                  for_tensorrt=False,
+                 onnx_format=False,
                  is_full_quantize=False):
         """
         Quantization Config.
@@ -80,6 +81,7 @@ class Quantization(BaseStrategy):
             window_size(int): Window size for 'range_abs_max' quantization. Default: 10000.
             moving_rate(float): The decay coefficient of moving average. Default: 0.9.
             for_tensorrt(bool): If True, 'quantize_op_types' will be TENSORRT_OP_TYPES. Default: False.
+            onnx_format(bool): Whether to export the quantized model with format of ONNX. Default is False.
             is_full_quantize(bool): If True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES. Default: False.
         """
         super(Quantization, self).__init__("Quantization")
@@ -95,6 +97,7 @@ class Quantization(BaseStrategy):
         self.window_size = window_size
         self.moving_rate = moving_rate
         self.for_tensorrt = for_tensorrt
+        self.onnx_format = onnx_format
         self.is_full_quantize = is_full_quantize
......
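For ACT users, the new switch is exposed on the `Quantization` strategy config shown above. A hedged sketch of constructing it in Python rather than YAML; the import path and the exact set of accepted keyword arguments are assumptions based on PaddleSlim's ACT config module and the YAML keys in this commit:

```python
# Sketch only: the import path is an assumption; the kwargs mirror the
# constructor signature and docstring added in this diff.
from paddleslim.auto_compression import Quantization

quant_strategy = Quantization(
    activation_quantize_type='moving_average_abs_max',
    window_size=10000,
    moving_rate=0.9,
    for_tensorrt=False,
    onnx_format=True,  # export quantize ops in the ONNX-compatible format
    is_full_quantize=False)
```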
@@ -91,7 +91,9 @@ _quant_config_default = {
     # if True, 'quantize_op_types' will be TENSORRT_OP_TYPES
     'for_tensorrt': False,
     # if True, 'quantoze_op_types' will be TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES
-    'is_full_quantize': False
+    'is_full_quantize': False,
+    # if True, use onnx format to quant.
+    'onnx_format': False,
 }
@@ -222,7 +224,6 @@ def quant_aware(program,
                 act_preprocess_func=None,
                 optimizer_func=None,
                 executor=None,
-                onnx_format=False,
                 return_program=False,
                 draw_graph=False):
     """Add quantization and dequantization operators to "program"
@@ -236,7 +237,9 @@ def quant_aware(program,
             Default: None.
         scope(paddle.static.Scope): Scope records the mapping between variable names and variables,
             similar to brackets in programming languages. Usually users can use
-            `paddle.static.global_scope <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_. When ``None`` will use `paddle.static.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_ . Default: ``None``.
+            `paddle.static.global_scope <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_.
+            When ``None`` will use `paddle.static.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_ .
+            Default: ``None``.
         for_test(bool): If the 'program' parameter is a test program, this parameter should be set to ``True``.
             Otherwise, set to ``False``.Default: False
         weight_quantize_func(function): Function that defines how to quantize weight. Using this
@@ -291,7 +294,8 @@ def quant_aware(program,
         elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
             quant_dequant_ops.append(op_type)
     if len(transform_pass_ops) > 0:
-        trannsform_func = 'QuantizationTransformPassV2' if onnx_format else 'QuantizationTransformPass'
+        trannsform_func = 'QuantizationTransformPassV2' if config[
+            'onnx_format'] else 'QuantizationTransformPass'
         transform_pass = eval(trannsform_func)(
             scope=scope,
             place=place,
@@ -313,7 +317,8 @@ def quant_aware(program,
         transform_pass.apply(main_graph)
     if len(quant_dequant_ops) > 0:
-        qdq_func = 'AddQuantDequantPassV2' if onnx_format else 'AddQuantDequantPass'
+        qdq_func = 'AddQuantDequantPassV2' if config[
+            'onnx_format'] else 'AddQuantDequantPass'
         quant_dequant_pass = eval(qdq_func)(
             scope=scope,
             place=place,
@@ -516,12 +521,7 @@ def quant_post_static(
 quant_post = quant_post_static
-def convert(program,
-            place,
-            config=None,
-            scope=None,
-            save_int8=False,
-            onnx_format=False):
+def convert(program, place, config=None, scope=None, save_int8=False):
     """
     convert quantized and well-trained ``program`` to final quantized
     ``program``that can be used to save ``inference model``.
@@ -560,7 +560,7 @@ def convert(program,
     _logger.info("convert config {}".format(config))
     test_graph = IrGraph(core.Graph(program.desc), for_test=True)
-    if onnx_format:
+    if config['onnx_format']:
         quant_weight_pass = QuantWeightPass(scope, place)
         quant_weight_pass.apply(test_graph)
     else:
......
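Taken together, the quanter-side change means callers no longer pass `onnx_format` to `quant_aware()` or `convert()`; it is read from the config dict, defaulting to `False` via `_quant_config_default`. A minimal sketch of the new call pattern, assuming a static-graph `val_program` has already been built:

```python
import paddle
from paddleslim.quant import quant_aware, convert

paddle.enable_static()
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
    else paddle.CPUPlace()

# Unspecified keys fall back to _quant_config_default, so
# 'onnx_format' is False when omitted.
quant_config = {'onnx_format': True}

# Insert fake quant/dequant ops; with onnx_format=True the V2 passes
# (QuantizationTransformPassV2 / AddQuantDequantPassV2) are selected.
val_quant_program = quant_aware(
    val_program, place, quant_config, for_test=True)

# ... quantization-aware training happens elsewhere ...

# Freeze for inference; with onnx_format=True, QuantWeightPass is applied.
float_program = convert(
    val_quant_program, place, quant_config,
    scope=paddle.static.global_scope())
```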