Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
bb2dcf2c
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
68
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bb2dcf2c
编写于
9月 22, 2020
作者:
L
likejiao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change model.ckpt to model_dir
上级
d310c549
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
34 addition
and
34 deletion
+34
-34
docs/tutorial/save_param.rst
docs/tutorial/save_param.rst
+6
-6
docs/zh_CN/tutorial/param.md
docs/zh_CN/tutorial/param.md
+3
-3
examples/MADDPG/train.py
examples/MADDPG/train.py
+4
-4
examples/QuickStart/train.py
examples/QuickStart/train.py
+4
-4
examples/offline-Q-learning/parallel_run.py
examples/offline-Q-learning/parallel_run.py
+3
-3
parl/core/fluid/agent.py
parl/core/fluid/agent.py
+10
-10
parl/core/fluid/tests/agent_base_test.py
parl/core/fluid/tests/agent_base_test.py
+4
-4
未找到文件。
docs/tutorial/save_param.rst
浏览文件 @
bb2dcf2c
...
@@ -16,11 +16,11 @@ Here is a demonstration of usage:
...
@@ -16,11 +16,11 @@ Here is a demonstration of usage:
.. code-block:: python
.. code-block:: python
agent = AtariAgent()
agent = AtariAgent()
# save the parameters of agent to ./model
.ckpt
# save the parameters of agent to ./model
_dir
agent.save('./model
.ckpt
')
agent.save('./model
_dir
')
# restore the parameters from ./model
.ckpt
to agent
# restore the parameters from ./model
_dir
to agent
agent.restore('./model
.ckpt
')
agent.restore('./model
_dir
')
# restore the parameters from ./model
.ckpt
to another_agent
# restore the parameters from ./model
_dir
to another_agent
another_agent = AtariAgent()
another_agent = AtariAgent()
another_agent.restore('./model
.ckpt
')
another_agent.restore('./model
_dir
')
docs/zh_CN/tutorial/param.md
浏览文件 @
bb2dcf2c
...
@@ -4,10 +4,10 @@
...
@@ -4,10 +4,10 @@
当用户构建好agent之后,可以直接通过agent的相关接口来完成参数的存储。
当用户构建好agent之后,可以直接通过agent的相关接口来完成参数的存储。
```
python
```
python
agent
=
AtariAgent
()
agent
=
AtariAgent
()
# 保存参数到 ./model
.ckpt
# 保存参数到 ./model
_dir
agent
.
save
(
'./model
.ckpt
'
)
agent
.
save
(
'./model
_dir
'
)
# 恢复参数到这个agent上
# 恢复参数到这个agent上
agent
.
restore
(
'./model
.ckpt
'
)
agent
.
restore
(
'./model
_dir
'
)
```
```
场景2: 并行训练过程中,经常需要把最新的模型参数同步到另一台服务器上,这时候,需要把模型参数拿到内存中,然后再赋值给另一台机器上的agent(actor)。
场景2: 并行训练过程中,经常需要把最新的模型参数同步到另一台服务器上,这时候,需要把模型参数拿到内存中,然后再赋值给另一台机器上的agent(actor)。
...
...
examples/MADDPG/train.py
浏览文件 @
bb2dcf2c
...
@@ -121,10 +121,10 @@ def train_agent():
...
@@ -121,10 +121,10 @@ def train_agent():
if
args
.
restore
:
if
args
.
restore
:
# restore modle
# restore modle
for
i
in
range
(
len
(
agents
)):
for
i
in
range
(
len
(
agents
)):
model_file
=
args
.
model_dir
+
'/agent_'
+
str
(
i
)
+
'.ckpt'
model_file
=
args
.
model_dir
+
'/agent_'
+
str
(
i
)
if
not
os
.
path
.
exists
(
model_file
):
if
not
os
.
path
.
exists
(
model_file
):
logger
.
info
(
'model file {} does not exits'
.
format
(
model_file
))
raise
Exception
(
raise
Exception
'model file {} does not exits'
.
format
(
model_file
))
agents
[
i
].
restore
(
model_file
)
agents
[
i
].
restore
(
model_file
)
t_start
=
time
.
time
()
t_start
=
time
.
time
()
...
@@ -166,7 +166,7 @@ def train_agent():
...
@@ -166,7 +166,7 @@ def train_agent():
if
not
args
.
restore
:
if
not
args
.
restore
:
os
.
makedirs
(
os
.
path
.
dirname
(
args
.
model_dir
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
args
.
model_dir
),
exist_ok
=
True
)
for
i
in
range
(
len
(
agents
)):
for
i
in
range
(
len
(
agents
)):
model_name
=
'/agent_'
+
str
(
i
)
+
'.ckpt'
model_name
=
'/agent_'
+
str
(
i
)
agents
[
i
].
save
(
args
.
model_dir
+
model_name
)
agents
[
i
].
save
(
args
.
model_dir
+
model_name
)
...
...
examples/QuickStart/train.py
浏览文件 @
bb2dcf2c
...
@@ -57,8 +57,8 @@ def main():
...
@@ -57,8 +57,8 @@ def main():
agent
=
CartpoleAgent
(
alg
,
obs_dim
=
OBS_DIM
,
act_dim
=
ACT_DIM
)
agent
=
CartpoleAgent
(
alg
,
obs_dim
=
OBS_DIM
,
act_dim
=
ACT_DIM
)
# if the file already exists, restore parameters from it
# if the file already exists, restore parameters from it
if
os
.
path
.
exists
(
'./model
.ckpt
'
):
if
os
.
path
.
exists
(
'./model
_dir
'
):
agent
.
restore
(
'./model
.ckpt
'
)
agent
.
restore
(
'./model
_dir
'
)
for
i
in
range
(
1000
):
for
i
in
range
(
1000
):
obs_list
,
action_list
,
reward_list
=
run_episode
(
env
,
agent
)
obs_list
,
action_list
,
reward_list
=
run_episode
(
env
,
agent
)
...
@@ -76,8 +76,8 @@ def main():
...
@@ -76,8 +76,8 @@ def main():
total_reward
=
np
.
sum
(
reward_list
)
total_reward
=
np
.
sum
(
reward_list
)
logger
.
info
(
'Test reward: {}'
.
format
(
total_reward
))
logger
.
info
(
'Test reward: {}'
.
format
(
total_reward
))
# save the parameters to ./model
.ckpt
# save the parameters to ./model
_dir
agent
.
save
(
'./model
.ckpt
'
)
agent
.
save
(
'./model
_dir
'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
examples/offline-Q-learning/parallel_run.py
浏览文件 @
bb2dcf2c
...
@@ -97,15 +97,15 @@ def main():
...
@@ -97,15 +97,15 @@ def main():
model
,
act_dim
=
act_dim
,
gamma
=
GAMMA
,
lr
=
LEARNING_RATE
*
gpu_num
)
model
,
act_dim
=
act_dim
,
gamma
=
GAMMA
,
lr
=
LEARNING_RATE
*
gpu_num
)
agent
=
AtariAgent
(
agent
=
AtariAgent
(
algorithm
,
act_dim
=
act_dim
,
total_step
=
args
.
train_total_steps
)
algorithm
,
act_dim
=
act_dim
,
total_step
=
args
.
train_total_steps
)
if
os
.
path
.
isfile
(
'./model
.ckpt
'
):
if
os
.
path
.
isfile
(
'./model
_dir
'
):
logger
.
info
(
"load model from file"
)
logger
.
info
(
"load model from file"
)
agent
.
restore
(
'./model
.ckpt
'
)
agent
.
restore
(
'./model
_dir
'
)
if
args
.
train
:
if
args
.
train
:
logger
.
info
(
"train with memory data"
)
logger
.
info
(
"train with memory data"
)
run_train_step
(
agent
,
rpm
)
run_train_step
(
agent
,
rpm
)
logger
.
info
(
"finish training. Save the model."
)
logger
.
info
(
"finish training. Save the model."
)
agent
.
save
(
'./model
.ckpt
'
)
agent
.
save
(
'./model
_dir
'
)
else
:
else
:
logger
.
info
(
"collect experience"
)
logger
.
info
(
"collect experience"
)
collect_exp
(
env
,
rpm
,
agent
)
collect_exp
(
env
,
rpm
,
agent
)
...
...
parl/core/fluid/agent.py
浏览文件 @
bb2dcf2c
...
@@ -21,6 +21,7 @@ from parl.core.fluid import layers
...
@@ -21,6 +21,7 @@ from parl.core.fluid import layers
from
parl.core.agent_base
import
AgentBase
from
parl.core.agent_base
import
AgentBase
from
parl.core.fluid.algorithm
import
Algorithm
from
parl.core.fluid.algorithm
import
Algorithm
from
parl.utils
import
machine_info
from
parl.utils
import
machine_info
from
parl.utils
import
logger
__all__
=
[
'Agent'
]
__all__
=
[
'Agent'
]
...
@@ -147,13 +148,14 @@ class Agent(AgentBase):
...
@@ -147,13 +148,14 @@ class Agent(AgentBase):
.. code-block:: python
.. code-block:: python
agent = AtariAgent()
agent = AtariAgent()
agent.save()
agent.save('./model_dir')
agent.save('./program_model')
agent.save('./model_dir', program=agent.learn_program)
agent.save('./program_model', program=agent.learn_program)
"""
"""
if
save_path
is
None
:
assert
save_path
is
not
None
,
'please specify `save_path` '
save_path
=
'./program_model'
if
os
.
path
.
isfile
(
save_path
):
raise
Exception
(
'can not save to {}, it is a file, not directory'
.
format
(
save_path
))
if
not
os
.
path
.
exists
(
save_path
):
if
not
os
.
path
.
exists
(
save_path
):
os
.
makedirs
(
save_path
)
os
.
makedirs
(
save_path
)
all_programs
=
[
all_programs
=
[
...
@@ -189,7 +191,6 @@ class Agent(AgentBase):
...
@@ -189,7 +191,6 @@ class Agent(AgentBase):
def
restore
(
self
,
save_path
=
None
,
program
=
None
):
def
restore
(
self
,
save_path
=
None
,
program
=
None
):
"""Restore previously saved parameters from save_path.
"""Restore previously saved parameters from save_path.
default save_path is ./program_model
Args:
Args:
save_path(str): path where parameters were previously saved.
save_path(str): path where parameters were previously saved.
...
@@ -203,12 +204,11 @@ class Agent(AgentBase):
...
@@ -203,12 +204,11 @@ class Agent(AgentBase):
.. code-block:: python
.. code-block:: python
agent = AtariAgent()
agent = AtariAgent()
agent.save()
agent.save(
'./model_dir'
)
agent.restore()
agent.restore(
'./model_dir'
)
"""
"""
if
save_path
is
None
:
assert
save_path
is
not
None
,
'please specify `save_path` '
save_path
=
'./program_model'
if
not
os
.
path
.
exists
(
save_path
):
if
not
os
.
path
.
exists
(
save_path
):
raise
Exception
(
raise
Exception
(
'can not restore from {}, directory does not exists'
.
format
(
'can not restore from {}, directory does not exists'
.
format
(
...
...
parl/core/fluid/tests/agent_base_test.py
浏览文件 @
bb2dcf2c
...
@@ -92,8 +92,8 @@ class AgentBaseTest(unittest.TestCase):
...
@@ -92,8 +92,8 @@ class AgentBaseTest(unittest.TestCase):
agent
=
TestAgent
(
self
.
alg
)
agent
=
TestAgent
(
self
.
alg
)
obs
=
np
.
random
.
random
([
3
,
10
]).
astype
(
'float32'
)
obs
=
np
.
random
.
random
([
3
,
10
]).
astype
(
'float32'
)
output_np
=
agent
.
predict
(
obs
)
output_np
=
agent
.
predict
(
obs
)
save_path1
=
'model
.ckpt
'
save_path1
=
'model
_dir
'
save_path2
=
os
.
path
.
join
(
'my_model'
,
'model-2
.ckpt
'
)
save_path2
=
os
.
path
.
join
(
'my_model'
,
'model-2
_dir
'
)
agent
.
save
(
save_path1
)
agent
.
save
(
save_path1
)
agent
.
save
(
save_path2
)
agent
.
save
(
save_path2
)
self
.
assertTrue
(
os
.
path
.
exists
(
save_path1
))
self
.
assertTrue
(
os
.
path
.
exists
(
save_path1
))
...
@@ -103,7 +103,7 @@ class AgentBaseTest(unittest.TestCase):
...
@@ -103,7 +103,7 @@ class AgentBaseTest(unittest.TestCase):
agent
=
TestAgent
(
self
.
alg
)
agent
=
TestAgent
(
self
.
alg
)
obs
=
np
.
random
.
random
([
3
,
10
]).
astype
(
'float32'
)
obs
=
np
.
random
.
random
([
3
,
10
]).
astype
(
'float32'
)
output_np
=
agent
.
predict
(
obs
)
output_np
=
agent
.
predict
(
obs
)
save_path1
=
'model
.ckpt
'
save_path1
=
'model
_dir
'
previous_output
=
agent
.
predict
(
obs
)
previous_output
=
agent
.
predict
(
obs
)
agent
.
save
(
save_path1
)
agent
.
save
(
save_path1
)
agent
.
restore
(
save_path1
)
agent
.
restore
(
save_path1
)
...
@@ -121,7 +121,7 @@ class AgentBaseTest(unittest.TestCase):
...
@@ -121,7 +121,7 @@ class AgentBaseTest(unittest.TestCase):
agent
.
learn_program
=
parl
.
compile
(
agent
.
learn_program
)
agent
.
learn_program
=
parl
.
compile
(
agent
.
learn_program
)
obs
=
np
.
random
.
random
([
3
,
10
]).
astype
(
'float32'
)
obs
=
np
.
random
.
random
([
3
,
10
]).
astype
(
'float32'
)
previous_output
=
agent
.
predict
(
obs
)
previous_output
=
agent
.
predict
(
obs
)
save_path1
=
'model
.ckpt
'
save_path1
=
'model
_dir
'
agent
.
save
(
save_path1
)
agent
.
save
(
save_path1
)
agent
.
restore
(
save_path1
)
agent
.
restore
(
save_path1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录