Commit cf5217ad authored by lixiang

add onnx tool

Parent 3f071225
@@ -2,8 +2,7 @@
Hand 21-keypoint detection
## Important Updates
### The first complete-pipeline project, "dpcas": a local hand-gesture interaction demo. A web-based gesture interaction version will follow.
### "dpcas" has only just been born and still has many shortcomings, but I will keep improving it and working to bring it to more people; I hope it becomes more than just a demo.
### Added an ONNX module; the pretrained models include a converted resnet50 ONNX model. Note: rexnetv1 is not supported yet.
### "dpcas" project: https://codechina.csdn.net/EricLee/dpcas
## Project Wiki
@@ -104,6 +103,9 @@
### Model Inference
* Run from the repo root: python inference.py (check the parameter settings inside the script)
### ONNX Usage
* step1: Set the model type and the model weight path, then run from the repo root: python model2onnx.py (check the parameter settings inside the script)
* step2: Set the ONNX model path, then run from the repo root: python onnx_inference.py (check the parameter settings inside the script)
* Suggestions
```
......
#-*-coding:utf-8-*-
# date:2021-10-5
# Author: Eric.Lee
# function: pytorch model 2 onnx
import os
import argparse
import torch
import torch.nn as nn
import numpy as np
from models.resnet import resnet18,resnet34,resnet50,resnet101
from models.squeezenet import squeezenet1_1,squeezenet1_0
from models.shufflenetv2 import ShuffleNetV2
from models.shufflenet import ShuffleNet
from models.mobilenetv2 import MobileNetV2
from torchvision.models import shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0
from models.rexnetv1 import ReXNetV1
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description=' Project Hand Pose Inference')
    parser.add_argument('--model_path', type=str, default='./weights1/resnet_50-size-256-wingloss102-0.119.pth',
        help='model_path')  # path to the trained weights
    parser.add_argument('--model', type=str, default='shufflenet_v2_x1_5',
        help='''model : resnet_34,resnet_50,resnet_101,squeezenet1_0,squeezenet1_1,shufflenetv2,shufflenet,mobilenetv2
        shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0''')  # model type
    parser.add_argument('--num_classes', type=int, default=42,
        help='num_classes')  # 21 hand keypoints, 21*(x,y) = 42 outputs
    parser.add_argument('--GPUS', type=str, default='0',
        help='GPUS')  # GPU selection
    parser.add_argument('--test_path', type=str, default='./image/',
        help='test_path')  # test image folder
    parser.add_argument('--img_size', type=tuple, default=(256, 256),
        help='img_size')  # model input size
    parser.add_argument('--vis', type=bool, default=True,
        help='vis')  # whether to visualize images

    print('\n/******************* {} ******************/\n'.format(parser.description))
    #--------------------------------------------------------------------------
    ops = parser.parse_args()  # parse command-line arguments
    #--------------------------------------------------------------------------
    print('----------------------------------')

    unparsed = vars(ops)  # parse_args() returns a Namespace; vars() converts it to a dict
    for key in unparsed.keys():
        print('{} : {}'.format(key, unparsed[key]))

    #---------------------------------------------------------------------------
    os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS

    test_path = ops.test_path  # test image folder path

    #---------------------------------------------------------------- build the model
    print('use model : %s' % (ops.model))

    if ops.model == 'resnet_50':
        model_ = resnet50(num_classes=ops.num_classes, img_size=ops.img_size[0])
    elif ops.model == 'resnet_18':
        model_ = resnet18(num_classes=ops.num_classes, img_size=ops.img_size[0])
    elif ops.model == 'resnet_34':
        model_ = resnet34(num_classes=ops.num_classes, img_size=ops.img_size[0])
    elif ops.model == 'resnet_101':
        model_ = resnet101(num_classes=ops.num_classes, img_size=ops.img_size[0])
    elif ops.model == "squeezenet1_0":
        model_ = squeezenet1_0(num_classes=ops.num_classes)
    elif ops.model == "squeezenet1_1":
        model_ = squeezenet1_1(num_classes=ops.num_classes)
    elif ops.model == "shufflenetv2":
        model_ = ShuffleNetV2(ratio=1., num_classes=ops.num_classes)
    elif ops.model == "shufflenet_v2_x1_5":
        model_ = shufflenet_v2_x1_5(pretrained=False, num_classes=ops.num_classes)
    elif ops.model == "shufflenet_v2_x1_0":
        model_ = shufflenet_v2_x1_0(pretrained=False, num_classes=ops.num_classes)
    elif ops.model == "shufflenet_v2_x2_0":
        model_ = shufflenet_v2_x2_0(pretrained=False, num_classes=ops.num_classes)
    elif ops.model == "shufflenet":
        model_ = ShuffleNet(num_blocks=[2, 4, 2], num_classes=ops.num_classes, groups=3)
    elif ops.model == "mobilenetv2":
        model_ = MobileNetV2(num_classes=ops.num_classes)
    else:
        raise ValueError('unsupported model : {}'.format(ops.model))

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model_ = model_.to(device)
    model_.eval()  # switch to inference mode

    # load the trained checkpoint
    if os.access(ops.model_path, os.F_OK):  # checkpoint
        chkpt = torch.load(ops.model_path, map_location=device)
        model_.load_state_dict(chkpt)
        print('load test model : {}'.format(ops.model_path))

    input_size = ops.img_size[0]
    batch_size = 1  # batch size
    input_shape = (3, input_size, input_size)  # input shape; change to match your own input
    print("input_size : ", input_size)

    x = torch.randn(batch_size, *input_shape)  # dummy input tensor used for tracing
    x = x.to(device)

    export_onnx_file = "{}_size-{}.onnx".format(ops.model, input_size)  # output ONNX file name
    torch.onnx.export(model_,
                      x,
                      export_onnx_file,
                      opset_version=9,
                      do_constant_folding=True,   # apply constant-folding optimization
                      input_names=["input"],      # input name
                      output_names=["output"],    # output name
                      # dynamic_axes={"input": {0: "batch_size"},    # uncomment to make the batch dimension dynamic
                      #               "output": {0: "batch_size"}}
                      )
#-*-coding:utf-8-*-
# date:2021-10-5
# Author: Eric.Lee
# function: onnx Inference
import os, sys
sys.path.append(os.getcwd())
import onnxruntime
import onnx
import cv2
import torch
import numpy as np
from hand_data_iter.datasets import draw_bd_handpose
class ONNXModel():
    def __init__(self, onnx_path, gpu_cfg=False):
        """
        :param onnx_path: path to the ONNX model file
        """
        self.onnx_session = onnxruntime.InferenceSession(onnx_path)
        if gpu_cfg:
            self.onnx_session.set_providers(['CUDAExecutionProvider'], [{'device_id': 0}])
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)
        print("input_name:{}".format(self.input_name))
        print("output_name:{}".format(self.output_name))

    def get_output_name(self, onnx_session):
        """
        output_name = onnx_session.get_outputs()[0].name
        :param onnx_session:
        :return:
        """
        output_name = []
        for node in onnx_session.get_outputs():
            output_name.append(node.name)
        return output_name

    def get_input_name(self, onnx_session):
        """
        :param onnx_session:
        :return:
        """
        input_name = []
        for node in onnx_session.get_inputs():
            input_name.append(node.name)
        return input_name

    def get_input_feed(self, input_name, image_numpy):
        """
        :param input_name:
        :param image_numpy:
        :return:
        """
        input_feed = {}
        for name in input_name:
            input_feed[name] = image_numpy
        return input_feed

    def forward(self, image_numpy):
        '''
        # image_numpy = image_numpy[np.newaxis, :]
        # onnx_session.run([output_name], {input_name: x})
        # :param image_numpy:
        # :return:
        '''
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        output = self.onnx_session.run(self.output_name, input_feed=input_feed)
        return output

if __name__ == "__main__":

    img_size = 256
    model = ONNXModel("resnet_50_size-256.onnx")

    path_ = "./image/"
    for f_ in os.listdir(path_):
        img0 = cv2.imread(path_ + f_)
        img_width = img0.shape[1]
        img_height = img0.shape[0]

        # preprocessing: resize, HWC -> CHW, scale to [0, 1], add a batch dimension
        img = cv2.resize(img0, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
        img_ndarray = img.transpose((2, 0, 1))
        img_ndarray = img_ndarray / 255.
        img_ndarray = np.expand_dims(img_ndarray, 0)

        output = model.forward(img_ndarray.astype('float32'))[0][0]
        output = np.array(output)
        print(output.shape[0])

        pts_hand = {}  # build the keypoint dict used to draw the hand skeleton
        for i in range(int(output.shape[0] / 2)):
            x = (output[i * 2 + 0] * float(img_width))
            y = (output[i * 2 + 1] * float(img_height))

            pts_hand[str(i)] = {
                "x": x,
                "y": y,
            }
        draw_bd_handpose(img0, pts_hand, 0, 0)  # draw the skeleton lines between keypoints

        #------------- draw the keypoints
        for i in range(int(output.shape[0] / 2)):
            x = (output[i * 2 + 0] * float(img_width))
            y = (output[i * 2 + 1] * float(img_height))

            cv2.circle(img0, (int(x), int(y)), 3, (255, 50, 60), -1)
            cv2.circle(img0, (int(x), int(y)), 1, (255, 150, 180), -1)

        cv2.namedWindow('image', 0)
        cv2.imshow('image', img0)
        if cv2.waitKey(600) == 27:
            break

    cv2.waitKey(0)
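For reuse outside this demo loop, the postprocessing above (interleaved normalized x/y values scaled back to the source image) can be pulled into a small helper. The sketch below is not part of this commit; decode_keypoints is a hypothetical helper name, and it assumes the same output layout as in the script: 42 values, i.e. 21 (x, y) pairs normalized to [0, 1].

```python
#-*-coding:utf-8-*-
# Minimal sketch (hypothetical helper, not part of this commit):
# turn the flat 42-value network output into 21 pixel-space keypoints.
import numpy as np

def decode_keypoints(output, img_width, img_height):
    """output: 1-D array of length 42 holding normalized (x, y) pairs."""
    pts = output.reshape(-1, 2).astype(np.float32)   # (21, 2), values in [0, 1]
    pts[:, 0] *= float(img_width)                    # scale x back to the original width
    pts[:, 1] *= float(img_height)                   # scale y back to the original height
    return pts

# usage with the ONNXModel class above, once an image has been preprocessed as img_ndarray:
# output = np.array(model.forward(img_ndarray.astype('float32'))[0][0])
# pts = decode_keypoints(output, img_width, img_height)   # pts[i] == (x, y) of keypoint i
```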