Unverified commit e8797bd0 authored by rical730, committed by GitHub

update tutorials (#298)

* update tutorials
Parent 2deefa8f
......@@ -21,7 +21,7 @@ import parl
from parl.utils import logger # logging utility
from model import Model
from algorithm import DQN
from algorithm import DQN # from parl.algorithms import DQN # parl >= 1.3.1
from agent import Agent
from replay_memory import ReplayMemory
......@@ -117,7 +117,7 @@ def main():
# test part
eval_reward = evaluate(env, agent, render=True) # render=True to visualize the run
logger.info('episode:{} e_greed:{} test_reward:{}'.format(
logger.info('episode:{} e_greed:{} Test reward:{}'.format(
episode, agent.e_greed, eval_reward))
# training finished; save the model
......
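The comment added to the import above points at the packaged algorithm. A minimal sketch (not part of this commit) of an import that works on both older and newer parl releases:

try:
    from parl.algorithms import DQN  # shipped with parl >= 1.3.1
except ImportError:
    from algorithm import DQN  # fall back to the tutorial's local copy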
......@@ -28,6 +28,7 @@ from parl.utils import logger
LEARNING_RATE = 1e-3
# run one training episode
def run_episode(env, agent):
obs_list, action_list, reward_list = [], [], []
obs = env.reset()
......@@ -44,19 +45,22 @@ def run_episode(env, agent):
return obs_list, action_list, reward_list
# evaluate the agent by running 1 episode
# evaluate the agent: run 5 episodes and average the total reward
def evaluate(env, agent, render=False):
obs = env.reset()
episode_reward = 0
while True:
action = agent.predict(obs)
obs, reward, isOver, _ = env.step(action)
episode_reward += reward
if render:
env.render()
if isOver:
break
return episode_reward
eval_reward = []
for i in range(5):
obs = env.reset()
episode_reward = 0
while True:
action = agent.predict(obs)
obs, reward, isOver, _ = env.step(action)
episode_reward += reward
if render:
env.render()
if isOver:
break
eval_reward.append(episode_reward)
return np.mean(eval_reward)
def calc_reward_to_go(reward_list, gamma=1.0):
......
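The body of calc_reward_to_go is collapsed in this view, and the new evaluate above relies on an `import numpy as np` appearing earlier in the file. A hedged sketch of the usual implementation of this helper in policy-gradient tutorials (an assumption, not the commit's collapsed code): it accumulates discounted future rewards in place, so each entry becomes G_t = r_t + gamma * G_{t+1}.

import numpy as np

def calc_reward_to_go(reward_list, gamma=1.0):
    # walk backwards so each entry becomes the discounted return from step t onward:
    # G_t = r_t + gamma * G_{t+1}
    for i in range(len(reward_list) - 2, -1, -1):
        reward_list[i] += gamma * reward_list[i + 1]
    return np.array(reward_list)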
......@@ -37,7 +37,8 @@ NOISE = 0.05 # action noise variance
TRAIN_EPISODE = 6e3 # total number of training episodes
def run_train_episode(agent, env, rpm):
# run one training episode
def run_episode(agent, env, rpm):
obs = env.reset()
total_reward = 0
steps = 0
......@@ -68,7 +69,8 @@ def run_train_episode(agent, env, rpm):
return total_reward
def run_evaluate_episode(env, agent, render=False):
# evaluate the agent: run 5 episodes and average the total reward
def evaluate(env, agent, render=False):
eval_reward = []
for i in range(5):
obs = env.reset()
......@@ -109,16 +111,16 @@ def main():
rpm = ReplayMemory(MEMORY_SIZE)
# warm up the replay memory with initial data
while len(rpm) < MEMORY_WARMUP_SIZE:
run_train_episode(agent, env, rpm)
run_episode(agent, env, rpm)
episode = 0
while episode < TRAIN_EPISODE:
for i in range(50):
total_reward = run_train_episode(agent, env, rpm)
total_reward = run_episode(agent, env, rpm)
episode += 1
eval_reward = run_evaluate_episode(env, agent, render=False)
logger.info('episode:{} test_reward:{}'.format(
eval_reward = evaluate(env, agent, render=False)
logger.info('episode:{} Test reward:{}'.format(
episode, eval_reward))
......
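The collapsed body of run_episode in this DDPG tutorial uses the NOISE constant from the hunk header above. A hedged sketch of the standard DDPG exploration step such a loop typically performs (the clipping range [-1, 1] is an assumption about the environment's action space, not the commit's collapsed code):

import numpy as np

NOISE = 0.05  # action noise variance, as defined in the constants above

def add_exploration_noise(action):
    # sample Gaussian noise around the deterministic policy action,
    # then clip back into the assumed action range
    return np.clip(np.random.normal(action, NOISE), -1.0, 1.0)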