Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BlackJayson
yolo_v3
提交
d018bd5f
yolo_v3
项目概览
BlackJayson
/
yolo_v3
与 Fork 源项目一致
Fork自
DataBall / yolo_v3
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
yolo_v3
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d018bd5f
编写于
2月 02, 2021
作者:
DataBall
🚴🏻
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
create pro
上级
13e1ab1c
变更
2
展开全部
隐藏空白更改
内联
并排
Showing
2 changed file
with
713 addition
and
0 deletion
+713
-0
train.py
train.py
+208
-0
yolov3.py
yolov3.py
+505
-0
未找到文件。
train.py
0 → 100644
浏览文件 @
d018bd5f
#coding:utf-8
from
yolov3
import
Yolov3
,
Yolov3Tiny
from
utils.parse_config
import
parse_data_cfg
from
utils.torch_utils
import
select_device
import
torch
from
torch.utils.data
import
DataLoader
from
utils.datasets
import
LoadImagesAndLabels
from
utils.utils
import
*
import
os
import
numpy
as
np
def
set_learning_rate
(
optimizer
,
lr
):
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'lr'
]
=
lr
def
train
(
data_cfg
=
'cfg/voc.data'
,
accumulate
=
1
):
device
=
select_device
()
# Configure run
get_data_cfg
=
parse_data_cfg
(
data_cfg
)
#返回训练配置参数,类型:字典
gpus
=
get_data_cfg
[
'gpus'
]
num_workers
=
int
(
get_data_cfg
[
'num_workers'
])
cfg_model
=
get_data_cfg
[
'cfg_model'
]
train_path
=
get_data_cfg
[
'train'
]
valid_ptah
=
get_data_cfg
[
'valid'
]
num_classes
=
int
(
get_data_cfg
[
'classes'
])
finetune_model
=
get_data_cfg
[
'finetune_model'
]
batch_size
=
int
(
get_data_cfg
[
'batch_size'
])
img_size
=
int
(
get_data_cfg
[
'img_size'
])
multi_scale
=
get_data_cfg
[
'multi_scale'
]
epochs
=
int
(
get_data_cfg
[
'epochs'
])
lr_step
=
str
(
get_data_cfg
[
'lr_step'
])
lr0
=
float
(
get_data_cfg
[
'lr0'
])
if
multi_scale
==
'True'
:
multi_scale
=
True
else
:
multi_scale
=
False
print
(
'data_cfg : '
,
data_cfg
)
print
(
'voc.data config len : '
,
len
(
get_data_cfg
))
print
(
'gpus : '
,
gpus
)
print
(
'num_workers : '
,
num_workers
)
print
(
'model : '
,
cfg_model
)
print
(
'finetune_model : '
,
finetune_model
)
print
(
'train_path : '
,
train_path
)
print
(
'valid_ptah : '
,
valid_ptah
)
print
(
'num_classes : '
,
num_classes
)
print
(
'batch_size : '
,
batch_size
)
print
(
'img_size : '
,
img_size
)
print
(
'multi_scale : '
,
multi_scale
)
print
(
'lr0 : '
,
lr0
)
print
(
'lr_step : '
,
lr_step
)
# load model
pattern_data_
=
data_cfg
.
split
(
"/"
)[
-
1
:][
0
].
replace
(
".data"
,
""
)
if
"-tiny"
in
cfg_model
:
a_scalse
=
416.
/
img_size
anchors
=
[(
10
,
14
),
(
23
,
27
),
(
37
,
58
),
(
81
,
82
),
(
135
,
169
),
(
344
,
319
)]
anchors_new
=
[
(
int
(
anchors
[
j
][
0
]
/
a_scalse
),
int
(
anchors
[
j
][
1
]
/
a_scalse
))
for
j
in
range
(
len
(
anchors
))
]
model
=
Yolov3Tiny
(
num_classes
,
anchors
=
anchors_new
)
# weights = './weights-yolov3-person-tiny/'
weights
=
'./weights-yolov3-{}-tiny/'
.
format
(
pattern_data_
)
else
:
a_scalse
=
416.
/
img_size
anchors
=
[(
10
,
13
),
(
16
,
30
),
(
33
,
23
),
(
30
,
61
),
(
62
,
45
),
(
59
,
119
),
(
116
,
90
),
(
156
,
198
),
(
373
,
326
)]
anchors_new
=
[
(
int
(
anchors
[
j
][
0
]
/
a_scalse
),
int
(
anchors
[
j
][
1
]
/
a_scalse
))
for
j
in
range
(
len
(
anchors
))
]
model
=
Yolov3
(
num_classes
,
anchors
=
anchors_new
)
weights
=
'./weights-yolov3-{}/'
.
format
(
pattern_data_
)
# mkdir save model document
if
not
os
.
path
.
exists
(
weights
):
os
.
mkdir
(
weights
)
model
=
model
.
to
(
device
)
latest
=
weights
+
'latest_{}.pt'
.
format
(
img_size
)
best
=
weights
+
'best_{}.pt'
.
format
(
img_size
)
# Optimizer
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
lr0
,
momentum
=
0.9
,
weight_decay
=
0.0005
)
start_epoch
=
0
if
os
.
access
(
finetune_model
,
os
.
F_OK
):
# load retrain/finetune_model
print
(
'loading yolo-v3 finetune_model ~~~~~~'
,
finetune_model
)
not_load_filters
=
3
*
(
80
+
5
)
# voc: 3*(20+5), coco: 3*(80+5)=255
chkpt
=
torch
.
load
(
finetune_model
,
map_location
=
device
)
model
.
load_state_dict
({
k
:
v
for
k
,
v
in
chkpt
[
'model'
].
items
()
if
v
.
numel
()
>
1
and
v
.
shape
[
0
]
!=
not_load_filters
},
strict
=
False
)
# model.load_state_dict(chkpt['model'])
if
'coco'
not
in
finetune_model
:
start_epoch
=
chkpt
[
'epoch'
]
if
chkpt
[
'optimizer'
]
is
not
None
:
optimizer
.
load_state_dict
(
chkpt
[
'optimizer'
])
best_loss
=
chkpt
[
'best_loss'
]
# Set scheduler (reduce lr at epochs 218, 245, i.e. batches 400k, 450k) gamma:学习率下降的乘数因子
milestones
=
[
int
(
i
)
for
i
in
lr_step
.
split
(
","
)]
print
(
'milestones : '
,
milestones
)
scheduler
=
torch
.
optim
.
lr_scheduler
.
MultiStepLR
(
optimizer
,
milestones
=
[
int
(
i
)
for
i
in
lr_step
.
split
(
","
)],
gamma
=
0.1
,
last_epoch
=
start_epoch
-
1
)
# Dataset
print
(
'multi_scale : '
,
multi_scale
)
dataset
=
LoadImagesAndLabels
(
train_path
,
batch_size
=
batch_size
,
img_size
=
img_size
,
augment
=
True
,
multi_scale
=
multi_scale
)
print
(
'--------------->>> imge num : '
,
dataset
.
__len__
())
# Dataloader
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
num_workers
,
shuffle
=
True
,
pin_memory
=
False
,
drop_last
=
False
,
collate_fn
=
dataset
.
collate_fn
)
# Start training
t
=
time
.
time
()
# model_info(model)# 打印模型信息
nB
=
len
(
dataloader
)
n_burnin
=
min
(
round
(
nB
/
5
+
1
),
1000
)
# burn-in batches
best_loss
=
float
(
'inf'
)
test_loss
=
float
(
'inf'
)
flag_start
=
False
for
epoch
in
range
(
0
,
epochs
):
print
(
' ~~~~'
)
model
.
train
()
if
flag_start
:
scheduler
.
step
()
flag_start
=
True
mloss
=
defaultdict
(
float
)
# mean loss
for
i
,
(
imgs
,
targets
,
img_path_
,
_
)
in
enumerate
(
dataloader
):
multi_size
=
imgs
.
size
()
imgs
=
imgs
.
to
(
device
)
targets
=
targets
.
to
(
device
)
nt
=
len
(
targets
)
if
nt
==
0
:
# if no targets continue
continue
# SGD burn-in
if
epoch
==
0
and
i
<=
n_burnin
:
lr
=
lr0
*
(
i
/
n_burnin
)
**
4
for
x
in
optimizer
.
param_groups
:
x
[
'lr'
]
=
lr
# Run model
pred
=
model
(
imgs
)
# Build targets
target_list
=
build_targets
(
model
,
targets
)
# Compute loss
loss
,
loss_dict
=
compute_loss
(
pred
,
target_list
)
# Compute gradient
loss
.
backward
()
# Accumulate gradient for x batches before optimizing
if
(
i
+
1
)
%
accumulate
==
0
or
(
i
+
1
)
==
nB
:
optimizer
.
step
()
optimizer
.
zero_grad
()
# Running epoch-means of tracked metrics
for
key
,
val
in
loss_dict
.
items
():
mloss
[
key
]
=
(
mloss
[
key
]
*
i
+
val
)
/
(
i
+
1
)
print
(
' Epoch {:3d}/{:3d}, Batch {:6d}/{:6d}, Img_size {}x{}, nTargets {}, lr {:.6f}, loss: xy {:.3f}, wh {:.3f}, '
'conf {:.3f}, cls {:.3f}, total {:.3f}, time {:.3f}s'
.
format
(
epoch
,
epochs
-
1
,
i
,
nB
-
1
,
multi_size
[
2
],
multi_size
[
3
]
,
nt
,
scheduler
.
get_lr
()[
0
],
mloss
[
'xy'
],
mloss
[
'wh'
],
mloss
[
'conf'
],
mloss
[
'cls'
],
mloss
[
'total'
],
time
.
time
()
-
t
),
end
=
'
\r
'
)
s
=
(
'%8s%12s'
+
'%10.3g'
*
7
)
%
(
'%g/%g'
%
(
epoch
,
epochs
-
1
),
'%g/%g'
%
(
i
,
nB
-
1
),
mloss
[
'xy'
],
mloss
[
'wh'
],
mloss
[
'conf'
],
mloss
[
'cls'
],
mloss
[
'total'
],
nt
,
time
.
time
()
-
t
)
t
=
time
.
time
()
print
()
# Create checkpoint
chkpt
=
{
'epoch'
:
epoch
,
'best_loss'
:
best_loss
,
'model'
:
model
.
module
.
state_dict
()
if
type
(
model
)
is
nn
.
parallel
.
DistributedDataParallel
else
model
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
()}
# Save latest checkpoint
torch
.
save
(
chkpt
,
latest
)
# Save best checkpoint
if
best_loss
==
test_loss
and
epoch
%
5
==
0
:
torch
.
save
(
chkpt
,
best
)
# Save backup every 10 epochs (optional)
if
epoch
>
0
and
epoch
%
5
==
0
:
torch
.
save
(
chkpt
,
weights
+
'yoloV3_{}_epoch_{}.pt'
.
format
(
img_size
,
epoch
))
# Delete checkpoint
del
chkpt
#-------------------------------------------------------------------------------
if
__name__
==
'__main__'
:
train
(
data_cfg
=
'cfg/hand.data'
)
print
(
'well done ~ '
)
yolov3.py
0 → 100644
浏览文件 @
d018bd5f
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录