Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
2e0bc41c
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
289
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2e0bc41c
编写于
9月 21, 2020
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Print config
上级
40ed988d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
64 addition
and
70 deletion
+64
-70
dygraph/paddleseg/utils/config.py
dygraph/paddleseg/utils/config.py
+50
-68
dygraph/train.py
dygraph/train.py
+9
-2
dygraph/val.py
dygraph/val.py
+5
-0
未找到文件。
dygraph/paddleseg/utils/config.py
浏览文件 @
2e0bc41c
...
...
@@ -15,7 +15,6 @@
import
codecs
import
os
from
typing
import
Any
,
Callable
import
pprint
import
yaml
import
paddle
...
...
@@ -37,10 +36,10 @@ class Config(object):
if
not
os
.
path
.
exists
(
path
):
raise
FileNotFoundError
(
'File {} does not exist'
.
format
(
path
))
self
.
_model
=
None
self
.
_losses
=
None
if
path
.
endswith
(
'yml'
)
or
path
.
endswith
(
'yaml'
):
dic
=
self
.
_parse_from_yaml
(
path
)
logger
.
info
(
'
\n
'
+
pprint
.
pformat
(
dic
))
self
.
_build
(
dic
)
self
.
dic
=
self
.
_parse_from_yaml
(
path
)
else
:
raise
RuntimeError
(
'Config file should in yaml format!'
)
...
...
@@ -61,6 +60,7 @@ class Config(object):
'''Parse a yaml file and build config'''
with
codecs
.
open
(
path
,
'r'
,
'utf-8'
)
as
file
:
dic
=
yaml
.
load
(
file
,
Loader
=
yaml
.
FullLoader
)
if
'_base_'
in
dic
:
cfg_dir
=
os
.
path
.
dirname
(
path
)
base_path
=
dic
.
pop
(
'_base_'
)
...
...
@@ -69,111 +69,85 @@ class Config(object):
dic
=
self
.
_update_dic
(
dic
,
base_dic
)
return
dic
def
_build
(
self
,
dic
:
dict
):
'''Build config from dictionary'''
dic
=
dic
.
copy
()
self
.
_batch_size
=
dic
.
get
(
'batch_size'
,
1
)
self
.
_iters
=
dic
.
get
(
'iters'
)
if
'model'
not
in
dic
:
raise
RuntimeError
()
self
.
_model_cfg
=
dic
[
'model'
]
self
.
_model
=
None
self
.
_train_dataset
=
dic
.
get
(
'train_dataset'
)
self
.
_val_dataset
=
dic
.
get
(
'val_dataset'
)
self
.
_learning_rate_cfg
=
dic
.
get
(
'learning_rate'
,
{})
self
.
_learning_rate
=
self
.
_learning_rate_cfg
.
get
(
'value'
)
self
.
_decay
=
self
.
_learning_rate_cfg
.
get
(
'decay'
,
{
'type'
:
'poly'
,
'power'
:
0.9
})
self
.
_loss_cfg
=
dic
.
get
(
'loss'
,
{})
self
.
_losses
=
None
self
.
_optimizer_cfg
=
dic
.
get
(
'optimizer'
,
{})
def
update
(
self
,
learning_rate
:
float
=
None
,
batch_size
:
int
=
None
,
iters
:
int
=
None
):
'''Update config'''
if
learning_rate
:
self
.
_learning_rate
=
learning_rate
self
.
dic
[
'learning_rate'
][
'value'
]
=
learning_rate
if
batch_size
:
self
.
_batch_size
=
batch_size
self
.
dic
[
'batch_size'
]
=
batch_size
if
iters
:
self
.
_iters
=
iters
self
.
dic
[
'iters'
]
=
iters
@
property
def
batch_size
(
self
)
->
int
:
return
self
.
_batch_size
return
self
.
dic
.
get
(
'batch_size'
,
1
)
@
property
def
iters
(
self
)
->
int
:
if
not
self
.
_iters
:
iters
=
self
.
dic
.
get
(
'iters'
)
if
not
iters
:
raise
RuntimeError
(
'No iters specified in the configuration file.'
)
return
self
.
_
iters
return
iters
@
property
def
learning_rate
(
self
)
->
float
:
if
not
self
.
_learning_rate
:
_learning_rate
=
self
.
dic
.
get
(
'learning_rate'
,
{}).
get
(
'value'
)
if
not
_learning_rate
:
raise
RuntimeError
(
'No learning rate specified in the configuration file.'
)
if
self
.
decay_type
==
'poly'
:
lr
=
self
.
_learning_rate
args
=
self
.
decay_args
args
.
setdefault
(
'decay_steps'
,
self
.
iters
)
args
.
setdefault
(
'end_lr'
,
0
)
args
=
self
.
decay_args
decay_type
=
args
.
pop
(
'type'
)
if
decay_type
==
'poly'
:
lr
=
_learning_rate
return
paddle
.
optimizer
.
PolynomialLR
(
lr
,
**
args
)
else
:
raise
RuntimeError
(
'Only poly decay support.'
)
@
property
def
optimizer
(
self
)
->
paddle
.
optimizer
.
Optimizer
:
if
self
.
optimizer_type
==
'sgd'
:
args
=
self
.
optimizer_args
optimizer_type
=
args
.
pop
(
'type'
)
if
optimizer_type
==
'sgd'
:
lr
=
self
.
learning_rate
args
=
self
.
optimizer_args
args
.
setdefault
(
'momentum'
,
0.9
)
return
paddle
.
optimizer
.
Momentum
(
lr
,
parameters
=
self
.
model
.
parameters
(),
**
args
)
else
:
raise
RuntimeError
(
'Only sgd optimizer support.'
)
@
property
def
optimizer_type
(
self
)
->
str
:
otype
=
self
.
_optimizer_cfg
.
get
(
'type'
)
if
not
otype
:
raise
RuntimeError
(
'No optimizer type specified in the configuration file.'
)
return
otype
@
property
def
optimizer_args
(
self
)
->
dict
:
args
=
self
.
_optimizer_cfg
.
copy
()
args
.
pop
(
'type'
)
return
args
args
=
self
.
dic
.
get
(
'optimizer'
,
{})
.
copy
()
if
args
[
'type'
]
==
'sgd'
:
args
.
setdefault
(
'momentum'
,
0.9
)
@
property
def
decay_type
(
self
)
->
str
:
return
self
.
_decay
[
'type'
]
return
args
@
property
def
decay_args
(
self
)
->
dict
:
args
=
self
.
_decay
.
copy
()
args
.
pop
(
'type'
)
args
=
self
.
dic
.
get
(
'learning_rate'
,
{}).
get
(
'decay'
,
{
'type'
:
'poly'
,
'power'
:
0.9
}).
copy
()
if
args
[
'type'
]
==
'poly'
:
args
.
setdefault
(
'decay_steps'
,
self
.
iters
)
args
.
setdefault
(
'end_lr'
,
0
)
return
args
@
property
def
loss
(
self
)
->
list
:
args
=
self
.
dic
.
get
(
'loss'
,
{}).
copy
()
if
not
self
.
_losses
:
args
=
self
.
_loss_cfg
.
copy
()
self
.
_losses
=
dict
()
for
key
,
val
in
args
.
items
():
if
key
==
'types'
:
...
...
@@ -191,21 +165,26 @@ class Config(object):
@
property
def
model
(
self
)
->
Callable
:
model_cfg
=
self
.
dic
.
get
(
'model'
).
copy
()
if
not
model_cfg
:
raise
RuntimeError
(
'No model specified in the configuration file.'
)
if
not
self
.
_model
:
self
.
_model
=
self
.
_load_object
(
self
.
_
model_cfg
)
self
.
_model
=
self
.
_load_object
(
model_cfg
)
return
self
.
_model
@
property
def
train_dataset
(
self
)
->
Any
:
if
not
self
.
_train_dataset
:
_train_dataset
=
self
.
dic
.
get
(
'train_dataset'
).
copy
()
if
not
_train_dataset
:
return
None
return
self
.
_load_object
(
self
.
_train_dataset
)
return
self
.
_load_object
(
_train_dataset
)
@
property
def
val_dataset
(
self
)
->
Any
:
if
not
self
.
_val_dataset
:
_val_dataset
=
self
.
dic
.
get
(
'val_dataset'
).
copy
()
if
not
_val_dataset
:
return
None
return
self
.
_load_object
(
self
.
_val_dataset
)
return
self
.
_load_object
(
_val_dataset
)
def
_load_component
(
self
,
com_name
:
str
)
->
Any
:
com_list
=
[
...
...
@@ -243,3 +222,6 @@ class Config(object):
def
_is_meta_type
(
self
,
item
:
Any
)
->
bool
:
return
isinstance
(
item
,
dict
)
and
'type'
in
item
def
__str__
(
self
)
->
str
:
return
yaml
.
dump
(
self
.
dic
)
dygraph/train.py
浏览文件 @
2e0bc41c
...
...
@@ -100,15 +100,22 @@ def main(args):
raise
RuntimeError
(
'No configuration file specified.'
)
cfg
=
Config
(
args
.
cfg
)
cfg
.
update
(
learning_rate
=
args
.
learning_rate
,
iters
=
args
.
iters
,
batch_size
=
args
.
batch_size
)
train_dataset
=
cfg
.
train_dataset
if
not
train_dataset
:
raise
RuntimeError
(
'The training dataset is not specified in the configuration file.'
)
val_dataset
=
cfg
.
val_dataset
if
args
.
do_eval
else
None
losses
=
cfg
.
loss
print
(
'---------------Config Information---------------'
)
print
(
cfg
)
print
(
'------------------------------------------------'
)
train
(
cfg
.
model
,
train_dataset
,
...
...
dygraph/val.py
浏览文件 @
2e0bc41c
...
...
@@ -55,6 +55,11 @@ def main(args):
raise
RuntimeError
(
'The verification dataset is not specified in the configuration file.'
)
print
(
'---------------Config Information---------------'
)
print
(
cfg
)
print
(
'------------------------------------------------'
)
evaluate
(
cfg
.
model
,
val_dataset
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录