提交 4c98e3fd 编写于 作者: L LI Yunxiang 提交者: Bo Zhou

add save_params in docs and quickStart (#172)

* add save_param in docs and quickstart

* Update train.py
上级 4abc0534
......@@ -60,6 +60,7 @@ Abstractions
getting_started.rst
new_alg.rst
save_param.rst
.. toctree::
:maxdepth: 2
......
Save and Restore Parameters
=============================
Goal of this tutorial:
- Learn how to save and restore parameters.
Example
---------------
Sometimes we need to save the parameters into a file and reuse them later on. PARL provides operators
to save parameters to a file and restore parameters from a file easily. You only need several lines to implement this.
Here is a demonstration of usage:
.. code-block:: python
agent = AtariAgent()
# save the parameters of agent to ./model.ckpt
agent.save('./model.ckpt')
# restore the parameters from ./model.ckpt to agent
agent.restore('./model.ckpt')
# restore the parameters from ./model.ckpt to another_agent
another_agent = AtariAgent()
another_agent.restore('./model.ckpt')
......@@ -15,6 +15,7 @@
import gym
import numpy as np
import parl
import os.path
from cartpole_agent import CartpoleAgent
from cartpole_model import CartpoleModel
from parl.utils import logger
......@@ -51,6 +52,10 @@ def main():
alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)
# if the file already exists, restore parameters from it
if os.path.exists('./model.ckpt'):
agent.restore('./model.ckpt')
for i in range(1000):
obs_list, action_list, reward_list = run_episode(env, agent)
if i % 10 == 0:
......@@ -67,6 +72,9 @@ def main():
total_reward = np.sum(reward_list)
logger.info('Test reward: {}'.format(total_reward))
# save the parameters to ./model.ckpt
agent.save('./model.ckpt')
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册