提交 1a3f37c5 编写于 作者: W wjj19950828

fixed reload bug

上级 bdbc0262
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
import os import os
import sys
import importlib
import numpy as np import numpy as np
import logging import logging
import paddle import paddle
...@@ -187,20 +189,23 @@ class ONNXConverter(object): ...@@ -187,20 +189,23 @@ class ONNXConverter(object):
make paddle res make paddle res
""" """
# input data # input data
paddle_numpy_feed = list()
paddle_tensor_feed = list() paddle_tensor_feed = list()
for i in range(len(self.input_feed)): for i in range(len(self.input_feed)):
paddle_numpy_feed.append(self.input_feed[self.inputs_name[i]])
paddle_tensor_feed.append( paddle_tensor_feed.append(
paddle.to_tensor(self.input_feed[self.inputs_name[i]])) paddle.to_tensor(self.input_feed[self.inputs_name[i]]))
if self.run_dynamic: if self.run_dynamic:
paddle_path = os.path.join(self.pwd, self.name, paddle_path = os.path.join(self.pwd, self.name,
self.name + '_' + str(ver) + '_paddle/') self.name + '_' + str(ver) + '_paddle/')
import sys restore = paddle.load(os.path.join(paddle_path, "model.pdparams"))
sys.path.append(paddle_path) sys.path.insert(0, paddle_path)
from x2paddle_code import main import x2paddle_code
result = main(*paddle_tensor_feed) # Solve the problem of function overloading caused by traversing the model
importlib.reload(x2paddle_code)
model = getattr(x2paddle_code, "ONNXModel")()
model.set_dict(restore)
model.eval()
result = model(*paddle_tensor_feed)
else: else:
paddle_path = os.path.join( paddle_path = os.path.join(
self.pwd, self.name, self.pwd, self.name,
...@@ -208,7 +213,7 @@ class ONNXConverter(object): ...@@ -208,7 +213,7 @@ class ONNXConverter(object):
paddle.disable_static() paddle.disable_static()
# run # run
model = paddle.jit.load(paddle_path) model = paddle.jit.load(paddle_path)
result = model(*paddle_numpy_feed) result = model(*paddle_tensor_feed)
# get paddle outputs # get paddle outputs
if isinstance(result, (tuple, list)): if isinstance(result, (tuple, list)):
result = tuple(out.numpy() for out in result) result = tuple(out.numpy() for out in result)
......
...@@ -21,6 +21,7 @@ import sys ...@@ -21,6 +21,7 @@ import sys
import os import os
import six import six
import pickle import pickle
import importlib
from os import path as osp from os import path as osp
from x2paddle.core.util import * from x2paddle.core.util import *
...@@ -566,6 +567,8 @@ class PaddleGraph(object): ...@@ -566,6 +567,8 @@ class PaddleGraph(object):
path = osp.abspath(save_dir) path = osp.abspath(save_dir)
sys.path.insert(0, save_dir) sys.path.insert(0, save_dir)
import x2paddle_code import x2paddle_code
# Solve the problem of function overloading caused by traversing the model
importlib.reload(x2paddle_code)
paddle.disable_static() paddle.disable_static()
restore = paddle.load(osp.join(save_dir, "model.pdparams")) restore = paddle.load(osp.join(save_dir, "model.pdparams"))
model = getattr(x2paddle_code, self.name)() model = getattr(x2paddle_code, self.name)()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册