From bb2dcf2c9f20a20f1efe2d59afb4baa0a3eaa24b Mon Sep 17 00:00:00 2001 From: likejiao Date: Tue, 22 Sep 2020 16:37:10 +0800 Subject: [PATCH] change model.ckpt to model_dir --- docs/tutorial/save_param.rst | 12 ++++++------ docs/zh_CN/tutorial/param.md | 6 +++--- examples/MADDPG/train.py | 8 ++++---- examples/QuickStart/train.py | 8 ++++---- examples/offline-Q-learning/parallel_run.py | 6 +++--- parl/core/fluid/agent.py | 20 ++++++++++---------- parl/core/fluid/tests/agent_base_test.py | 8 ++++---- 7 files changed, 34 insertions(+), 34 deletions(-) diff --git a/docs/tutorial/save_param.rst b/docs/tutorial/save_param.rst index 82e411a..d5b3505 100644 --- a/docs/tutorial/save_param.rst +++ b/docs/tutorial/save_param.rst @@ -16,11 +16,11 @@ Here is a demonstration of usage: .. code-block:: python agent = AtariAgent() - # save the parameters of agent to ./model.ckpt - agent.save('./model.ckpt') - # restore the parameters from ./model.ckpt to agent - agent.restore('./model.ckpt') + # save the parameters of agent to ./model_dir + agent.save('./model_dir') + # restore the parameters from ./model_dir to agent + agent.restore('./model_dir') - # restore the parameters from ./model.ckpt to another_agent + # restore the parameters from ./model_dir to another_agent another_agent = AtariAgent() - another_agent.restore('./model.ckpt') + another_agent.restore('./model_dir') diff --git a/docs/zh_CN/tutorial/param.md b/docs/zh_CN/tutorial/param.md index bf0f749..8d68389 100644 --- a/docs/zh_CN/tutorial/param.md +++ b/docs/zh_CN/tutorial/param.md @@ -4,10 +4,10 @@ 当用户构建好agent之后,可以直接通过agent的相关接口来完成参数的存储。 ```python agent = AtariAgent() -# 保存参数到 ./model.ckpt -agent.save('./model.ckpt') +# 保存参数到 ./model_dir +agent.save('./model_dir') # 恢复参数到这个agent上 -agent.restore('./model.ckpt') +agent.restore('./model_dir') ``` 场景2: 并行训练过程中,经常需要把最新的模型参数同步到另一台服务器上,这时候,需要把模型参数拿到内存中,然后再赋值给另一台机器上的agent(actor)。 diff --git a/examples/MADDPG/train.py b/examples/MADDPG/train.py index 8454a73..e4c1490 100644 --- a/examples/MADDPG/train.py +++ b/examples/MADDPG/train.py @@ -121,10 +121,10 @@ def train_agent(): if args.restore: # restore modle for i in range(len(agents)): - model_file = args.model_dir + '/agent_' + str(i) + '.ckpt' + model_file = args.model_dir + '/agent_' + str(i) if not os.path.exists(model_file): - logger.info('model file {} does not exits'.format(model_file)) - raise Exception + raise Exception( + 'model file {} does not exits'.format(model_file)) agents[i].restore(model_file) t_start = time.time() @@ -166,7 +166,7 @@ def train_agent(): if not args.restore: os.makedirs(os.path.dirname(args.model_dir), exist_ok=True) for i in range(len(agents)): - model_name = '/agent_' + str(i) + '.ckpt' + model_name = '/agent_' + str(i) agents[i].save(args.model_dir + model_name) diff --git a/examples/QuickStart/train.py b/examples/QuickStart/train.py index fb4e66d..3f790e5 100644 --- a/examples/QuickStart/train.py +++ b/examples/QuickStart/train.py @@ -57,8 +57,8 @@ def main(): agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM) # if the file already exists, restore parameters from it - if os.path.exists('./model.ckpt'): - agent.restore('./model.ckpt') + if os.path.exists('./model_dir'): + agent.restore('./model_dir') for i in range(1000): obs_list, action_list, reward_list = run_episode(env, agent) @@ -76,8 +76,8 @@ def main(): total_reward = np.sum(reward_list) logger.info('Test reward: {}'.format(total_reward)) - # save the parameters to ./model.ckpt - agent.save('./model.ckpt') + # save the parameters to ./model_dir + agent.save('./model_dir') if __name__ == '__main__': diff --git a/examples/offline-Q-learning/parallel_run.py b/examples/offline-Q-learning/parallel_run.py index d7da430..0440f06 100644 --- a/examples/offline-Q-learning/parallel_run.py +++ b/examples/offline-Q-learning/parallel_run.py @@ -97,15 +97,15 @@ def main(): model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE * gpu_num) agent = AtariAgent( algorithm, act_dim=act_dim, total_step=args.train_total_steps) - if os.path.isfile('./model.ckpt'): + if os.path.isfile('./model_dir'): logger.info("load model from file") - agent.restore('./model.ckpt') + agent.restore('./model_dir') if args.train: logger.info("train with memory data") run_train_step(agent, rpm) logger.info("finish training. Save the model.") - agent.save('./model.ckpt') + agent.save('./model_dir') else: logger.info("collect experience") collect_exp(env, rpm, agent) diff --git a/parl/core/fluid/agent.py b/parl/core/fluid/agent.py index aef49b9..0fdd1b7 100644 --- a/parl/core/fluid/agent.py +++ b/parl/core/fluid/agent.py @@ -21,6 +21,7 @@ from parl.core.fluid import layers from parl.core.agent_base import AgentBase from parl.core.fluid.algorithm import Algorithm from parl.utils import machine_info +from parl.utils import logger __all__ = ['Agent'] @@ -147,13 +148,14 @@ class Agent(AgentBase): .. code-block:: python agent = AtariAgent() - agent.save() - agent.save('./program_model') - agent.save('./program_model', program=agent.learn_program) + agent.save('./model_dir') + agent.save('./model_dir', program=agent.learn_program) """ - if save_path is None: - save_path = './program_model' + assert save_path is not None, 'please specify `save_path` ' + if os.path.isfile(save_path): + raise Exception('can not save to {}, it is a file, not directory'. + format(save_path)) if not os.path.exists(save_path): os.makedirs(save_path) all_programs = [ @@ -189,7 +191,6 @@ class Agent(AgentBase): def restore(self, save_path=None, program=None): """Restore previously saved parameters from save_path. - default save_path is ./program_model Args: save_path(str): path where parameters were previously saved. @@ -203,12 +204,11 @@ class Agent(AgentBase): .. code-block:: python agent = AtariAgent() - agent.save() - agent.restore() + agent.save('./model_dir') + agent.restore('./model_dir') """ - if save_path is None: - save_path = './program_model' + assert save_path is not None, 'please specify `save_path` ' if not os.path.exists(save_path): raise Exception( 'can not restore from {}, directory does not exists'.format( diff --git a/parl/core/fluid/tests/agent_base_test.py b/parl/core/fluid/tests/agent_base_test.py index 82688b9..e08ed0d 100644 --- a/parl/core/fluid/tests/agent_base_test.py +++ b/parl/core/fluid/tests/agent_base_test.py @@ -92,8 +92,8 @@ class AgentBaseTest(unittest.TestCase): agent = TestAgent(self.alg) obs = np.random.random([3, 10]).astype('float32') output_np = agent.predict(obs) - save_path1 = 'model.ckpt' - save_path2 = os.path.join('my_model', 'model-2.ckpt') + save_path1 = 'model_dir' + save_path2 = os.path.join('my_model', 'model-2_dir') agent.save(save_path1) agent.save(save_path2) self.assertTrue(os.path.exists(save_path1)) @@ -103,7 +103,7 @@ class AgentBaseTest(unittest.TestCase): agent = TestAgent(self.alg) obs = np.random.random([3, 10]).astype('float32') output_np = agent.predict(obs) - save_path1 = 'model.ckpt' + save_path1 = 'model_dir' previous_output = agent.predict(obs) agent.save(save_path1) agent.restore(save_path1) @@ -121,7 +121,7 @@ class AgentBaseTest(unittest.TestCase): agent.learn_program = parl.compile(agent.learn_program) obs = np.random.random([3, 10]).astype('float32') previous_output = agent.predict(obs) - save_path1 = 'model.ckpt' + save_path1 = 'model_dir' agent.save(save_path1) agent.restore(save_path1) -- GitLab