diff --git a/best (3).onnx b/best (3).onnx new file mode 100644 index 0000000000000000000000000000000000000000..266143d3daa56ba32c34055d4dac9f3d97453d95 Binary files /dev/null and b/best (3).onnx differ diff --git a/main.py b/main.py index 4c0c135f61696bcf42c375ca5ab62aa5b105afc8..7d6e7007ad4879fc4378f2e8c10c076dda870d64 100644 --- a/main.py +++ b/main.py @@ -1 +1,142 @@ -print('欢迎来到 InsCode') \ No newline at end of file +import cv2 +import numpy as np +import onnxruntime as ort + +class YOLOv5_Lite: + def __init__(self, onnx_model_path, class_names_path, conf_threshold=0.5, iou_threshold=0.5): + """ + 初始化 YOLOv5_Lite 类 + :param onnx_model_path: ONNX 模型文件路径 + :param class_names_path: 类别名称文件路径 + :param conf_threshold: 置信度阈值 + :param iou_threshold: IOU 阈值 + """ + # 创建 ONNX Runtime 会话 + so = ort.SessionOptions() + so.log_severity_level = 3 + self.session = ort.InferenceSession(onnx_model_path, so) + # 获取输入输出节点名称 + self.input_name = self.session.get_inputs()[0].name + self.output_name = self.session.get_outputs()[0].name + # 加载类别名称 + with open(class_names_path, 'r') as f: + self.class_names = [line.strip() for line in f.readlines()] + # 设置阈值 + self.conf_threshold = conf_threshold + self.iou_threshold = iou_threshold + + def preprocess(self, image): + """ + 图像预处理 + :param image: 输入图像 + :return: 预处理后的图像数据 + """ + # 转为 RGB 格式 + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + # 缩放至模型输入尺寸 + img = cv2.resize(image, (640, 640)) + # 归一化 + img = img.astype(np.float32) / 255.0 + # 调整维度 + img = np.expand_dims(img.transpose(2, 0, 1), axis=0).astype(np.float32) + return img + + def postprocess(self, outputs, img_shape, origin_shape): + """ + 后处理函数 + :param outputs: 模型输出 + :param img_shape: 输入图像尺寸 + :param origin_shape: 原始图像尺寸 + :return: 检测框坐标、类别名称、置信度 + """ + # 获取模型输出 + outputs = outputs[0] + # 过滤低置信度的检测结果 + outputs = outputs[outputs[:, 4] >= self.conf_threshold] + # 初始化结果列表 + boxes = [] + scores = [] + classes = [] + # 遍历每个检测结果 + for output in outputs: + # 获取类别索引和置信度 + class_id = np.argmax(output[5:]) + confidence = output[5 + class_id] + # 过滤低置信度的检测结果 + if confidence < self.conf_threshold: + continue + # 获取边界框坐标 + x, y, w, h = output[:4] + # 反归一化 + x *= img_shape[1] + y *= img_shape[0] + w *= img_shape[1] + h *= img_shape[0] + # 转换为左上角和右下角坐标 + x1 = int(x - w / 2) + y1 = int(y - h / 2) + x2 = int(x + w / 2) + y2 = int(y + h / 2) + # 调整边界框坐标以适配原始图像尺寸 + x1 = max(0, min(x1, origin_shape[1])) + y1 = max(0, min(y1, origin_shape[0])) + x2 = max(0, min(x2, origin_shape[1])) + y2 = max(0, min(y2, origin_shape[0])) + # 添加到结果列表 + boxes.append([x1, y1, x2, y2]) + scores.append(float(confidence)) + classes.append(self.class_names[class_id]) + # 非极大值抑制 + indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf_threshold, self.iou_threshold) + # 过滤后的检测结果 + boxes = [boxes[i] for i in indices] + classes = [classes[i] for i in indices] + scores = [scores[i] for i in indices] + return boxes, classes, scores + + def detect(self, image): + """ + 物体检测函数 + :param image: 输入图像 + :return: 检测框坐标、类别名称、置信度 + """ + # 图像预处理 + img = self.preprocess(image) + # 模型推理 + outputs = self.session.run([self.output_name], {self.input_name: img}) + # 后处理 + boxes, classes, scores = self.postprocess(outputs, img.shape[2:], image.shape[:2]) + return boxes, classes, scores + +if __name__ == "__main__": + # 模型路径 + onnx_model_path = "best (3).onnx" + # 类别文件路径 + class_names_path = "rubbish.names" + # 初始化 YOLOv5_Lite 类 + detector = YOLOv5_Lite(onnx_model_path, class_names_path) + # 打开摄像头 + cap = cv2.VideoCapture(0) + while True: + # 读取一帧图像 + ret, frame = cap.read() + if not ret: + break + # 检测物体 + boxes, classes, scores = detector.detect(frame) + # 在图像上绘制检测结果 + for box, cls, score in zip(boxes, classes, scores): + x1, y1, x2, y2 = box + # 绘制边界框 + cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2) + # 显示类别名称和置信度 + label = f"{cls}: {score:.2f}" + cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) + # 显示图像 + cv2.imshow("Object Detection", frame) + # 按下 'q' 键退出循环 + if cv2.waitKey(1) & 0xFF == ord('q'): + break + # 释放摄像头和关闭窗口 + cap.release() + cv2.destroyAllWindows() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..51decf87adad3e5697ac7a7e325a119339eab6b9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1 @@ +onnxruntime \ No newline at end of file diff --git a/rubbish.names b/rubbish.names new file mode 100644 index 0000000000000000000000000000000000000000..c607191b16dc191ff5df5f15f4ce602ce33cb1b7 --- /dev/null +++ b/rubbish.names @@ -0,0 +1,4 @@ +person +bicycle +car +motorcycle \ No newline at end of file