Commit 162d160f authored by SunAhong1993

fix the modify

Parent 77565629
@@ -67,8 +67,8 @@ def arg_parser():
     parser.add_argument(
         "--without_data_format_optimization",
         "-wo",
-        action="store_true",
-        default=False,
+        type=_text_type,
+        default="True",
         help="tf model conversion without data format optimization")
     parser.add_argument(
         "--define_input_shape",
@@ -135,14 +135,21 @@ def tf2paddle(model_path,
     from x2paddle.decoder.tf_decoder import TFDecoder
     from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
-    from x2paddle.op_mapper.tf_op_mapper_nhwc import TFOpMapperNHWC
-    from x2paddle.optimizer.tf_optimizer import TFOptimizer
+    from x2paddle.optimizer.tensorflow.bias import BiasOpt
+    from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
+    from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
     print("Now translating model from tensorflow to paddle.")
     model = TFDecoder(model_path, define_input_shape=define_input_shape)
-    mapper = TFOpMapperNHWC(model)
+    mapper = TFOpMapper(model)
     program.build()
+    bias_opt = BiasOpt()
+    transpose_opt = TransposeOpt()
+    batch_norm_opt = BatchNormOpt()
+    bias_opt.run(program)
+    batch_norm_opt.run(program)
+    transpose_opt.run(program)
     program.gen_model(save_dir)
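Note: the NHWC mapper plus `TFOptimizer` pair is replaced by the plain `TFOpMapper` followed by standalone graph passes, each exposing a `run(program)` entry point and applied in the order bias, batch norm, transpose. Since they share that interface, the same sequence could be driven from a small loop; a sketch, where the `run_passes` helper is illustrative and not part of this commit:

```python
def run_passes(program, passes):
    # Apply each optimization pass in order; every pass mutates the
    # program in place through its run() method.
    for opt_pass in passes:
        opt_pass.run(program)

# e.g. run_passes(program, [BiasOpt(), BatchNormOpt(), TransposeOpt()])
```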
@@ -172,8 +179,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
     try:
         import onnx
         version = onnx.version.version
-        if version != '1.6.0':
-            print("[ERROR] onnx==1.6.0 is required")
+        if version < '1.6.0':
+            print("[ERROR] onnx>=1.6.0 is required")
             return
     except:
         print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
@@ -229,11 +236,21 @@ def pytorch2paddle(model_path, save_dir, jit_type, input_files):
 def paddle2onnx(model_path, save_dir, opset_version=10):
-    from x2paddle.decoder.paddle_decoder import PaddleDecoder
-    from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper
-    model = PaddleDecoder(model_path, '__model__', '__params__')
-    mapper = PaddleOpMapper()
-    mapper.convert(model.program, save_dir, opset_number=opset_version)
+    import paddle.fluid as fluid
+    try:
+        import paddle2onnx
+    except:
+        print(
+            "[ERROR] paddle2onnx not installed, use \"pip install paddle2onnx\"")
+    import paddle2onnx as p2o
+    model = p2o.PaddleDecoder(model_path, '__model__', '__params__')
+    mapper = p2o.PaddleOpMapper()
+    mapper.convert(
+        model.program,
+        save_dir,
+        scope=fluid.global_scope(),
+        opset_version=opset_version)
 
 def main():
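Note: the Paddle-to-ONNX export is now delegated to the external paddle2onnx package while the converter keeps its original signature. A usage sketch, assuming this module is importable as `x2paddle.convert` and that the model directory contains the `__model__` and `__params__` files the decoder expects:

```python
from x2paddle.convert import paddle2onnx  # import path is an assumption

# Export a saved Paddle inference model (directory containing
# __model__ and __params__) to ONNX with opset 11.
paddle2onnx("./paddle_infer_model", "./onnx_model", opset_version=11)
```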
@@ -304,7 +321,7 @@ def main():
     elif args.framework == "paddle2onnx":
         assert args.model is not None, "--model should be defined while translating paddle model to onnx"
-        paddle2onnx(args.model, args.save_dir, args.onnx_opset)
+        paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset)
     else:
         raise Exception(
...