Commit 2960b4c2 authored by lubin10

support onnx inference for cls; add readme

Parent fcc2fac0
# paddle2onnx Model Conversion and Prediction
This section describes how to convert the ResNet50_vd model to ONNX format and run prediction with the ONNX Runtime engine.
## 1. Environment Preparation
You need to set up both the Paddle2ONNX model conversion environment and the ONNX model inference environment.
### Paddle2ONNX
Paddle2ONNX converts models from the PaddlePaddle format to the ONNX format. Its operators currently provide stable export support for ONNX Opset 9 through 11, and some Paddle operators can also be converted to lower ONNX opsets.
For more details, see [Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/README_zh.md).
- Install Paddle2ONNX
```
python3.7 -m pip install paddle2onnx
```
- Install ONNX Runtime
```
python3.7 -m pip install onnxruntime
```
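To verify that both packages installed correctly, you can import them and print the runtime version (an optional sanity check, not part of the original steps):
```
python3.7 -c "import paddle2onnx, onnxruntime; print(onnxruntime.__version__)"
```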
## 2. Model Conversion
- Download the ResNet50_vd inference model
```
cd deploy
mkdir models && cd models
wget -nc https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_vd_infer.tar && tar xf ResNet50_vd_infer.tar
cd ..
```
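After extraction, the directory should contain the static-graph files used in the conversion step below, `inference.pdmodel` and `inference.pdiparams` (a quick optional check; auxiliary files may vary by release):
```
ls ./models/ResNet50_vd_infer
```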
- Model conversion
Use Paddle2ONNX to convert the Paddle static-graph model to the ONNX format:
```
paddle2onnx --model_dir=./models/ResNet50_vd_infer/ \
--model_filename=inference.pdmodel \
--params_filename=inference.pdiparams \
--save_file=./models/ResNet50_vd_infer/inference.onnx \
--opset_version=10 \
--enable_onnx_checker=True
```
After the command finishes, the ONNX model `inference.onnx` is saved under `./models/ResNet50_vd_infer/`.
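Before running the full prediction script, you can optionally sanity-check the export by loading it with ONNX Runtime and feeding a dummy batch. This is a minimal sketch assuming the path from the step above and the standard 1000-class ImageNet head:
```
import numpy as np
import onnxruntime as ort

# Load the exported model and run a single dummy batch through it.
sess = ort.InferenceSession("./models/ResNet50_vd_infer/inference.onnx")
input_name = sess.get_inputs()[0].name
dummy = np.zeros((1, 3, 224, 224), dtype="float32")
outputs = sess.run(None, {input_name: dummy})
print(outputs[0].shape)  # expect (1, 1000) for the ImageNet classifier
```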
## 3. ONNX Prediction
Run the following command:
```
python3.7 python/predict_cls.py \
-c configs/inference_cls.yaml \
-o Global.use_onnx=True \
-o Global.use_gpu=False \
-o Global.inference_model_dir=./models/ResNet50_vd_infer
```
The output is as follows:
```
ILSVRC2012_val_00000010.jpeg: class id(s): [153, 204, 229, 332, 155], score(s): [0.69, 0.10, 0.02, 0.01, 0.01], label_name(s): ['Maltese dog, Maltese terrier, Maltese', 'Lhasa, Lhasa apso', 'Old English sheepdog, bobtail', 'Angora, Angora rabbit', 'Shih-Tzu']
```
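As an optional cross-check (not part of the original steps), you can run the same script against the Paddle Inference backend by setting `Global.use_onnx=False`; the top-k class ids and scores should closely match the ONNX output above:
```
python3.7 python/predict_cls.py \
    -c configs/inference_cls.yaml \
    -o Global.use_onnx=False \
    -o Global.use_gpu=False \
    -o Global.inference_model_dir=./models/ResNet50_vd_infer
```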
```
@@ -67,12 +67,17 @@ class ClsPredictor(Predictor):
         warmup=2)

     def predict(self, images):
-        input_names = self.paddle_predictor.get_input_names()
-        input_tensor = self.paddle_predictor.get_input_handle(input_names[0])
-
-        output_names = self.paddle_predictor.get_output_names()
-        output_tensor = self.paddle_predictor.get_output_handle(output_names[
-            0])
+        use_onnx = self.args.get("use_onnx", False)
+        if not use_onnx:
+            input_names = self.predictor.get_input_names()
+            input_tensor = self.predictor.get_input_handle(input_names[0])
+
+            output_names = self.predictor.get_output_names()
+            output_tensor = self.predictor.get_output_handle(output_names[0])
+        else:
+            input_names = self.predictor.get_inputs()[0].name
+            output_names = self.predictor.get_outputs()[0].name

         if self.benchmark:
             self.auto_logger.times.start()
         if not isinstance(images, (list, )):
@@ -84,9 +89,15 @@ class ClsPredictor(Predictor):
             if self.benchmark:
                 self.auto_logger.times.stamp()

-            input_tensor.copy_from_cpu(image)
-            self.paddle_predictor.run()
-            batch_output = output_tensor.copy_to_cpu()
+            if not use_onnx:
+                input_tensor.copy_from_cpu(image)
+                self.predictor.run()
+                batch_output = output_tensor.copy_to_cpu()
+            else:
+                batch_output = self.predictor.run(
+                    output_names=[output_names],
+                    input_feed={input_names: image})[0]

             if self.benchmark:
                 self.auto_logger.times.stamp()
             if self.postprocess is not None:
```
```
@@ -58,12 +58,16 @@ class RecPredictor(Predictor):
         warmup=2)

     def predict(self, images, feature_normalize=True):
-        input_names = self.paddle_predictor.get_input_names()
-        input_tensor = self.paddle_predictor.get_input_handle(input_names[0])
-
-        output_names = self.paddle_predictor.get_output_names()
-        output_tensor = self.paddle_predictor.get_output_handle(output_names[
-            0])
+        use_onnx = self.args.get("use_onnx", False)
+        if not use_onnx:
+            input_names = self.predictor.get_input_names()
+            input_tensor = self.predictor.get_input_handle(input_names[0])
+
+            output_names = self.predictor.get_output_names()
+            output_tensor = self.predictor.get_output_handle(output_names[0])
+        else:
+            input_names = self.predictor.get_inputs()[0].name
+            output_names = self.predictor.get_outputs()[0].name

         if self.benchmark:
             self.auto_logger.times.start()
@@ -76,9 +80,15 @@ class RecPredictor(Predictor):
             if self.benchmark:
                 self.auto_logger.times.stamp()

-            input_tensor.copy_from_cpu(image)
-            self.paddle_predictor.run()
-            batch_output = output_tensor.copy_to_cpu()
+            if not use_onnx:
+                input_tensor.copy_from_cpu(image)
+                self.predictor.run()
+                batch_output = output_tensor.copy_to_cpu()
+            else:
+                batch_output = self.predictor.run(
+                    output_names=[output_names],
+                    input_feed={input_names: image})[0]

             if self.benchmark:
                 self.auto_logger.times.stamp()
```
```
@@ -28,8 +28,12 @@ class Predictor(object):
         if args.use_fp16 is True:
             assert args.use_tensorrt is True
         self.args = args
-        self.paddle_predictor, self.config = self.create_paddle_predictor(
-            args, inference_model_dir)
+        if self.args.get("use_onnx", False):
+            self.predictor, self.config = self.create_onnx_predictor(
+                args, inference_model_dir)
+        else:
+            self.predictor, self.config = self.create_paddle_predictor(
+                args, inference_model_dir)

     def predict(self, image):
         raise NotImplementedError
```
```
@@ -69,3 +73,20 @@ class Predictor(object):
         predictor = create_predictor(config)
         return predictor, config

+    def create_onnx_predictor(self, args, inference_model_dir=None):
+        import onnxruntime as ort
+        if inference_model_dir is None:
+            inference_model_dir = args.inference_model_dir
+        model_file = os.path.join(inference_model_dir, "inference.onnx")
+        config = ort.SessionOptions()
+        if args.use_gpu:
+            raise ValueError(
+                "onnx inference now only supports cpu! please specify use_gpu false."
+            )
+        else:
+            config.intra_op_num_threads = args.cpu_num_threads
+            if args.ir_optim:
+                config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+        predictor = ort.InferenceSession(model_file, sess_options=config)
+        return predictor, config
```
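For reference, here is a minimal standalone sketch of the session setup that `create_onnx_predictor` performs (CPU-only; the thread count stands in for `args.cpu_num_threads` and is an illustrative value):
```
import onnxruntime as ort

# Mirror the CPU-only configuration from create_onnx_predictor above.
opts = ort.SessionOptions()
opts.intra_op_num_threads = 10  # illustrative; the real code reads cpu_num_threads
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

sess = ort.InferenceSession(
    "./models/ResNet50_vd_infer/inference.onnx", sess_options=opts)
print([inp.name for inp in sess.get_inputs()])
```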