提交 4bb194dd 编写于 作者: S SunAhong1993

fix the convert.py

上级 2ed1bcd8
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from six import text_type as _text_type from six import text_type as _text_type
from x2paddle import program
import argparse import argparse
import sys import sys
...@@ -67,8 +66,8 @@ def arg_parser(): ...@@ -67,8 +66,8 @@ def arg_parser():
parser.add_argument( parser.add_argument(
"--without_data_format_optimization", "--without_data_format_optimization",
"-wo", "-wo",
action="store_true", type=_text_type,
default=False, default="True",
help="tf model conversion without data format optimization") help="tf model conversion without data format optimization")
parser.add_argument( parser.add_argument(
"--define_input_shape", "--define_input_shape",
...@@ -94,13 +93,11 @@ def arg_parser(): ...@@ -94,13 +93,11 @@ def arg_parser():
action='append', action='append',
default=None, default=None,
help="define the inputs' shape") help="define the inputs' shape")
return parser return parser
def tf2paddle(model_path, def tf2paddle(model_path,
save_dir, save_dir,
without_data_format_optimization=False, without_data_format_optimization,
define_input_shape=False, define_input_shape=False,
params_merge=False): params_merge=False):
# check tensorflow installation and version # check tensorflow installation and version
...@@ -127,10 +124,29 @@ def tf2paddle(model_path, ...@@ -127,10 +124,29 @@ def tf2paddle(model_path,
print("Now translating model from tensorflow to paddle.") print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape) model = TFDecoder(model_path, define_input_shape=define_input_shape)
if not without_data_format_optimization:
mapper = TFOpMapper(model)
optimizer = TFOptimizer(mapper)
# neccesary optimization
optimizer.delete_redundance_code()
# optimizer below is experimental
optimizer.optimize_elementwise_op()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.optimize_sub_graph()
# optimizer.merge_batch_norm()
# optimizer.merge_prelu()
else:
mapper = TFOpMapperNHWC(model) mapper = TFOpMapperNHWC(model)
program.build() optimizer = TFOptimizer(mapper)
program.gen_model(save_dir) optimizer.delete_redundance_code()
optimizer.strip_graph()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.make_nchw_input_output()
optimizer.remove_transpose()
mapper.save_inference_model(save_dir, params_merge)
def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False): def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
...@@ -158,8 +174,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False): ...@@ -158,8 +174,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
try: try:
import onnx import onnx
version = onnx.version.version version = onnx.version.version
if version != '1.6.0': if version < '1.6.0':
print("[ERROR] onnx==1.6.0 is required") print("[ERROR] onnx>=1.6.0 is required")
return return
except: except:
print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".") print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
...@@ -222,9 +238,14 @@ def pytorch2paddle(model_path, save_dir, input_shapes): ...@@ -222,9 +238,14 @@ def pytorch2paddle(model_path, save_dir, input_shapes):
def paddle2onnx(model_path, save_dir, opset_version=10): def paddle2onnx(model_path, save_dir, opset_version=10):
from x2paddle.decoder.paddle_decoder import PaddleDecoder from x2paddle.decoder.paddle_decoder import PaddleDecoder
from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper
import paddle.fluid as fluid
model = PaddleDecoder(model_path, '__model__', '__params__') model = PaddleDecoder(model_path, '__model__', '__params__')
mapper = PaddleOpMapper() mapper = PaddleOpMapper()
mapper.convert(model.program, save_dir, opset_number=opset_version) mapper.convert(
model.program,
save_dir,
scope=fluid.global_scope(),
opset_version=opset_version)
def main(): def main():
...@@ -262,11 +283,12 @@ def main(): ...@@ -262,11 +283,12 @@ def main():
if args.framework == "tensorflow": if args.framework == "tensorflow":
assert args.model is not None, "--model should be defined while translating tensorflow model" assert args.model is not None, "--model should be defined while translating tensorflow model"
without_data_format_optimization = False assert args.without_data_format_optimization in [
"True", "False"
], "--the param without_data_format_optimization should be defined True or False"
define_input_shape = False define_input_shape = False
params_merge = False params_merge = False
if args.without_data_format_optimization: without_data_format_optimization = True if args.without_data_format_optimization == "True" else False
without_data_format_optimization = True
if args.define_input_shape: if args.define_input_shape:
define_input_shape = True define_input_shape = True
if args.params_merge: if args.params_merge:
...@@ -288,13 +310,14 @@ def main(): ...@@ -288,13 +310,14 @@ def main():
if args.params_merge: if args.params_merge:
params_merge = True params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge) onnx2paddle(args.model, args.save_dir, params_merge)
elif args.framework == "pytorch": elif args.framework == "pytorch":
assert args.model is not None, "--model should be defined while translating pytorch model" assert args.model is not None, "--model should be defined while translating pytorch model"
pytorch2paddle(args.model, args.save_dir, args.input_shapes) pytorch2paddle(args.model, args.save_dir, args.input_shapes)
elif args.framework == "paddle2onnx": elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx" assert args.model is not None, "--model should be defined while translating paddle model to onnx"
paddle2onnx(args.model, args.save_dir, args.onnx_opset) paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset)
else: else:
raise Exception( raise Exception(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册