Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
4dd59a1a
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4dd59a1a
编写于
4月 09, 2020
作者:
W
WuHaobo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Init PaddleClas
上级
9f39da88
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
20 addition
and
14 deletion
+20
-14
ppcls/utils/config.py
ppcls/utils/config.py
+6
-4
ppcls/utils/save_load.py
ppcls/utils/save_load.py
+12
-10
tools/program.py
tools/program.py
+2
-0
未找到文件。
ppcls/utils/config.py
浏览文件 @
4dd59a1a
...
@@ -89,8 +89,8 @@ def print_config(config):
...
@@ -89,8 +89,8 @@ def print_config(config):
config: configs
config: configs
"""
"""
copyright
=
"PaddleC
LS
is powered by PaddlePaddle"
copyright
=
"PaddleC
las
is powered by PaddlePaddle"
ad
=
"https://github.com/PaddlePaddle/PaddleC
LS
"
ad
=
"https://github.com/PaddlePaddle/PaddleC
las
"
logger
.
info
(
"
\n
"
*
2
)
logger
.
info
(
"
\n
"
*
2
)
logger
.
info
(
copyright
)
logger
.
info
(
copyright
)
...
@@ -193,9 +193,11 @@ def get_config(fname, overrides=[], show=True):
...
@@ -193,9 +193,11 @@ def get_config(fname, overrides=[], show=True):
assert
os
.
path
.
exists
(
fname
),
\
assert
os
.
path
.
exists
(
fname
),
\
(
'config file({}) is not exist'
.
format
(
fname
))
(
'config file({}) is not exist'
.
format
(
fname
))
config
=
parse_config
(
fname
)
config
=
parse_config
(
fname
)
if
show
:
print_config
(
config
)
if
show
:
print_config
(
config
)
if
len
(
overrides
)
>
0
:
if
len
(
overrides
)
>
0
:
override_config
(
config
,
overrides
)
override_config
(
config
,
overrides
)
print_config
(
config
)
if
show
:
print_config
(
config
)
check_config
(
config
)
check_config
(
config
)
return
config
return
config
ppcls/utils/save_load.py
浏览文件 @
4dd59a1a
...
@@ -30,7 +30,7 @@ __all__ = ['init_model', 'save_model']
...
@@ -30,7 +30,7 @@ __all__ = ['init_model', 'save_model']
def
_mkdir_if_not_exist
(
path
):
def
_mkdir_if_not_exist
(
path
):
"""
"""
mkdir if not exists
mkdir if not exists
"""
"""
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
path
)):
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
path
)):
os
.
makedirs
(
os
.
path
.
join
(
path
))
os
.
makedirs
(
os
.
path
.
join
(
path
))
...
@@ -97,25 +97,27 @@ def load_params(exe, prog, path, ignore_params=[]):
...
@@ -97,25 +97,27 @@ def load_params(exe, prog, path, ignore_params=[]):
fluid
.
io
.
set_program_state
(
prog
,
state
)
fluid
.
io
.
set_program_state
(
prog
,
state
)
def
init_model
(
config
,
program
,
exe
):
def
init_model
(
config
,
program
,
exe
,
prefix
=
"ppcls"
):
"""
"""
load model from checkpoint or pretrained_model
load model from checkpoint or pretrained_model
"""
"""
checkpoints
=
config
.
get
(
'checkpoints'
)
checkpoints
=
config
.
get
(
'checkpoints'
)
if
checkpoints
and
os
.
path
.
exists
(
checkpoints
):
if
checkpoints
:
fluid
.
load
(
program
,
checkpoints
,
exe
)
path
=
os
.
path
.
join
(
checkpoints
,
prefix
)
logger
.
info
(
"Finish initing model from {}"
.
format
(
checkpoints
))
fluid
.
load
(
program
,
path
,
exe
)
logger
.
info
(
"Finish initing model from {}"
.
format
(
path
))
return
return
pretrained_model
=
config
.
get
(
'pretrained_model'
)
pretrained_model
=
config
.
get
(
'pretrained_model'
)
if
pretrained_model
and
os
.
path
.
exists
(
pretrained_model
):
if
pretrained_model
:
load_params
(
exe
,
program
,
pretrained_model
)
path
=
os
.
path
.
join
(
pretrained_model
,
prefix
)
logger
.
info
(
"Finish initing model from {}"
.
format
(
pretrained_model
))
load_params
(
exe
,
program
,
path
)
logger
.
info
(
"Finish initing model from {}"
.
format
(
path
))
def
save_model
(
program
,
model_path
,
epoch_id
,
prefix
=
'ppcls'
):
def
save_model
(
program
,
model_path
,
epoch_id
,
prefix
=
'ppcls'
):
"""
"""
save model to the target path
save model to the target path
"""
"""
model_path
=
os
.
path
.
join
(
model_path
,
str
(
epoch_id
))
model_path
=
os
.
path
.
join
(
model_path
,
str
(
epoch_id
))
_mkdir_if_not_exist
(
model_path
)
_mkdir_if_not_exist
(
model_path
)
...
...
tools/program.py
浏览文件 @
4dd59a1a
...
@@ -357,6 +357,8 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
...
@@ -357,6 +357,8 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train'):
"""
"""
fetch_list
=
[
f
[
0
]
for
f
in
fetchs
.
values
()]
fetch_list
=
[
f
[
0
]
for
f
in
fetchs
.
values
()]
metric_list
=
[
f
[
1
]
for
f
in
fetchs
.
values
()]
metric_list
=
[
f
[
1
]
for
f
in
fetchs
.
values
()]
for
m
in
metric_list
:
m
.
reset
()
batch_time
=
AverageMeter
(
'cost'
,
':6.3f'
)
batch_time
=
AverageMeter
(
'cost'
,
':6.3f'
)
tic
=
time
.
time
()
tic
=
time
.
time
()
for
idx
,
batch
in
enumerate
(
dataloader
()):
for
idx
,
batch
in
enumerate
(
dataloader
()):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录