diff --git a/docs/tutorial/save_param.rst b/docs/tutorial/save_param.rst
index 82e411ab2010ef3f9b4dcca0fd0c23f319eac7b7..d5b3505587f975c0bb8a8d51c9d7c6e73b259ea9 100644
--- a/docs/tutorial/save_param.rst
+++ b/docs/tutorial/save_param.rst
@@ -16,11 +16,11 @@ Here is a demonstration of usage:
 .. code-block:: python
 
     agent = AtariAgent()
-    # save the parameters of agent to ./model.ckpt
-    agent.save('./model.ckpt')
-    # restore the parameters from ./model.ckpt to agent
-    agent.restore('./model.ckpt')
+    # save the parameters of agent to ./model_dir
+    agent.save('./model_dir')
+    # restore the parameters from ./model_dir to agent
+    agent.restore('./model_dir')
 
-    # restore the parameters from ./model.ckpt to another_agent
+    # restore the parameters from ./model_dir to another_agent
     another_agent = AtariAgent()
-    another_agent.restore('./model.ckpt')
+    another_agent.restore('./model_dir')
diff --git a/docs/zh_CN/tutorial/param.md b/docs/zh_CN/tutorial/param.md
index bf0f74911412a14fcfb463a08d443ec7c7f52632..8d6838942cc0da8e4c020cf526209893e2e28f7d 100644
--- a/docs/zh_CN/tutorial/param.md
+++ b/docs/zh_CN/tutorial/param.md
@@ -4,10 +4,10 @@
 Once the user has built the agent, parameters can be saved directly through the agent's own interface.
 ```python
 agent = AtariAgent()
-# save the parameters to ./model.ckpt
-agent.save('./model.ckpt')
+# save the parameters to ./model_dir
+agent.save('./model_dir')
 # restore the parameters to this agent
-agent.restore('./model.ckpt')
+agent.restore('./model_dir')
 ```
 
 Scenario 2: During parallel training, the latest model parameters often need to be synchronized to another server. In that case, the parameters are first fetched into memory and then assigned to the agent (actor) on the other machine.
diff --git a/examples/MADDPG/train.py b/examples/MADDPG/train.py
index 8454a73ee209707c65340897ce9b090d482c6751..e4c14904cea76ba51c31f6faa53cc3db3532050c 100644
--- a/examples/MADDPG/train.py
+++ b/examples/MADDPG/train.py
@@ -121,10 +121,10 @@ def train_agent():
     if args.restore:
         # restore model
         for i in range(len(agents)):
-            model_file = args.model_dir + '/agent_' + str(i) + '.ckpt'
+            model_file = args.model_dir + '/agent_' + str(i)
             if not os.path.exists(model_file):
-                logger.info('model file {} does not exits'.format(model_file))
-                raise Exception
+                raise Exception(
+                    'model file {} does not exist'.format(model_file))
             agents[i].restore(model_file)
 
     t_start = time.time()
@@ -166,7 +166,7 @@ def train_agent():
 
         if not args.restore:
             os.makedirs(os.path.dirname(args.model_dir), exist_ok=True)
             for i in range(len(agents)):
-                model_name = '/agent_' + str(i) + '.ckpt'
+                model_name = '/agent_' + str(i)
                 agents[i].save(args.model_dir + model_name)
 
diff --git a/examples/QuickStart/train.py b/examples/QuickStart/train.py
index fb4e66d3d52b404b07e109798e26d5b077ad8513..3f790e5f7addda1d32f6cafe38d37ac19b142d48 100644
--- a/examples/QuickStart/train.py
+++ b/examples/QuickStart/train.py
@@ -57,8 +57,8 @@ def main():
     agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)
 
     # if the file already exists, restore parameters from it
-    if os.path.exists('./model.ckpt'):
-        agent.restore('./model.ckpt')
+    if os.path.exists('./model_dir'):
+        agent.restore('./model_dir')
 
     for i in range(1000):
         obs_list, action_list, reward_list = run_episode(env, agent)
@@ -76,8 +76,8 @@ def main():
         total_reward = np.sum(reward_list)
         logger.info('Test reward: {}'.format(total_reward))
 
-    # save the parameters to ./model.ckpt
-    agent.save('./model.ckpt')
+    # save the parameters to ./model_dir
+    agent.save('./model_dir')
 
 
 if __name__ == '__main__':
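The documentation and example hunks above all make the same change: `save()`/`restore()` now take a directory path, with no `.ckpt` suffix. As a quick illustration of the per-agent checkpoint layout used in `examples/MADDPG/train.py`, here is a minimal sketch under the new convention; `save_all`/`restore_all` are hypothetical helper names, and the agents are assumed to expose the PARL `save`/`restore` interface shown in this patch.

```python
import os

def save_all(agents, model_dir):
    # Hypothetical helper: one sub-directory per agent, no '.ckpt' suffix.
    os.makedirs(model_dir, exist_ok=True)
    for i, agent in enumerate(agents):
        agent.save(os.path.join(model_dir, 'agent_' + str(i)))

def restore_all(agents, model_dir):
    # Hypothetical helper: fail fast if a per-agent directory is missing,
    # mirroring the guard in the MADDPG hunk above.
    for i, agent in enumerate(agents):
        model_file = os.path.join(model_dir, 'agent_' + str(i))
        if not os.path.exists(model_file):
            raise Exception('model file {} does not exist'.format(model_file))
        agent.restore(model_file)
```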
diff --git a/examples/offline-Q-learning/parallel_run.py b/examples/offline-Q-learning/parallel_run.py
index d7da430e83de46be82a935bc01ce35ca6bd83c6e..0440f0609978134ab6557d18b6a980836cee1287 100644
--- a/examples/offline-Q-learning/parallel_run.py
+++ b/examples/offline-Q-learning/parallel_run.py
@@ -97,15 +97,15 @@ def main():
         model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE * gpu_num)
     agent = AtariAgent(
         algorithm, act_dim=act_dim, total_step=args.train_total_steps)
-    if os.path.isfile('./model.ckpt'):
+    if os.path.exists('./model_dir'):
         logger.info("load model from file")
-        agent.restore('./model.ckpt')
+        agent.restore('./model_dir')
 
     if args.train:
         logger.info("train with memory data")
         run_train_step(agent, rpm)
         logger.info("finish training. Save the model.")
-        agent.save('./model.ckpt')
+        agent.save('./model_dir')
     else:
         logger.info("collect experience")
         collect_exp(env, rpm, agent)
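One subtlety in the `parallel_run.py` hunk: with this patch, `agent.save('./model_dir')` produces a directory (see the `os.makedirs` call in `parl/core/fluid/agent.py` below), so a guard written with `os.path.isfile` would never fire again. A minimal sketch of the load-if-present pattern, assuming only the `save`/`restore` interface shown in this diff; the helper name is hypothetical:

```python
import os

def restore_if_present(agent, model_dir='./model_dir'):
    # The checkpoint is now a directory, not a single '.ckpt' file, so the
    # presence check must be os.path.exists (or os.path.isdir), never
    # os.path.isfile.
    if os.path.exists(model_dir):
        agent.restore(model_dir)
        return True
    return False
```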
diff --git a/parl/core/fluid/agent.py b/parl/core/fluid/agent.py
index aef49b95ea8481dc41bf2a604d6ba04e6794d58e..0fdd1b709e61254ecaf8a02ea3ad7b12a3d531cc 100644
--- a/parl/core/fluid/agent.py
+++ b/parl/core/fluid/agent.py
@@ -21,6 +21,7 @@
 from parl.core.fluid import layers
 from parl.core.agent_base import AgentBase
 from parl.core.fluid.algorithm import Algorithm
 from parl.utils import machine_info
+from parl.utils import logger
 
 __all__ = ['Agent']
@@ -147,13 +148,14 @@ class Agent(AgentBase):
         .. code-block:: python
 
             agent = AtariAgent()
-            agent.save()
-            agent.save('./program_model')
-            agent.save('./program_model', program=agent.learn_program)
+            agent.save('./model_dir')
+            agent.save('./model_dir', program=agent.learn_program)
 
         """
-        if save_path is None:
-            save_path = './program_model'
+        assert save_path is not None, 'please specify `save_path`'
+        if os.path.isfile(save_path):
+            raise Exception('cannot save to {}: it is a file, not a '
+                            'directory'.format(save_path))
         if not os.path.exists(save_path):
             os.makedirs(save_path)
         all_programs = [
@@ -189,7 +191,6 @@
 
     def restore(self, save_path=None, program=None):
         """Restore previously saved parameters from save_path.
-        default save_path is ./program_model
 
         Args:
             save_path(str): path where parameters were previously saved.
@@ -203,12 +204,11 @@
         .. code-block:: python
 
             agent = AtariAgent()
-            agent.save()
-            agent.restore()
+            agent.save('./model_dir')
+            agent.restore('./model_dir')
 
         """
-        if save_path is None:
-            save_path = './program_model'
+        assert save_path is not None, 'please specify `save_path`'
         if not os.path.exists(save_path):
             raise Exception(
                 'can not restore from {}, directory does not exists'.format(
diff --git a/parl/core/fluid/tests/agent_base_test.py b/parl/core/fluid/tests/agent_base_test.py
index 82688b97997cd8b3abf67cac24c6eab6772ac147..e08ed0daaa6ef13d37150b2a5ba046dc16611584 100644
--- a/parl/core/fluid/tests/agent_base_test.py
+++ b/parl/core/fluid/tests/agent_base_test.py
@@ -92,8 +92,8 @@ class AgentBaseTest(unittest.TestCase):
         agent = TestAgent(self.alg)
         obs = np.random.random([3, 10]).astype('float32')
         output_np = agent.predict(obs)
-        save_path1 = 'model.ckpt'
-        save_path2 = os.path.join('my_model', 'model-2.ckpt')
+        save_path1 = 'model_dir'
+        save_path2 = os.path.join('my_model', 'model-2_dir')
         agent.save(save_path1)
         agent.save(save_path2)
         self.assertTrue(os.path.exists(save_path1))
@@ -103,7 +103,7 @@
         agent = TestAgent(self.alg)
         obs = np.random.random([3, 10]).astype('float32')
         output_np = agent.predict(obs)
-        save_path1 = 'model.ckpt'
+        save_path1 = 'model_dir'
         previous_output = agent.predict(obs)
         agent.save(save_path1)
         agent.restore(save_path1)
@@ -121,7 +121,7 @@
         agent.learn_program = parl.compile(agent.learn_program)
         obs = np.random.random([3, 10]).astype('float32')
         previous_output = agent.predict(obs)
-        save_path1 = 'model.ckpt'
+        save_path1 = 'model_dir'
         agent.save(save_path1)
         agent.restore(save_path1)
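For the second scenario described in `docs/zh_CN/tutorial/param.md` (pushing the latest parameters to an actor on another machine), the parameters travel through memory rather than through `save()`/`restore()`. A rough sketch, assuming the `get_weights`/`set_weights` methods of PARL's `AgentBase` and leaving the actual network transport out:

```python
def sync_weights(learner_agent, actor_agent):
    # Pull the latest parameters into host memory on the learner side ...
    weights = learner_agent.get_weights()
    # ... ship `weights` across machines (transport omitted here), then
    # assign them to the actor's agent on the other side.
    actor_agent.set_weights(weights)
```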