Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
码农StayUp
yolov7-obb
提交
a7c29abb
Y
yolov7-obb
项目概览
码农StayUp
/
yolov7-obb
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Y
yolov7-obb
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
a7c29abb
编写于
1月 31, 2023
作者:
_白鹭先生_
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
修改解耦
上级
d59a72cb
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
109 addition
and
108 deletion
+109
-108
utils/utils_bbox.py
utils/utils_bbox.py
+105
-105
yolo.py
yolo.py
+4
-3
未找到文件。
utils/utils_bbox.py
浏览文件 @
a7c29abb
import
numpy
as
np
import
torch
from
torchvision.ops
import
nms
from
utils.nms_rotated
import
obb_nms
#
from utils.nms_rotated import obb_nms
class
DecodeBox
():
def
__init__
(
self
,
anchors
,
num_classes
,
input_shape
,
anchors_mask
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]):
...
...
@@ -23,7 +23,7 @@ class DecodeBox():
#-----------------------------------------------#
# 输入的input一共有三个,他们的shape分别是
# batch_size = 1
# batch_size, 3 * (
4
+ 1 + 80), 20, 20
# batch_size, 3 * (
5
+ 1 + 80), 20, 20
# batch_size, 255, 40, 40
# batch_size, 255, 80, 80
#-----------------------------------------------#
...
...
@@ -64,11 +64,11 @@ class DecodeBox():
#-----------------------------------------------#
# 获得置信度,是否有物体
#-----------------------------------------------#
conf
=
torch
.
sigmoid
(
prediction
[...,
4
])
conf
=
torch
.
sigmoid
(
prediction
[...,
5
])
#-----------------------------------------------#
# 种类置信度
#-----------------------------------------------#
pred_cls
=
torch
.
sigmoid
(
prediction
[...,
5
:])
pred_cls
=
torch
.
sigmoid
(
prediction
[...,
6
:])
FloatTensor
=
torch
.
cuda
.
FloatTensor
if
x
.
is_cuda
else
torch
.
FloatTensor
LongTensor
=
torch
.
cuda
.
LongTensor
if
x
.
is_cuda
else
torch
.
LongTensor
...
...
@@ -232,7 +232,7 @@ class DecodeBox():
output
[
i
][:,
:
4
]
=
self
.
yolo_correct_boxes
(
box_xy
,
box_wh
,
input_shape
,
image_shape
,
letterbox_image
)
return
output
def
non_max_suppression_obb
(
prediction
,
conf_thres
=
0.25
,
iou_thres
=
0.45
,
classes
=
None
,
agnostic
=
False
,
multi_label
=
False
,
def
non_max_suppression_obb
(
self
,
prediction
,
conf_thres
=
0.25
,
iou_thres
=
0.45
,
classes
=
None
,
agnostic
=
False
,
multi_label
=
False
,
labels
=
()):
"""Runs Non-Maximum Suppression (NMS) on inference results
...
...
@@ -339,7 +339,7 @@ if __name__ == "__main__":
#---------------------------------------------------#
def
get_anchors_and_decode
(
input
,
input_shape
,
anchors
,
anchors_mask
,
num_classes
):
#-----------------------------------------------#
# input batch_size, 3 * (
4
+ 1 + num_classes), 20, 20
# input batch_size, 3 * (
5
+ 1 + num_classes), 20, 20
#-----------------------------------------------#
batch_size
=
input
.
size
(
0
)
input_height
=
input
.
size
(
2
)
...
...
@@ -364,7 +364,7 @@ if __name__ == "__main__":
# batch_size, 3, 20, 20, 4 + 1 + num_classes
#-----------------------------------------------#
prediction
=
input
.
view
(
batch_size
,
len
(
anchors_mask
[
2
]),
num_classes
+
5
,
input_height
,
input_width
).
permute
(
0
,
1
,
3
,
4
,
2
).
contiguous
()
num_classes
+
6
,
input_height
,
input_width
).
permute
(
0
,
1
,
3
,
4
,
2
).
contiguous
()
#-----------------------------------------------#
# 先验框的中心位置的调整参数
...
...
@@ -379,11 +379,11 @@ if __name__ == "__main__":
#-----------------------------------------------#
# 获得置信度,是否有物体 0 - 1
#-----------------------------------------------#
conf
=
torch
.
sigmoid
(
prediction
[...,
4
])
conf
=
torch
.
sigmoid
(
prediction
[...,
5
])
#-----------------------------------------------#
# 种类置信度 0 - 1
#-----------------------------------------------#
pred_cls
=
torch
.
sigmoid
(
prediction
[...,
5
:])
pred_cls
=
torch
.
sigmoid
(
prediction
[...,
6
:])
FloatTensor
=
torch
.
cuda
.
FloatTensor
if
x
.
is_cuda
else
torch
.
FloatTensor
LongTensor
=
torch
.
cuda
.
LongTensor
if
x
.
is_cuda
else
torch
.
LongTensor
...
...
@@ -498,7 +498,7 @@ if __name__ == "__main__":
plt
.
show
()
#
feat
=
torch
.
from_numpy
(
np
.
random
.
normal
(
0.2
,
0.5
,
[
4
,
25
5
,
20
,
20
])).
float
()
feat
=
torch
.
from_numpy
(
np
.
random
.
normal
(
0.2
,
0.5
,
[
4
,
25
8
,
20
,
20
])).
float
()
anchors
=
np
.
array
([[
116
,
90
],
[
156
,
198
],
[
373
,
326
],
[
30
,
61
],
[
62
,
45
],
[
59
,
119
],
[
10
,
13
],
[
16
,
30
],
[
33
,
23
]])
anchors_mask
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
get_anchors_and_decode
(
feat
,
[
640
,
640
],
anchors
,
anchors_mask
,
80
)
yolo.py
浏览文件 @
a7c29abb
...
...
@@ -10,7 +10,7 @@ from PIL import ImageDraw, ImageFont
from
nets.yolo
import
YoloBody
from
utils.utils
import
(
cvtColor
,
get_anchors
,
get_classes
,
preprocess_input
,
resize_image
,
show_config
)
from
utils.utils_bbox
import
non_max_suppression_obb
from
utils.utils_bbox
import
DecodeBox
from
utils.utils_rbox
import
rbox2poly
'''
训练自己的数据集必看注释!
...
...
@@ -84,7 +84,7 @@ class YOLO(object):
#---------------------------------------------------#
self
.
class_names
,
self
.
num_classes
=
get_classes
(
self
.
classes_path
)
self
.
anchors
,
self
.
num_anchors
=
get_anchors
(
self
.
anchors_path
)
self
.
bbox_util
=
DecodeBox
(
self
.
anchors
,
self
.
num_classes
,
(
self
.
input_shape
[
0
],
self
.
input_shape
[
1
]),
self
.
anchors_mask
)
#---------------------------------------------------#
# 画框设置不同的颜色
#---------------------------------------------------#
...
...
@@ -144,10 +144,11 @@ class YOLO(object):
# 将图像输入网络当中进行预测!
#---------------------------------------------------------#
outputs
=
self
.
net
(
images
)
outputs
=
self
.
bbox_util
.
decode_box
(
outputs
)
#---------------------------------------------------------#
# 将预测框进行堆叠,然后进行非极大抑制
#---------------------------------------------------------#
results
=
non_max_suppression_obb
(
outputs
,
self
.
confidence
,
self
.
nms_iou
,
classes
=
self
.
num_classes
)
results
=
self
.
bbox_util
.
non_max_suppression_obb
(
torch
.
cat
(
outputs
,
1
)
,
self
.
confidence
,
self
.
nms_iou
,
classes
=
self
.
num_classes
)
if
results
[
0
]
is
None
:
return
image
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录