Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
d310c549
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
68
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d310c549
编写于
9月 22, 2020
作者:
L
likejiao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
compatible with `program` argument
上级
ba58d597
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
79 addition
and
84 deletion
+79
-84
parl/core/fluid/agent.py
parl/core/fluid/agent.py
+79
-84
未找到文件。
parl/core/fluid/agent.py
浏览文件 @
d310c549
...
@@ -132,11 +132,15 @@ class Agent(AgentBase):
...
@@ -132,11 +132,15 @@ class Agent(AgentBase):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
save
(
self
,
save_path
=
None
):
def
save
(
self
,
save_path
=
None
,
program
=
None
):
"""Save parameters
for every fluid program
.
"""Save parameters.
Args:
Args:
save_path(str): a directory where to save all the parameters.
save_path(str): a directory where to save the parameters.
program(fluid.Program): program that describes the neural network structure. If None, will all program.
Raises:
Error: if program does not exist
Example:
Example:
...
@@ -144,13 +148,35 @@ class Agent(AgentBase):
...
@@ -144,13 +148,35 @@ class Agent(AgentBase):
agent = AtariAgent()
agent = AtariAgent()
agent.save()
agent.save()
agent.save('./program_model')
agent.save('./program_model', program=agent.learn_program)
"""
"""
if
save_path
is
None
:
if
save_path
is
None
:
save_path
=
'./program_model'
save_path
=
'./program_model'
if
not
os
.
path
.
exists
(
save_path
):
if
not
os
.
path
.
exists
(
save_path
):
os
.
makedirs
(
save_path
)
os
.
makedirs
(
save_path
)
for
keyval
in
self
.
__dict__
.
items
():
all_programs
=
[
(
kv
[
0
],
kv
[
1
])
for
kv
in
self
.
__dict__
.
items
()
if
(
isinstance
(
kv
[
1
],
fluid
.
framework
.
Program
)
or
isinstance
(
kv
[
1
],
fluid
.
compiler
.
CompiledProgram
))
]
if
program
:
filename
=
None
for
keyval
in
all_programs
:
if
program
==
keyval
[
1
]:
filename
=
keyval
[
0
]
break
if
filename
is
None
:
raise
Exception
(
'can not find program {}.'
.
format
(
program
))
fluid
.
io
.
save_params
(
executor
=
self
.
fluid_executor
,
dirname
=
save_path
,
main_program
=
program
,
filename
=
filename
)
else
:
for
keyval
in
all_programs
:
filename
=
keyval
[
0
]
filename
=
keyval
[
0
]
program
=
keyval
[
1
]
program
=
keyval
[
1
]
if
isinstance
(
program
,
fluid
.
framework
.
Program
)
or
\
if
isinstance
(
program
,
fluid
.
framework
.
Program
)
or
\
...
@@ -161,15 +187,16 @@ class Agent(AgentBase):
...
@@ -161,15 +187,16 @@ class Agent(AgentBase):
main_program
=
program
,
main_program
=
program
,
filename
=
filename
)
filename
=
filename
)
def
restore
(
self
,
save_path
=
None
):
def
restore
(
self
,
save_path
=
None
,
program
=
None
):
"""Restore previously saved parameters from save_path.
"""Restore previously saved parameters from save_path.
default save_path is ./program_model
default save_path is ./program_model
Args:
Args:
save_path(str): path where parameters were previously saved.
save_path(str): path where parameters were previously saved.
program(fluid.Program): program that describes the neural network structure. If None, will restore all program.
Raises:
Raises:
ValueError: if save_path does not exist or no
file in save_path.
Error: if save_path does not exist or can not find the specific program
file in save_path.
Example:
Example:
...
@@ -190,82 +217,50 @@ class Agent(AgentBase):
...
@@ -190,82 +217,50 @@ class Agent(AgentBase):
raise
Exception
(
raise
Exception
(
'can not restore from {}, it is a file, not directory'
.
format
(
'can not restore from {}, it is a file, not directory'
.
format
(
save_path
))
save_path
))
all_programs
=
[
for
keyval
in
self
.
__dict__
.
items
():
(
kv
[
0
],
kv
[
1
])
for
kv
in
self
.
__dict__
.
items
()
if
(
isinstance
(
kv
[
1
],
fluid
.
framework
.
Program
)
or
isinstance
(
kv
[
1
],
fluid
.
compiler
.
CompiledProgram
))
]
if
program
:
filename
=
None
for
keyval
in
all_programs
:
if
program
==
keyval
[
1
]:
filename
=
keyval
[
0
]
filename
=
keyval
[
0
]
program
=
keyval
[
1
]
break
if
isinstance
(
program
,
fluid
.
framework
.
Program
)
or
\
if
filename
is
None
:
isinstance
(
program
,
fluid
.
compiler
.
CompiledProgram
):
raise
Exception
(
'can not find the program to restore.'
)
if
not
os
.
path
.
isfile
(
'{}/{}'
.
format
(
save_path
,
filename
)):
if
not
os
.
path
.
isfile
(
'{}/{}'
.
format
(
save_path
,
filename
)):
raise
Exception
(
'{}/{} does not exits'
.
format
(
raise
Exception
(
'{}/{} does not exits'
.
format
(
save_path
,
filename
))
save_path
,
filename
))
if
type
(
program
)
is
fluid
.
compiler
.
CompiledProgram
:
if
type
(
program
)
is
fluid
.
compiler
.
CompiledProgram
:
program
=
program
.
_init_program
program
=
program
.
_init_program
fluid
.
io
.
load_params
(
fluid
.
io
.
load_params
(
executor
=
self
.
fluid_executor
,
executor
=
self
.
fluid_executor
,
dirname
=
save_path
,
dirname
=
save_path
,
main_program
=
program
,
main_program
=
program
,
filename
=
filename
)
filename
=
filename
)
else
:
def
save_program
(
self
,
save_path
,
program
=
None
):
programs_list
=
[
kv
[
0
]
for
kv
in
all_programs
]
"""Save parameters.
exist_files
=
os
.
listdir
(
save_path
)
if
len
(
programs_list
)
!=
len
(
exist_files
):
Args:
raise
Exception
(
save_path(str): where to save the parameters.
'expected to restore {} model file under directory {}: {}, but {} files are found: {}.'
program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program.
.
format
(
len
(
programs_list
),
save_path
,
programs_list
,
Raises:
len
(
exist_files
),
exist_files
))
ValueError: if program is None and self.learn_program does not exist.
for
keyval
in
all_programs
:
filename
=
keyval
[
0
]
Example:
program
=
keyval
[
1
]
if
not
os
.
path
.
isfile
(
'{}/{}'
.
format
(
save_path
,
filename
)):
.. code-block:: python
raise
Exception
(
'{}/{} does not exits'
.
format
(
save_path
,
filename
))
agent = AtariAgent()
agent.save('./model.ckpt')
"""
if
program
is
None
:
program
=
self
.
learn_program
dirname
=
os
.
sep
.
join
(
save_path
.
split
(
os
.
sep
)[:
-
1
])
filename
=
save_path
.
split
(
os
.
sep
)[
-
1
]
fluid
.
io
.
save_params
(
executor
=
self
.
fluid_executor
,
dirname
=
dirname
,
main_program
=
program
,
filename
=
filename
)
def
restore_program
(
self
,
save_path
,
program
=
None
):
"""Restore previously saved parameters.
This method requires a program that describes the network structure.
The save_path argument is typically a value previously passed to ``save_params()``.
Args:
save_path(str): path where parameters were previously saved.
program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program.
Raises:
ValueError: if program is None and self.learn_program does not exist.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save('./model.ckpt')
agent.restore('./model.ckpt')
"""
if
program
is
None
:
program
=
self
.
learn_program
if
type
(
program
)
is
fluid
.
compiler
.
CompiledProgram
:
if
type
(
program
)
is
fluid
.
compiler
.
CompiledProgram
:
program
=
program
.
_init_program
program
=
program
.
_init_program
dirname
=
os
.
sep
.
join
(
save_path
.
split
(
os
.
sep
)[:
-
1
])
filename
=
save_path
.
split
(
os
.
sep
)[
-
1
]
fluid
.
io
.
load_params
(
fluid
.
io
.
load_params
(
executor
=
self
.
fluid_executor
,
executor
=
self
.
fluid_executor
,
dirname
=
dirname
,
dirname
=
save_path
,
main_program
=
program
,
main_program
=
program
,
filename
=
filename
)
filename
=
filename
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录