diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index 585c829654ecce26627fcb5bc56a76c992263dfe..81f6bb317e96ca538f2816e24beff5ce16e29c24 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -68,6 +68,11 @@ def arg_parser():
                         action="store_true",
                         default=False,
                         help="define input shape for tf model")
+    parser.add_argument("--param_merge",
+                        "-pm",
+                        action="store_true",
+                        default=False,
+                        help="define whether to merge the params")
     return parser
 
 
@@ -75,7 +80,8 @@
 def tf2paddle(model_path,
               save_dir,
               without_data_format_optimization=False,
-              define_input_shape=False):
+              define_input_shape=False,
+              param_merge=False):
     # check tensorflow installation and version
     try:
         import os
@@ -121,10 +127,10 @@
         optimizer.merge_bias()
         optimizer.make_nchw_input_output()
         optimizer.remove_transpose()
-    mapper.save_inference_model(save_dir)
+    mapper.save_inference_model(save_dir, param_merge)
 
 
-def caffe2paddle(proto, weight, save_dir, caffe_proto):
+def caffe2paddle(proto, weight, save_dir, caffe_proto, param_merge=False):
     from x2paddle.decoder.caffe_decoder import CaffeDecoder
     from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
     from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
@@ -141,10 +147,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto):
     optimizer = CaffeOptimizer(mapper)
     optimizer.merge_bn_scale()
     optimizer.merge_op_activation()
-    mapper.save_inference_model(save_dir)
+    mapper.save_inference_model(save_dir, param_merge)
 
 
-def onnx2paddle(model_path, save_dir):
+def onnx2paddle(model_path, save_dir, param_merge=False):
     # check onnx installation and version
     try:
         import onnx
@@ -167,7 +173,7 @@ def onnx2paddle(model_path, save_dir):
 
     optimizer = ONNXOptimizer(mapper)
     optimizer.delete_redundance_code()
-    mapper.save_inference_model(save_dir)
+    mapper.save_inference_model(save_dir, param_merge)
 
 
 def main():
@@ -202,20 +208,29 @@ def main():
         assert args.model is not None, "--model should be defined while translating tensorflow model"
         without_data_format_optimization = False
         define_input_shape = False
+        param_merge = False
         if args.without_data_format_optimization:
             without_data_format_optimization = True
         if args.define_input_shape:
             define_input_shape = True
+        if args.param_merge:
+            param_merge = True
         tf2paddle(args.model, args.save_dir, without_data_format_optimization,
-                  define_input_shape)
+                  define_input_shape, param_merge)
 
     elif args.framework == "caffe":
         assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
+        param_merge = False
+        if args.param_merge:
+            param_merge = True
         caffe2paddle(args.prototxt, args.weight, args.save_dir,
-                     args.caffe_proto)
+                     args.caffe_proto, param_merge)
     elif args.framework == "onnx":
         assert args.model is not None, "--model should be defined while translating onnx model"
-        onnx2paddle(args.model, args.save_dir)
+        param_merge = False
+        if args.param_merge:
+            param_merge = True
+        onnx2paddle(args.model, args.save_dir, param_merge)
     else:
         raise Exception("--framework only support tensorflow/caffe/onnx now")
 
diff --git a/x2paddle/core/op_mapper.py b/x2paddle/core/op_mapper.py
index 34aebb0905dc1888cd037f51a4d3864005e6e141..7adced2775885775008328ab85c016b9cab6a39d 100644
--- a/x2paddle/core/op_mapper.py
+++ b/x2paddle/core/op_mapper.py
@@ -110,7 +110,7 @@ class OpMapper(object):
         self.add_codes("import paddle.fluid as fluid")
         self.add_codes("")
 
-    def save_inference_model(self, save_dir):
+    def save_inference_model(self, save_dir, param_merge):
         self.save_python_model(save_dir)
 
         import sys
@@ -138,13 +138,20 @@ class OpMapper(object):
                                py_code_dir,
                                fluid.default_main_program(),
                                predicate=if_exist)
-
-            fluid.io.save_inference_model(dirname=os.path.join(
-                save_dir, "inference_model"),
-                                          feeded_var_names=input_names,
-                                          target_vars=outputs,
-                                          executor=exe,
-                                          params_filename=None)
+            if param_merge:
+                fluid.io.save_inference_model(dirname=os.path.join(
+                    save_dir, "inference_model"),
+                                              feeded_var_names=input_names,
+                                              target_vars=outputs,
+                                              executor=exe,
+                                              params_filename="__params__")
+            else:
+                fluid.io.save_inference_model(dirname=os.path.join(
+                    save_dir, "inference_model"),
+                                              feeded_var_names=input_names,
+                                              target_vars=outputs,
+                                              executor=exe,
+                                              params_filename=None)
         except:
             raise Exception(
                 "Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually."
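
Usage note: with this patch, passing --param_merge (or -pm) on the command line, e.g. x2paddle --framework=onnx --model=model.onnx --save_dir=save_dir --param_merge, makes fluid.io.save_inference_model write all weights into a single __params__ file instead of one file per variable. A minimal loading sketch for a model exported this way follows; the paths, input name, and input shape are illustrative placeholders, not part of the patch:

    import numpy as np
    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())

    # A merged export stores every weight in one "__params__" file, so the
    # matching params_filename must be passed when loading; without it,
    # fluid looks for one file per variable, which no longer exists.
    program, feed_names, fetch_targets = fluid.io.load_inference_model(
        dirname="save_dir/inference_model",  # assumed export location
        executor=exe,
        params_filename="__params__")

    # Hypothetical input; substitute the model's real input name and shape.
    outs = exe.run(program,
                   feed={feed_names[0]:
                         np.random.rand(1, 3, 224, 224).astype("float32")},
                   fetch_list=fetch_targets)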