diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 42b4cbfe9d7b7d8d73bbbb512b68f9dce8f6af5d..8c1198c1aee8faf37938c6b67aa7df1b3806a1b2 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -66,8 +66,8 @@ def arg_parser(): parser.add_argument( "--without_data_format_optimization", "-wo", - action="store_true", - default=True, + type=_text_type, + default="True", help="tf model conversion without data format optimization") parser.add_argument( "--define_input_shape", @@ -93,7 +93,7 @@ def arg_parser(): def tf2paddle(model_path, save_dir, - without_data_format_optimization=False, + without_data_format_optimization, define_input_shape=False, params_merge=False): # check tensorflow installation and version @@ -240,11 +240,12 @@ def main(): if args.framework == "tensorflow": assert args.model is not None, "--model should be defined while translating tensorflow model" - without_data_format_optimization = False + assert args.without_data_format_optimization in [ + "True", "False" + ], "--the param without_data_format_optimization should be defined True or False" define_input_shape = False params_merge = False - if args.without_data_format_optimization: - without_data_format_optimization = True + without_data_format_optimization = True if args.without_data_format_optimization == "True" else False if args.define_input_shape: define_input_shape = True if args.params_merge: