Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
一汁程序喵
ssd-pytorch
提交
ccb1ac0b
S
ssd-pytorch
项目概览
一汁程序喵
/
ssd-pytorch
与 Fork 源项目一致
从无法访问的项目Fork
通知
8
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
S
ssd-pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
ccb1ac0b
编写于
5月 24, 2020
作者:
J
JiaQi Xu
提交者:
GitHub
5月 24, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add files via upload
上级
90fe4cca
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
113 addition
and
57 deletion
+113
-57
nets/ssd_training.py
nets/ssd_training.py
+9
-8
ssd.py
ssd.py
+12
-9
train.py
train.py
+92
-40
未找到文件。
nets/ssd_training.py
浏览文件 @
ccb1ac0b
...
...
@@ -40,24 +40,25 @@ class MultiBoxLoss(nn.Module):
priors
=
priors
[:
loc_data
.
size
(
1
),
:]
# 先验框的数量
num_priors
=
(
priors
.
size
(
0
))
num_classes
=
self
.
num_classes
# 创建一个tensor进行处理
loc_t
=
torch
.
Tensor
(
num
,
num_priors
,
4
)
conf_t
=
torch
.
LongTensor
(
num
,
num_priors
)
if
self
.
use_gpu
:
loc_t
=
loc_t
.
cuda
()
conf_t
=
conf_t
.
cuda
()
priors
=
priors
.
cuda
()
for
idx
in
range
(
num
):
# 获得框
truths
=
targets
[
idx
][:,
:
-
1
]
.
data
truths
=
targets
[
idx
][:,
:
-
1
]
# 获得标签
labels
=
targets
[
idx
][:,
-
1
]
.
data
labels
=
targets
[
idx
][:,
-
1
]
# 获得先验框
defaults
=
priors
.
data
defaults
=
priors
# 找到标签对应的先验框
match
(
self
.
threshold
,
truths
,
defaults
,
self
.
variance
,
labels
,
loc_t
,
conf_t
,
idx
)
if
self
.
use_gpu
:
loc_t
=
loc_t
.
cuda
()
conf_t
=
conf_t
.
cuda
()
# 转化成Variable
loc_t
=
Variable
(
loc_t
,
requires_grad
=
False
)
conf_t
=
Variable
(
conf_t
,
requires_grad
=
False
)
...
...
ssd.py
浏览文件 @
ccb1ac0b
...
...
@@ -17,6 +17,7 @@ class SSD(object):
"classes_path"
:
'model_data/voc_classes.txt'
,
"model_image_size"
:
(
300
,
300
,
3
),
"confidence"
:
0.5
,
"cuda"
:
True
,
}
@
classmethod
...
...
@@ -27,7 +28,7 @@ class SSD(object):
return
"Unrecognized attribute name '"
+
n
+
"'"
#---------------------------------------------------#
# 初始化
RFB
# 初始化
SSD
#---------------------------------------------------#
def
__init__
(
self
,
**
kwargs
):
self
.
__dict__
.
update
(
self
.
_defaults
)
...
...
@@ -51,14 +52,14 @@ class SSD(object):
# 载入模型,如果原来的模型里已经包括了模型结构则直接载入。
# 否则先构建模型再载入
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
model
=
ssd
.
get_ssd
(
"test"
,
self
.
num_classes
)
self
.
net
=
model
model
.
load_state_dict
(
torch
.
load
(
self
.
model_path
))
self
.
net
=
model
.
eval
()
self
.
net
=
torch
.
nn
.
DataParallel
(
self
.
net
)
cudnn
.
benchmark
=
True
self
.
net
=
self
.
net
.
cuda
()
if
self
.
cuda
:
cudnn
.
benchmark
=
True
self
.
net
=
self
.
net
.
cuda
()
print
(
'{} model, anchors, and classes loaded.'
.
format
(
self
.
model_path
))
# 画框设置不同的颜色
...
...
@@ -77,10 +78,12 @@ class SSD(object):
crop_img
=
np
.
array
(
letterbox_image
(
image
,
(
self
.
model_image_size
[
0
],
self
.
model_image_size
[
1
])))
photo
=
np
.
array
(
crop_img
,
dtype
=
np
.
float64
)
# 图片预处理,归一化
photo
=
Variable
(
torch
.
from_numpy
(
np
.
expand_dims
(
np
.
transpose
(
crop_img
-
MEANS
,(
2
,
0
,
1
)),
0
)).
cuda
().
type
(
torch
.
FloatTensor
))
preds
=
self
.
net
(
photo
)
photo
=
Variable
(
torch
.
from_numpy
(
np
.
expand_dims
(
np
.
transpose
(
crop_img
-
MEANS
,(
2
,
0
,
1
)),
0
)).
type
(
torch
.
FloatTensor
))
with
torch
.
no_grad
():
if
self
.
cuda
:
photo
=
photo
.
cuda
()
preds
=
self
.
net
(
photo
)
top_conf
=
[]
top_label
=
[]
...
...
train.py
浏览文件 @
ccb1ac0b
...
...
@@ -18,12 +18,19 @@ def adjust_learning_rate(optimizer, lr, gamma, step):
if
__name__
==
"__main__"
:
Batch_size
=
4
lr
=
1e-5
Epoch
=
50
# ------------------------------------#
# 先冻结一部分权重训练
# 后解冻全部权重训练
# 先大学习率
# 后小学习率
# ------------------------------------#
lr
=
1e-4
freeze_lr
=
1e-5
Cuda
=
True
Start_iter
=
0
# 需要使用device来指定网络在GPU还是CPU运行
device
=
torch
.
device
(
'cuda'
if
torch
.
cuda
.
is_available
()
else
'cpu'
)
Freeze_epoch
=
25
Epoch
=
50
model
=
get_ssd
(
"train"
,
Config
[
"num_classes"
])
print
(
'Loading weights into state dict...'
)
...
...
@@ -34,7 +41,7 @@ if __name__ == "__main__":
model
.
load_state_dict
(
model_dict
)
print
(
'Finished!'
)
net
=
model
net
=
model
.
train
()
if
Cuda
:
net
=
torch
.
nn
.
DataParallel
(
model
)
cudnn
.
benchmark
=
True
...
...
@@ -51,45 +58,90 @@ if __name__ == "__main__":
gen
=
Generator
(
Batch_size
,
lines
,
(
Config
[
"min_dim"
],
Config
[
"min_dim"
]),
Config
[
"num_classes"
]).
generate
()
optimizer
=
optim
.
Adam
(
net
.
parameters
(),
lr
=
lr
)
criterion
=
MultiBoxLoss
(
Config
[
'num_classes'
],
0.5
,
True
,
0
,
True
,
3
,
0.5
,
False
,
Cuda
)
epoch_size
=
num_train
//
Batch_size
net
.
train
()
if
True
:
# ------------------------------------#
# 冻结一定部分训练
# ------------------------------------#
for
param
in
model
.
vgg
.
parameters
():
param
.
requires_grad
=
False
optimizer
=
optim
.
Adam
(
net
.
parameters
(),
lr
=
lr
)
for
epoch
in
range
(
Start_iter
,
Freeze_epoch
):
if
epoch
%
10
==
0
:
adjust_learning_rate
(
optimizer
,
lr
,
0.95
,
epoch
)
loc_loss
=
0
conf_loss
=
0
for
iteration
in
range
(
epoch_size
):
images
,
targets
=
next
(
gen
)
with
torch
.
no_grad
():
if
Cuda
:
images
=
Variable
(
torch
.
from_numpy
(
images
).
type
(
torch
.
FloatTensor
)).
cuda
()
targets
=
[
Variable
(
torch
.
from_numpy
(
ann
).
type
(
torch
.
FloatTensor
)).
cuda
()
for
ann
in
targets
]
else
:
images
=
Variable
(
torch
.
from_numpy
(
images
).
type
(
torch
.
FloatTensor
))
targets
=
[
Variable
(
torch
.
from_numpy
(
ann
).
type
(
torch
.
FloatTensor
))
for
ann
in
targets
]
# 前向传播
out
=
net
(
images
)
# 清零梯度
optimizer
.
zero_grad
()
# 计算loss
loss_l
,
loss_c
=
criterion
(
out
,
targets
)
loss
=
loss_l
+
loss_c
# 反向传播
loss
.
backward
()
optimizer
.
step
()
# 加上
loc_loss
+=
loss_l
.
item
()
conf_loss
+=
loss_c
.
item
()
epoch_size
=
num_train
//
Batch_size
for
epoch
in
range
(
Start_iter
,
Epoch
):
if
epoch
%
10
==
0
:
adjust_learning_rate
(
optimizer
,
lr
,
0.95
,
epoch
)
loc_loss
=
0
conf_loss
=
0
for
iteration
in
range
(
epoch_size
):
images
,
targets
=
next
(
gen
)
with
torch
.
no_grad
():
if
Cuda
:
images
=
Variable
(
torch
.
from_numpy
(
images
).
cuda
().
type
(
torch
.
FloatTensor
))
targets
=
[
Variable
(
torch
.
from_numpy
(
ann
).
cuda
().
type
(
torch
.
FloatTensor
))
for
ann
in
targets
]
else
:
images
=
Variable
(
torch
.
from_numpy
(
images
).
type
(
torch
.
FloatTensor
))
targets
=
[
Variable
(
torch
.
from_numpy
(
ann
).
type
(
torch
.
FloatTensor
))
for
ann
in
targets
]
# 前向传播
out
=
net
(
images
)
# 清零梯度
optimizer
.
zero_grad
()
# 计算loss
loss_l
,
loss_c
=
criterion
(
out
,
targets
)
loss
=
loss_l
+
loss_c
# 反向传播
loss
.
backward
()
optimizer
.
step
()
# 加上
loc_loss
+=
loss_l
.
item
()
conf_loss
+=
loss_c
.
item
()
print
(
'
\n
Epoch:'
+
str
(
epoch
+
1
)
+
'/'
+
str
(
Freeze_epoch
))
print
(
'iter:'
+
str
(
iteration
)
+
'/'
+
str
(
epoch_size
)
+
' || Loc_Loss: %.4f || Conf_Loss: %.4f ||'
%
(
loc_loss
/
(
iteration
+
1
),
conf_loss
/
(
iteration
+
1
)),
end
=
' '
)
print
(
'Saving state, iter:'
,
str
(
epoch
+
1
))
torch
.
save
(
model
.
state_dict
(),
'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'
%
((
epoch
+
1
),
loc_loss
/
(
iteration
+
1
),
conf_loss
/
(
iteration
+
1
)))
if
True
:
# ------------------------------------#
# 全部解冻训练
# ------------------------------------#
for
param
in
model
.
vgg
.
parameters
():
param
.
requires_grad
=
True
optimizer
=
optim
.
Adam
(
net
.
parameters
(),
lr
=
freeze_lr
)
for
epoch
in
range
(
Freeze_epoch
,
Epoch
):
if
epoch
%
10
==
0
:
adjust_learning_rate
(
optimizer
,
freeze_lr
,
0.95
,
epoch
)
loc_loss
=
0
conf_loss
=
0
for
iteration
in
range
(
epoch_size
):
images
,
targets
=
next
(
gen
)
with
torch
.
no_grad
():
if
Cuda
:
images
=
Variable
(
torch
.
from_numpy
(
images
).
type
(
torch
.
FloatTensor
)).
cuda
()
targets
=
[
Variable
(
torch
.
from_numpy
(
ann
).
type
(
torch
.
FloatTensor
)).
cuda
()
for
ann
in
targets
]
else
:
images
=
Variable
(
torch
.
from_numpy
(
images
).
type
(
torch
.
FloatTensor
))
targets
=
[
Variable
(
torch
.
from_numpy
(
ann
).
type
(
torch
.
FloatTensor
))
for
ann
in
targets
]
# 前向传播
out
=
net
(
images
)
# 清零梯度
optimizer
.
zero_grad
()
# 计算loss
loss_l
,
loss_c
=
criterion
(
out
,
targets
)
loss
=
loss_l
+
loss_c
# 反向传播
loss
.
backward
()
optimizer
.
step
()
# 加上
loc_loss
+=
loss_l
.
item
()
conf_loss
+=
loss_c
.
item
()
print
(
'
\n
Epoch:'
+
str
(
epoch
+
1
)
+
'/'
+
str
(
Epoch
))
print
(
'iter:'
+
str
(
iteration
)
+
'/'
+
str
(
epoch_size
)
+
' || Loc_Loss: %.4f || Conf_Loss: %.4f ||'
%
(
loc_loss
/
(
iteration
+
1
),
conf_loss
/
(
iteration
+
1
)),
end
=
' '
)
print
(
'
\n
Epoch:'
+
str
(
epoch
+
1
)
+
'/'
+
str
(
Epoch
))
print
(
'iter:'
+
str
(
iteration
)
+
'/'
+
str
(
epoch_size
)
+
' || Loc_Loss: %.4f || Conf_Loss: %.4f ||'
%
(
loc_loss
/
(
iteration
+
1
),
conf_loss
/
(
iteration
+
1
)),
end
=
' '
)
print
(
'Saving state, iter:'
,
str
(
epoch
+
1
))
torch
.
save
(
model
.
state_dict
(),
'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'
%
((
epoch
+
1
),
loc_loss
/
(
iteration
+
1
),
conf_loss
/
(
iteration
+
1
)))
print
(
'Saving state, iter:'
,
str
(
epoch
+
1
))
torch
.
save
(
model
.
state_dict
(),
'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'
%
((
epoch
+
1
),
loc_loss
/
(
iteration
+
1
),
conf_loss
/
(
iteration
+
1
)))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录