From b24c366764920a36b31788f4abdb1c73ac0a4501 Mon Sep 17 00:00:00 2001
From: "Eric.Lee2021" <305141918@qq.com>
Date: Sat, 20 Feb 2021 15:36:56 +0800
Subject: [PATCH] add safety helmet pretrained model

---
 README.md  |  2 +-
 predict.py | 29 ++++++++++++++++++-----------
 train.py   |  3 ++-
 3 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index 3959772..caaa2bf 100644
--- a/README.md
+++ b/README.md
@@ -90,7 +90,7 @@ h = h*dh
 * [Pretrained model download link (Baidu Netdisk Password: ise9)](https://pan.baidu.com/s/1mxiI-tOpE3sU-9TVPJmPWw)
 
 ### 4. Pretrained model for safety helmet detection
-* [Pretrained model download link (Baidu Netdisk Password: )]()
+* [Pretrained model download link (Baidu Netdisk Password: inu8)](https://pan.baidu.com/s/1rWUEsPnOTdmfJW4xM8m3Eg)
 
 ## How to use the project
 
diff --git a/predict.py b/predict.py
index 180474d..ca301cc 100644
--- a/predict.py
+++ b/predict.py
@@ -60,6 +60,7 @@ def detect(
     img_size=416,
     conf_thres=0.5,
     nms_thres=0.5,
+    video_path = 0,
 ):
     classes = load_classes(parse_data_cfg(data_cfg)['names'])
     num_classes = len(classes)
@@ -93,7 +94,7 @@ def detect(
     colors = [(v // 32 * 64 + 64, (v // 8) % 4 * 64, v % 8 * 32) for v in range(1, num_classes + 1)][::-1]
 
-    video_capture = cv2.VideoCapture("./video/bean_1.mp4")
+    video_capture = cv2.VideoCapture(video_path)
 
     # url="http://admin:admin@192.168.43.1:8081"
     # video_capture=cv2.VideoCapture(url)
 
@@ -151,7 +152,11 @@ def detect(
                     # print(conf, cls_conf)
 
                     # xyxy = refine_hand_bbox(xyxy,im0.shape)
-                    plot_one_box(xyxy, im0, label=label, color=(15,255,95),line_thickness = 3)
+                    xyxy = int(xyxy[0]),int(xyxy[1])+6,int(xyxy[2]),int(xyxy[3])
+                    if int(cls) == 0:
+                        plot_one_box(xyxy, im0, label=label, color=(15,255,95),line_thickness = 3)
+                    else:
+                        plot_one_box(xyxy, im0, label=label, color=(15,155,255),line_thickness = 3)
 
             s2 = time.time()
             print("detect time: {} \n".format(s2 - t))
@@ -177,20 +182,22 @@ def detect(
 if __name__ == '__main__':
 
-    voc_config = 'cfg/person.data' # model configuration file
-    model_path = './weights-yolov3-person/latest_416.pt' # detection model path
-    model_cfg = 'yolo' # yolo / yolo-tiny
+    voc_config = 'cfg/helmet.data' # model configuration file
+    model_path = './weights-yolov3-helmet/hat_416_epoch_410.pt' # detection model path
+    model_cfg = 'yolo' # yolo / yolo-tiny model architecture
+    video_path = "./video/hat5.mp4" # test video
     img_size = 416 # image size
-    conf_thres = 0.36 # detection confidence threshold
-    nms_thres = 0.5 # NMS threshold
+    conf_thres = 0.5 # detection confidence threshold
+    nms_thres = 0.6 # NMS threshold
 
-    with torch.no_grad(): # run without gradients
+    with torch.no_grad(): # run model inference without gradients
         detect(
             model_path = model_path,
             cfg = model_cfg,
             data_cfg = voc_config,
-            img_size=img_size,
-            conf_thres=conf_thres,
-            nms_thres=nms_thres,
+            img_size = img_size,
+            conf_thres = conf_thres,
+            nms_thres = nms_thres,
+            video_path = video_path,
         )
 
 
diff --git a/train.py b/train.py
index f2fb891..7a59cf5 100644
--- a/train.py
+++ b/train.py
@@ -204,7 +204,8 @@ if __name__ == '__main__':
 
     # train(data_cfg="cfg/hand.data")
     # train(data_cfg = "cfg/face.data")
-    train(data_cfg = "cfg/person.data")
+    # train(data_cfg = "cfg/person.data")
+    train(data_cfg = "cfg/helmet.data")
 
     print('well done ~ ')
 
--
GitLab
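
For context, a minimal usage sketch of the patched predict.py entry point follows. The paths, thresholds, and the new video_path keyword are the ones introduced by the commit above; the direct import of detect() and the webcam variant (video_path=0, which cv2.VideoCapture treats as the first local camera) are illustrative assumptions, not part of the patch.

    # Hypothetical driver script; assumes predict.py is importable from the repo root.
    import torch
    from predict import detect

    if __name__ == '__main__':
        with torch.no_grad():  # inference only, no gradient bookkeeping
            detect(
                model_path='./weights-yolov3-helmet/hat_416_epoch_410.pt',  # helmet weights from the patch
                cfg='yolo',                      # yolo / yolo-tiny model architecture
                data_cfg='cfg/helmet.data',      # dataset / class-names config
                img_size=416,
                conf_thres=0.5,
                nms_thres=0.6,
                video_path='./video/hat5.mp4',   # or 0 to read from the default webcam
            )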