Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
d6e52889
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d6e52889
编写于
9月 19, 2019
作者:
W
wangguanzhong
提交者:
GitHub
9月 19, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reconstruction for cfg printer (#3362)
上级
81fd696e
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
84 addition
and
85 deletion
+84
-85
ppdet/utils/cli.py
ppdet/utils/cli.py
+74
-3
tools/configure.py
tools/configure.py
+1
-71
tools/eval.py
tools/eval.py
+1
-1
tools/infer.py
tools/infer.py
+1
-1
tools/train.py
tools/train.py
+7
-9
未找到文件。
ppdet/utils/cli.py
浏览文件 @
d6e52889
...
...
@@ -15,6 +15,8 @@
from
argparse
import
ArgumentParser
,
RawDescriptionHelpFormatter
import
yaml
import
re
from
ppdet.core.workspace
import
get_registered_modules
__all__
=
[
'ColorTTY'
,
'ArgsParser'
]
...
...
@@ -42,13 +44,12 @@ class ColorTTY(object):
class
ArgsParser
(
ArgumentParser
):
def
__init__
(
self
):
super
(
ArgsParser
,
self
).
__init__
(
formatter_class
=
RawDescriptionHelpFormatter
)
self
.
add_argument
(
"-c"
,
"--config"
,
help
=
"configuration file to use"
)
self
.
add_argument
(
"-o"
,
"--opt"
,
nargs
=
'*'
,
help
=
"set configuration options"
)
self
.
add_argument
(
"-o"
,
"--opt"
,
nargs
=
'*'
,
help
=
"set configuration options"
)
def
parse_args
(
self
,
argv
=
None
):
args
=
super
(
ArgsParser
,
self
).
parse_args
(
argv
)
...
...
@@ -78,3 +79,73 @@ class ArgsParser(ArgumentParser):
cur
[
key
]
=
{}
cur
=
cur
[
key
]
return
config
def
print_total_cfg
(
config
):
modules
=
get_registered_modules
()
color_tty
=
ColorTTY
()
green
=
'___{}___'
.
format
(
color_tty
.
colors
.
index
(
'green'
)
+
31
)
styled
=
{}
for
key
in
config
.
keys
():
if
not
config
[
key
]:
# empty schema
continue
if
key
not
in
modules
and
not
hasattr
(
config
[
key
],
'__dict__'
):
styled
[
key
]
=
config
[
key
]
continue
elif
key
in
modules
:
module
=
modules
[
key
]
else
:
type_name
=
type
(
config
[
key
]).
__name__
if
type_name
in
modules
:
module
=
modules
[
type_name
].
copy
()
module
.
update
({
k
:
v
for
k
,
v
in
config
[
key
].
__dict__
.
items
()
if
k
in
module
.
schema
})
key
+=
" ({})"
.
format
(
type_name
)
default
=
module
.
find_default_keys
()
missing
=
module
.
find_missing_keys
()
mismatch
=
module
.
find_mismatch_keys
()
extra
=
module
.
find_extra_keys
()
dep_missing
=
[]
for
dep
in
module
.
inject
:
if
isinstance
(
module
[
dep
],
str
)
and
module
[
dep
]
!=
'<value>'
:
if
module
[
dep
]
not
in
modules
:
# not a valid module
dep_missing
.
append
(
dep
)
else
:
dep_mod
=
modules
[
module
[
dep
]]
# empty dict but mandatory
if
not
dep_mod
and
dep_mod
.
mandatory
():
dep_missing
.
append
(
dep
)
override
=
list
(
set
(
module
.
keys
())
-
set
(
default
)
-
set
(
extra
)
-
set
(
dep_missing
))
replacement
=
{}
for
name
in
set
(
override
+
default
+
extra
+
mismatch
+
missing
):
new_name
=
name
if
name
in
missing
:
value
=
"<missing>"
else
:
value
=
module
[
name
]
if
name
in
extra
:
value
=
dump_value
(
value
)
+
" <extraneous>"
elif
name
in
mismatch
:
value
=
dump_value
(
value
)
+
" <type mismatch>"
elif
name
in
dep_missing
:
value
=
dump_value
(
value
)
+
" <module config missing>"
elif
name
in
override
and
value
!=
'<missing>'
:
mark
=
green
new_name
=
mark
+
name
replacement
[
new_name
]
=
value
styled
[
key
]
=
replacement
buffer
=
yaml
.
dump
(
styled
,
default_flow_style
=
False
,
default_style
=
''
)
buffer
=
(
re
.
sub
(
r
"<missing>"
,
r
"[31m<missing>[0m"
,
buffer
))
buffer
=
(
re
.
sub
(
r
"<extraneous>"
,
r
"[33m<extraneous>[0m"
,
buffer
))
buffer
=
(
re
.
sub
(
r
"<type mismatch>"
,
r
"[31m<type mismatch>[0m"
,
buffer
))
buffer
=
(
re
.
sub
(
r
"<module config missing>"
,
r
"[31m<module config missing>[0m"
,
buffer
))
buffer
=
re
.
sub
(
r
"___(\d+)___(.*?):"
,
r
"[\1m\2[0m:"
,
buffer
)
print
(
buffer
)
tools/configure.py
浏览文件 @
d6e52889
...
...
@@ -14,14 +14,13 @@
from
__future__
import
print_function
import
re
import
sys
from
argparse
import
ArgumentParser
,
RawDescriptionHelpFormatter
import
yaml
from
ppdet.core.workspace
import
get_registered_modules
,
load_config
from
ppdet.utils.cli
import
ColorTTY
from
ppdet.utils.cli
import
ColorTTY
,
print_total_cfg
color_tty
=
ColorTTY
()
...
...
@@ -151,75 +150,6 @@ def generate_config(**kwargs):
print
(
dump_config
(
s
,
minimal
))
def
print_total_cfg
(
config
):
modules
=
get_registered_modules
()
green
=
'___{}___'
.
format
(
color_tty
.
colors
.
index
(
'green'
)
+
31
)
styled
=
{}
for
key
in
config
.
keys
():
if
not
config
[
key
]:
# empty schema
continue
if
key
not
in
modules
and
not
hasattr
(
config
[
key
],
'__dict__'
):
styled
[
key
]
=
config
[
key
]
continue
elif
key
in
modules
:
module
=
modules
[
key
]
else
:
type_name
=
type
(
config
[
key
]).
__name__
if
type_name
in
modules
:
module
=
modules
[
type_name
].
copy
()
module
.
update
({
k
:
v
for
k
,
v
in
config
[
key
].
__dict__
.
items
()
if
k
in
module
.
schema
})
key
+=
" ({})"
.
format
(
type_name
)
default
=
module
.
find_default_keys
()
missing
=
module
.
find_missing_keys
()
mismatch
=
module
.
find_mismatch_keys
()
extra
=
module
.
find_extra_keys
()
dep_missing
=
[]
for
dep
in
module
.
inject
:
if
isinstance
(
module
[
dep
],
str
)
and
module
[
dep
]
!=
'<value>'
:
if
module
[
dep
]
not
in
modules
:
# not a valid module
dep_missing
.
append
(
dep
)
else
:
dep_mod
=
modules
[
module
[
dep
]]
# empty dict but mandatory
if
not
dep_mod
and
dep_mod
.
mandatory
():
dep_missing
.
append
(
dep
)
override
=
list
(
set
(
module
.
keys
())
-
set
(
default
)
-
set
(
extra
)
-
set
(
dep_missing
))
replacement
=
{}
for
name
in
set
(
override
+
default
+
extra
+
mismatch
+
missing
):
new_name
=
name
if
name
in
missing
:
value
=
"<missing>"
else
:
value
=
module
[
name
]
if
name
in
extra
:
value
=
dump_value
(
value
)
+
" <extraneous>"
elif
name
in
mismatch
:
value
=
dump_value
(
value
)
+
" <type mismatch>"
elif
name
in
dep_missing
:
value
=
dump_value
(
value
)
+
" <module config missing>"
elif
name
in
override
and
value
!=
'<missing>'
:
mark
=
green
new_name
=
mark
+
name
replacement
[
new_name
]
=
value
styled
[
key
]
=
replacement
buffer
=
yaml
.
dump
(
styled
,
default_flow_style
=
False
,
default_style
=
''
)
buffer
=
(
re
.
sub
(
r
"<missing>"
,
r
"[31m<missing>[0m"
,
buffer
))
buffer
=
(
re
.
sub
(
r
"<extraneous>"
,
r
"[33m<extraneous>[0m"
,
buffer
))
buffer
=
(
re
.
sub
(
r
"<type mismatch>"
,
r
"[31m<type mismatch>[0m"
,
buffer
))
buffer
=
(
re
.
sub
(
r
"<module config missing>"
,
r
"[31m<module config missing>[0m"
,
buffer
))
buffer
=
re
.
sub
(
r
"___(\d+)___(.*?):"
,
r
"[\1m\2[0m:"
,
buffer
)
print
(
buffer
)
# FIXME this is pretty hackish, maybe implement a custom YAML printer?
def
analyze_config
(
**
kwargs
):
config
=
load_config
(
kwargs
[
'file'
])
...
...
tools/eval.py
浏览文件 @
d6e52889
...
...
@@ -33,7 +33,7 @@ set_paddle_flags(
import
paddle.fluid
as
fluid
from
tools.configure
import
print_total_cfg
from
ppdet.utils.cli
import
print_total_cfg
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_run
,
eval_results
,
json_eval_results
import
ppdet.utils.checkpoint
as
checkpoint
from
ppdet.utils.cli
import
ArgsParser
...
...
tools/infer.py
浏览文件 @
d6e52889
...
...
@@ -37,7 +37,7 @@ set_paddle_flags(
from
paddle
import
fluid
from
tools.configure
import
print_total_cfg
from
ppdet.utils.cli
import
print_total_cfg
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.modeling.model_input
import
create_feed
from
ppdet.data.data_feed
import
create_reader
...
...
tools/train.py
浏览文件 @
d6e52889
...
...
@@ -21,7 +21,6 @@ import time
import
numpy
as
np
import
datetime
from
collections
import
deque
from
tools.configure
import
print_total_cfg
def
set_paddle_flags
(
**
kwargs
):
...
...
@@ -40,6 +39,7 @@ from paddle import fluid
from
ppdet.core.workspace
import
load_config
,
merge_config
,
create
from
ppdet.data.data_feed
import
create_reader
from
ppdet.utils.cli
import
print_total_cfg
from
ppdet.utils.eval_utils
import
parse_fetches
,
eval_run
,
eval_results
from
ppdet.utils.stats
import
TrainingStats
from
ppdet.utils.cli
import
ArgsParser
...
...
@@ -142,7 +142,8 @@ def main():
train_compile_program
=
fluid
.
compiler
.
CompiledProgram
(
train_prog
).
with_data_parallel
(
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
,
loss_name
=
loss
.
name
,
build_strategy
=
build_strategy
,
exec_strategy
=
exec_strategy
)
if
FLAGS
.
eval
:
eval_compile_program
=
fluid
.
compiler
.
CompiledProgram
(
eval_prog
)
...
...
@@ -159,13 +160,10 @@ def main():
elif
cfg
.
pretrain_weights
:
checkpoint
.
load_pretrain
(
exe
,
train_prog
,
cfg
.
pretrain_weights
)
train_reader
=
create_reader
(
train_feed
,
(
cfg
.
max_iters
-
start_iter
)
*
devices_num
,
FLAGS
.
dataset_dir
)
train_reader
=
create_reader
(
train_feed
,
(
cfg
.
max_iters
-
start_iter
)
*
devices_num
,
FLAGS
.
dataset_dir
)
train_pyreader
.
decorate_sample_list_generator
(
train_reader
,
place
)
# whether output bbox is normalized in model output layer
is_bbox_normalized
=
False
if
hasattr
(
model
,
'is_bbox_normalized'
)
and
\
...
...
@@ -230,12 +228,12 @@ def main():
box_ap_stats
=
eval_results
(
results
,
eval_feed
,
cfg
.
metric
,
cfg
.
num_classes
,
resolution
,
is_bbox_normalized
,
FLAGS
.
output_eval
,
map_type
)
# use tb_paddle to log mAP
if
FLAGS
.
use_tb
:
tb_writer
.
add_scalar
(
"mAP"
,
box_ap_stats
[
0
],
tb_mAP_step
)
tb_mAP_step
+=
1
if
box_ap_stats
[
0
]
>
best_box_ap_list
[
0
]:
best_box_ap_list
[
0
]
=
box_ap_stats
[
0
]
best_box_ap_list
[
1
]
=
it
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录