Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
b735e353
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b735e353
编写于
9月 13, 2020
作者:
D
dengkaipeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update yolov3 use new API. test=develop
上级
931564e5
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
55 addition
and
66 deletion
+55
-66
yolov3/darknet.py
yolov3/darknet.py
+6
-5
yolov3/infer.py
yolov3/infer.py
+6
-12
yolov3/main.py
yolov3/main.py
+7
-26
yolov3/modeling.py
yolov3/modeling.py
+36
-23
未找到文件。
yolov3/darknet.py
浏览文件 @
b735e353
...
...
@@ -12,14 +12,14 @@
#See the License for the specific language governing permissions and
#limitations under the License.
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.static
import
InputSpec
from
paddle.fluid.dygraph.nn
import
Conv2D
,
BatchNorm
from
paddle.incubate.hapi.model
import
Model
from
paddle.incubate.hapi.download
import
get_weights_path_from_url
from
paddle.utils.download
import
get_weights_path_from_url
__all__
=
[
'DarkNet'
,
'darknet53'
]
...
...
@@ -131,7 +131,7 @@ class LayerWarp(fluid.dygraph.Layer):
DarkNet_cfg
=
{
53
:
([
1
,
2
,
8
,
8
,
4
])}
class
DarkNet
(
Model
):
class
DarkNet
(
fluid
.
dygraph
.
Layer
):
"""DarkNet model from
`"YOLOv3: An Incremental Improvement" <https://arxiv.org/abs/1804.02767>`_
...
...
@@ -190,7 +190,8 @@ def _darknet(num_layers=53, input_channels=3, pretrained=True):
weight_path
=
get_weights_path_from_url
(
*
(
pretrain_infos
[
num_layers
]))
assert
weight_path
.
endswith
(
'.pdparams'
),
\
"suffix of weight must be .pdparams"
model
.
load
(
weight_path
[:
-
9
])
weight_dict
,
_
=
fluid
.
load_dygraph
(
weight_path
[:
-
9
])
model
.
set_dict
(
weight_dict
)
return
model
...
...
yolov3/infer.py
浏览文件 @
b735e353
...
...
@@ -20,12 +20,11 @@ import argparse
import
numpy
as
np
from
PIL
import
Image
import
paddle
from
paddle
import
fluid
from
paddle.fluid.optimizer
import
Momentum
from
paddle.io
import
DataLoader
from
paddle.incubate.hapi.model
import
Model
,
Input
,
set_device
from
modeling
import
yolov3_darknet53
,
YoloLoss
from
transforms
import
*
from
utils
import
print_arguments
...
...
@@ -36,6 +35,7 @@ logger = logging.getLogger(__name__)
IMAGE_MEAN
=
[
0.485
,
0.456
,
0.406
]
IMAGE_STD
=
[
0.229
,
0.224
,
0.225
]
NUM_MAX_BOXES
=
50
def
get_save_image_name
(
output_dir
,
image_path
):
...
...
@@ -62,24 +62,18 @@ def load_labels(label_list, with_background=True):
def
main
():
device
=
set_device
(
FLAGS
.
device
)
fluid
.
enable_dygraph
(
device
)
if
FLAGS
.
dynamic
else
None
inputs
=
[
Input
(
[
None
,
1
],
'int64'
,
name
=
'img_id'
),
Input
(
[
None
,
2
],
'int32'
,
name
=
'img_shape'
),
Input
(
[
None
,
3
,
None
,
None
],
'float32'
,
name
=
'image'
)
]
device
=
paddle
.
set_device
(
FLAGS
.
device
)
paddle
.
disable_static
(
device
)
if
FLAGS
.
dynamic
else
None
cat2name
=
load_labels
(
FLAGS
.
label_list
,
with_background
=
False
)
model
=
yolov3_darknet53
(
num_classes
=
len
(
cat2name
),
num_max_boxes
=
NUM_MAX_BOXES
,
model_mode
=
'test'
,
pretrained
=
FLAGS
.
weights
is
None
)
model
.
prepare
(
inputs
=
inputs
,
device
=
FLAGS
.
device
)
model
.
prepare
()
if
FLAGS
.
weights
is
not
None
:
model
.
load
(
FLAGS
.
weights
,
reset_optimizer
=
True
)
...
...
yolov3/main.py
浏览文件 @
b735e353
...
...
@@ -21,13 +21,11 @@ import os
import
numpy
as
np
import
paddle
from
paddle
import
fluid
from
paddle.fluid.optimizer
import
Momentum
from
paddle.io
import
DataLoader
from
paddle.incubate.hapi.model
import
Model
,
Input
,
set_device
from
paddle.incubate.hapi.distributed
import
DistributedBatchSampler
from
paddle.incubate.hapi.vision.transforms
import
Compose
,
BatchCompose
from
paddle.io
import
DataLoader
,
DistributedBatchSampler
from
paddle.vision.transforms
import
Compose
,
BatchCompose
from
modeling
import
yolov3_darknet53
,
YoloLoss
from
coco
import
COCODataset
...
...
@@ -61,22 +59,8 @@ def make_optimizer(step_per_epoch, parameter_list=None):
def
main
():
device
=
set_device
(
FLAGS
.
device
)
fluid
.
enable_dygraph
(
device
)
if
FLAGS
.
dynamic
else
None
inputs
=
[
Input
(
[
None
,
1
],
'int64'
,
name
=
'img_id'
),
Input
(
[
None
,
2
],
'int32'
,
name
=
'img_shape'
),
Input
(
[
None
,
3
,
None
,
None
],
'float32'
,
name
=
'image'
)
]
labels
=
[
Input
(
[
None
,
NUM_MAX_BOXES
,
4
],
'float32'
,
name
=
'gt_bbox'
),
Input
(
[
None
,
NUM_MAX_BOXES
],
'int32'
,
name
=
'gt_label'
),
Input
(
[
None
,
NUM_MAX_BOXES
],
'float32'
,
name
=
'gt_score'
)
]
device
=
paddle
.
set_device
(
FLAGS
.
device
)
paddle
.
disable_static
(
device
)
if
FLAGS
.
dynamic
else
None
if
not
FLAGS
.
eval_only
:
# training mode
train_transform
=
Compose
([
...
...
@@ -129,6 +113,7 @@ def main():
pretrained
=
FLAGS
.
eval_only
and
FLAGS
.
weights
is
None
model
=
yolov3_darknet53
(
num_classes
=
dataset
.
num_classes
,
num_max_boxes
=
NUM_MAX_BOXES
,
model_mode
=
'eval'
if
FLAGS
.
eval_only
else
'train'
,
pretrained
=
pretrained
)
...
...
@@ -140,11 +125,7 @@ def main():
len
(
batch_sampler
),
parameter_list
=
model
.
parameters
())
model
.
prepare
(
optim
,
YoloLoss
(
num_classes
=
dataset
.
num_classes
),
inputs
=
inputs
,
labels
=
labels
,
device
=
FLAGS
.
device
)
optimizer
=
optim
,
loss
=
YoloLoss
(
num_classes
=
dataset
.
num_classes
))
# NOTE: we implement COCO metric of YOLOv3 model here, separately
# from 'prepare' and 'fit' framework for follwing reason:
...
...
yolov3/modeling.py
浏览文件 @
b735e353
...
...
@@ -15,14 +15,15 @@
from
__future__
import
division
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.nn
import
Conv2D
,
BatchNorm
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.
incubate.hapi.model
import
Model
from
paddle.
incubate.hapi.loss
import
Loss
from
paddle.incubate.hapi.download
import
get_weights_path_from_url
from
paddle.
static
import
InputSpec
from
paddle.
utils.download
import
get_weights_path_from_url
from
darknet
import
darknet53
__all__
=
[
'YoloLoss'
,
'YOLOv3'
,
'yolov3_darknet53'
]
...
...
@@ -125,7 +126,7 @@ class YoloDetectionBlock(fluid.dygraph.Layer):
return
route
,
tip
class
YOLOv3
(
Model
):
class
YOLOv3
(
fluid
.
dygraph
.
Layer
):
"""YOLOv3 model from
`"YOLOv3: An Incremental Improvement" <https://arxiv.org/abs/1804.02767>`_
...
...
@@ -194,25 +195,13 @@ class YOLOv3(Model):
act
=
'leaky_relu'
))
self
.
route_blocks
.
append
(
route
)
def
extract_feats
(
self
,
inputs
):
out
=
self
.
backbone
.
conv0
(
inputs
)
out
=
self
.
backbone
.
downsample0
(
out
)
blocks
=
[]
for
i
,
conv_block_i
in
enumerate
(
self
.
backbone
.
darknet53_conv_block_list
):
out
=
conv_block_i
(
out
)
blocks
.
append
(
out
)
if
i
<
len
(
self
.
backbone
.
stages
)
-
1
:
out
=
self
.
backbone
.
downsample_list
[
i
](
out
)
return
blocks
[
-
1
:
-
4
:
-
1
]
def
forward
(
self
,
img_id
,
img_shape
,
inputs
):
outputs
=
[]
boxes
=
[]
scores
=
[]
downsample
=
32
feats
=
self
.
extract_feats
(
inputs
)
feats
=
self
.
backbone
(
inputs
)
route
=
None
for
idx
,
feat
in
enumerate
(
feats
):
if
idx
>
0
:
...
...
@@ -267,7 +256,7 @@ class YOLOv3(Model):
return
outputs
+
preds
class
YoloLoss
(
Loss
):
class
YoloLoss
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
num_classes
=
80
,
num_max_boxes
=
50
):
super
(
YoloLoss
,
self
).
__init__
()
self
.
num_classes
=
num_classes
...
...
@@ -279,11 +268,16 @@ class YoloLoss(Loss):
]
self
.
anchor_masks
=
[[
6
,
7
,
8
],
[
3
,
4
,
5
],
[
0
,
1
,
2
]]
def
forward
(
self
,
outputs
,
label
s
):
def
forward
(
self
,
*
input
s
):
downsample
=
32
gt_box
,
gt_label
,
gt_score
=
labels
losses
=
[]
if
len
(
inputs
)
==
6
:
output1
,
output2
,
output3
,
gt_box
,
gt_label
,
gt_score
=
inputs
elif
len
(
inputs
)
==
8
:
output1
,
output2
,
output3
,
img_id
,
bbox
,
gt_box
,
gt_label
,
gt_score
=
inputs
outputs
=
[
output1
,
output2
,
output3
]
for
idx
,
out
in
enumerate
(
outputs
):
if
idx
==
3
:
break
# debug
anchor_mask
=
self
.
anchor_masks
[
idx
]
...
...
@@ -306,9 +300,23 @@ class YoloLoss(Loss):
def
_yolov3_darknet
(
num_layers
=
53
,
num_classes
=
80
,
num_max_boxes
=
50
,
model_mode
=
'train'
,
pretrained
=
True
):
model
=
YOLOv3
(
num_classes
,
model_mode
)
inputs
=
[
InputSpec
(
[
None
,
1
],
'int64'
,
name
=
'img_id'
),
InputSpec
(
[
None
,
2
],
'int32'
,
name
=
'img_shape'
),
InputSpec
(
[
None
,
3
,
None
,
None
],
'float32'
,
name
=
'image'
)
]
labels
=
[
InputSpec
(
[
None
,
num_max_boxes
,
4
],
'float32'
,
name
=
'gt_bbox'
),
InputSpec
(
[
None
,
num_max_boxes
],
'int32'
,
name
=
'gt_label'
),
InputSpec
(
[
None
,
num_max_boxes
],
'float32'
,
name
=
'gt_score'
)
]
net
=
YOLOv3
(
num_classes
,
model_mode
)
model
=
paddle
.
Model
(
net
,
inputs
,
labels
)
if
pretrained
:
assert
num_layers
in
pretrain_infos
.
keys
(),
\
"YOLOv3-DarkNet{} do not have pretrained weights now, "
\
...
...
@@ -320,11 +328,15 @@ def _yolov3_darknet(num_layers=53,
return
model
def
yolov3_darknet53
(
num_classes
=
80
,
model_mode
=
'train'
,
pretrained
=
True
):
def
yolov3_darknet53
(
num_classes
=
80
,
num_max_boxes
=
50
,
model_mode
=
'train'
,
pretrained
=
True
):
"""YOLOv3 model with 53-layer DarkNet as backbone
Args:
num_classes (int): class number, default 80.
num_classes (int): max bbox number in a image, default 50.
model_mode (str): 'train', 'eval', 'test' mode, network structure
will be diffrent in the output layer and data, in 'train' mode,
no output layer append, in 'eval' and 'test', output feature
...
...
@@ -334,4 +346,5 @@ def yolov3_darknet53(num_classes=80, model_mode='train', pretrained=True):
pretrained (bool): If True, returns a model with pre-trained model
on COCO, default True
"""
return
_yolov3_darknet
(
53
,
num_classes
,
model_mode
,
pretrained
)
return
_yolov3_darknet
(
53
,
num_classes
,
num_max_boxes
,
model_mode
,
pretrained
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录