Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
f5a5baa5
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
68
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f5a5baa5
编写于
9月 21, 2020
作者:
L
likejiao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Save the model without specifying a program
上级
b966fa78
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
81 addition
and
2 deletion
+81
-2
parl/core/fluid/agent.py
parl/core/fluid/agent.py
+81
-2
未找到文件。
parl/core/fluid/agent.py
浏览文件 @
f5a5baa5
...
...
@@ -21,6 +21,7 @@ from parl.core.fluid import layers
from
parl.core.agent_base
import
AgentBase
from
parl.core.fluid.algorithm
import
Algorithm
from
parl.utils
import
machine_info
from
parl.utils
import
logger
__all__
=
[
'Agent'
]
...
...
@@ -132,7 +133,85 @@ class Agent(AgentBase):
"""
raise
NotImplementedError
def
save
(
self
,
save_path
,
program
=
None
):
def
save
(
self
,
save_path
=
None
):
"""Save parameters for every fluid program.
Args:
save_path(str): a directory where to save all the parameters.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save()
"""
if
save_path
is
None
:
save_path
=
'./program_model'
if
not
os
.
path
.
exists
(
save_path
):
os
.
makedirs
(
save_path
)
for
keyval
in
self
.
__dict__
.
items
():
filename
=
keyval
[
0
]
program
=
keyval
[
1
]
if
isinstance
(
program
,
fluid
.
framework
.
Program
)
or
\
isinstance
(
program
,
fluid
.
compiler
.
CompiledProgram
):
fluid
.
io
.
save_params
(
executor
=
self
.
fluid_executor
,
dirname
=
save_path
,
main_program
=
program
,
filename
=
filename
)
def
restore
(
self
,
save_path
=
None
):
"""Restore previously saved parameters from save_path.
default save_path is ./program_model
Args:
save_path(str): path where parameters were previously saved.
Raises:
ValueError: if save_path does not exist or no file in save_path.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save()
agent.restore()
"""
if
save_path
is
None
:
save_path
=
'./program_model'
if
not
os
.
path
.
exists
(
save_path
):
raise
Exception
(
'can not restore from {}, directory does not exists'
.
format
(
save_path
))
if
os
.
path
.
isfile
(
save_path
):
raise
Exception
(
'can not restore from {}, it is a file, not directory'
.
format
(
save_path
))
for
keyval
in
self
.
__dict__
.
items
():
filename
=
keyval
[
0
]
program
=
keyval
[
1
]
if
isinstance
(
program
,
fluid
.
framework
.
Program
)
or
\
isinstance
(
program
,
fluid
.
compiler
.
CompiledProgram
):
if
not
os
.
path
.
isfile
(
'{}/{}'
.
format
(
save_path
,
filename
)):
raise
Exception
(
'{}/{} does not exits'
.
format
(
save_path
,
filename
))
logger
.
info
(
type
(
program
))
if
type
(
program
)
is
fluid
.
compiler
.
CompiledProgram
:
program
=
program
.
_init_program
logger
.
info
(
type
(
program
))
fluid
.
io
.
load_params
(
executor
=
self
.
fluid_executor
,
dirname
=
save_path
,
main_program
=
program
,
filename
=
filename
)
def
save_program
(
self
,
save_path
,
program
=
None
):
"""Save parameters.
Args:
...
...
@@ -160,7 +239,7 @@ class Agent(AgentBase):
main_program
=
program
,
filename
=
filename
)
def
restore
(
self
,
save_path
,
program
=
None
):
def
restore
_program
(
self
,
save_path
,
program
=
None
):
"""Restore previously saved parameters.
This method requires a program that describes the network structure.
The save_path argument is typically a value previously passed to ``save_params()``.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录