未验证 提交 d4bbb576 编写于 作者: W WJJ1995 提交者: GitHub

Solve the overload problem (#865)

* add logical ops

* add run_dynamic switch

* add Or and Xor

* add Compare ops

* fixed compare bug

* fixed pad bug

* fixed pad bug

* fixed reload bug
上级 5e645693
......@@ -13,6 +13,8 @@
# limitations under the License.
import os
import sys
import importlib
import numpy as np
import logging
import paddle
......@@ -187,20 +189,23 @@ class ONNXConverter(object):
make paddle res
"""
# input data
paddle_numpy_feed = list()
paddle_tensor_feed = list()
for i in range(len(self.input_feed)):
paddle_numpy_feed.append(self.input_feed[self.inputs_name[i]])
paddle_tensor_feed.append(
paddle.to_tensor(self.input_feed[self.inputs_name[i]]))
if self.run_dynamic:
paddle_path = os.path.join(self.pwd, self.name,
self.name + '_' + str(ver) + '_paddle/')
import sys
sys.path.append(paddle_path)
from x2paddle_code import main
result = main(*paddle_tensor_feed)
restore = paddle.load(os.path.join(paddle_path, "model.pdparams"))
sys.path.insert(0, paddle_path)
import x2paddle_code
# Solve the problem of function overloading caused by traversing the model
importlib.reload(x2paddle_code)
model = getattr(x2paddle_code, "ONNXModel")()
model.set_dict(restore)
model.eval()
result = model(*paddle_tensor_feed)
else:
paddle_path = os.path.join(
self.pwd, self.name,
......@@ -208,7 +213,7 @@ class ONNXConverter(object):
paddle.disable_static()
# run
model = paddle.jit.load(paddle_path)
result = model(*paddle_numpy_feed)
result = model(*paddle_tensor_feed)
# get paddle outputs
if isinstance(result, (tuple, list)):
result = tuple(out.numpy() for out in result)
......
......@@ -21,6 +21,7 @@ import sys
import os
import six
import pickle
import importlib
from os import path as osp
from x2paddle.core.util import *
......@@ -566,6 +567,8 @@ class PaddleGraph(object):
path = osp.abspath(save_dir)
sys.path.insert(0, save_dir)
import x2paddle_code
# Solve the problem of function overloading caused by traversing the model
importlib.reload(x2paddle_code)
paddle.disable_static()
restore = paddle.load(osp.join(save_dir, "model.pdparams"))
model = getattr(x2paddle_code, self.name)()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册