Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
7d3a89f6
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7d3a89f6
编写于
9月 27, 2020
作者:
W
wangxinxin08
提交者:
GitHub
9月 27, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
migrate yolov3 tp API 2.0 (#1500)
上级
079c83c7
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
92 addition
and
85 deletion
+92
-85
ppdet/core/workspace.py
ppdet/core/workspace.py
+1
-1
ppdet/modeling/architecture/meta_arch.py
ppdet/modeling/architecture/meta_arch.py
+5
-5
ppdet/modeling/backbone/darknet.py
ppdet/modeling/backbone/darknet.py
+21
-24
ppdet/modeling/bbox.py
ppdet/modeling/bbox.py
+6
-3
ppdet/modeling/head/yolo_head.py
ppdet/modeling/head/yolo_head.py
+14
-17
ppdet/optimizer.py
ppdet/optimizer.py
+12
-10
ppdet/utils/check.py
ppdet/utils/check.py
+8
-4
tools/eval.py
tools/eval.py
+11
-11
tools/train.py
tools/train.py
+14
-10
未找到文件。
ppdet/core/workspace.py
浏览文件 @
7d3a89f6
...
...
@@ -248,7 +248,7 @@ def create(cls_or_name, **kwargs):
if
isinstance
(
target
,
SchemaDict
):
kwargs
[
k
]
=
create
(
target_key
)
elif
hasattr
(
target
,
'__dict__'
):
# serialized object
kwargs
[
k
]
=
new_dic
t
kwargs
[
k
]
=
targe
t
else
:
raise
ValueError
(
"Unsupported injection type:"
,
target_key
)
# prevent modification of global config values of reference types
...
...
ppdet/modeling/architecture/meta_arch.py
浏览文件 @
7d3a89f6
...
...
@@ -3,8 +3,8 @@ from __future__ import division
from
__future__
import
print_function
import
numpy
as
np
from
paddle.fluid.dygraph
import
Layer
from
paddle.fluid.dygraph.base
import
to_variable
import
paddle
import
paddle.nn
as
nn
from
ppdet.core.workspace
import
register
from
ppdet.utils.data_structure
import
BufferDict
...
...
@@ -12,7 +12,7 @@ __all__ = ['BaseArch']
@
register
class
BaseArch
(
Layer
):
class
BaseArch
(
nn
.
Layer
):
def
__init__
(
self
):
super
(
BaseArch
,
self
).
__init__
()
...
...
@@ -39,10 +39,10 @@ class BaseArch(Layer):
input_v
=
np
.
array
(
input
)[
np
.
newaxis
,
...]
inputs
[
name
].
append
(
input_v
)
for
name
in
input_def
:
inputs
[
name
]
=
to_variable
(
np
.
concatenate
(
inputs
[
name
]))
inputs
[
name
]
=
paddle
.
to_tensor
(
np
.
concatenate
(
inputs
[
name
]))
return
inputs
def
model_arch
(
self
,
mode
):
def
model_arch
(
self
):
raise
NotImplementedError
(
"Should implement model_arch method!"
)
def
loss
(
self
,
):
...
...
ppdet/modeling/backbone/darknet.py
浏览文件 @
7d3a89f6
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Layer
from
paddle.fluid.param_attr
import
ParamAttr
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.dygraph.nn
import
Conv2D
,
BatchNorm
from
ppdet.core.workspace
import
register
,
serializable
__all__
=
[
'DarkNet'
,
'ConvBNLayer'
]
class
ConvBNLayer
(
Layer
):
class
ConvBNLayer
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
ch_out
,
...
...
@@ -20,25 +20,22 @@ class ConvBNLayer(Layer):
name
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
()
self
.
conv
=
Conv2D
(
num
_channels
=
ch_in
,
num_filter
s
=
ch_out
,
filter
_size
=
filter_size
,
self
.
conv
=
nn
.
Conv2d
(
in
_channels
=
ch_in
,
out_channel
s
=
ch_out
,
kernel
_size
=
filter_size
,
stride
=
stride
,
padding
=
padding
,
groups
=
groups
,
param_attr
=
ParamAttr
(
name
=
name
+
'.conv.weights'
),
bias_attr
=
False
,
act
=
None
)
weight_attr
=
ParamAttr
(
name
=
name
+
'.conv.weights'
),
bias_attr
=
False
)
bn_name
=
name
+
'.bn'
self
.
batch_norm
=
BatchNorm
(
num_channels
=
ch_out
,
param
_attr
=
ParamAttr
(
self
.
batch_norm
=
nn
.
BatchNorm2d
(
ch_out
,
weight
_attr
=
ParamAttr
(
name
=
bn_name
+
'.scale'
,
regularizer
=
L2Decay
(
0.
)),
bias_attr
=
ParamAttr
(
name
=
bn_name
+
'.offset'
,
regularizer
=
L2Decay
(
0.
)),
moving_mean_name
=
bn_name
+
'.mean'
,
moving_variance_name
=
bn_name
+
'.var'
)
name
=
bn_name
+
'.offset'
,
regularizer
=
L2Decay
(
0.
)))
self
.
act
=
act
...
...
@@ -46,11 +43,11 @@ class ConvBNLayer(Layer):
out
=
self
.
conv
(
inputs
)
out
=
self
.
batch_norm
(
out
)
if
self
.
act
==
'leaky'
:
out
=
fluid
.
layers
.
leaky_relu
(
x
=
out
,
alpha
=
0.1
)
out
=
F
.
leaky_relu
(
out
,
0.1
)
return
out
class
DownSample
(
Layer
):
class
DownSample
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
ch_out
,
...
...
@@ -75,7 +72,7 @@ class DownSample(Layer):
return
out
class
BasicBlock
(
Layer
):
class
BasicBlock
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
ch_out
,
name
=
None
):
super
(
BasicBlock
,
self
).
__init__
()
...
...
@@ -97,11 +94,11 @@ class BasicBlock(Layer):
def
forward
(
self
,
inputs
):
conv1
=
self
.
conv1
(
inputs
)
conv2
=
self
.
conv2
(
conv1
)
out
=
fluid
.
layers
.
elementwise_add
(
x
=
inputs
,
y
=
conv2
,
act
=
None
)
out
=
paddle
.
add
(
x
=
inputs
,
y
=
conv2
)
return
out
class
Blocks
(
Layer
):
class
Blocks
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
ch_out
,
count
,
name
=
None
):
super
(
Blocks
,
self
).
__init__
()
...
...
@@ -127,7 +124,7 @@ DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
@
register
@
serializable
class
DarkNet
(
Layer
):
class
DarkNet
(
nn
.
Layer
):
def
__init__
(
self
,
depth
=
53
,
freeze_at
=-
1
,
...
...
ppdet/modeling/bbox.py
浏览文件 @
7d3a89f6
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
ppdet.core.workspace
import
register
...
...
@@ -90,9 +93,9 @@ class BBoxPostProcessYOLO(object):
self
.
num_classes
,
i
)
boxes_list
.
append
(
boxes
)
scores_list
.
append
(
fluid
.
layers
.
transpose
(
scores
,
perm
=
[
0
,
2
,
1
]))
yolo_boxes
=
fluid
.
layers
.
concat
(
boxes_list
,
axis
=
1
)
yolo_scores
=
fluid
.
layers
.
concat
(
scores_list
,
axis
=
2
)
scores_list
.
append
(
paddle
.
transpose
(
scores
,
perm
=
[
0
,
2
,
1
]))
yolo_boxes
=
paddle
.
concat
(
boxes_list
,
axis
=
1
)
yolo_scores
=
paddle
.
concat
(
scores_list
,
axis
=
2
)
bbox
=
self
.
nms
(
bboxes
=
yolo_boxes
,
scores
=
yolo_scores
)
# TODO: parse the lod of nmsed_bbox
# default batch size is 1
...
...
ppdet/modeling/head/yolo_head.py
浏览文件 @
7d3a89f6
import
paddle.fluid
as
fluid
import
paddle
from
paddle.fluid.dygraph
import
Layer
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle
.fluid.initializer
import
Normal
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle
import
ParamAttr
from
paddle.fluid.regularizer
import
L2Decay
from
paddle.fluid.dygraph.nn
import
Conv2D
,
BatchNorm
from
paddle.fluid.dygraph
import
Sequential
from
ppdet.core.workspace
import
register
from
..backbone.darknet
import
ConvBNLayer
class
YoloDetBlock
(
Layer
):
class
YoloDetBlock
(
nn
.
Layer
):
def
__init__
(
self
,
ch_in
,
channel
,
name
):
super
(
YoloDetBlock
,
self
).
__init__
()
self
.
ch_in
=
ch_in
...
...
@@ -26,7 +24,7 @@ class YoloDetBlock(Layer):
#['tip', channel, channel * 2, 3],
]
self
.
conv_module
=
Sequential
()
self
.
conv_module
=
nn
.
Sequential
()
for
idx
,
(
conv_name
,
ch_in
,
ch_out
,
filter_size
,
post_name
)
in
enumerate
(
conv_def
):
self
.
conv_module
.
add_sublayer
(
...
...
@@ -52,7 +50,7 @@ class YoloDetBlock(Layer):
@
register
class
YOLOFeat
(
Layer
):
class
YOLOFeat
(
nn
.
Layer
):
__shared__
=
[
'num_levels'
]
def
__init__
(
self
,
feat_in_list
=
[
1024
,
768
,
384
],
num_levels
=
3
):
...
...
@@ -88,19 +86,19 @@ class YOLOFeat(Layer):
yolo_feats
=
[]
for
i
,
block
in
enumerate
(
body_feats
):
if
i
>
0
:
block
=
fluid
.
layers
.
concat
(
input
=
[
route
,
block
],
axis
=
1
)
block
=
paddle
.
concat
(
[
route
,
block
],
axis
=
1
)
route
,
tip
=
self
.
yolo_blocks
[
i
](
block
)
yolo_feats
.
append
(
tip
)
if
i
<
self
.
num_levels
-
1
:
route
=
self
.
route_blocks
[
i
](
route
)
route
=
fluid
.
layers
.
resize_nearest
(
route
,
scale
=
2.
)
route
=
F
.
resize_nearest
(
route
,
scale
=
2.
)
return
yolo_feats
@
register
class
YOLOv3Head
(
Layer
):
class
YOLOv3Head
(
nn
.
Layer
):
__shared__
=
[
'num_classes'
,
'num_levels'
,
'use_fine_grained_loss'
]
__inject__
=
[
'yolo_feat'
]
...
...
@@ -130,14 +128,13 @@ class YOLOv3Head(Layer):
name
=
'yolo_output.{}'
.
format
(
i
)
yolo_out
=
self
.
add_sublayer
(
name
,
Conv2D
(
num
_channels
=
1024
//
(
2
**
i
),
num_filter
s
=
num_filters
,
filter
_size
=
1
,
nn
.
Conv2d
(
in
_channels
=
1024
//
(
2
**
i
),
out_channel
s
=
num_filters
,
kernel
_size
=
1
,
stride
=
1
,
padding
=
0
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
name
+
'.conv.weights'
),
weight_attr
=
ParamAttr
(
name
=
name
+
'.conv.weights'
),
bias_attr
=
ParamAttr
(
name
=
name
+
'.conv.bias'
,
regularizer
=
L2Decay
(
0.
))))
self
.
yolo_out_list
.
append
(
yolo_out
)
...
...
ppdet/optimizer.py
浏览文件 @
7d3a89f6
...
...
@@ -19,12 +19,12 @@ from __future__ import print_function
import
math
import
logging
from
paddle
import
fluid
import
paddle
import
paddle.nn
as
nn
import
paddle.
fluid.
optimizer
as
optimizer
import
paddle.optimizer
as
optimizer
import
paddle.fluid.regularizer
as
regularizer
from
paddle.fluid.layers.learning_rate_scheduler
import
_decay_step_counter
from
paddle.fluid.layers.ops
import
cos
from
paddle
import
cos
from
ppdet.core.workspace
import
register
,
serializable
...
...
@@ -61,7 +61,7 @@ class PiecewiseDecay(object):
for
i
in
self
.
gamma
:
value
.
append
(
base_lr
*
i
)
return
fluid
.
dygraph
.
PiecewiseDecay
(
boundary
,
value
,
begin
=
0
,
step
=
1
)
return
optimizer
.
lr_scheduler
.
PiecewiseLR
(
boundary
,
value
)
@
serializable
...
...
@@ -142,9 +142,10 @@ class OptimizerBuilder():
def
__call__
(
self
,
learning_rate
,
params
=
None
):
if
self
.
clip_grad_by_norm
is
not
None
:
fluid
.
clip
.
set_gradient_clip
(
clip
=
fluid
.
clip
.
GradientClipByGlobalNorm
(
clip_norm
=
self
.
clip_grad_by_norm
))
grad_clip
=
nn
.
GradientClipByGlobalNorm
(
clip_norm
=
self
.
clip_grad_by_norm
)
else
:
grad_clip
=
None
if
self
.
regularizer
:
reg_type
=
self
.
regularizer
[
'type'
]
+
'Decay'
...
...
@@ -158,6 +159,7 @@ class OptimizerBuilder():
del
optim_args
[
'type'
]
op
=
getattr
(
optimizer
,
optim_type
)
return
op
(
learning_rate
=
learning_rate
,
parameter_list
=
params
,
regularization
=
regularization
,
parameters
=
params
,
weight_decay
=
regularization
,
grad_clip
=
grad_clip
,
**
optim_args
)
ppdet/utils/check.py
浏览文件 @
7d3a89f6
...
...
@@ -18,8 +18,8 @@ from __future__ import print_function
import
sys
import
paddle
.fluid
as
fluid
import
paddle
from
paddle
import
fluid
import
logging
import
six
import
paddle.version
as
fluid_version
...
...
@@ -65,9 +65,13 @@ def check_version(version='1.7.0'):
version_split
=
version
.
split
(
'.'
)
length
=
min
(
len
(
version_installed
),
len
(
version_split
))
flag
=
False
for
i
in
six
.
moves
.
range
(
length
):
if
version_installed
[
i
]
<
version_split
[
i
]:
raise
Exception
(
err
)
if
version_installed
[
i
]
>
version_split
[
i
]:
flag
=
True
break
if
not
flag
:
raise
Exception
(
err
)
def
check_config
(
cfg
):
...
...
tools/eval.py
浏览文件 @
7d3a89f6
...
...
@@ -13,7 +13,8 @@ import warnings
warnings
.
filterwarnings
(
'ignore'
)
import
random
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle
from
paddle.distributed
import
ParallelEnv
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.utils.check
import
check_gpu
,
check_version
,
check_config
from
ppdet.utils.cli
import
ArgsParser
...
...
@@ -50,10 +51,10 @@ def run(FLAGS, cfg):
main_arch
=
cfg
.
architecture
model
=
create
(
cfg
.
architecture
)
# Init Model
# Init Model
model
=
load_dygraph_ckpt
(
model
,
ckpt
=
cfg
.
weights
)
# Data Reader
# Data Reader
if
FLAGS
.
use_gpu
:
devices_num
=
1
else
:
...
...
@@ -65,12 +66,12 @@ def run(FLAGS, cfg):
start_time
=
time
.
time
()
sample_num
=
0
for
iter_id
,
data
in
enumerate
(
eval_reader
()):
# forward
# forward
model
.
eval
()
outs
=
model
(
data
,
cfg
[
'EvalReader'
][
'inputs_def'
][
'fields'
],
'infer'
)
outs_res
.
append
(
outs
)
# log
# log
sample_num
+=
len
(
data
)
if
iter_id
%
100
==
0
:
logger
.
info
(
"Eval iter: {}"
.
format
(
iter_id
))
...
...
@@ -78,7 +79,7 @@ def run(FLAGS, cfg):
cost_time
=
time
.
time
()
-
start_time
logger
.
info
(
'Total sample number: {}, averge FPS: {}'
.
format
(
sample_num
,
sample_num
/
cost_time
))
# Metric
# Metric
coco_eval_results
(
outs_res
,
include_mask
=
True
if
getattr
(
cfg
,
'MaskHead'
,
None
)
else
False
,
...
...
@@ -94,11 +95,10 @@ def main():
check_gpu
(
cfg
.
use_gpu
)
check_version
()
place
=
fluid
.
CUDAPlace
(
fluid
.
dygraph
.
parallel
.
Env
()
.
dev_id
)
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
with
fluid
.
dygraph
.
guard
(
place
):
run
(
FLAGS
,
cfg
)
place
=
paddle
.
CUDAPlace
(
ParallelEnv
()
.
dev_id
)
if
cfg
.
use_gpu
else
paddle
.
CPUPlace
()
paddle
.
disable_static
(
place
)
run
(
FLAGS
,
cfg
)
if
__name__
==
'__main__'
:
...
...
tools/train.py
浏览文件 @
7d3a89f6
...
...
@@ -15,14 +15,15 @@ import random
import
datetime
import
numpy
as
np
from
collections
import
deque
import
paddle.fluid
as
fluid
import
paddle
from
paddle
import
fluid
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.data.reader
import
create_reader
from
ppdet.utils.stats
import
TrainingStats
from
ppdet.utils.check
import
check_gpu
,
check_version
,
check_config
from
ppdet.utils.cli
import
ArgsParser
from
ppdet.utils.checkpoint
import
load_dygraph_ckpt
,
save_dygraph_ckpt
from
paddle.
fluid.dygraph.parallel
import
ParallelEnv
from
paddle.
distributed
import
ParallelEnv
import
logging
FORMAT
=
'%(asctime)s-%(levelname)s: %(message)s'
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
FORMAT
)
...
...
@@ -117,9 +118,10 @@ def run(FLAGS, cfg):
# Parallel Model
if
ParallelEnv
().
nranks
>
1
:
strategy
=
fluid
.
dygraph
.
parallel
.
prepare_context
()
model
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
model
,
strategy
)
strategy
=
paddle
.
distributed
.
init_parallel_env
()
model
=
paddle
.
DataParallel
(
model
,
strategy
)
logger
.
info
(
"success!"
)
# Data Reader
start_iter
=
0
if
cfg
.
use_gpu
:
...
...
@@ -157,8 +159,10 @@ def run(FLAGS, cfg):
else
:
loss
.
backward
()
optimizer
.
minimize
(
loss
)
model
.
clear_gradients
()
curr_lr
=
optimizer
.
current_step_lr
()
optimizer
.
step
()
curr_lr
=
optimizer
.
get_lr
()
lr
.
step
()
optimizer
.
clear_grad
()
if
ParallelEnv
().
nranks
<
2
or
ParallelEnv
().
local_rank
==
0
:
# Log state
...
...
@@ -190,11 +194,11 @@ def main():
check_gpu
(
cfg
.
use_gpu
)
check_version
()
place
=
fluid
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
\
if
cfg
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
paddle
.
CUDAPlace
(
ParallelEnv
().
dev_id
)
\
if
cfg
.
use_gpu
else
paddle
.
CPUPlace
()
paddle
.
disable_static
(
place
)
with
fluid
.
dygraph
.
guard
(
place
):
run
(
FLAGS
,
cfg
)
run
(
FLAGS
,
cfg
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录