diff --git a/python/paddle/fluid/contrib/slim/tests/save_quant_model.py b/python/paddle/fluid/contrib/slim/tests/save_quant_model.py
index 3fadf25150f9ef3556a343fdce8acc24d788f5dc..f97c2778c0918ecbfbed546089c17e9d505818cd 100644
--- a/python/paddle/fluid/contrib/slim/tests/save_quant_model.py
+++ b/python/paddle/fluid/contrib/slim/tests/save_quant_model.py
@@ -52,6 +52,30 @@ def parse_args():
         '--debug',
         action='store_true',
         help='If used, the graph of Quant model is drawn.')
+    parser.add_argument(
+        '--quant_model_filename',
+        type=str,
+        default="",
+        help='The file name of the input model. If empty, the default `__model__` file with separate parameter files is looked for first; if it is not found, loading `model` and `params` files is attempted.'
+    )
+    parser.add_argument(
+        '--quant_params_filename',
+        type=str,
+        default="",
+        help='The name of the file containing all parameters of the input model. Ignored if quant_model_filename is empty. If empty, parameters are loaded from separate files.'
+    )
+    parser.add_argument(
+        '--save_model_filename',
+        type=str,
+        default="__model__",
+        help='The name of the file to save the inference program in. If set to None, the default filename `__model__` is used.'
+    )
+    parser.add_argument(
+        '--save_params_filename',
+        type=str,
+        default=None,
+        help='The name of the file to save all related parameters in. If set to None, parameters are saved in separate files.'
+    )
     test_args, args = parser.parse_known_args(namespace=unittest)
     return test_args, sys.argv[:1] + args

@@ -61,18 +85,29 @@ def transform_and_save_int8_model(original_path,
                                   save_path,
                                   ops_to_quantize='',
                                   op_ids_to_skip='',
-                                  debug=False):
+                                  debug=False,
+                                  quant_model_filename='',
+                                  quant_params_filename='',
+                                  save_model_filename='',
+                                  save_params_filename=''):
     place = fluid.CPUPlace()
     exe = fluid.Executor(place)
     inference_scope = fluid.executor.global_scope()
     with fluid.scope_guard(inference_scope):
-        if os.path.exists(os.path.join(original_path, '__model__')):
-            [inference_program, feed_target_names,
-             fetch_targets] = fluid.io.load_inference_model(original_path, exe)
+        if not quant_model_filename:
+            if os.path.exists(os.path.join(original_path, '__model__')):
+                [inference_program, feed_target_names,
+                 fetch_targets] = fluid.io.load_inference_model(original_path,
+                                                                exe)
+            else:
+                [inference_program, feed_target_names,
+                 fetch_targets] = fluid.io.load_inference_model(
+                     original_path, exe, 'model', 'params')
         else:
             [inference_program, feed_target_names,
-             fetch_targets] = fluid.io.load_inference_model(original_path, exe,
-                                                            'model', 'params')
+             fetch_targets] = fluid.io.load_inference_model(
+                 original_path, exe, quant_model_filename,
+                 quant_params_filename)

         ops_to_quantize_set = set()
         print(ops_to_quantize)
@@ -97,8 +132,14 @@ def transform_and_save_int8_model(original_path,
         graph = transform_to_mkldnn_int8_pass.apply(graph)
         inference_program = graph.to_program()
         with fluid.scope_guard(inference_scope):
-            fluid.io.save_inference_model(save_path, feed_target_names,
-                                          fetch_targets, exe, inference_program)
+            fluid.io.save_inference_model(
+                save_path,
+                feed_target_names,
+                fetch_targets,
+                exe,
+                inference_program,
+                model_filename=save_model_filename,
+                params_filename=save_params_filename)
         print(
             "Success! INT8 model obtained from the Quant model can be found at {}\n"
             .format(save_path))
@@ -109,4 +150,6 @@ if __name__ == '__main__':
     test_args, remaining_args = parse_args()
     transform_and_save_int8_model(
         test_args.quant_model_path, test_args.int8_model_save_path,
-        test_args.ops_to_quantize, test_args.op_ids_to_skip, test_args.debug)
+        test_args.ops_to_quantize, test_args.op_ids_to_skip, test_args.debug,
+        test_args.quant_model_filename, test_args.quant_params_filename,
+        test_args.save_model_filename, test_args.save_params_filename)
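
A minimal usage sketch of the extended function, not taken from the patch: it assumes the script is importable as `save_quant_model`, and the directory names and the op list below are hypothetical. It converts a Quant model stored as combined `model`/`params` files and saves the INT8 result under the default `__model__` name with parameters in separate files:

    from save_quant_model import transform_and_save_int8_model

    transform_and_save_int8_model(
        'quant_model_dir',                # hypothetical input directory
        'int8_model_dir',                 # hypothetical output directory
        ops_to_quantize='conv2d,pool2d',  # hypothetical op list
        quant_model_filename='model',     # load program from a single `model` file
        quant_params_filename='params',   # load all parameters from one `params` file
        save_model_filename='__model__',  # save the INT8 program as `__model__`
        save_params_filename=None)        # None: save parameters in separate files

The same choice of input files can be made from the command line with the new `--quant_model_filename` and `--quant_params_filename` flags, and the output names with `--save_model_filename` and `--save_params_filename`.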