#-*-coding:utf-8-*-
# date:2021-10-5
# Author: Eric.Lee
# function: pytorch model 2 onnx
"""Export a hand-keypoint (21 points) PyTorch model to an ONNX file.

ONNX workflow, as described in the project README:

* step1: set the model type and the weights path, then run from the
  repository root: ``python model2onnx.py``
* step2: set the ONNX model path, then run from the repository root:
  ``python onnx_inference.py``

The pretrained models include an already converted resnet50 ONNX model.
Note: rexnetv1 is currently not supported by the ONNX export.
"""

import os
import argparse
import torch
import torch.nn as nn
import numpy as np

from models.resnet import resnet18,resnet34,resnet50,resnet101
from models.squeezenet import squeezenet1_1,squeezenet1_0
from models.shufflenetv2 import ShuffleNetV2
from models.shufflenet import ShuffleNet
from models.mobilenetv2 import MobileNetV2
from torchvision.models import shufflenet_v2_x1_5 ,shufflenet_v2_x1_0 , shufflenet_v2_x2_0
from models.rexnetv1 import ReXNetV1


def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string
    (including "False") as True, so the flag is parsed explicitly.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


# Model name -> factory(num_classes, img_size). The resnet variants are the
# only ones that take the input size; the other factories ignore it.
_MODEL_BUILDERS = {
    'resnet_18': lambda n, s: resnet18(num_classes=n, img_size=s),
    'resnet_34': lambda n, s: resnet34(num_classes=n, img_size=s),
    'resnet_50': lambda n, s: resnet50(num_classes=n, img_size=s),
    'resnet_101': lambda n, s: resnet101(num_classes=n, img_size=s),
    'squeezenet1_0': lambda n, s: squeezenet1_0(num_classes=n),
    'squeezenet1_1': lambda n, s: squeezenet1_1(num_classes=n),
    'shufflenetv2': lambda n, s: ShuffleNetV2(ratio=1., num_classes=n),
    'shufflenet_v2_x1_5': lambda n, s: shufflenet_v2_x1_5(pretrained=False, num_classes=n),
    'shufflenet_v2_x1_0': lambda n, s: shufflenet_v2_x1_0(pretrained=False, num_classes=n),
    'shufflenet_v2_x2_0': lambda n, s: shufflenet_v2_x2_0(pretrained=False, num_classes=n),
    'shufflenet': lambda n, s: ShuffleNet(num_blocks=[2, 4, 2], num_classes=n, groups=3),
    'mobilenetv2': lambda n, s: MobileNetV2(num_classes=n),
}


def build_model(name, num_classes, img_size):
    """Instantiate the network registered under ``name``.

    Args:
        name: one of the keys of ``_MODEL_BUILDERS``.
        num_classes: size of the output vector.
        img_size: input edge length (only used by the resnet variants).

    Raises:
        ValueError: if ``name`` is not a supported model type.
    """
    try:
        factory = _MODEL_BUILDERS[name]
    except KeyError:
        raise ValueError('unsupported model : {} (supported: {})'.format(
            name, ', '.join(sorted(_MODEL_BUILDERS))))
    return factory(num_classes, img_size)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description=' Project Hand Pose Inference')
    parser.add_argument('--model_path', type=str, default = './weights1/resnet_50-size-256-wingloss102-0.119.pth',
        help = 'model_path') # path of the trained weights
    # The default model must match the default weights above (resnet_50);
    # it also yields the file name that onnx_inference.py loads.
    parser.add_argument('--model', type=str, default = 'resnet_50',
        choices = sorted(_MODEL_BUILDERS),
        help = 'model : ' + ','.join(sorted(_MODEL_BUILDERS))) # model type
    parser.add_argument('--num_classes', type=int , default = 42,
        help = 'num_classes') # 21 hand keypoints, (x,y) each -> 42 values
    parser.add_argument('--GPUS', type=str, default = '0',
        help = 'GPUS') # GPU selection
    parser.add_argument('--test_path', type=str, default = './image/',
        help = 'test_path') # test image folder (kept for CLI compatibility)
    # type=tuple would split the text "256" into ('2','5','6'); take two ints.
    parser.add_argument('--img_size', type=int, nargs=2, default = (256,256),
        metavar = ('H', 'W'),
        help = 'img_size') # model input size
    parser.add_argument('--vis', type=_str2bool , default = True,
        help = 'vis') # whether to visualise images (kept for CLI compatibility)

    print('\n/******************* {} ******************/\n'.format(parser.description))
    #--------------------------------------------------------------------------
    ops = parser.parse_args() # parse the arguments
    #--------------------------------------------------------------------------
    print('----------------------------------')

    # parse_args() returns a namespace; vars() turns it into a dict.
    for key, value in vars(ops).items():
        print('{} : {}'.format(key, value))

    #---------------------------------------------------------------------------
    os.environ['CUDA_VISIBLE_DEVICES'] = ops.GPUS

    #---------------------------------------------------------------- build model
    print('use model : %s'%(ops.model))
    input_size = ops.img_size[0]
    model_ = build_model(ops.model, ops.num_classes, input_size)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model_ = model_.to(device)
    model_.eval() # inference mode

    # Load the trained weights.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    if os.path.isfile(ops.model_path):
        chkpt = torch.load(ops.model_path, map_location=device)
        model_.load_state_dict(chkpt)
        print('load test model : {}'.format(ops.model_path))
    else:
        # Export still proceeds, but the result is an untrained network.
        print('warning : weights not found at {} , exporting a randomly initialised model'.format(ops.model_path))

    batch_size = 1 # batch size baked into the exported graph
    input_shape = (3, input_size, input_size) # change to your own input shape
    print("input_size : ",input_size)

    x = torch.randn(batch_size, *input_shape).to(device) # dummy input used for tracing
    export_onnx_file = "{}_size-{}.onnx".format(ops.model,input_size) # target ONNX file name
    # dynamic_axes is not passed, so the exported graph has a fixed batch size.
    torch.onnx.export(model_,
                      x,
                      export_onnx_file,
                      opset_version=9,
                      do_constant_folding=True, # run constant-folding optimisation
                      input_names=["input"],    # input name
                      output_names=["output"],  # output name
                      )
# onnx_inference.py : ONNX Runtime inference demo for the hand-keypoint model.
class ONNXModel():
    """Small wrapper around an ``onnxruntime.InferenceSession``.

    On construction it records the names of all graph inputs and outputs so
    that ``forward`` can be called with a single numpy array.
    """

    def __init__(self, onnx_path, gpu_cfg = False):
        """
        :param onnx_path: path of the ``.onnx`` file to load.
        :param gpu_cfg: when true, run on CUDA device 0 (with CPU fallback);
            otherwise run on the CPU provider.
        """
        # Local import keeps the class importable without onnxruntime installed.
        import onnxruntime

        # ONNX Runtime >= 1.9 requires the providers to be given explicitly
        # on GPU-enabled builds, so pass them at construction time.
        if gpu_cfg:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
            provider_options = [{'device_id': 0}, {}]
        else:
            providers = ['CPUExecutionProvider']
            provider_options = [{}]
        self.onnx_session = onnxruntime.InferenceSession(
            onnx_path, providers=providers, provider_options=provider_options)
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)
        print("input_name:{}".format(self.input_name))
        print("output_name:{}".format(self.output_name))

    def get_output_name(self, onnx_session):
        """
        :param onnx_session: object exposing ``get_outputs()``.
        :return: list with the ``name`` of every output node, in order.
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """
        :param onnx_session: object exposing ``get_inputs()``.
        :return: list with the ``name`` of every input node, in order.
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """
        :param input_name: iterable of input node names.
        :param image_numpy: value fed to every one of those inputs.
        :return: dict mapping each input name to ``image_numpy``.
        """
        return {name: image_numpy for name in input_name}

    def forward(self, image_numpy):
        """
        Run the session on one input.

        :param image_numpy: array fed to every graph input.
        :return: whatever ``onnx_session.run`` returns for all recorded outputs.
        """
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        return self.onnx_session.run(self.output_name, input_feed=input_feed)


def _scale_keypoints(output, img_width, img_height):
    """Turn a flat [x0, y0, x1, y1, ...] vector into pixel (x, y) pairs.

    Each x is multiplied by the image width and each y by the image height;
    a trailing unpaired value is ignored.
    """
    return [
        (output[i * 2 + 0] * float(img_width), output[i * 2 + 1] * float(img_height))
        for i in range(len(output) // 2)
    ]


def main():
    """Run the exported model on every image in ./image/ and show the result."""
    import os
    import sys

    sys.path.append(os.getcwd())  # resolve project packages when run from the repo root

    import cv2
    import numpy as np
    from hand_data_iter.datasets import draw_bd_handpose

    img_size = 256
    model = ONNXModel("resnet_50_size-256.onnx")
    path_ = "./image/"
    for f_ in os.listdir(path_):

        img0 = cv2.imread(os.path.join(path_, f_))
        if img0 is None:
            # cv2.imread returns None for files it cannot decode.
            print("skip unreadable file : {}".format(f_))
            continue
        img_height, img_width = img0.shape[:2]
        img = cv2.resize(img0, (img_size,img_size), interpolation = cv2.INTER_CUBIC)

        # HWC -> CHW, scale by 1/255, add the batch axis.
        # NOTE(review): confirm this normalisation matches the preprocessing
        # used when the model was trained.
        img_ndarray = img.transpose((2, 0, 1)) / 255.
        img_ndarray = np.expand_dims(img_ndarray, 0)

        output = np.array(model.forward(img_ndarray.astype('float32'))[0][0])
        print(output.shape[0])

        keypoints = _scale_keypoints(output, img_width, img_height)

        # Structure expected by the keypoint-connection drawing helper.
        pts_hand = {str(i): {"x": x, "y": y} for i, (x, y) in enumerate(keypoints)}
        draw_bd_handpose(img0,pts_hand,0,0) # draw the keypoint connections

        #------------- draw the keypoints
        for x, y in keypoints:
            cv2.circle(img0, (int(x),int(y)), 3, (255,50,60),-1)
            cv2.circle(img0, (int(x),int(y)), 1, (255,150,180),-1)

        cv2.namedWindow('image',0)
        cv2.imshow('image',img0)
        if cv2.waitKey(600) == 27 : # ESC stops the slideshow
            break

    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()