提交 162d160f 编写于 作者: S SunAhong1993

fix the modify

上级 77565629
......@@ -67,8 +67,8 @@ def arg_parser():
parser.add_argument(
"--without_data_format_optimization",
"-wo",
action="store_true",
default=False,
type=_text_type,
default="True",
help="tf model conversion without data format optimization")
parser.add_argument(
"--define_input_shape",
......@@ -135,14 +135,21 @@ def tf2paddle(model_path,
from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
from x2paddle.op_mapper.tf_op_mapper_nhwc import TFOpMapperNHWC
from x2paddle.optimizer.tf_optimizer import TFOptimizer
from x2paddle.optimizer.tensorflow.bias import BiasOpt
from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapperNHWC(model)
mapper = TFOpMapper(model)
program.build()
bias_opt = BiasOpt()
transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt()
bias_opt.run(program)
batch_norm_opt.run(program)
transpose_opt.run(program)
program.gen_model(save_dir)
......@@ -172,8 +179,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
try:
import onnx
version = onnx.version.version
if version != '1.6.0':
print("[ERROR] onnx==1.6.0 is required")
if version < '1.6.0':
print("[ERROR] onnx>=1.6.0 is required")
return
except:
print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
......@@ -229,11 +236,21 @@ def pytorch2paddle(model_path, save_dir, jit_type, input_files):
def paddle2onnx(model_path, save_dir, opset_version=10):
from x2paddle.decoder.paddle_decoder import PaddleDecoder
from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper
model = PaddleDecoder(model_path, '__model__', '__params__')
mapper = PaddleOpMapper()
mapper.convert(model.program, save_dir, opset_number=opset_version)
import paddle.fluid as fluid
try:
import paddle2onnx
except:
print(
"[ERROR] paddle2onnx not installed, use \"pip install paddle2onnx\"")
import paddle2onnx as p2o
model = p2o.PaddleDecoder(model_path, '__model__', '__params__')
mapper = p2o.PaddleOpMapper()
mapper.convert(
model.program,
save_dir,
scope=fluid.global_scope(),
opset_version=opset_version)
def main():
......@@ -304,7 +321,7 @@ def main():
elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx"
paddle2onnx(args.model, args.save_dir, args.onnx_opset)
paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset)
else:
raise Exception(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册