Commit 1256f60f authored by wjj19950828

move to paddleinference api

Parent be6475f5
@@ -20,6 +20,8 @@ import logging
 import paddle
 import onnx
 import shutil
+from paddle.inference import create_predictor, PrecisionType
+from paddle.inference import Config
 from onnx import helper
 from onnx import TensorProto
 from onnxruntime import InferenceSession
@@ -191,6 +193,7 @@ class ONNXConverter(object):
         """
         # input data
         paddle_tensor_feed = list()
+        result = list()
         for i in range(len(self.input_feed)):
             paddle_tensor_feed.append(
                 paddle.to_tensor(self.input_feed[self.inputs_name[i]]))
@@ -208,18 +211,47 @@ class ONNXConverter(object):
             model.eval()
             result = model(*paddle_tensor_feed)
         else:
-            paddle_path = os.path.join(
-                self.pwd, self.name,
-                self.name + '_' + str(ver) + '_paddle/inference_model/model')
-            paddle.disable_static()
-            # run
-            model = paddle.jit.load(paddle_path)
-            model.eval()
-            result = model(*paddle_tensor_feed)
-        shutil.rmtree(os.path.join(self.pwd, self.name))
+            paddle_model_path = os.path.join(
+                self.pwd, self.name, self.name + '_' + str(ver) +
+                '_paddle/inference_model/model.pdmodel')
+            paddle_param_path = os.path.join(
+                self.pwd, self.name, self.name + '_' + str(ver) +
+                '_paddle/inference_model/model.pdiparams')
+            config = Config()
+            config.set_prog_file(paddle_model_path)
+            if os.path.exists(paddle_param_path):
+                config.set_params_file(paddle_param_path)
+            # initial GPU memory(M), device ID
+            config.enable_use_gpu(200, 0)
+            # optimize graph and fuse op
+            config.switch_ir_optim(False)
+            config.enable_memory_optim()
+            # disable feed, fetch OP, needed by zero_copy_run
+            config.switch_use_feed_fetch_ops(False)
+            config.disable_glog_info()
+            pass_builder = config.pass_builder()
+            predictor = create_predictor(config)
+            input_names = predictor.get_input_names()
+            output_names = predictor.get_output_names()
+            for i in range(len(input_names)):
+                input_tensor = predictor.get_input_handle(input_names[i])
+                input_tensor.copy_from_cpu(self.input_feed[self.inputs_name[i]])
+            predictor.run()
+            for output_name in output_names:
+                output_tensor = predictor.get_output_handle(output_name)
+                result.append(output_tensor.copy_to_cpu())
+        shutil.rmtree(
+            os.path.join(self.pwd, self.name, self.name + '_' + str(ver) +
+                         '_paddle/'))
         # get paddle outputs
         if isinstance(result, (tuple, list)):
-            result = tuple(out.numpy() for out in result)
+            if isinstance(result[0], np.ndarray):
+                result = tuple(out for out in result)
+            else:
+                result = tuple(out.numpy() for out in result)
         else:
-            result = (result.numpy(), )
+            if isinstance(result, np.ndarray):
+                result = (result, )
+            else:
+                result = (result.numpy(), )
         return result
...
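For reference, the Paddle Inference predictor workflow this commit switches to can be exercised standalone. Below is a minimal sketch built only from the calls the diff itself uses, assuming a GPU build of Paddle and an exported inference model; the paths and the input shape are placeholders, not part of this commit:

import numpy as np
from paddle.inference import Config, create_predictor

# Placeholder paths; point them at a real exported inference model.
config = Config()
config.set_prog_file("./inference_model/model.pdmodel")
config.set_params_file("./inference_model/model.pdiparams")
config.enable_use_gpu(200, 0)            # initial GPU memory pool (MB), device ID
config.switch_ir_optim(False)            # skip graph optimization / op fusion
config.enable_memory_optim()
config.switch_use_feed_fetch_ops(False)  # required for the zero-copy handles below
config.disable_glog_info()

predictor = create_predictor(config)

# Feed every input by name through zero-copy handles.
fake_input = np.random.rand(1, 3, 224, 224).astype("float32")  # placeholder shape
for name in predictor.get_input_names():
    input_handle = predictor.get_input_handle(name)
    input_handle.copy_from_cpu(fake_input)

predictor.run()

# Fetch every output as a numpy array, mirroring how the test collects `result`.
outputs = [
    predictor.get_output_handle(name).copy_to_cpu()
    for name in predictor.get_output_names()
]

Because the predictor returns numpy arrays rather than paddle tensors, the output-normalization branch at the end of the diff checks for np.ndarray before deciding whether to call .numpy().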