Unverified commit e8797bd0, authored by rical730, committed by GitHub

update tutorials (#298)

* update tutorials
Parent 2deefa8f
...
@@ -21,7 +21,7 @@ import parl
 from parl.utils import logger  # logging utility
 from model import Model
-from algorithm import DQN
+from algorithm import DQN  # from parl.algorithms import DQN  # parl >= 1.3.1
 from agent import Agent
 from replay_memory import ReplayMemory
...
@@ -117,7 +117,7 @@ def main():
         # test part
         eval_reward = evaluate(env, agent, render=True)  # render=True to watch the rendered episodes
-        logger.info('episode:{} e_greed:{} test_reward:{}'.format(
+        logger.info('episode:{} e_greed:{} Test reward:{}'.format(
             episode, agent.e_greed, eval_reward))

     # training finished, save the model
...
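The inline comment added above notes that from parl >= 1.3.1 the DQN algorithm ships with the library itself. A minimal sketch of how the tutorial's import could tolerate both setups, assuming the two modules expose an interchangeable DQN class (the try/except fallback is an illustration, not part of this commit):

    # Prefer the DQN implementation bundled with PARL >= 1.3.1, fall back to the local file.
    try:
        from parl.algorithms import DQN  # available in parl >= 1.3.1
    except ImportError:
        from algorithm import DQN  # tutorial-local implementation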
...
@@ -28,6 +28,7 @@ from parl.utils import logger
 LEARNING_RATE = 1e-3

+# run one training episode
 def run_episode(env, agent):
     obs_list, action_list, reward_list = [], [], []
     obs = env.reset()
...
@@ -44,19 +45,22 @@ def run_episode(env, agent):
     return obs_list, action_list, reward_list

-# evaluate the agent by running 1 episode
+# evaluate the agent by running 5 episodes and averaging the total reward
 def evaluate(env, agent, render=False):
-    obs = env.reset()
-    episode_reward = 0
-    while True:
-        action = agent.predict(obs)
-        obs, reward, isOver, _ = env.step(action)
-        episode_reward += reward
-        if render:
-            env.render()
-        if isOver:
-            break
-    return episode_reward
+    eval_reward = []
+    for i in range(5):
+        obs = env.reset()
+        episode_reward = 0
+        while True:
+            action = agent.predict(obs)
+            obs, reward, isOver, _ = env.step(action)
+            episode_reward += reward
+            if render:
+                env.render()
+            if isOver:
+                break
+        eval_reward.append(episode_reward)
+    return np.mean(eval_reward)

 def calc_reward_to_go(reward_list, gamma=1.0):
...
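The new evaluate returns np.mean(eval_reward), which assumes numpy is imported as np elsewhere in this train.py (the import sits outside the hunk). The body of calc_reward_to_go is likewise elided by the diff; a typical reward-to-go computation for a policy-gradient tutorial looks like the sketch below, shown only as an illustration of what such a helper usually does, not as the repository's actual code:

    import numpy as np

    def calc_reward_to_go(reward_list, gamma=1.0):
        # Accumulate discounted future reward backwards: G_t = r_t + gamma * G_{t+1}.
        for i in range(len(reward_list) - 2, -1, -1):
            reward_list[i] += gamma * reward_list[i + 1]
        return np.array(reward_list)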
...
@@ -37,7 +37,8 @@ NOISE = 0.05  # variance of the action noise
 TRAIN_EPISODE = 6e3  # total number of training episodes

-def run_train_episode(agent, env, rpm):
+# run one training episode
+def run_episode(agent, env, rpm):
     obs = env.reset()
     total_reward = 0
     steps = 0
...
@@ -68,7 +69,8 @@ def run_train_episode(agent, env, rpm):
     return total_reward

-def run_evaluate_episode(env, agent, render=False):
+# evaluate the agent by running 5 episodes and averaging the total reward
+def evaluate(env, agent, render=False):
     eval_reward = []
     for i in range(5):
         obs = env.reset()
...
@@ -109,16 +111,16 @@ def main():
     rpm = ReplayMemory(MEMORY_SIZE)

     # pre-fill the replay memory with experience
     while len(rpm) < MEMORY_WARMUP_SIZE:
-        run_train_episode(agent, env, rpm)
+        run_episode(agent, env, rpm)

     episode = 0
     while episode < TRAIN_EPISODE:
         for i in range(50):
-            total_reward = run_train_episode(agent, env, rpm)
+            total_reward = run_episode(agent, env, rpm)
             episode += 1

-        eval_reward = run_evaluate_episode(env, agent, render=False)
-        logger.info('episode:{} test_reward:{}'.format(
+        eval_reward = evaluate(env, agent, render=False)
+        logger.info('episode:{} Test reward:{}'.format(
             episode, eval_reward))
...
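The DDPG hunks mostly rename the two episode helpers; the usual split is that run_episode perturbs the policy's action with Gaussian noise (NOISE = 0.05 in the context above) while evaluate acts greedily via agent.predict. A sketch of that exploration step, assuming the action space is normalized to [-1, 1] (the clipping bounds are an assumption, not taken from this commit):

    import numpy as np

    # Inside run_episode: perturb the deterministic action for exploration.
    action = agent.predict(obs)
    action = np.clip(np.random.normal(action, NOISE), -1.0, 1.0)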