diff --git a/docs/index.rst b/docs/index.rst index 13e57da1143b033a2b216af5e07a4d82672c3198..e7d6c144112fca11f836b6890c68b2e4c2010832 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -60,6 +60,7 @@ Abstractions getting_started.rst new_alg.rst + save_param.rst .. toctree:: :maxdepth: 2 diff --git a/docs/save_param.rst b/docs/save_param.rst new file mode 100644 index 0000000000000000000000000000000000000000..3824eb9d3fe23c47f375877a75c6c88aab06c0b4 --- /dev/null +++ b/docs/save_param.rst @@ -0,0 +1,26 @@ +Save and Restore Parameters +============================= + +Goal of this tutorial: + +- Learn how to save and restore parameters. + +Example +--------------- + +Sometimes we need to save the parameters into a file and reuse them later on. PARL provides operators +to save parameters to a file and restore parameters from a file easily. You only need several lines to implement this. + +Here is a demonstration of usage: + +.. code-block:: python + + agent = AtariAgent() + # save the parameters of agent to ./model.ckpt + agent.save('./model.ckpt') + # restore the parameters from ./model.ckpt to agent + agent.restore('./model.ckpt') + + # restore the parameters from ./model.ckpt to another_agent + another_agent = AtariAgent() + another_agent.restore('./model.ckpt') diff --git a/examples/QuickStart/train.py b/examples/QuickStart/train.py index b944e05270fc1a63d7ccaba549436f03d9c35a2f..27de75a08f4c0030b1101b1c40e25046cb484739 100644 --- a/examples/QuickStart/train.py +++ b/examples/QuickStart/train.py @@ -15,6 +15,7 @@ import gym import numpy as np import parl +import os.path from cartpole_agent import CartpoleAgent from cartpole_model import CartpoleModel from parl.utils import logger @@ -51,6 +52,10 @@ def main(): alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE) agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM) + # if the file already exists, restore parameters from it + if os.path.exists('./model.ckpt'): + agent.restore('./model.ckpt') + for i in range(1000): obs_list, action_list, reward_list = run_episode(env, agent) if i % 10 == 0: @@ -67,6 +72,9 @@ def main(): total_reward = np.sum(reward_list) logger.info('Test reward: {}'.format(total_reward)) + # save the parameters to ./model.ckpt + agent.save('./model.ckpt') + if __name__ == '__main__': main()