未验证 提交 472af090 编写于 作者: L Lyon 提交者: GitHub

Merge pull request #35 from Oneflow-Inc/resnet_onnx

add docs about convert resnet to onnx
......@@ -516,3 +516,33 @@ python3 preprocess_imagenet_validation_data.py ../data/imagenet/validation
至此,已经完成了全部的数据预处理,您可以直接跳转至**转换训练集**和**转换验证集**部分,轻松完成ImageNet-2012数据集到OFRecord的转换过程了。
### OneFlow 模型转 ONNX 模型
ONNX (Open Neural Network Exchange) 是一种较为广泛使用的神经网络中间格式,通过 ONNX 格式,OneFlow 模型可以被许多部署框架(如 OpenVINO、ONNX Runtime 和移动端的 ncnn、tnn、TEngine 等)所使用。这一节介绍如何将训练好的 resnet50 v1.5 模型转换为 ONNX 模型并验证正确性,可以在 resnet\_to\_onnx.py 中找到参考代码。
#### 如何生成 ONNX 模型
**步骤一:将网络权重保存到磁盘**
首先将训练得到的网络权重保存到磁盘,例如我们保存到 /tmp/resnet50_weights 这个文件夹下
```python
check_point = flow.train.CheckPoint()
check_point.save("/tmp/resnet50_weights")
```
**步骤二:新建一个用于推理的 job function**
然后新建一个用于推理的 job function,它只包含网络结构本身,不包含读取 OFRecord 的算子,并且直接接受 numpy 数组形式的输入。可参考 resnet\_to\_onnx.py 中的 `InferenceNet`
**步骤三:调用 flow.onnx.export 方法**
接下来调用 `flow.onnx.export` 方法,从 OneFlow 网络得到 ONNX 模型,它的第一个参数是上文所说的专用于推理的 job function,第二个参数是 /tmp/resnet50_weights 这个保存了网络权重的文件夹,第三个参数是 ONNX 模型文件的路径。
```python
flow.onnx.export(InferenceNet, '/tmp/resnet50_weights', 'resnet50_v1.5.onnx')
```
#### 验证 ONNX 模型的正确性
生成 ONNX 模型之后可以使用 ONNX Runtime 运行 ONNX 模型,以验证 OneFlow 模型和 ONNX 模型能够在相同的输入下产生相同的结果。相应的代码在 resnet\_to\_onnx.py 的 `check_equality`
......@@ -2,20 +2,22 @@
# from __future__ import division
# from __future__ import print_function
from collections import OrderedDict
import os
from PIL import Image
import time
from typing import Callable, Text
import numpy as np
import oneflow as flow
import onnx
import onnxruntime as ort
import numpy as np
import time
from PIL import Image
from collections import OrderedDict
from resnet_model import resnet50
from imagenet1000_clsidx_to_labels import clsidx_2_labels
def load_image(image_path):
def load_image(image_path: Text) -> np.ndarray:
rgb_mean = [123.68, 116.779, 103.939]
rgb_std = [58.393, 57.12, 57.375]
print(image_path)
......@@ -42,74 +44,62 @@ def InferenceNet(images=flow.FixedTensorDef((1, 3, 224, 224))):
return predictions
def onnx_inference(image: np.ndarray, onnx_model: onnx.ModelProto):
    """
    Run an in-memory ONNX model with ONNX Runtime on a single image.

    :param image: input image, a preprocessed numpy array
    :param onnx_model: onnx model (onnx.ModelProto)
    :return: raw network output as a numpy array
    """
    # Feed the serialized proto directly so no .onnx file is needed on disk.
    sess = ort.InferenceSession(onnx_model.SerializeToString())
    # The exported graph is expected to have exactly one output and at most one input.
    assert len(sess.get_outputs()) == 1 and len(sess.get_inputs()) <= 1
    ipt_dict = OrderedDict()
    for ipt in sess.get_inputs():
        ipt_dict[ipt.name] = image

    start = time.time()
    # Passing [] for output_names fetches all (here: the single) outputs.
    onnx_res = sess.run([], ipt_dict)[0]
    print('Cost: %.4f s' % (time.time() - start))
    clsidx_onnx = onnx_res.argmax()
    print('Onnx >> ', onnx_res.max(), clsidx_2_labels[clsidx_onnx])
    return onnx_res
def oneflow_to_onnx(job_func: Callable, flow_weights_path: Text, onnx_model_dir: Text, external_data: bool = False):
    """
    Convert a OneFlow model to an ONNX model.

    :param job_func: inference job function defined in oneflow
    :param flow_weights_path: folder holding the saved oneflow weights
    :param onnx_model_dir: output directory for the generated .onnx file
    :param external_data: if True, store weights in external data files
        instead of embedding them in the .onnx file
    :return: the converted model, loaded as an onnx.ModelProto
    """
    if not os.path.exists(onnx_model_dir):
        os.makedirs(onnx_model_dir)
    assert os.path.exists(flow_weights_path) and os.path.isdir(onnx_model_dir)

    check_point = flow.train.CheckPoint()
    # it is a trick to keep check_point.save() from hanging when there is no variable
    @flow.global_function(flow.FunctionConfig())
    def add_var():
        return flow.get_variable(
            name="trick",
            shape=(1,),
            dtype=flow.float,
            initializer=flow.random_uniform_initializer(),
        )
    check_point.init()

    # Name the .onnx file after the weights folder, e.g. resnet50_weights.onnx.
    onnx_model_path = os.path.join(onnx_model_dir, os.path.basename(flow_weights_path) + '.onnx')
    flow.onnx.export(job_func, flow_weights_path, onnx_model_path, opset=11, external_data=external_data)
    print('Convert to onnx success! >> ', onnx_model_path)
    return onnx.load_model(onnx_model_path)
def check_equality(job_func: Callable, onnx_model: onnx.ModelProto, image_path: Text) -> (bool, np.ndarray):
    """
    Check that the OneFlow job and the converted ONNX model agree on one image.

    :param job_func: inference job function defined in oneflow
    :param onnx_model: converted onnx model (onnx.ModelProto)
    :param image_path: path of the test image
    :return: (True if the two outputs are numerically close, onnx output array)
    """
    img = load_image(image_path)
    onnx_out = onnx_inference(img, onnx_model)
    flow_out = job_func(img).get().ndarray()
    # Tolerances allow for small float differences between the two runtimes.
    return np.allclose(onnx_out, flow_out, rtol=1e-4, atol=1e-5), onnx_out
if __name__ == "__main__":
    # image_path = 'tiger.jpg'
    image_path = 'test_img/ILSVRC2012_val_00020287.JPEG'
    flow_weights_path = '/your/oneflow/weights/path'
    onnx_model_dir = 'onnx/model'

    # Load the trained weights before exporting/running the oneflow job.
    check_point = flow.train.CheckPoint()
    check_point.load(flow_weights_path)

    # convert oneflow to onnx
    onnx_model = oneflow_to_onnx(InferenceNet, flow_weights_path, onnx_model_dir, external_data=False)

    # check that oneflow and onnx runtime produce the same result
    are_equal, onnx_res = check_equality(InferenceNet, onnx_model, image_path)
    clsidx_onnx = onnx_res.argmax()
    print('Are the results equal? {}'.format('Yes' if are_equal else 'No'))
    print('Class: {}; score: {}'.format(clsidx_2_labels[clsidx_onnx], onnx_res.max()))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册