Unverified commit 9dc152f0, authored by Bo Zhou, committed by GitHub

make the quickstart more compact (#88)

* make the quickstart more compact

* remove args in the main function

* yapf

* add gif

* remove render

* Update README.md

* Update README.md

* Update README.md
Parent f8e594c9
## Quick Start Example
Based on PARL, train an agent to play the CartPole game with the policy gradient algorithm in a few minutes.
## Quick Start
Train an agent with PARL to solve the CartPole problem, a classical benchmark in RL.
## How to use
### Dependencies:
@@ -23,8 +23,9 @@ pip install .
# Train model
cd examples/QuickStart/
python train.py
# Or visualize when evaluating: python train.py --eval_vis
```
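Under the hood, `train.py` wires together PARL's three building blocks: a model, an algorithm, and an agent. A condensed sketch of that setup is shown below; the class names and hyperparameters are taken from the script further down in this commit, while the import paths for `CartpoleModel` and `PolicyGradient` are elided in the diff and are therefore assumptions here.

```python
import gym

from cartpole_agent import CartpoleAgent
from cartpole_model import CartpoleModel    # assumed module name for the quickstart model
from parl.algorithms import PolicyGradient  # assumed import path for the algorithm

OBS_DIM = 4           # CartPole-v0 observation size
ACT_DIM = 2           # CartPole-v0 action size
LEARNING_RATE = 1e-3

env = gym.make("CartPole-v0")
model = CartpoleModel(act_dim=ACT_DIM)                         # policy network for CartPole
alg = PolicyGradient(model, hyperparas={'lr': LEARNING_RATE})  # algorithm defining the policy-gradient update
agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)   # agent that interacts with the environment
```

The rest of the script is the episode loop shown in the diff below.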
### Result
After training, you will see the agent get the best score (200 points).
### Expected Result
<img src="performance.gif" width = "300" height ="200" alt="result"/>
The agent can get around 200 points in a few minutes.
@@ -19,19 +19,15 @@ from parl.framework.agent_base import Agent
class CartpoleAgent(Agent):
    def __init__(self, algorithm, obs_dim, act_dim, seed=1):
    def __init__(self, algorithm, obs_dim, act_dim):
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.seed = seed
        super(CartpoleAgent, self).__init__(algorithm)

    def build_program(self):
        # Build separate fluid programs for prediction and training.
        self.pred_program = fluid.Program()
        self.train_program = fluid.Program()
        fluid.default_startup_program().random_seed = self.seed
        self.train_program.random_seed = self.seed
        with fluid.program_guard(self.pred_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
......
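The training script only calls two inference methods on this agent: `sample` while collecting training episodes and `predict` when evaluating. A minimal illustration follows; the `agent` is assumed to be constructed as in `main()` below, and the stochastic-versus-greedy behaviour described in the comments is the usual policy-gradient convention rather than something visible in this excerpt.

```python
import numpy as np

# One CartPole observation (obs_dim = 4); the values here are placeholders.
obs = np.array([0.02, -0.01, 0.03, 0.00], dtype='float32')

train_action = agent.sample(obs)   # stochastic action, used by run_episode in train mode
test_action = agent.predict(obs)   # greedy action, used by run_episode in test mode
```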
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gym
import numpy as np
from cartpole_agent import CartpoleAgent
@@ -25,15 +24,17 @@ OBS_DIM = 4
ACT_DIM = 2
GAMMA = 0.99
LEARNING_RATE = 1e-3
SEED = 1
def run_train_episode(env, agent):
def run_episode(env, agent, train_or_test='train'):
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        action = agent.sample(obs)
        if train_or_test == 'train':
            action = agent.sample(obs)   # stochastic action for exploration
        else:
            action = agent.predict(obs)  # greedy action for evaluation
        action_list.append(action)
        obs, reward, done, info = env.step(action)
@@ -44,30 +45,14 @@ def run_train_episode(env, agent):
    return obs_list, action_list, reward_list
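With the `train_or_test` switch, a single function now covers both the training rollout and the periodic evaluation, which is what makes the script more compact; the calls in `main()` below reduce to roughly:

```python
# Training episode: actions come from agent.sample (exploration).
obs_list, action_list, reward_list = run_episode(env, agent)

# Evaluation episode: actions come from agent.predict (greedy); only the rewards are needed.
_, _, test_rewards = run_episode(env, agent, train_or_test='test')
total_reward = np.sum(test_rewards)
```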
def run_evaluate_episode(env, agent):
    obs = env.reset()
    all_reward = 0
    while True:
        if args.eval_vis:
            env.render()
        action = agent.predict(obs)
        obs, reward, done, info = env.step(action)
        all_reward += reward
        if done:
            break
    return all_reward
def main():
    env = gym.make("CartPole-v0")
    env.seed(SEED)
    np.random.seed(SEED)
    model = CartpoleModel(act_dim=ACT_DIM)
    alg = PolicyGradient(model, hyperparas={'lr': LEARNING_RATE})
    agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM, seed=SEED)
    agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)

    for i in range(1000):
        obs_list, action_list, reward_list = run_train_episode(env, agent)
        obs_list, action_list, reward_list = run_episode(env, agent)
        logger.info("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))
        batch_obs = np.array(obs_list)
@@ -76,16 +61,10 @@ def main():
        agent.learn(batch_obs, batch_action, batch_reward)

        if (i + 1) % 100 == 0:
            all_reward = run_evaluate_episode(env, agent)
            logger.info('Test reward: {}'.format(all_reward))
            _, _, reward_list = run_episode(env, agent, train_or_test='test')
            total_reward = np.sum(reward_list)
            logger.info('Test reward: {}'.format(total_reward))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--eval_vis',
        action='store_true',
        help='if set, will visualize the game when evaluating')
    args = parser.parse_args()
    main()