提交 c50746b1 编写于 作者: T TomorrowIsAnOtherDay

Update quick_start.md

上级 8ebddf33
# **教程:使用PARL解决Cartpole问题**
本教程会使用 [示例](~/parl/examples/QuickStart)中的代码来解释任何通过PARL构建智能体解决经典的Cartpole问题。
本教程会使用 [示例](https://github.com/PaddlePaddle/PARL/tree/develop/examples/QuickStart)中的代码来解释任何通过PARL构建智能体解决经典的Cartpole问题。
本教程的目标:
- 熟悉PARL构建智能体过程中需要用到的子模块。
......@@ -103,5 +103,67 @@ class CartpoleAgent(parl.Agent):
self.train_program, feed=feed, fetch_list=[self.cost])[0]
return cost
```
一般情况下,用户必须实现以下几个函数:
- 构造函数:
把algorithm传进来,以及相关的环境参数(用户自定义的)。需要注意的是,这里必须得要初始化父类:super(CartpoleAgent, self).__init__(algorithm)。
- build_program: 定义paddle里面的program。通常需要定义两个program:一个用于训练,一个用于预测。
- predict: 根据输入返回预测动作(action)。
- sample:根据输入返回动作(action),带探索的动作。
- learn: 输入训练数据,更新算法。
## 开始训练
首先,我们来定一个智能体。逐步定义model|algorithm|agent,然后得到一个可以和环境进行交互的智能体。
```python
model = CartpoleModel(act_dim=2)
alg = parl.algorithms.PolicyGradient(model, lr=1e-3)
agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=2)
```
然后我们用这个agent和环境进行交互,训练模型,1000个episode之后,agent就可以很好地解决Cartpole问题,拿到满分(200)。
```python
def run_episode(env, agent, train_or_test='train'):
obs_list, action_list, reward_list = [], [], []
obs = env.reset()
while True:
obs_list.append(obs)
if train_or_test == 'train':
action = agent.sample(obs)
else:
action = agent.predict(obs)
action_list.append(action)
obs, reward, done, info = env.step(action)
reward_list.append(reward)
if done:
break
return obs_list, action_list, reward_list
env = gym.make("CartPole-v0")
for i in range(1000):
obs_list, action_list, reward_list = run_episode(env, agent)
if i % 10 == 0:
logger.info("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))
batch_obs = np.array(obs_list)
batch_action = np.array(action_list)
batch_reward = calc_discount_norm_reward(reward_list, GAMMA)
agent.learn(batch_obs, batch_action, batch_reward)
if (i + 1) % 100 == 0:
_, _, reward_list = run_episode(env, agent, train_or_test='test')
total_reward = np.sum(reward_list)
logger.info('Test reward: {}'.format(total_reward))
```
## 总结
<img src="../../../examples/QuickStart/performance.gif" width="300"/>
<img src="../../images/quickstart.png" width="300"/>
在这个教程,我们展示了如何一步步地构建强化学习智能体,用于解决经典的Cartpole问题。完整的训练代码可以在这个[文件夹](https://github.com/PaddlePaddle/PARL/tree/develop/examples/QuickStart)中找到。
- 构造函数
```shell
cd examples/QuickStart/
python train.py
```
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册