Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
9da63d61
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9da63d61
编写于
3月 11, 2019
作者:
T
tink2123
提交者:
dengkaipeng
3月 11, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added darknet and modified models
上级
b4bdcc5e
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
405 addition
and
1025 deletion
+405
-1025
fluid/PaddleCV/yolov3/config.py
fluid/PaddleCV/yolov3/config.py
+8
-2
fluid/PaddleCV/yolov3/config/__init__.py
fluid/PaddleCV/yolov3/config/__init__.py
+0
-0
fluid/PaddleCV/yolov3/config/edict.py
fluid/PaddleCV/yolov3/config/edict.py
+0
-37
fluid/PaddleCV/yolov3/config/yolov3-tiny.cfg
fluid/PaddleCV/yolov3/config/yolov3-tiny.cfg
+0
-182
fluid/PaddleCV/yolov3/config/yolov3.cfg
fluid/PaddleCV/yolov3/config/yolov3.cfg
+0
-789
fluid/PaddleCV/yolov3/config_parser.py
fluid/PaddleCV/yolov3/config_parser.py
+0
-0
fluid/PaddleCV/yolov3/eval.py
fluid/PaddleCV/yolov3/eval.py
+2
-3
fluid/PaddleCV/yolov3/infer.py
fluid/PaddleCV/yolov3/infer.py
+2
-2
fluid/PaddleCV/yolov3/models/darknet.py
fluid/PaddleCV/yolov3/models/darknet.py
+102
-0
fluid/PaddleCV/yolov3/models/yolov3.py
fluid/PaddleCV/yolov3/models/yolov3.py
+279
-0
fluid/PaddleCV/yolov3/reader.py
fluid/PaddleCV/yolov3/reader.py
+1
-1
fluid/PaddleCV/yolov3/train.py
fluid/PaddleCV/yolov3/train.py
+10
-8
fluid/PaddleCV/yolov3/utility.py
fluid/PaddleCV/yolov3/utility.py
+1
-1
未找到文件。
fluid/PaddleCV/yolov3/config
/config
.py
→
fluid/PaddleCV/yolov3/config.py
浏览文件 @
9da63d61
...
@@ -13,7 +13,7 @@ from __future__ import absolute_import
...
@@ -13,7 +13,7 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
unicode_literals
from
.
edict
import
AttrDict
from
edict
import
AttrDict
import
six
import
six
import
numpy
as
np
import
numpy
as
np
...
@@ -24,6 +24,10 @@ cfg = _C
...
@@ -24,6 +24,10 @@ cfg = _C
# Training options
# Training options
#
#
# batch
_C
.
batch
=
8
# Snapshot period
# Snapshot period
_C
.
snapshot_iter
=
2000
_C
.
snapshot_iter
=
2000
...
@@ -88,7 +92,9 @@ _C.weight_decay = 0.0005
...
@@ -88,7 +92,9 @@ _C.weight_decay = 0.0005
# momentum with SGD
# momentum with SGD
_C
.
momentum
=
0.9
_C
.
momentum
=
0.9
#
# decay
_C
.
decay
=
0.0005
# ENV options
# ENV options
#
#
...
...
fluid/PaddleCV/yolov3/config/__init__.py
已删除
100644 → 0
浏览文件 @
b4bdcc5e
fluid/PaddleCV/yolov3/config/edict.py
已删除
100644 → 0
浏览文件 @
b4bdcc5e
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
class
AttrDict
(
dict
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
(
AttrDict
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
__getattr__
(
self
,
name
):
if
name
in
self
.
__dict__
:
return
self
.
__dict__
[
name
]
elif
name
in
self
:
return
self
[
name
]
else
:
raise
AttributeError
(
name
)
def
__setattr__
(
self
,
name
,
value
):
if
name
in
self
.
__dict__
:
self
.
__dict__
[
name
]
=
value
else
:
self
[
name
]
=
value
fluid/PaddleCV/yolov3/config/yolov3-tiny.cfg
已删除
100644 → 0
浏览文件 @
b4bdcc5e
[net]
# Testing
batch=1
subdivisions=1
# Training
# batch=64
# subdivisions=2
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1
[convolutional]
batch_normalize=1
filters=16
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=1
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
###########
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 3,4,5
anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
classes=80
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 8
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 0,1,2
anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319
classes=80
num=6
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
fluid/PaddleCV/yolov3/config/yolov3.cfg
已删除
100644 → 0
浏览文件 @
b4bdcc5e
[net]
# Testing
# batch=1
# subdivisions=1
# Training
batch=64
subdivisions=16
width=608
height=608
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
# Downsample
[convolutional]
batch_normalize=1
filters=64
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=32
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=128
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=256
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=512
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
######################
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 6,7,8
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 61
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 3,4,5
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 36
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=255
activation=linear
[yolo]
mask = 0,1,2
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=80
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
fluid/PaddleCV/yolov3/config
/config
_parser.py
→
fluid/PaddleCV/yolov3/config_parser.py
浏览文件 @
9da63d61
文件已移动
fluid/PaddleCV/yolov3/eval.py
浏览文件 @
9da63d61
...
@@ -21,12 +21,12 @@ import numpy as np
...
@@ -21,12 +21,12 @@ import numpy as np
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
reader
import
reader
import
models
import
models
.yolov3
as
models
from
utility
import
print_arguments
,
parse_args
from
utility
import
print_arguments
,
parse_args
import
json
import
json
from
pycocotools.coco
import
COCO
from
pycocotools.coco
import
COCO
from
pycocotools.cocoeval
import
COCOeval
,
Params
from
pycocotools.cocoeval
import
COCOeval
,
Params
from
config
.config
import
cfg
from
config
import
cfg
def
eval
():
def
eval
():
...
@@ -42,7 +42,6 @@ def eval():
...
@@ -42,7 +42,6 @@ def eval():
model
=
models
.
YOLOv3
(
cfg
.
model_cfg_path
,
is_train
=
False
)
model
=
models
.
YOLOv3
(
cfg
.
model_cfg_path
,
is_train
=
False
)
model
.
build_model
()
model
.
build_model
()
outputs
=
model
.
get_pred
()
outputs
=
model
.
get_pred
()
hyperparams
=
model
.
get_hyperparams
()
yolo_anchors
=
model
.
get_yolo_anchors
()
yolo_anchors
=
model
.
get_yolo_anchors
()
yolo_classes
=
model
.
get_yolo_classes
()
yolo_classes
=
model
.
get_yolo_classes
()
place
=
fluid
.
CUDAPlace
(
0
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
...
...
fluid/PaddleCV/yolov3/infer.py
浏览文件 @
9da63d61
...
@@ -6,12 +6,12 @@ import paddle.fluid as fluid
...
@@ -6,12 +6,12 @@ import paddle.fluid as fluid
import
box_utils
import
box_utils
import
reader
import
reader
from
utility
import
print_arguments
,
parse_args
from
utility
import
print_arguments
,
parse_args
import
models
import
models
.yolov3
as
models
# from coco_reader import load_label_names
# from coco_reader import load_label_names
import
json
import
json
from
pycocotools.coco
import
COCO
from
pycocotools.coco
import
COCO
from
pycocotools.cocoeval
import
COCOeval
,
Params
from
pycocotools.cocoeval
import
COCOeval
,
Params
from
config
.config
import
cfg
from
config
import
cfg
def
infer
():
def
infer
():
...
...
fluid/PaddleCV/yolov3/models/darknet.py
0 → 100644
浏览文件 @
9da63d61
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.initializer
import
Constant
from
paddle.fluid.regularizer
import
L2Decay
from
config
import
cfg
def
conv_bn_layer
(
input
,
ch_out
,
filter_size
,
stride
,
padding
,
act
=
'leaky'
,
i
=
0
):
conv1
=
fluid
.
layers
.
conv2d
(
input
=
input
,
num_filters
=
ch_out
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
act
=
None
,
param_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
Normal
(
0.
,
0.02
),
name
=
"conv"
+
str
(
i
)
+
"_weights"
),
bias_attr
=
False
)
bn_name
=
"bn"
+
str
(
i
)
out
=
fluid
.
layers
.
batch_norm
(
input
=
conv1
,
act
=
None
,
is_test
=
True
,
param_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
Normal
(
0.
,
0.02
),
regularizer
=
L2Decay
(
0.
),
name
=
bn_name
+
'_scale'
),
bias_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
0.0
),
regularizer
=
L2Decay
(
0.
),
name
=
bn_name
+
'_offset'
),
moving_mean_name
=
bn_name
+
'_mean'
,
moving_variance_name
=
bn_name
+
'_var'
)
if
act
==
'leaky'
:
out
=
fluid
.
layers
.
leaky_relu
(
x
=
out
,
alpha
=
0.1
)
return
out
def
basicblock
(
input
,
ch_out
,
stride
,
i
):
"""
channel: convolution channels for 1x1 conv
"""
conv1
=
conv_bn_layer
(
input
,
ch_out
,
1
,
1
,
0
,
i
=
i
)
conv2
=
conv_bn_layer
(
conv1
,
ch_out
*
2
,
3
,
1
,
1
,
i
=
i
+
1
)
out
=
fluid
.
layers
.
elementwise_add
(
x
=
input
,
y
=
conv2
,
act
=
None
,
name
=
"res"
+
str
(
i
+
2
))
return
out
def
layer_warp
(
block_func
,
input
,
ch_out
,
count
,
stride
,
i
):
res_out
=
block_func
(
input
,
ch_out
,
stride
,
i
=
i
)
for
j
in
range
(
1
,
count
):
res_out
=
block_func
(
res_out
,
ch_out
,
1
,
i
=
i
+
j
*
3
)
return
res_out
DarkNet_cfg
=
{
53
:
([
1
,
2
,
8
,
8
,
4
],
basicblock
)
}
# num_filters = [32, 64, 128, 256, 512, 1024]
def
add_DarkNet53_conv_body
(
body_input
):
stages
,
block_func
=
DarkNet_cfg
[
53
]
stages
=
stages
[
0
:
5
]
conv1
=
conv_bn_layer
(
body_input
,
ch_out
=
32
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
act
=
"leaky"
,
i
=
0
)
conv2
=
conv_bn_layer
(
conv1
,
ch_out
=
64
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
act
=
"leaky"
,
i
=
1
)
block3
=
layer_warp
(
block_func
,
conv2
,
32
,
stages
[
0
],
1
,
i
=
2
)
downsample3
=
conv_bn_layer
(
block3
,
ch_out
=
128
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
i
=
5
)
block4
=
layer_warp
(
block_func
,
downsample3
,
64
,
stages
[
1
],
1
,
i
=
6
)
downsample4
=
conv_bn_layer
(
block4
,
ch_out
=
256
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
i
=
12
)
block5
=
layer_warp
(
block_func
,
downsample4
,
128
,
stages
[
2
],
1
,
i
=
13
)
downsample5
=
conv_bn_layer
(
block5
,
ch_out
=
512
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
i
=
37
)
block6
=
layer_warp
(
block_func
,
downsample5
,
256
,
stages
[
3
],
1
,
i
=
38
)
downsample6
=
conv_bn_layer
(
block6
,
ch_out
=
1024
,
filter_size
=
3
,
stride
=
2
,
padding
=
1
,
i
=
62
)
block7
=
layer_warp
(
block_func
,
downsample6
,
512
,
stages
[
4
],
1
,
i
=
63
)
return
block7
,
block6
,
block5
fluid/PaddleCV/yolov3/models/yolov3.py
0 → 100644
浏览文件 @
9da63d61
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from
__future__
import
division
from
__future__
import
print_function
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.initializer
import
Constant
from
paddle.fluid.initializer
import
Normal
from
paddle.fluid.regularizer
import
L2Decay
from
config_parser
import
ConfigPaser
from
config
import
cfg
from
darknet
import
add_DarkNet53_conv_body
from
darknet
import
conv_bn_layer
def
yolo_detection_block
(
input
,
channel
,
i
):
assert
channel
%
2
==
0
,
"channel {} cannot be divided by 2"
.
format
(
channel
)
conv1
=
input
for
j
in
range
(
2
):
conv1
=
conv_bn_layer
(
conv1
,
channel
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
i
=
i
+
j
*
2
)
conv1
=
conv_bn_layer
(
conv1
,
channel
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
i
=
i
+
j
*
2
+
1
)
route
=
conv_bn_layer
(
conv1
,
channel
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
i
=
i
+
4
)
tip
=
conv_bn_layer
(
route
,
channel
*
2
,
filter_size
=
3
,
stride
=
1
,
padding
=
1
,
i
=
i
+
5
)
return
route
,
tip
def
upsample
(
out
,
stride
=
2
,
name
=
None
):
out
=
out
scale
=
stride
# get dynamic upsample output shape
shape_nchw
=
fluid
.
layers
.
shape
(
out
)
shape_hw
=
fluid
.
layers
.
slice
(
shape_nchw
,
axes
=
[
0
],
starts
=
[
2
],
ends
=
[
4
])
shape_hw
.
stop_gradient
=
True
in_shape
=
fluid
.
layers
.
cast
(
shape_hw
,
dtype
=
'int32'
)
out_shape
=
in_shape
*
scale
out_shape
.
stop_gradient
=
True
# reisze by actual_shape
out
=
fluid
.
layers
.
resize_nearest
(
input
=
out
,
scale
=
scale
,
actual_shape
=
out_shape
,
name
=
name
)
return
out
class
YOLOv3
(
object
):
def
__init__
(
self
,
model_cfg_path
,
is_train
=
True
,
use_pyreader
=
True
,
use_random
=
True
):
self
.
model_cfg_path
=
model_cfg_path
self
.
config_parser
=
ConfigPaser
(
model_cfg_path
)
self
.
is_train
=
is_train
self
.
use_pyreader
=
use_pyreader
self
.
use_random
=
use_random
self
.
outputs
=
[]
self
.
losses
=
[]
self
.
downsample
=
32
self
.
ignore_thresh
=
.
7
self
.
class_num
=
80
def
build_model
(
self
):
self
.
img_height
=
cfg
.
input_size
self
.
img_width
=
cfg
.
input_size
self
.
build_input
()
out
=
self
.
image
self
.
yolo_anchors
=
[]
self
.
yolo_classes
=
[]
self
.
outputs
=
[]
self
.
boxes
=
[]
self
.
scores
=
[]
scale1
,
scale2
,
scale3
=
add_DarkNet53_conv_body
(
out
)
# 13*13 scale output
route1
,
tip1
=
yolo_detection_block
(
scale1
,
channel
=
512
,
i
=
75
)
# scale1 output
scale1_out
=
fluid
.
layers
.
conv2d
(
input
=
tip1
,
num_filters
=
255
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
param_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
Normal
(
0.
,
0.02
),
name
=
"conv81_weights"
),
bias_attr
=
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
0.0
),
regularizer
=
L2Decay
(
0.
),
name
=
"conv81_bias"
))
self
.
outputs
.
append
(
scale1_out
)
route1
=
conv_bn_layer
(
input
=
route1
,
ch_out
=
256
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
i
=
84
)
# upsample
route1
=
upsample
(
route1
)
# concat
route1
=
fluid
.
layers
.
concat
(
input
=
[
route1
,
scale2
],
axis
=
1
)
# 26*26 scale output
route2
,
tip2
=
yolo_detection_block
(
route1
,
channel
=
256
,
i
=
87
)
# scale2 output
scale2_out
=
fluid
.
layers
.
conv2d
(
input
=
tip2
,
num_filters
=
255
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
"conv93_weights"
),
bias_attr
=
ParamAttr
(
name
=
"conv93_bias"
))
self
.
outputs
.
append
(
scale2_out
)
route2
=
conv_bn_layer
(
input
=
route2
,
ch_out
=
128
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
i
=
96
)
# upsample
route2
=
upsample
(
route2
)
# concat
route2
=
fluid
.
layers
.
concat
(
input
=
[
route2
,
scale3
],
axis
=
1
)
# 52*52 scale output
route3
,
tip3
=
yolo_detection_block
(
route2
,
channel
=
128
,
i
=
99
)
# scale3 output
scale3_out
=
fluid
.
layers
.
conv2d
(
input
=
tip3
,
num_filters
=
255
,
filter_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
"conv105_weights"
),
bias_attr
=
ParamAttr
(
name
=
"conv105_bias"
))
self
.
outputs
.
append
(
scale3_out
)
# yolo
anchor_mask
=
[
6
,
7
,
8
,
3
,
4
,
5
,
0
,
1
,
2
]
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
,
30
,
61
,
62
,
45
,
59
,
119
,
116
,
90
,
156
,
198
,
373
,
326
]
for
i
,
out
in
enumerate
(
self
.
outputs
):
mask
=
anchor_mask
[
i
*
3
:
(
i
+
1
)
*
3
]
mask_anchors
=
[]
for
m
in
mask
:
mask_anchors
.
append
(
anchors
[
2
*
m
])
mask_anchors
.
append
(
anchors
[
2
*
m
+
1
])
self
.
yolo_anchors
.
append
(
mask_anchors
)
class_num
=
int
(
self
.
class_num
)
self
.
yolo_classes
.
append
(
class_num
)
if
self
.
is_train
:
ignore_thresh
=
float
(
self
.
ignore_thresh
)
loss
=
fluid
.
layers
.
yolov3_loss
(
x
=
out
,
gtbox
=
self
.
gtbox
,
gtlabel
=
self
.
gtlabel
,
# gtscore=self.gtscore,
anchors
=
anchors
,
anchor_mask
=
mask
,
class_num
=
class_num
,
ignore_thresh
=
ignore_thresh
,
downsample_ratio
=
self
.
downsample
,
# use_label_smooth=False,
name
=
"yolo_loss"
+
str
(
i
))
self
.
losses
.
append
(
fluid
.
layers
.
reduce_mean
(
loss
))
else
:
boxes
,
scores
=
fluid
.
layers
.
yolo_box
(
x
=
out
,
img_size
=
self
.
im_shape
,
anchors
=
mask_anchors
,
class_num
=
class_num
,
conf_thresh
=
cfg
.
valid_thresh
,
downsample_ratio
=
self
.
downsample
,
name
=
"yolo_box"
+
str
(
i
))
self
.
boxes
.
append
(
boxes
)
self
.
scores
.
append
(
fluid
.
layers
.
transpose
(
scores
,
perm
=
[
0
,
2
,
1
]))
self
.
downsample
//=
2
def
loss
(
self
):
return
sum
(
self
.
losses
)
def
get_pred
(
self
):
# return self.outputs
yolo_boxes
=
fluid
.
layers
.
concat
(
self
.
boxes
,
axis
=
1
)
yolo_scores
=
fluid
.
layers
.
concat
(
self
.
scores
,
axis
=
2
)
return
fluid
.
layers
.
multiclass_nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
,
score_threshold
=
cfg
.
valid_thresh
,
nms_top_k
=
cfg
.
nms_topk
,
keep_top_k
=
cfg
.
nms_posk
,
nms_threshold
=
cfg
.
nms_thresh
,
background_label
=-
1
,
name
=
"multiclass_nms"
)
def
get_yolo_anchors
(
self
):
return
self
.
yolo_anchors
def
get_yolo_classes
(
self
):
return
self
.
yolo_classes
def
build_input
(
self
):
self
.
image_shape
=
[
3
,
self
.
img_height
,
self
.
img_width
]
if
self
.
use_pyreader
and
self
.
is_train
:
self
.
py_reader
=
fluid
.
layers
.
py_reader
(
capacity
=
64
,
shapes
=
[[
-
1
]
+
self
.
image_shape
,
[
-
1
,
cfg
.
max_box_num
,
4
],
[
-
1
,
cfg
.
max_box_num
],
[
-
1
,
cfg
.
max_box_num
]],
lod_levels
=
[
0
,
0
,
0
,
0
],
dtypes
=
[
'float32'
]
*
2
+
[
'int32'
]
+
[
'float32'
],
use_double_buffer
=
True
)
self
.
image
,
self
.
gtbox
,
self
.
gtlabel
,
self
.
gtscore
=
fluid
.
layers
.
read_file
(
self
.
py_reader
)
else
:
self
.
image
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
self
.
image_shape
,
dtype
=
'float32'
)
self
.
gtbox
=
fluid
.
layers
.
data
(
name
=
'gtbox'
,
shape
=
[
cfg
.
max_box_num
,
4
],
dtype
=
'float32'
)
self
.
gtlabel
=
fluid
.
layers
.
data
(
name
=
'gtlabel'
,
shape
=
[
cfg
.
max_box_num
],
dtype
=
'int32'
)
self
.
gtscore
=
fluid
.
layers
.
data
(
name
=
'gtscore'
,
shape
=
[
cfg
.
max_box_num
],
dtype
=
'float32'
)
self
.
im_shape
=
fluid
.
layers
.
data
(
name
=
"im_shape"
,
shape
=
[
2
],
dtype
=
'int32'
)
self
.
im_id
=
fluid
.
layers
.
data
(
name
=
"im_id"
,
shape
=
[
1
],
dtype
=
'int32'
)
def
feeds
(
self
):
if
not
self
.
is_train
:
return
[
self
.
image
,
self
.
im_id
,
self
.
im_shape
]
return
[
self
.
image
,
self
.
gtbox
,
self
.
gtlabel
,
self
.
gtscore
]
def
get_input_size
(
self
):
return
cfg
.
input_size
fluid/PaddleCV/yolov3/reader.py
浏览文件 @
9da63d61
...
@@ -28,7 +28,7 @@ import box_utils
...
@@ -28,7 +28,7 @@ import box_utils
import
image_utils
import
image_utils
from
pycocotools.coco
import
COCO
from
pycocotools.coco
import
COCO
from
data_utils
import
GeneratorEnqueuer
from
data_utils
import
GeneratorEnqueuer
from
config
.config
import
cfg
from
config
import
cfg
class
DataSetReader
(
object
):
class
DataSetReader
(
object
):
...
...
fluid/PaddleCV/yolov3/train.py
浏览文件 @
9da63d61
...
@@ -26,9 +26,9 @@ from utility import parse_args, print_arguments, SmoothedValue
...
@@ -26,9 +26,9 @@ from utility import parse_args, print_arguments, SmoothedValue
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
reader
import
reader
import
models
import
models
.yolov3
as
models
from
learning_rate
import
exponential_with_warmup_decay
from
learning_rate
import
exponential_with_warmup_decay
from
config
.config
import
cfg
from
config
import
cfg
def
train
():
def
train
():
...
@@ -48,12 +48,14 @@ def train():
...
@@ -48,12 +48,14 @@ def train():
loss
=
model
.
loss
()
loss
=
model
.
loss
()
loss
.
persistable
=
True
loss
.
persistable
=
True
hyperparams
=
model
.
get_hyperparams
()
print
(
"cfg.learning"
,
cfg
.
learning_rate
)
print
(
"cfg.decay"
,
cfg
.
decay
)
devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
)
or
""
devices
=
os
.
getenv
(
"CUDA_VISIBLE_DEVICES"
)
or
""
devices_num
=
len
(
devices
.
split
(
","
))
devices_num
=
len
(
devices
.
split
(
","
))
print
(
"Found {} CUDA devices."
.
format
(
devices_num
))
print
(
"Found {} CUDA devices."
.
format
(
devices_num
))
learning_rate
=
float
(
hyperparams
[
'learning_rate'
]
)
learning_rate
=
float
(
cfg
.
learning_rate
)
boundaries
=
cfg
.
lr_steps
boundaries
=
cfg
.
lr_steps
gamma
=
cfg
.
lr_gamma
gamma
=
cfg
.
lr_gamma
step_num
=
len
(
cfg
.
lr_steps
)
step_num
=
len
(
cfg
.
lr_steps
)
...
@@ -70,8 +72,8 @@ def train():
...
@@ -70,8 +72,8 @@ def train():
warmup_iter
=
cfg
.
warm_up_iter
,
warmup_iter
=
cfg
.
warm_up_iter
,
warmup_factor
=
cfg
.
warm_up_factor
,
warmup_factor
=
cfg
.
warm_up_factor
,
start_step
=
cfg
.
start_iter
),
start_step
=
cfg
.
start_iter
),
regularization
=
fluid
.
regularizer
.
L2Decay
(
float
(
hyperparams
[
'decay'
]
)),
regularization
=
fluid
.
regularizer
.
L2Decay
(
float
(
cfg
.
decay
)),
momentum
=
float
(
hyperparams
[
'momentum'
]
))
momentum
=
float
(
cfg
.
momentum
))
optimizer
.
minimize
(
loss
)
optimizer
.
minimize
(
loss
)
fluid
.
memory_optimize
(
fluid
.
default_main_program
())
fluid
.
memory_optimize
(
fluid
.
default_main_program
())
...
@@ -96,11 +98,11 @@ def train():
...
@@ -96,11 +98,11 @@ def train():
mixup_iter
=
cfg
.
max_iter
-
cfg
.
start_iter
-
cfg
.
no_mixup_iter
mixup_iter
=
cfg
.
max_iter
-
cfg
.
start_iter
-
cfg
.
no_mixup_iter
if
cfg
.
use_pyreader
:
if
cfg
.
use_pyreader
:
train_reader
=
reader
.
train
(
input_size
,
batch_size
=
int
(
hyperparams
[
'batch'
]
)
/
devices_num
,
shuffle
=
True
,
mixup_iter
=
mixup_iter
*
devices_num
,
random_sizes
=
random_sizes
,
interval
=
10
,
pyreader_num
=
devices_num
,
use_multiprocessing
=
cfg
.
use_multiprocess
)
train_reader
=
reader
.
train
(
input_size
,
batch_size
=
int
(
cfg
.
batch
)
/
devices_num
,
shuffle
=
True
,
mixup_iter
=
mixup_iter
*
devices_num
,
random_sizes
=
random_sizes
,
interval
=
10
,
pyreader_num
=
devices_num
,
use_multiprocessing
=
cfg
.
use_multiprocess
)
py_reader
=
model
.
py_reader
py_reader
=
model
.
py_reader
py_reader
.
decorate_paddle_reader
(
train_reader
)
py_reader
.
decorate_paddle_reader
(
train_reader
)
else
:
else
:
train_reader
=
reader
.
train
(
input_size
,
batch_size
=
int
(
hyperparams
[
'batch'
]
),
shuffle
=
True
,
mixup_iter
=
mixup_iter
,
random_sizes
=
random_sizes
,
use_multiprocessing
=
cfg
.
use_multiprocess
)
train_reader
=
reader
.
train
(
input_size
,
batch_size
=
int
(
cfg
.
batch
),
shuffle
=
True
,
mixup_iter
=
mixup_iter
,
random_sizes
=
random_sizes
,
use_multiprocessing
=
cfg
.
use_multiprocess
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
model
.
feeds
())
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
model
.
feeds
())
def
save_model
(
postfix
):
def
save_model
(
postfix
):
...
...
fluid/PaddleCV/yolov3/utility.py
浏览文件 @
9da63d61
...
@@ -26,7 +26,7 @@ from collections import deque
...
@@ -26,7 +26,7 @@ from collections import deque
from
paddle.fluid
import
core
from
paddle.fluid
import
core
import
argparse
import
argparse
import
functools
import
functools
from
config
.config
import
*
from
config
import
*
def
print_arguments
(
args
):
def
print_arguments
(
args
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录