Unverified commit db170b2b, authored by Paulina Gacek, committed by GitHub

Remove tests with save_quant_model.py (#50307)

* got rid of save_quant_model

* review changes
Parent 7f87d75b
@@ -200,32 +200,6 @@ function(download_quant_model install_dir data_file check_sum)
endif()
endfunction()
function(save_quant_ic_model_test target quant_model_dir int8_model_save_path)
py_test(
${target}
SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
ARGS
--quant_model_path
${quant_model_dir}
--int8_model_save_path
${int8_model_save_path}
--debug)
endfunction()
function(save_quant_nlp_model_test target quant_model_dir int8_model_save_path
ops_to_quantize)
py_test(
${target}
SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
ARGS
--quant_model_path
${quant_model_dir}
--int8_model_save_path
${int8_model_save_path}
--ops_to_quantize
${ops_to_quantize})
endfunction()
function(convert_model2dot_test target model_path save_graph_dir
save_graph_name)
py_test(
@@ -438,31 +412,6 @@ if(LINUX AND WITH_MKLDNN)
set(QUANT2_LSTM_MODEL_DIR "${QUANT_INSTALL_DIR}/lstm_quant_test")
download_quant_model(${QUANT2_LSTM_MODEL_DIR} ${QUANT2_LSTM_MODEL_ARCHIVE}
40a693803b12ee9e251258f32559abcb)
set(QUANT2_LSTM_OPS_TO_QUANTIZE "fusion_lstm")
### Save FP32 model or INT8 model from Quant model
set(QUANT2_INT8_RESNET50_SAVE_PATH
"${QUANT_INSTALL_DIR}/ResNet50_quant2_int8")
save_quant_ic_model_test(
save_quant2_model_resnet50
${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float
${QUANT2_INT8_RESNET50_SAVE_PATH})
set(QUANT2_INT8_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8")
save_quant_nlp_model_test(
save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float
${QUANT2_INT8_ERNIE_SAVE_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE})
set(QUANT2_INT8_GRU_SAVE_PATH "${QUANT_INSTALL_DIR}/GRU_quant2_int8")
save_quant_nlp_model_test(
save_quant2_model_gru ${QUANT2_GRU_MODEL_DIR}/GRU_quant_acc
${QUANT2_INT8_GRU_SAVE_PATH} ${QUANT2_GRU_OPS_TO_QUANTIZE})
set(QUANT2_INT8_LSTM_SAVE_PATH "${QUANT_INSTALL_DIR}/lstm_quant2_int8")
save_quant_nlp_model_test(
save_quant2_model_lstm ${QUANT2_LSTM_MODEL_DIR}/lstm_quant
${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_LSTM_OPS_TO_QUANTIZE})
# Convert Quant2 model to dot and pdf files
set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH
@@ -474,6 +423,7 @@ if(LINUX AND WITH_MKLDNN)
### PTQ INT8
# PTQ int8 lstm model
set(QUANT2_INT8_LSTM_SAVE_PATH "${QUANT_INSTALL_DIR}/lstm_quant2_int8")
set(LSTM_DATA_FILE "quant_lstm_input_data.tar.gz")
set(LSTM_URL "${INFERENCE_URL}/int8/unittest_model_data")
download_data(${QUANT2_INT8_LSTM_SAVE_PATH} ${LSTM_URL} ${LSTM_DATA_FILE}
@@ -561,7 +511,6 @@ if(LINUX AND WITH_MKLDNN)
120)
set_tests_properties(test_quant2_int8_resnet50_range_mkldnn PROPERTIES TIMEOUT
120)
set_tests_properties(save_quant2_model_resnet50 PROPERTIES TIMEOUT 120)
set_tests_properties(test_quant_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120)
set_tests_properties(test_quant_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT
120)
@@ -19,7 +19,6 @@ import time
import unittest
import numpy as np
from save_quant_model import transform_and_save_int8_model
import paddle
from paddle.framework import core
@@ -107,7 +106,7 @@ class TestLstmModelPTQ(unittest.TestCase):
mkldnn_cache_capacity,
warmup_data=None,
use_analysis=False,
enable_ptq=False,
mode="fp32",
):
config = core.AnalysisConfig(model_path)
config.set_cpu_math_library_num_threads(num_threads)
@@ -118,12 +117,15 @@ class TestLstmModelPTQ(unittest.TestCase):
config.enable_mkldnn()
config.disable_mkldnn_fc_passes() # fc passes caused dnnl error
config.set_mkldnn_cache_capacity(mkldnn_cache_capacity)
if enable_ptq:
if mode == "ptq":
# For this pass to work properly, it must be added before fc_fuse_pass
config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass")
config.enable_quantizer()
config.quantizer_config().set_quant_data(warmup_data)
config.quantizer_config().set_quant_batch_size(1)
elif mode == "qat":
config.enable_mkldnn_int8()
return config
def run_program(
@@ -134,7 +136,7 @@ class TestLstmModelPTQ(unittest.TestCase):
mkldnn_cache_capacity,
warmup_iter,
use_analysis=False,
enable_ptq=False,
mode="fp32",
):
place = paddle.CPUPlace()
warmup_data, inputs = self.get_warmup_tensor(data_path, place)
@@ -145,7 +147,7 @@ class TestLstmModelPTQ(unittest.TestCase):
mkldnn_cache_capacity,
warmup_data,
use_analysis,
enable_ptq,
mode,
)
predictor = core.create_paddle_predictor(config)
@@ -228,7 +230,7 @@ class TestLstmModelPTQ(unittest.TestCase):
mkldnn_cache_capacity,
warmup_iter,
False,
False,
mode="fp32",
)
(int8_hx_acc, int8_ctc_acc, int8_fps) = self.run_program(
@@ -238,23 +240,17 @@ class TestLstmModelPTQ(unittest.TestCase):
mkldnn_cache_capacity,
warmup_iter,
True,
True,
)
quant_model_save_path = quant_model + "_int8"
# transform model to quant2
transform_and_save_int8_model(
quant_model, quant_model_save_path, "fusion_lstm,concat"
mode="ptq",
)
(quant_hx_acc, quant_ctc_acc, quant_fps) = self.run_program(
quant_model_save_path,
quant_model + "_int8",
infer_data,
num_threads,
mkldnn_cache_capacity,
warmup_iter,
True,
False,
mode="qat",
)
print(
@@ -270,7 +266,7 @@ class TestLstmModelPTQ(unittest.TestCase):
)
print(
"QUANT2_INT8: fps {0}, hx_acc {1}, ctc_acc {2}".format(
"QAT: fps {0}, hx_acc {1}, ctc_acc {2}".format(
quant_fps, quant_hx_acc, quant_ctc_acc
)
)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import sys
import unittest
import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core
from paddle.static.quantization import Quant2Int8MkldnnPass
paddle.enable_static()
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--quant_model_path',
type=str,
default='',
help='A path to a Quant model.',
)
parser.add_argument(
'--int8_model_save_path',
type=str,
default='',
help='A path where the optimized and quantized INT8 model is saved.',
)
parser.add_argument(
'--ops_to_quantize',
type=str,
default='',
help='A comma-separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt is made to quantize all quantizable operators.',
)
parser.add_argument(
'--op_ids_to_skip',
type=str,
default='',
help='A comma-separated list of operator ids to skip during quantization.',
)
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of the Quant model is drawn.',
)
parser.add_argument(
'--quant_model_filename',
type=str,
default="",
help='The file name of the input model. If empty, the default `__model__` file and separate parameter files are searched for and used; if they are not found, loading `model` and `params` files is attempted.',
)
parser.add_argument(
'--quant_params_filename',
type=str,
default="",
help='The name of the file containing all parameters of the input model. If quant_model_filename is empty, this field is ignored. If empty, parameters are loaded from separate files.',
)
parser.add_argument(
'--save_model_filename',
type=str,
default="__model__",
help='The name of the file in which to save the inference program. If set to None, the default filename __model__ will be used.',
)
parser.add_argument(
'--save_params_filename',
type=str,
default=None,
help='The name of the file in which to save all related parameters. If set to None, parameters will be saved in separate files.',
)
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
def transform_and_save_int8_model(
original_path,
save_path,
ops_to_quantize='',
op_ids_to_skip='',
debug=False,
quant_model_filename='',
quant_params_filename='',
save_model_filename="model",
save_params_filename=None,
):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
inference_scope = paddle.static.global_scope()
with paddle.static.scope_guard(inference_scope):
if not quant_model_filename:
if os.path.exists(os.path.join(original_path, '__model__')):
[
inference_program,
feed_target_names,
fetch_targets,
] = paddle.fluid.io.load_inference_model(original_path, exe)
else:
[
inference_program,
feed_target_names,
fetch_targets,
] = paddle.static.load_inference_model(
original_path,
exe,
model_filename='model',
params_filename='params',
)
else:
[
inference_program,
feed_target_names,
fetch_targets,
] = paddle.static.load_inference_model(
original_path,
exe,
model_filename=quant_model_filename,
params_filename=quant_params_filename,
)
ops_to_quantize_set = set()
print(ops_to_quantize)
if len(ops_to_quantize) > 0:
ops_to_quantize_set = set(ops_to_quantize.split(','))
op_ids_to_skip_set = set([-1])
print(op_ids_to_skip)
if len(op_ids_to_skip) > 0:
op_ids_to_skip_set = set(map(int, op_ids_to_skip.split(',')))
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if debug:
graph.draw('.', 'quant_orig', graph.all_op_nodes())
transform_to_mkldnn_int8_pass = Quant2Int8MkldnnPass(
ops_to_quantize_set,
_op_ids_to_skip=op_ids_to_skip_set,
_scope=inference_scope,
_place=place,
_core=core,
_debug=debug,
)
graph = transform_to_mkldnn_int8_pass.apply(graph)
inference_program = graph.to_program()
with paddle.static.scope_guard(inference_scope):
path_prefix = os.path.join(save_path, save_model_filename)
feed_vars = [
inference_program.global_block().var(name)
for name in feed_target_names
]
paddle.static.save_inference_model(
path_prefix,
feed_vars,
fetch_targets,
executor=exe,
program=inference_program,
)
print(
"Success! INT8 model obtained from the Quant model can be found at {}\n".format(
save_path
)
)
if __name__ == '__main__':
global test_args
test_args, remaining_args = parse_args()
transform_and_save_int8_model(
test_args.quant_model_path,
test_args.int8_model_save_path,
test_args.ops_to_quantize,
test_args.op_ids_to_skip,
test_args.debug,
test_args.quant_model_filename,
test_args.quant_params_filename,
test_args.save_model_filename,
test_args.save_params_filename,
)
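For context on what replaces the removed script: the updated test no longer converts the Quant model offline with transform_and_save_int8_model; when mode == "qat" it enables MKL-DNN INT8 execution directly on the inference config. Below is a minimal sketch of that path, assembled from the config calls shown in the test diff above; the function name, model path, thread count, and cache capacity are placeholder assumptions, not part of this change.

import paddle
from paddle.framework import core

paddle.enable_static()


def qat_int8_config(model_path, num_threads=1, cache_capacity=10):
    # Build an inference config that runs the Quant model with MKL-DNN INT8
    # kernels directly, instead of first saving a converted INT8 model.
    config = core.AnalysisConfig(model_path)
    config.set_cpu_math_library_num_threads(num_threads)
    config.enable_mkldnn()
    config.set_mkldnn_cache_capacity(cache_capacity)
    config.enable_mkldnn_int8()
    return config


# Usage (hypothetical model directory):
# predictor = core.create_paddle_predictor(qat_int8_config("lstm_quant"))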