From 4bb194dd79012a266367f1d9eb5f3765f65fa218 Mon Sep 17 00:00:00 2001
From: SunAhong1993
Date: Thu, 17 Sep 2020 14:22:44 +0800
Subject: [PATCH] fix convert.py

---
 x2paddle/convert.py | 63 +++++++++++++++++++++++++++++++--------------
 1 file changed, 43 insertions(+), 20 deletions(-)

diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index 1503913..d8e7d0c 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from six import text_type as _text_type
-from x2paddle import program
 import argparse
 import sys
 
@@ -67,8 +66,8 @@ def arg_parser():
     parser.add_argument(
         "--without_data_format_optimization",
         "-wo",
-        action="store_true",
-        default=False,
+        type=_text_type,
+        default="True",
         help="tf model conversion without data format optimization")
     parser.add_argument(
         "--define_input_shape",
@@ -94,13 +93,11 @@ def arg_parser():
         action='append',
         default=None,
         help="define the inputs' shape")
-
     return parser
 
-
 def tf2paddle(model_path,
               save_dir,
-              without_data_format_optimization=False,
+              without_data_format_optimization,
               define_input_shape=False,
               params_merge=False):
     # check tensorflow installation and version
@@ -127,10 +124,29 @@ def tf2paddle(model_path,
 
     print("Now translating model from tensorflow to paddle.")
     model = TFDecoder(model_path, define_input_shape=define_input_shape)
-
-    mapper = TFOpMapperNHWC(model)
-    program.build()
-    program.gen_model(save_dir)
+    if not without_data_format_optimization:
+        mapper = TFOpMapper(model)
+        optimizer = TFOptimizer(mapper)
+        # necessary optimization
+        optimizer.delete_redundance_code()
+        # optimizers below are experimental
+        optimizer.optimize_elementwise_op()
+        optimizer.merge_activation()
+        optimizer.merge_bias()
+        optimizer.optimize_sub_graph()
+
+#        optimizer.merge_batch_norm()
+#        optimizer.merge_prelu()
+    else:
+        mapper = TFOpMapperNHWC(model)
+        optimizer = TFOptimizer(mapper)
+        optimizer.delete_redundance_code()
+        optimizer.strip_graph()
+        optimizer.merge_activation()
+        optimizer.merge_bias()
+        optimizer.make_nchw_input_output()
+        optimizer.remove_transpose()
+    mapper.save_inference_model(save_dir, params_merge)
 
 
 def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
@@ -158,8 +174,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
     try:
         import onnx
         version = onnx.version.version
-        if version != '1.6.0':
-            print("[ERROR] onnx==1.6.0 is required")
+        if version < '1.6.0':
+            print("[ERROR] onnx>=1.6.0 is required")
             return
     except:
         print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
@@ -178,8 +194,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
     print("Paddle model and code generating ...")
     mapper.save_inference_model(save_dir, params_merge)
     print("Paddle model and code generated.")
-
-
+
+
 def pytorch2paddle(model_path, save_dir, input_shapes):
     # check pytorch installation and version
     try:
@@ -222,9 +238,14 @@ def pytorch2paddle(model_path, save_dir, input_shapes):
 def paddle2onnx(model_path, save_dir, opset_version=10):
     from x2paddle.decoder.paddle_decoder import PaddleDecoder
     from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper
+    import paddle.fluid as fluid
     model = PaddleDecoder(model_path, '__model__', '__params__')
     mapper = PaddleOpMapper()
-    mapper.convert(model.program, save_dir, opset_number=opset_version)
+    mapper.convert(
+        model.program,
+        save_dir,
+        scope=fluid.global_scope(),
+        opset_version=opset_version)
 
 
 def main():
@@ -262,11 +283,12 @@ def main():
 
     if args.framework == "tensorflow":
         assert args.model is not None, "--model should be defined while translating tensorflow model"
-        without_data_format_optimization = False
+        assert args.without_data_format_optimization in [
+            "True", "False"
+        ], "--without_data_format_optimization should be set to True or False"
         define_input_shape = False
         params_merge = False
-        if args.without_data_format_optimization:
-            without_data_format_optimization = True
+        without_data_format_optimization = True if args.without_data_format_optimization == "True" else False
         if args.define_input_shape:
             define_input_shape = True
         if args.params_merge:
@@ -288,13 +310,14 @@ def main():
         if args.params_merge:
             params_merge = True
         onnx2paddle(args.model, args.save_dir, params_merge)
+
     elif args.framework == "pytorch":
         assert args.model is not None, "--model should be defined while translating pytorch model"
         pytorch2paddle(args.model, args.save_dir, args.input_shapes)
 
     elif args.framework == "paddle2onnx":
         assert args.model is not None, "--model should be defined while translating paddle model to onnx"
-        paddle2onnx(args.model, args.save_dir, args.onnx_opset)
+        paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset)
 
     else:
         raise Exception(
@@ -302,4 +325,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
-- 
GitLab
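
Note on the string-valued flag (editor's sketch, not part of the patch): after this change,
--without_data_format_optimization is parsed as the text "True"/"False" rather than acting as a
store_true switch, and main() validates the value and converts it back to a bool before calling
tf2paddle(). Below is a minimal, standalone sketch of that parsing behaviour, using only the
standard argparse module; plain str stands in for six's _text_type, and the example argument
list is made up for illustration.

    import argparse

    parser = argparse.ArgumentParser()
    # string-valued option, mirroring type=_text_type / default="True" in the patch
    parser.add_argument(
        "--without_data_format_optimization", "-wo", type=str, default="True")

    # hypothetical command line, equivalent to passing -wo False on the CLI
    args = parser.parse_args(["--without_data_format_optimization", "False"])

    # validate and normalize, equivalent to the assert + ternary main() now uses
    assert args.without_data_format_optimization in ["True", "False"]
    without_data_format_optimization = args.without_data_format_optimization == "True"
    print(without_data_format_optimization)  # -> False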