Unverified commit dec2b1ca authored by joanna.wozna.intel, committed by GitHub

Modify save_quant_model to support different input and output filenames (#40542)

* Modify save_quant_model.py to support different input and output filenames

* Correct wrong order of arguments
Parent 23c036d6
@@ -52,6 +52,30 @@ def parse_args():
         '--debug',
         action='store_true',
         help='If used, the graph of Quant model is drawn.')
+    parser.add_argument(
+        '--quant_model_filename',
+        type=str,
+        default="",
+        help='The input model`s file name. If empty, search default `__model__` and separate parameter files and use them or in case if not found, attempt loading `model` and `params` files.'
+    )
+    parser.add_argument(
+        '--quant_params_filename',
+        type=str,
+        default="",
+        help='If quant_model_filename is empty, this field is ignored. The input model`s all parameters file name. If empty load parameters from separate files.'
+    )
+    parser.add_argument(
+        '--save_model_filename',
+        type=str,
+        default="__model__",
+        help='The name of file to save the inference program itself. If is set None, a default filename __model__ will be used.'
+    )
+    parser.add_argument(
+        '--save_params_filename',
+        type=str,
+        default=None,
+        help='The name of file to save all related parameters. If it is set None, parameters will be saved in separate files'
+    )
     test_args, args = parser.parse_known_args(namespace=unittest)
     return test_args, sys.argv[:1] + args
@@ -61,18 +85,29 @@ def transform_and_save_int8_model(original_path,
                                   save_path,
                                   ops_to_quantize='',
                                   op_ids_to_skip='',
-                                  debug=False):
+                                  debug=False,
+                                  quant_model_filename='',
+                                  quant_params_filename='',
+                                  save_model_filename='',
+                                  save_params_filename=''):
     place = fluid.CPUPlace()
     exe = fluid.Executor(place)
     inference_scope = fluid.executor.global_scope()
     with fluid.scope_guard(inference_scope):
-        if os.path.exists(os.path.join(original_path, '__model__')):
-            [inference_program, feed_target_names,
-             fetch_targets] = fluid.io.load_inference_model(original_path, exe)
+        if not quant_model_filename:
+            if os.path.exists(os.path.join(original_path, '__model__')):
+                [inference_program, feed_target_names,
+                 fetch_targets] = fluid.io.load_inference_model(original_path,
+                                                                exe)
+            else:
+                [inference_program, feed_target_names,
+                 fetch_targets] = fluid.io.load_inference_model(
+                     original_path, exe, 'model', 'params')
         else:
             [inference_program, feed_target_names,
-             fetch_targets] = fluid.io.load_inference_model(original_path, exe,
-                                                            'model', 'params')
+             fetch_targets] = fluid.io.load_inference_model(
+                 original_path, exe, quant_model_filename,
+                 quant_params_filename)
         ops_to_quantize_set = set()
         print(ops_to_quantize)
@@ -97,8 +132,14 @@ def transform_and_save_int8_model(original_path,
         graph = transform_to_mkldnn_int8_pass.apply(graph)
         inference_program = graph.to_program()
         with fluid.scope_guard(inference_scope):
-            fluid.io.save_inference_model(save_path, feed_target_names,
-                                          fetch_targets, exe, inference_program)
+            fluid.io.save_inference_model(
+                save_path,
+                feed_target_names,
+                fetch_targets,
+                exe,
+                inference_program,
+                model_filename=save_model_filename,
+                params_filename=save_params_filename)
         print(
             "Success! INT8 model obtained from the Quant model can be found at {}\n"
             .format(save_path))
@@ -109,4 +150,6 @@ if __name__ == '__main__':
     test_args, remaining_args = parse_args()
     transform_and_save_int8_model(
         test_args.quant_model_path, test_args.int8_model_save_path,
-        test_args.ops_to_quantize, test_args.op_ids_to_skip, test_args.debug)
+        test_args.ops_to_quantize, test_args.op_ids_to_skip, test_args.debug,
+        test_args.quant_model_filename, test_args.quant_params_filename,
+        test_args.save_model_filename, test_args.save_params_filename)
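
For reference, a minimal usage sketch of the function extended by this commit, based on the new signature shown in the diff above. The import path, directory paths, and file names below are hypothetical examples rather than part of the change; ops_to_quantize, op_ids_to_skip, and debug are left at their defaults.

# Sketch only: calling the updated transform_and_save_int8_model with the new
# filename arguments. All paths and filenames here are illustrative assumptions.
from save_quant_model import transform_and_save_int8_model

transform_and_save_int8_model(
    '/path/to/quant_model_dir',    # original_path: directory holding the Quant model
    '/path/to/int8_model_dir',     # save_path: where the INT8 model is written
    quant_model_filename='model.pdmodel',      # named model file instead of `__model__`
    quant_params_filename='model.pdiparams',   # single params file instead of separate files
    save_model_filename='int8_model.pdmodel',
    save_params_filename='int8_model.pdiparams')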