提交 c990aed7 编写于 作者: R root

[onnx] fix bug when importing onnxruntime

上级 2ed0c371
...@@ -194,6 +194,16 @@ def main():
assert args.framework is not None, "--framework is not defined(support tensorflow/caffe/onnx)"
assert args.save_dir is not None, "--save_dir is not defined"
if args.framework == "onnx":
try:
import onnxruntime as rt
version = rt.__version__
if version != '1.0.0':
print("onnxruntime==1.0.0 is required")
return
except:
print("onnxruntime is not installed, use \"pip install onnxruntime==1.0.0\".")
try:
import paddle
v0, v1, v2 = paddle.__version__.split('.')
......
...@@ -476,17 +476,7 @@ class ONNXDecoder(object):
return 'x2paddle_' + name
def check_model_running_state(self, model_path):
try:
import onnxruntime as rt
version = rt.__version__
if version != '1.0.0':
print("onnxruntime==1.0.0 is required")
return
except:
raise Exception(
"onnxruntime is not installed, use \"pip install onnxruntime==1.0.0\"."
)
model = onnx.load(model_path)
model = onnx.shape_inference.infer_shapes(model)
if len(model.graph.value_info) < len(model.graph.node) - 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册