diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index 0dcda088e0f530b3c9a42e252cd7c151d6a61813..57da0f0626b6c5f6b94b8345cecb947ec1e024b9 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -68,6 +68,11 @@ def arg_parser():
                         action="store_true",
                         default=False,
                         help="define input shape for tf model")
+    parser.add_argument("--params_merge",
+                        "-pm",
+                        action="store_true",
+                        default=False,
+                        help="define whether to merge the params")
     return parser
 
 
@@ -75,7 +80,8 @@ def arg_parser():
 def tf2paddle(model_path,
               save_dir,
               without_data_format_optimization=False,
-              define_input_shape=False):
+              define_input_shape=False,
+              params_merge=False):
     # check tensorflow installation and version
     try:
         import os
@@ -104,10 +110,10 @@
     optimizer.strip_graph()
     # optimizer.merge_activation()
     # optimizer.merge_bias()
-    mapper.save_inference_model(save_dir)
+    mapper.save_inference_model(save_dir, params_merge)
 
 
-def caffe2paddle(proto, weight, save_dir, caffe_proto):
+def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
     from x2paddle.decoder.caffe_decoder import CaffeDecoder
     from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
     from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
@@ -124,10 +130,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto):
     optimizer = CaffeOptimizer(mapper)
     optimizer.merge_bn_scale()
     optimizer.merge_op_activation()
-    mapper.save_inference_model(save_dir)
+    mapper.save_inference_model(save_dir, params_merge)
 
 
-def onnx2paddle(model_path, save_dir):
+def onnx2paddle(model_path, save_dir, params_merge=False):
     # check onnx installation and version
     try:
         import onnx
@@ -150,7 +156,7 @@
     optimizer = ONNXOptimizer(mapper)
     optimizer.delete_redundance_code()
 
-    mapper.save_inference_model(save_dir)
+    mapper.save_inference_model(save_dir, params_merge)
 
 
 def main():
@@ -193,20 +199,29 @@
         assert args.model is not None, "--model should be defined while translating tensorflow model"
         without_data_format_optimization = False
         define_input_shape = False
+        params_merge = False
         if args.without_data_format_optimization:
             without_data_format_optimization = True
         if args.define_input_shape:
             define_input_shape = True
+        if args.params_merge:
+            params_merge = True
         tf2paddle(args.model, args.save_dir, without_data_format_optimization,
-                  define_input_shape)
+                  define_input_shape, params_merge)
 
     elif args.framework == "caffe":
         assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
+        params_merge = False
+        if args.params_merge:
+            params_merge = True
         caffe2paddle(args.prototxt, args.weight, args.save_dir,
-                     args.caffe_proto)
+                     args.caffe_proto, params_merge)
     elif args.framework == "onnx":
         assert args.model is not None, "--model should be defined while translating onnx model"
-        onnx2paddle(args.model, args.save_dir)
+        params_merge = False
+        if args.params_merge:
+            params_merge = True
+        onnx2paddle(args.model, args.save_dir, params_merge)
     else:
         raise Exception("--framework only support tensorflow/caffe/onnx now")
diff --git a/x2paddle/core/op_mapper.py b/x2paddle/core/op_mapper.py
index 34aebb0905dc1888cd037f51a4d3864005e6e141..d311e3093f2697137dc334bf4b32a21465bb6328 100644
--- a/x2paddle/core/op_mapper.py
+++ b/x2paddle/core/op_mapper.py
@@ -110,7 +110,7 @@ class OpMapper(object):
         self.add_codes("import paddle.fluid as fluid")
         self.add_codes("")
 
-    def save_inference_model(self, save_dir):
+    def save_inference_model(self, save_dir, params_merge):
         self.save_python_model(save_dir)
 
         import sys
@@ -138,13 +138,20 @@ class OpMapper(object):
                                py_code_dir,
                                fluid.default_main_program(),
                                predicate=if_exist)
-
-            fluid.io.save_inference_model(dirname=os.path.join(
-                save_dir, "inference_model"),
-                                          feeded_var_names=input_names,
-                                          target_vars=outputs,
-                                          executor=exe,
-                                          params_filename=None)
+            if params_merge:
+                fluid.io.save_inference_model(dirname=os.path.join(
+                    save_dir, "inference_model"),
+                                              feeded_var_names=input_names,
+                                              target_vars=outputs,
+                                              executor=exe,
+                                              params_filename="__params__")
+            else:
+                fluid.io.save_inference_model(dirname=os.path.join(
+                    save_dir, "inference_model"),
+                                              feeded_var_names=input_names,
+                                              target_vars=outputs,
+                                              executor=exe,
+                                              params_filename=None)
         except:
             raise Exception(
                 "Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually."
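
Usage note: with this patch, passing --params_merge (or -pm) on the x2paddle command line makes save_inference_model write all weights into a single "__params__" file (params_filename="__params__") instead of one file per variable (params_filename=None). Below is a minimal sketch of loading such a merged model back, assuming the fluid.io.load_inference_model API of the same Paddle generation; the output directory "pd_model" is hypothetical and used only for illustration:

    import paddle.fluid as fluid

    exe = fluid.Executor(fluid.CPUPlace())
    # A model converted with --params_merge keeps every weight in one
    # "__params__" file, so the merged filename must be named explicitly.
    program, feed_names, fetch_targets = fluid.io.load_inference_model(
        dirname="pd_model/inference_model",  # hypothetical save_dir
        executor=exe,
        params_filename="__params__")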