Commit 1256f60f authored by wjj19950828

move to Paddle Inference API

Parent be6475f5
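This commit swaps the test harness from `paddle.jit.load` dynamic-graph execution to the Paddle Inference predictor API. A minimal sketch of the workflow the new code follows is below; the model paths, input shape, and dtype are hypothetical placeholders, and `enable_use_gpu` assumes a GPU build of Paddle:

```python
import numpy as np
from paddle.inference import Config, create_predictor

# Hypothetical model files; the diff builds the real paths per test case.
config = Config("inference_model/model.pdmodel",
                "inference_model/model.pdiparams")
config.enable_use_gpu(200, 0)            # 200 MB initial GPU memory pool, device 0
config.switch_use_feed_fetch_ops(False)  # required for copy_from_cpu/copy_to_cpu

predictor = create_predictor(config)

# Feed every input by name, run once, then fetch outputs as numpy arrays.
for name in predictor.get_input_names():
    handle = predictor.get_input_handle(name)
    handle.copy_from_cpu(np.ones([1, 3, 224, 224], dtype="float32"))
predictor.run()
outputs = [predictor.get_output_handle(n).copy_to_cpu()
           for n in predictor.get_output_names()]
```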
@@ -20,6 +20,8 @@ import logging
 import paddle
 import onnx
 import shutil
+from paddle.inference import create_predictor, PrecisionType
+from paddle.inference import Config
 from onnx import helper
 from onnx import TensorProto
 from onnxruntime import InferenceSession
@@ -191,6 +193,7 @@ class ONNXConverter(object):
         """
         # input data
         paddle_tensor_feed = list()
+        result = list()
         for i in range(len(self.input_feed)):
             paddle_tensor_feed.append(
                 paddle.to_tensor(self.input_feed[self.inputs_name[i]]))
@@ -208,20 +211,49 @@ class ONNXConverter(object):
             model.eval()
             result = model(*paddle_tensor_feed)
         else:
-            paddle_path = os.path.join(
-                self.pwd, self.name,
-                self.name + '_' + str(ver) + '_paddle/inference_model/model')
-            paddle.disable_static()
-            # run
-            model = paddle.jit.load(paddle_path)
-            model.eval()
-            result = model(*paddle_tensor_feed)
-        shutil.rmtree(os.path.join(self.pwd, self.name))
+            paddle_model_path = os.path.join(
+                self.pwd, self.name, self.name + '_' + str(ver) +
+                '_paddle/inference_model/model.pdmodel')
+            paddle_param_path = os.path.join(
+                self.pwd, self.name, self.name + '_' + str(ver) +
+                '_paddle/inference_model/model.pdiparams')
+            config = Config()
+            config.set_prog_file(paddle_model_path)
+            if os.path.exists(paddle_param_path):
+                config.set_params_file(paddle_param_path)
+            # initial GPU memory pool size (MB) and device ID
+            config.enable_use_gpu(200, 0)
+            # disable IR optimization (graph passes and op fusion)
+            config.switch_ir_optim(False)
+            config.enable_memory_optim()
+            # disable feed/fetch ops, required for zero-copy run
+            config.switch_use_feed_fetch_ops(False)
+            config.disable_glog_info()
+            pass_builder = config.pass_builder()
+            predictor = create_predictor(config)
+            input_names = predictor.get_input_names()
+            output_names = predictor.get_output_names()
+            for i in range(len(input_names)):
+                input_tensor = predictor.get_input_handle(input_names[i])
+                input_tensor.copy_from_cpu(self.input_feed[self.inputs_name[i]])
+            predictor.run()
+            for output_name in output_names:
+                output_tensor = predictor.get_output_handle(output_name)
+                result.append(output_tensor.copy_to_cpu())
+        shutil.rmtree(
+            os.path.join(self.pwd, self.name, self.name + '_' + str(ver) +
+                         '_paddle/'))
         # get paddle outputs
         if isinstance(result, (tuple, list)):
-            result = tuple(out.numpy() for out in result)
+            if isinstance(result[0], np.ndarray):
+                result = tuple(out for out in result)
+            else:
+                result = tuple(out.numpy() for out in result)
         else:
-            result = (result.numpy(), )
+            if isinstance(result, np.ndarray):
+                result = (result, )
+            else:
+                result = (result.numpy(), )
         return result

     def _mk_onnx_res(self, ver):
......
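The widened result handling at the end of the hunk is needed because `copy_to_cpu` on a predictor output handle already returns a `numpy.ndarray`, while dynamic-graph model outputs are `paddle.Tensor` objects that still need `.numpy()`. A condensed, standalone equivalent of that branching (checking each element rather than only the first, which the diff assumes is representative):

```python
import numpy as np
import paddle

def normalize(result):
    """Always return a tuple of numpy arrays, whatever the backend gave us."""
    if isinstance(result, (tuple, list)):
        return tuple(out if isinstance(out, np.ndarray) else out.numpy()
                     for out in result)
    return (result, ) if isinstance(result, np.ndarray) else (result.numpy(), )

print(normalize(paddle.to_tensor([1.0, 2.0])))  # dygraph output: paddle.Tensor
print(normalize([np.zeros([2, 2])]))            # predictor output: numpy arrays
```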