onnx_inference.py 3.4 KB
Newer Older
L
lixiang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
#-*-coding:utf-8-*-
# date:2021-10-5
# Author: Eric.Lee
# function: onnx Inference
import os, sys
sys.path.append(os.getcwd())
import onnxruntime
import onnx
import cv2
import torch
import numpy as np
from hand_data_iter.datasets import draw_bd_handpose
class ONNXModel():
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    The model's input and output names are cached at construction so that
    ``forward`` can feed one array and fetch every output in a single call.
    """

    def __init__(self, onnx_path, gpu_cfg=False, device_id=0):
        """
        :param onnx_path: path to the ``.onnx`` model file.
        :param gpu_cfg: if True, request ``CUDAExecutionProvider`` (with
            ``CPUExecutionProvider`` as fallback); otherwise run on CPU only.
        :param device_id: CUDA device index used when ``gpu_cfg`` is True.
        """
        providers, provider_options = self._build_providers(gpu_cfg, device_id)
        # Providers are passed at construction time: since onnxruntime 1.9 a
        # session created without an explicit ``providers`` list raises on
        # builds that ship more than one execution provider (e.g. GPU builds),
        # and calling ``set_providers`` afterwards re-initializes the session.
        self.onnx_session = onnxruntime.InferenceSession(
            onnx_path, providers=providers, provider_options=provider_options)
        self.input_name = self.get_input_name(self.onnx_session)
        self.output_name = self.get_output_name(self.onnx_session)
        print("input_name:{}".format(self.input_name))
        print("output_name:{}".format(self.output_name))

    @staticmethod
    def _build_providers(gpu_cfg, device_id=0):
        """Return ``(providers, provider_options)`` for ``InferenceSession``.

        Both lists have the same length, as onnxruntime requires.
        """
        if gpu_cfg:
            return (['CUDAExecutionProvider', 'CPUExecutionProvider'],
                    [{'device_id': device_id}, {}])
        return ['CPUExecutionProvider'], [{}]

    def get_output_name(self, onnx_session):
        """Return the names of all model outputs, in session order.

        :param onnx_session: an ``onnxruntime.InferenceSession``.
        :return: list of output names (str).
        """
        return [node.name for node in onnx_session.get_outputs()]

    def get_input_name(self, onnx_session):
        """Return the names of all model inputs, in session order.

        :param onnx_session: an ``onnxruntime.InferenceSession``.
        :return: list of input names (str).
        """
        return [node.name for node in onnx_session.get_inputs()]

    def get_input_feed(self, input_name, image_numpy):
        """Build the ``input_feed`` dict for ``InferenceSession.run``.

        Every input name is mapped to the SAME array, so this is only
        meaningful for single-input models (or models whose inputs all take
        the same tensor).

        :param input_name: iterable of model input names.
        :param image_numpy: array fed to each input.
        :return: dict ``{name: image_numpy}``.
        """
        return {name: image_numpy for name in input_name}

    def forward(self, image_numpy):
        """Run inference and return every model output.

        :param image_numpy: input array; it must already match the model's
            expected shape and dtype (the demo below passes a float32 batch
            in channel-first layout).
        :return: list of output arrays, ordered like ``self.output_name``.
        """
        input_feed = self.get_input_feed(self.input_name, image_numpy)
        output = self.onnx_session.run(self.output_name, input_feed=input_feed)
        return output
if __name__ == "__main__":
    # Demo: run the hand-pose model over every image in ./image/ and display
    # the predicted keypoints and skeleton. ESC quits.
    img_size = 256
    model = ONNXModel("resnet_50_size-256.onnx")
    path_ = "./image/"
    for f_ in sorted(os.listdir(path_)):
        img0 = cv2.imread(os.path.join(path_, f_))
        if img0 is None:
            # cv2.imread returns None (rather than raising) for files it
            # cannot decode, e.g. non-image or hidden files; skip them.
            continue
        img_height, img_width = img0.shape[:2]

        # Preprocess: resize to the network input size, HWC -> CHW,
        # scale pixels to [0, 1], and add a leading batch axis.
        img = cv2.resize(img0, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
        img_ndarray = img.transpose((2, 0, 1))
        img_ndarray = img_ndarray / 255.
        img_ndarray = np.expand_dims(img_ndarray, 0)

        # First output tensor, first batch item: a flat vector of interleaved
        # (x, y) values. They are scaled by the image width/height below,
        # so presumably they are normalized to [0, 1]. TODO: confirm with the model.
        output = model.forward(img_ndarray.astype('float32'))[0][0]
        output = np.array(output)
        print(output.shape[0])

        # Convert to pixel coordinates once; reused for both lines and dots.
        keypoints = [
            (output[i * 2 + 0] * float(img_width),
             output[i * 2 + 1] * float(img_height))
            for i in range(output.shape[0] // 2)
        ]

        # Build the keypoint structure expected by draw_bd_handpose and
        # draw the skeleton connections.
        pts_hand = {str(i): {"x": x, "y": y} for i, (x, y) in enumerate(keypoints)}
        draw_bd_handpose(img0, pts_hand, 0, 0)

        # Draw the keypoints themselves.
        for x, y in keypoints:
            center = (int(x), int(y))
            cv2.circle(img0, center, 3, (255, 50, 60), -1)
            cv2.circle(img0, center, 1, (255, 150, 180), -1)

        cv2.namedWindow('image', 0)
        cv2.imshow('image', img0)
        # Show each image for 600 ms, then wait for a key press. ESC
        # (code 27) quits from either wait. The blocking wait is only
        # entered if the first wait did not already return ESC.
        if cv2.waitKey(600) == 27 or cv2.waitKey(0) == 27:
            break

    cv2.destroyAllWindows()