提交 bb2dcf2c 编写于 作者: L likejiao

change model.ckpt to model_dir

上级 d310c549
...@@ -16,11 +16,11 @@ Here is a demonstration of usage: ...@@ -16,11 +16,11 @@ Here is a demonstration of usage:
.. code-block:: python .. code-block:: python
agent = AtariAgent() agent = AtariAgent()
# save the parameters of agent to ./model.ckpt # save the parameters of agent to ./model_dir
agent.save('./model.ckpt') agent.save('./model_dir')
# restore the parameters from ./model.ckpt to agent # restore the parameters from ./model_dir to agent
agent.restore('./model.ckpt') agent.restore('./model_dir')
# restore the parameters from ./model.ckpt to another_agent # restore the parameters from ./model_dir to another_agent
another_agent = AtariAgent() another_agent = AtariAgent()
another_agent.restore('./model.ckpt') another_agent.restore('./model_dir')
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
当用户构建好agent之后,可以直接通过agent的相关接口来完成参数的存储。 当用户构建好agent之后,可以直接通过agent的相关接口来完成参数的存储。
```python ```python
agent = AtariAgent() agent = AtariAgent()
# 保存参数到 ./model.ckpt # 保存参数到 ./model_dir
agent.save('./model.ckpt') agent.save('./model_dir')
# 恢复参数到这个agent上 # 恢复参数到这个agent上
agent.restore('./model.ckpt') agent.restore('./model_dir')
``` ```
场景2: 并行训练过程中,经常需要把最新的模型参数同步到另一台服务器上,这时候,需要把模型参数拿到内存中,然后再赋值给另一台机器上的agent(actor)。 场景2: 并行训练过程中,经常需要把最新的模型参数同步到另一台服务器上,这时候,需要把模型参数拿到内存中,然后再赋值给另一台机器上的agent(actor)。
......
...@@ -121,10 +121,10 @@ def train_agent(): ...@@ -121,10 +121,10 @@ def train_agent():
if args.restore: if args.restore:
# restore modle # restore modle
for i in range(len(agents)): for i in range(len(agents)):
model_file = args.model_dir + '/agent_' + str(i) + '.ckpt' model_file = args.model_dir + '/agent_' + str(i)
if not os.path.exists(model_file): if not os.path.exists(model_file):
logger.info('model file {} does not exits'.format(model_file)) raise Exception(
raise Exception 'model file {} does not exits'.format(model_file))
agents[i].restore(model_file) agents[i].restore(model_file)
t_start = time.time() t_start = time.time()
...@@ -166,7 +166,7 @@ def train_agent(): ...@@ -166,7 +166,7 @@ def train_agent():
if not args.restore: if not args.restore:
os.makedirs(os.path.dirname(args.model_dir), exist_ok=True) os.makedirs(os.path.dirname(args.model_dir), exist_ok=True)
for i in range(len(agents)): for i in range(len(agents)):
model_name = '/agent_' + str(i) + '.ckpt' model_name = '/agent_' + str(i)
agents[i].save(args.model_dir + model_name) agents[i].save(args.model_dir + model_name)
......
...@@ -57,8 +57,8 @@ def main(): ...@@ -57,8 +57,8 @@ def main():
agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM) agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)
# if the file already exists, restore parameters from it # if the file already exists, restore parameters from it
if os.path.exists('./model.ckpt'): if os.path.exists('./model_dir'):
agent.restore('./model.ckpt') agent.restore('./model_dir')
for i in range(1000): for i in range(1000):
obs_list, action_list, reward_list = run_episode(env, agent) obs_list, action_list, reward_list = run_episode(env, agent)
...@@ -76,8 +76,8 @@ def main(): ...@@ -76,8 +76,8 @@ def main():
total_reward = np.sum(reward_list) total_reward = np.sum(reward_list)
logger.info('Test reward: {}'.format(total_reward)) logger.info('Test reward: {}'.format(total_reward))
# save the parameters to ./model.ckpt # save the parameters to ./model_dir
agent.save('./model.ckpt') agent.save('./model_dir')
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -97,15 +97,15 @@ def main(): ...@@ -97,15 +97,15 @@ def main():
model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE * gpu_num) model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE * gpu_num)
agent = AtariAgent( agent = AtariAgent(
algorithm, act_dim=act_dim, total_step=args.train_total_steps) algorithm, act_dim=act_dim, total_step=args.train_total_steps)
if os.path.isfile('./model.ckpt'): if os.path.isfile('./model_dir'):
logger.info("load model from file") logger.info("load model from file")
agent.restore('./model.ckpt') agent.restore('./model_dir')
if args.train: if args.train:
logger.info("train with memory data") logger.info("train with memory data")
run_train_step(agent, rpm) run_train_step(agent, rpm)
logger.info("finish training. Save the model.") logger.info("finish training. Save the model.")
agent.save('./model.ckpt') agent.save('./model_dir')
else: else:
logger.info("collect experience") logger.info("collect experience")
collect_exp(env, rpm, agent) collect_exp(env, rpm, agent)
......
...@@ -21,6 +21,7 @@ from parl.core.fluid import layers ...@@ -21,6 +21,7 @@ from parl.core.fluid import layers
from parl.core.agent_base import AgentBase from parl.core.agent_base import AgentBase
from parl.core.fluid.algorithm import Algorithm from parl.core.fluid.algorithm import Algorithm
from parl.utils import machine_info from parl.utils import machine_info
from parl.utils import logger
__all__ = ['Agent'] __all__ = ['Agent']
...@@ -147,13 +148,14 @@ class Agent(AgentBase): ...@@ -147,13 +148,14 @@ class Agent(AgentBase):
.. code-block:: python .. code-block:: python
agent = AtariAgent() agent = AtariAgent()
agent.save() agent.save('./model_dir')
agent.save('./program_model') agent.save('./model_dir', program=agent.learn_program)
agent.save('./program_model', program=agent.learn_program)
""" """
if save_path is None: assert save_path is not None, 'please specify `save_path` '
save_path = './program_model' if os.path.isfile(save_path):
raise Exception('can not save to {}, it is a file, not directory'.
format(save_path))
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.makedirs(save_path) os.makedirs(save_path)
all_programs = [ all_programs = [
...@@ -189,7 +191,6 @@ class Agent(AgentBase): ...@@ -189,7 +191,6 @@ class Agent(AgentBase):
def restore(self, save_path=None, program=None): def restore(self, save_path=None, program=None):
"""Restore previously saved parameters from save_path. """Restore previously saved parameters from save_path.
default save_path is ./program_model
Args: Args:
save_path(str): path where parameters were previously saved. save_path(str): path where parameters were previously saved.
...@@ -203,12 +204,11 @@ class Agent(AgentBase): ...@@ -203,12 +204,11 @@ class Agent(AgentBase):
.. code-block:: python .. code-block:: python
agent = AtariAgent() agent = AtariAgent()
agent.save() agent.save('./model_dir')
agent.restore() agent.restore('./model_dir')
""" """
if save_path is None: assert save_path is not None, 'please specify `save_path` '
save_path = './program_model'
if not os.path.exists(save_path): if not os.path.exists(save_path):
raise Exception( raise Exception(
'can not restore from {}, directory does not exists'.format( 'can not restore from {}, directory does not exists'.format(
......
...@@ -92,8 +92,8 @@ class AgentBaseTest(unittest.TestCase): ...@@ -92,8 +92,8 @@ class AgentBaseTest(unittest.TestCase):
agent = TestAgent(self.alg) agent = TestAgent(self.alg)
obs = np.random.random([3, 10]).astype('float32') obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs) output_np = agent.predict(obs)
save_path1 = 'model.ckpt' save_path1 = 'model_dir'
save_path2 = os.path.join('my_model', 'model-2.ckpt') save_path2 = os.path.join('my_model', 'model-2_dir')
agent.save(save_path1) agent.save(save_path1)
agent.save(save_path2) agent.save(save_path2)
self.assertTrue(os.path.exists(save_path1)) self.assertTrue(os.path.exists(save_path1))
...@@ -103,7 +103,7 @@ class AgentBaseTest(unittest.TestCase): ...@@ -103,7 +103,7 @@ class AgentBaseTest(unittest.TestCase):
agent = TestAgent(self.alg) agent = TestAgent(self.alg)
obs = np.random.random([3, 10]).astype('float32') obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs) output_np = agent.predict(obs)
save_path1 = 'model.ckpt' save_path1 = 'model_dir'
previous_output = agent.predict(obs) previous_output = agent.predict(obs)
agent.save(save_path1) agent.save(save_path1)
agent.restore(save_path1) agent.restore(save_path1)
...@@ -121,7 +121,7 @@ class AgentBaseTest(unittest.TestCase): ...@@ -121,7 +121,7 @@ class AgentBaseTest(unittest.TestCase):
agent.learn_program = parl.compile(agent.learn_program) agent.learn_program = parl.compile(agent.learn_program)
obs = np.random.random([3, 10]).astype('float32') obs = np.random.random([3, 10]).astype('float32')
previous_output = agent.predict(obs) previous_output = agent.predict(obs)
save_path1 = 'model.ckpt' save_path1 = 'model_dir'
agent.save(save_path1) agent.save(save_path1)
agent.restore(save_path1) agent.restore(save_path1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册