From 1a3f37c5a7d63fe952c03aa9101eebcfe1f6ebbd Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Tue, 9 Aug 2022 11:35:31 +0800 Subject: [PATCH] fixed reload bug --- tests/onnx/onnxbase.py | 19 ++++++++++++------- x2paddle/core/program.py | 3 +++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/onnx/onnxbase.py b/tests/onnx/onnxbase.py index 34954b4..f687108 100644 --- a/tests/onnx/onnxbase.py +++ b/tests/onnx/onnxbase.py @@ -13,6 +13,8 @@ # limitations under the License. import os +import sys +import importlib import numpy as np import logging import paddle @@ -187,20 +189,23 @@ class ONNXConverter(object): make paddle res """ # input data - paddle_numpy_feed = list() paddle_tensor_feed = list() for i in range(len(self.input_feed)): - paddle_numpy_feed.append(self.input_feed[self.inputs_name[i]]) paddle_tensor_feed.append( paddle.to_tensor(self.input_feed[self.inputs_name[i]])) if self.run_dynamic: paddle_path = os.path.join(self.pwd, self.name, self.name + '_' + str(ver) + '_paddle/') - import sys - sys.path.append(paddle_path) - from x2paddle_code import main - result = main(*paddle_tensor_feed) + restore = paddle.load(os.path.join(paddle_path, "model.pdparams")) + sys.path.insert(0, paddle_path) + import x2paddle_code + # Solve the problem of function overloading caused by traversing the model + importlib.reload(x2paddle_code) + model = getattr(x2paddle_code, "ONNXModel")() + model.set_dict(restore) + model.eval() + result = model(*paddle_tensor_feed) else: paddle_path = os.path.join( self.pwd, self.name, @@ -208,7 +213,7 @@ class ONNXConverter(object): paddle.disable_static() # run model = paddle.jit.load(paddle_path) - result = model(*paddle_numpy_feed) + result = model(*paddle_tensor_feed) # get paddle outputs if isinstance(result, (tuple, list)): result = tuple(out.numpy() for out in result) diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index f04748f..7d1ba96 100755 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -21,6 +21,7 @@ import sys import os import six import pickle +import importlib from os import path as osp from x2paddle.core.util import * @@ -566,6 +567,8 @@ class PaddleGraph(object): path = osp.abspath(save_dir) sys.path.insert(0, save_dir) import x2paddle_code + # Solve the problem of function overloading caused by traversing the model + importlib.reload(x2paddle_code) paddle.disable_static() restore = paddle.load(osp.join(save_dir, "model.pdparams")) model = getattr(x2paddle_code, self.name)() -- GitLab