Commit 49b0e706 authored by LI Yunxiang, committed by Bo Zhou

add dygraph pg (#155)

* add dygraph pg

* update acc. comments

* update comments
Parent 89c3366b
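README.md: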
## Dygraph Quick Start
Train an agent with PARL to solve the CartPole problem, a classical benchmark in RL. This example is the dygraph (eager mode) version of the original [QuickStart][origin].
## How to use
### Dependencies:
+ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
### Start Training:
```
# Install dependencies
pip install paddlepaddle
# Or use CUDA: pip install paddlepaddle-gpu
pip install gym
git clone https://github.com/PaddlePaddle/PARL.git
cd PARL
pip install .
# Train model
cd examples/EagerMode/QuickStart/
python train.py
```
### Expected Result
<img src="https://github.com/PaddlePaddle/PARL/blob/develop/examples/QuickStart/performance.gif" width = "300" height ="200" alt="result"/>
The agent should reach around 200 points (the maximum score of CartPole-v0) after a few minutes of training.
[origin]: https://github.com/PaddlePaddle/PARL/tree/develop/examples/QuickStart
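cartpole_agent.py: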
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
from parl.utils import machine_info


class CartpoleAgent(object):
    def __init__(self, alg, obs_dim, act_dim):
        self.alg = alg
        self.obs_dim = obs_dim
        self.act_dim = act_dim

    def sample(self, obs):
        # Sample an action from the policy distribution (used during training).
        obs = np.expand_dims(obs, axis=0)
        act_prob = self.alg.predict(obs).numpy()
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.random.choice(self.act_dim, p=act_prob)
        return act

    def predict(self, obs):
        # Take the greedy action (used during evaluation).
        obs = np.expand_dims(obs, axis=0)
        act_prob = self.alg.predict(obs).numpy()
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.argmax(act_prob)
        return act

    def learn(self, obs, act, reward):
        # Add a trailing axis so actions and rewards match the layout
        # expected by the algorithm.
        act = np.expand_dims(act, axis=-1)
        reward = np.expand_dims(reward, axis=-1)
        cost = self.alg.learn(obs, act, reward)
        return cost
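cartpole_model.py: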
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid


class CartpoleModel(fluid.dygraph.Layer):
    def __init__(self, name_scope, act_dim):
        super(CartpoleModel, self).__init__(name_scope)
        hid1_size = act_dim * 10
        # A small two-layer policy network: a tanh hidden layer followed by a
        # softmax output over the discrete actions.
        self.fc1 = fluid.FC('fc1', hid1_size, act='tanh')
        self.fc2 = fluid.FC('fc2', act_dim, act='softmax')

    def forward(self, obs):
        out = self.fc1(obs)
        out = self.fc2(out)
        return out
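policy_gradient.py: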
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers


class PolicyGradient(object):
    def __init__(self, model, lr):
        self.model = model
        self.optimizer = fluid.optimizer.Adam(learning_rate=lr)

    def predict(self, obs):
        obs = fluid.dygraph.to_variable(obs)
        obs = layers.cast(obs, dtype='float32')
        return self.model(obs)

    def learn(self, obs, action, reward):
        obs = fluid.dygraph.to_variable(obs)
        obs = layers.cast(obs, dtype='float32')
        act_prob = self.model(obs)
        action = fluid.dygraph.to_variable(action)
        reward = fluid.dygraph.to_variable(reward)

        # REINFORCE loss: cross_entropy(act_prob, action) is -log(pi(a|s)), so
        # weighting it by the normalized discounted return and averaging over
        # the batch gives the policy-gradient cost to minimize.
        log_prob = layers.cross_entropy(act_prob, action)
        cost = log_prob * reward
        cost = layers.cast(cost, dtype='float32')
        cost = layers.reduce_mean(cost)
        cost.backward()
        self.optimizer.minimize(cost)
        self.model.clear_gradients()
        return cost
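train.py: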
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import numpy as np
import paddle.fluid as fluid
from parl.utils import logger

from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
from policy_gradient import PolicyGradient
from utils import calc_discount_norm_reward

OBS_DIM = 4
ACT_DIM = 2
GAMMA = 0.99
LEARNING_RATE = 1e-3


def run_episode(env, agent, train_or_test='train'):
    # Roll out one episode; sample stochastically when training and act
    # greedily when testing.
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        if train_or_test == 'train':
            action = agent.sample(obs)
        else:
            action = agent.predict(obs)
        action_list.append(action)

        obs, reward, done, _ = env.step(action)
        reward_list.append(reward)
        if done:
            break
    return obs_list, action_list, reward_list


def main():
    env = gym.make('CartPole-v0')
    model = CartpoleModel(name_scope='noIdeaWhyNeedThis', act_dim=ACT_DIM)
    alg = PolicyGradient(model, LEARNING_RATE)
    agent = CartpoleAgent(alg, OBS_DIM, ACT_DIM)

    with fluid.dygraph.guard():
        for i in range(1000):  # 1000 episodes
            obs_list, action_list, reward_list = run_episode(env, agent)
            if i % 10 == 0:
                logger.info("Episode {}, Reward Sum {}.".format(
                    i, sum(reward_list)))

            batch_obs = np.array(obs_list)
            batch_action = np.array(action_list)
            batch_reward = calc_discount_norm_reward(reward_list, GAMMA)

            agent.learn(batch_obs, batch_action, batch_reward)

            # Evaluate with the greedy policy every 100 episodes.
            if (i + 1) % 100 == 0:
                _, _, reward_list = run_episode(
                    env, agent, train_or_test='test')
                total_reward = np.sum(reward_list)
                logger.info('Test reward: {}'.format(total_reward))


if __name__ == '__main__':
    main()
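utils.py: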
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np


def calc_discount_norm_reward(reward_list, gamma):
    '''
    Calculate the discounted return at each step of the trajectory according
    to the discount factor gamma, then normalize the result to zero mean and
    unit standard deviation.

    Args:
        reward_list(list): a list containing the rewards along the trajectory.
        gamma(float): the discount factor for accumulated reward computation.

    Returns:
        a numpy array containing the discounted, normalized rewards.
    '''
    discount_norm_reward = np.zeros_like(reward_list)

    # Accumulate rewards backwards: G_t = r_t + gamma * G_{t+1}.
    discount_cumulative_reward = 0
    for i in reversed(range(0, len(reward_list))):
        discount_cumulative_reward = (
            gamma * discount_cumulative_reward + reward_list[i])
        discount_norm_reward[i] = discount_cumulative_reward
    discount_norm_reward = discount_norm_reward - np.mean(discount_norm_reward)
    discount_norm_reward = discount_norm_reward / np.std(discount_norm_reward)
    return discount_norm_reward
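For reference, here is a minimal usage sketch of the helper above; the trajectory values below are illustrative only, since in train.py the rewards come from the CartPole environment.

```
import numpy as np
from utils import calc_discount_norm_reward

# Toy trajectory of three steps, each with reward 1.0 (as CartPole returns).
rewards = [1.0, 1.0, 1.0]
returns = calc_discount_norm_reward(rewards, gamma=0.99)

# Earlier steps get larger discounted returns; after normalization the array
# has (approximately) zero mean and unit standard deviation.
print(returns)
print(np.mean(returns), np.std(returns))
```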