Unverified commit 1cf17a57 authored by rical730, committed by GitHub

add tutorials (#270)

* add tutorials

* yapf

* yapf

* copyright

* yapf

* update tutorial lesson5

* delete drawing code

* yapf

* remove action_mapping

* update dqn and add README

* update

* update

* yapf
Parent 3b6e7a92
## Examples for the course "Introduction to RL Practice with PARL"
For reinforcement learning beginners, PARL provides an [introductory course](https://aistudio.baidu.com/aistudio/course/introduce/1335) with code examples for the five most fundamental RL algorithms.

## Course outline
+ 1. First impressions of reinforcement learning (RL)
    + RL overview, suggested learning path
    + Hands-on: environment setup ([lesson1](lesson1/gridworld.py) provides a rendering wrapper for grid-world environments)
+ 2. Solving RL with tabular methods
    + MDPs, state values, Q-tables
    + Hands-on: [Sarsa](lesson2/sarsa), [Q-learning](lesson2/q_learning)
+ 3. Solving RL with neural networks
    + Function approximation methods
    + Hands-on: [DQN](lesson3/dqn)
+ 4. Solving RL with policy gradients
    + Policy approximation, policy gradient
    + Hands-on: [Policy Gradient](lesson4/policy_gradient)
+ 5. Solving RL in continuous action spaces
    + Hands-on: [DDPG](lesson5/ddpg)

## Usage
### Dependencies
+ [paddlepaddle==1.6.3](https://github.com/PaddlePaddle/Paddle)
+ [parl==1.3.1](https://github.com/PaddlePaddle/PARL)
+ gym

### Running an example
Enter the folder of the example you want to run, then run

```
python train.py
```
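All five lessons share the same interaction pattern between agent and environment. The sketch below is a minimal, standalone illustration of that loop, using a random policy in place of a trained agent:

```
import gym
import numpy as np

# One episode of the standard gym interaction loop used by every train.py here.
env = gym.make('CartPole-v0')
obs = env.reset()
total_reward = 0
while True:
    action = np.random.randint(env.action_space.n)  # a trained agent would call agent.sample(obs)
    obs, reward, done, info = env.step(action)
    total_reward += reward
    if done:
        break
print('episode reward:', total_reward)
```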
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import gym
import turtle
import numpy as np

# turtle tutorial : https://docs.python.org/3.3/library/turtle.html


def GridWorld(gridmap=None, is_slippery=False):
    if gridmap is None:
        gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']
    env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=is_slippery)
    env = FrozenLakeWapper(env)
    return env


class FrozenLakeWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.max_y = env.desc.shape[0]
        self.max_x = env.desc.shape[1]
        self.t = None
        self.unit = 50

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for _ in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t is None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for i in range(self.desc.shape[0]):
                for j in range(self.desc.shape[1]):
                    x = j
                    y = self.max_y - 1 - i
                    if self.desc[i][j] == b'S':  # Start
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'F':  # Frozen ice
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'G':  # Goal
                        self.draw_box(x, y, 'yellow')
                    elif self.desc[i][j] == b'H':  # Hole
                        self.draw_box(x, y, 'black')
                    else:
                        self.draw_box(x, y, 'white')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


class CliffWalkingWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.t = None
        self.unit = 50
        self.max_x = 12
        self.max_y = 4

    def draw_x_line(self, y, x0, x1, color='gray'):
        assert x1 > x0
        self.t.color(color)
        self.t.setheading(0)
        self.t.up()
        self.t.goto(x0, y)
        self.t.down()
        self.t.forward(x1 - x0)

    def draw_y_line(self, x, y0, y1, color='gray'):
        assert y1 > y0
        self.t.color(color)
        self.t.setheading(90)
        self.t.up()
        self.t.goto(x, y0)
        self.t.down()
        self.t.forward(y1 - y0)

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for i in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t is None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for _ in range(2):
                self.t.forward(self.max_x * self.unit)
                self.t.left(90)
                self.t.forward(self.max_y * self.unit)
                self.t.left(90)
            for i in range(1, self.max_y):
                self.draw_x_line(
                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
            for i in range(1, self.max_x):
                self.draw_y_line(
                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)
            for i in range(1, self.max_x - 1):
                self.draw_box(i, 0, 'black')
            self.draw_box(self.max_x - 1, 0, 'yellow')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


if __name__ == '__main__':
    # Environment 1: FrozenLake; the ice can be configured to be slippery or not
    # actions: 0 left, 1 down, 2 right, 3 up
    env = gym.make("FrozenLake-v0", is_slippery=False)
    env = FrozenLakeWapper(env)

    # Environment 2: CliffWalking, the cliff-walking environment
    # env = gym.make("CliffWalking-v0")  # actions: 0 up, 1 right, 2 down, 3 left
    # env = CliffWalkingWapper(env)

    # Environment 3: custom grid world with a configurable map
    # S = Start, F = Floor, H = Hole, G = Goal
    # gridmap = [
    #     'SFFF',
    #     'FHFF',
    #     'FFFF',
    #     'HFGF']
    # env = GridWorld(gridmap)

    env.reset()
    for step in range(10):
        action = np.random.randint(0, 4)
        obs, reward, done, info = env.step(action)
        print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(
            step, action, obs, reward, done, info))
        env.render()  # render one frame
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import numpy as np


class QLearningAgent(object):
    def __init__(self,
                 obs_n,
                 act_n,
                 learning_rate=0.01,
                 gamma=0.9,
                 e_greed=0.1):
        self.act_n = act_n  # number of available actions
        self.lr = learning_rate  # learning rate
        self.gamma = gamma  # discount factor for rewards
        self.epsilon = e_greed  # probability of picking a random action
        self.Q = np.zeros((obs_n, act_n))

    # Sample an action for the given observation, with exploration
    def sample(self, obs):
        if np.random.uniform(0, 1) < (1.0 - self.epsilon):  # pick the action with the highest Q value
            action = self.predict(obs)
        else:
            action = np.random.choice(self.act_n)  # explore: pick a random action with probability epsilon
        return action

    # Predict the best action for the given observation
    def predict(self, obs):
        Q_list = self.Q[obs, :]
        maxQ = np.max(Q_list)
        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
        action = np.random.choice(action_list)
        return action

    # Learning method, i.e. how the Q-table is updated
    def learn(self, obs, action, reward, next_obs, done):
        """ off-policy
            obs: observation before the interaction, s_t
            action: the action chosen in this interaction, a_t
            reward: the reward received for this action, r
            next_obs: observation after the interaction, s_t+1
            done: whether the episode has ended
        """
        predict_Q = self.Q[obs, action]
        if done:
            target_Q = reward  # there is no next state
        else:
            target_Q = reward + self.gamma * np.max(
                self.Q[next_obs, :])  # Q-learning
        self.Q[obs, action] += self.lr * (target_Q - predict_Q)  # correct Q

    # Save the Q-table to a file
    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q)
        print(npy_file + ' saved.')

    # Restore the Q-table from a file
    def restore(self, npy_file='./q_table.npy'):
        self.Q = np.load(npy_file)
        print(npy_file + ' loaded.')
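To make the `learn()` update above concrete, here is a hand-worked single Q-learning step (the numbers are hypothetical, not from the repo):

```
import numpy as np

# Hypothetical 2-state, 2-action Q-table.
Q = np.zeros((2, 2))
Q[1] = [1.0, 2.0]  # Q(s', a) values for the next state
lr, gamma = 0.1, 0.9
obs, action, reward, next_obs = 0, 0, 1.0, 1

target_Q = reward + gamma * np.max(Q[next_obs])  # 1.0 + 0.9 * 2.0 = 2.8
Q[obs, action] += lr * (target_Q - Q[obs, action])
print(Q[obs, action])  # 0.1 * 2.8 = 0.28 (up to float rounding)
```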
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import gym
import turtle
import numpy as np

# turtle tutorial : https://docs.python.org/3.3/library/turtle.html


def GridWorld(gridmap=None, is_slippery=False):
    if gridmap is None:
        gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']
    env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=is_slippery)
    env = FrozenLakeWapper(env)
    return env


class FrozenLakeWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.max_y = env.desc.shape[0]
        self.max_x = env.desc.shape[1]
        self.t = None
        self.unit = 50

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for _ in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t is None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for i in range(self.desc.shape[0]):
                for j in range(self.desc.shape[1]):
                    x = j
                    y = self.max_y - 1 - i
                    if self.desc[i][j] == b'S':  # Start
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'F':  # Frozen ice
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'G':  # Goal
                        self.draw_box(x, y, 'yellow')
                    elif self.desc[i][j] == b'H':  # Hole
                        self.draw_box(x, y, 'black')
                    else:
                        self.draw_box(x, y, 'white')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


class CliffWalkingWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.t = None
        self.unit = 50
        self.max_x = 12
        self.max_y = 4

    def draw_x_line(self, y, x0, x1, color='gray'):
        assert x1 > x0
        self.t.color(color)
        self.t.setheading(0)
        self.t.up()
        self.t.goto(x0, y)
        self.t.down()
        self.t.forward(x1 - x0)

    def draw_y_line(self, x, y0, y1, color='gray'):
        assert y1 > y0
        self.t.color(color)
        self.t.setheading(90)
        self.t.up()
        self.t.goto(x, y0)
        self.t.down()
        self.t.forward(y1 - y0)

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for i in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t is None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for _ in range(2):
                self.t.forward(self.max_x * self.unit)
                self.t.left(90)
                self.t.forward(self.max_y * self.unit)
                self.t.left(90)
            for i in range(1, self.max_y):
                self.draw_x_line(
                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
            for i in range(1, self.max_x):
                self.draw_y_line(
                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)
            for i in range(1, self.max_x - 1):
                self.draw_box(i, 0, 'black')
            self.draw_box(self.max_x - 1, 0, 'yellow')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


if __name__ == '__main__':
    # Environment 1: FrozenLake; the ice can be configured to be slippery or not
    # actions: 0 left, 1 down, 2 right, 3 up
    env = gym.make("FrozenLake-v0", is_slippery=False)
    env = FrozenLakeWapper(env)

    # Environment 2: CliffWalking, the cliff-walking environment
    # env = gym.make("CliffWalking-v0")  # actions: 0 up, 1 right, 2 down, 3 left
    # env = CliffWalkingWapper(env)

    # Environment 3: custom grid world with a configurable map
    # S = Start, F = Floor, H = Hole, G = Goal
    # gridmap = [
    #     'SFFF',
    #     'FHFF',
    #     'FFFF',
    #     'HFGF']
    # env = GridWorld(gridmap)

    env.reset()
    for step in range(10):
        action = np.random.randint(0, 4)
        obs, reward, done, info = env.step(action)
        print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(
            step, action, obs, reward, done, info))
        # env.render()  # render one frame
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import gym
from gridworld import CliffWalkingWapper, FrozenLakeWapper
from agent import QLearningAgent
import time


def run_episode(env, agent, render=False):
    total_steps = 0  # number of steps taken in this episode
    total_reward = 0

    obs = env.reset()  # reset the environment, i.e. start a new episode

    while True:
        action = agent.sample(obs)  # pick an action according to the algorithm
        next_obs, reward, done, _ = env.step(action)  # one interaction with the environment
        # train the Q-learning algorithm
        agent.learn(obs, action, reward, next_obs, done)

        obs = next_obs  # move on to the next observation
        total_reward += reward
        total_steps += 1  # count the steps
        if render:
            env.render()  # render a new frame
        if done:
            break
    return total_reward, total_steps


def test_episode(env, agent):
    total_reward = 0
    obs = env.reset()
    while True:
        action = agent.predict(obs)  # greedy
        next_obs, reward, done, _ = env.step(action)
        total_reward += reward
        obs = next_obs
        time.sleep(0.5)
        env.render()
        if done:
            print('test reward = %.1f' % (total_reward))
            break


def main():
    # env = gym.make("FrozenLake-v0", is_slippery=False)  # 0 left, 1 down, 2 right, 3 up
    # env = FrozenLakeWapper(env)
    env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
    env = CliffWalkingWapper(env)

    agent = QLearningAgent(
        obs_n=env.observation_space.n,
        act_n=env.action_space.n,
        learning_rate=0.1,
        gamma=0.9,
        e_greed=0.1)

    is_render = False
    for episode in range(500):
        ep_reward, ep_steps = run_episode(env, agent, is_render)
        print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps,
                                                          ep_reward))

        # render every 20 episodes to inspect progress
        if episode % 20 == 0:
            is_render = True
        else:
            is_render = False
    # training finished; evaluate the learned policy
    test_episode(env, agent)


if __name__ == "__main__":
    main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import numpy as np


class SarsaAgent(object):
    def __init__(self,
                 obs_n,
                 act_n,
                 learning_rate=0.01,
                 gamma=0.9,
                 e_greed=0.1):
        self.act_n = act_n  # number of available actions
        self.lr = learning_rate  # learning rate
        self.gamma = gamma  # discount factor for rewards
        self.epsilon = e_greed  # probability of picking a random action
        self.Q = np.zeros((obs_n, act_n))

    # Sample an action for the given observation, with exploration
    def sample(self, obs):
        if np.random.uniform(0, 1) < (1.0 - self.epsilon):  # pick the action with the highest Q value
            action = self.predict(obs)
        else:
            action = np.random.choice(self.act_n)  # explore: pick a random action with probability epsilon
        return action

    # Predict the best action for the given observation
    def predict(self, obs):
        Q_list = self.Q[obs, :]
        maxQ = np.max(Q_list)
        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
        action = np.random.choice(action_list)
        return action

    # Learning method, i.e. how the Q-table is updated
    def learn(self, obs, action, reward, next_obs, next_action, done):
        """ on-policy
            obs: observation before the interaction, s_t
            action: the action chosen in this interaction, a_t
            reward: the reward received for this action, r
            next_obs: observation after the interaction, s_t+1
            next_action: the action that will be taken for next_obs according to the current Q-table, a_t+1
            done: whether the episode has ended
        """
        predict_Q = self.Q[obs, action]
        if done:
            target_Q = reward  # there is no next state
        else:
            target_Q = reward + self.gamma * self.Q[next_obs,
                                                    next_action]  # Sarsa
        self.Q[obs, action] += self.lr * (target_Q - predict_Q)  # correct Q

    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q)
        print(npy_file + ' saved.')

    def restore(self, npy_file='./q_table.npy'):
        self.Q = np.load(npy_file)
        print(npy_file + ' loaded.')
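The only difference from the Q-learning agent is the learning target: Sarsa bootstraps from the action the behavior policy actually takes (on-policy), while Q-learning uses the greedy maximum. A small numeric comparison with hypothetical values:

```
import numpy as np

Q_next = np.array([1.0, 2.0])  # Q(s', a) for two actions
reward, gamma = 0.0, 0.9
next_action = 0  # the action actually sampled by the epsilon-greedy policy

target_sarsa = reward + gamma * Q_next[next_action]  # 0.9: follows the behavior policy
target_qlearning = reward + gamma * np.max(Q_next)   # 1.8: follows the greedy policy
print(target_sarsa, target_qlearning)
```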
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import gym
import turtle
import numpy as np

# turtle tutorial : https://docs.python.org/3.3/library/turtle.html


def GridWorld(gridmap=None, is_slippery=False):
    if gridmap is None:
        gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']
    env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=is_slippery)
    env = FrozenLakeWapper(env)
    return env


class FrozenLakeWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.max_y = env.desc.shape[0]
        self.max_x = env.desc.shape[1]
        self.t = None
        self.unit = 50

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for _ in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t is None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for i in range(self.desc.shape[0]):
                for j in range(self.desc.shape[1]):
                    x = j
                    y = self.max_y - 1 - i
                    if self.desc[i][j] == b'S':  # Start
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'F':  # Frozen ice
                        self.draw_box(x, y, 'white')
                    elif self.desc[i][j] == b'G':  # Goal
                        self.draw_box(x, y, 'yellow')
                    elif self.desc[i][j] == b'H':  # Hole
                        self.draw_box(x, y, 'black')
                    else:
                        self.draw_box(x, y, 'white')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


class CliffWalkingWapper(gym.Wrapper):
    def __init__(self, env):
        gym.Wrapper.__init__(self, env)
        self.t = None
        self.unit = 50
        self.max_x = 12
        self.max_y = 4

    def draw_x_line(self, y, x0, x1, color='gray'):
        assert x1 > x0
        self.t.color(color)
        self.t.setheading(0)
        self.t.up()
        self.t.goto(x0, y)
        self.t.down()
        self.t.forward(x1 - x0)

    def draw_y_line(self, x, y0, y1, color='gray'):
        assert y1 > y0
        self.t.color(color)
        self.t.setheading(90)
        self.t.up()
        self.t.goto(x, y0)
        self.t.down()
        self.t.forward(y1 - y0)

    def draw_box(self, x, y, fillcolor='', line_color='gray'):
        self.t.up()
        self.t.goto(x * self.unit, y * self.unit)
        self.t.color(line_color)
        self.t.fillcolor(fillcolor)
        self.t.setheading(90)
        self.t.down()
        self.t.begin_fill()
        for i in range(4):
            self.t.forward(self.unit)
            self.t.right(90)
        self.t.end_fill()

    def move_player(self, x, y):
        self.t.up()
        self.t.setheading(90)
        self.t.fillcolor('red')
        self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)

    def render(self):
        if self.t is None:
            self.t = turtle.Turtle()
            self.wn = turtle.Screen()
            self.wn.setup(self.unit * self.max_x + 100,
                          self.unit * self.max_y + 100)
            self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,
                                        self.unit * self.max_y)
            self.t.shape('circle')
            self.t.width(2)
            self.t.speed(0)
            self.t.color('gray')
            for _ in range(2):
                self.t.forward(self.max_x * self.unit)
                self.t.left(90)
                self.t.forward(self.max_y * self.unit)
                self.t.left(90)
            for i in range(1, self.max_y):
                self.draw_x_line(
                    y=i * self.unit, x0=0, x1=self.max_x * self.unit)
            for i in range(1, self.max_x):
                self.draw_y_line(
                    x=i * self.unit, y0=0, y1=self.max_y * self.unit)
            for i in range(1, self.max_x - 1):
                self.draw_box(i, 0, 'black')
            self.draw_box(self.max_x - 1, 0, 'yellow')
            self.t.shape('turtle')

        x_pos = self.s % self.max_x
        y_pos = self.max_y - 1 - int(self.s / self.max_x)
        self.move_player(x_pos, y_pos)


if __name__ == '__main__':
    # Environment 1: FrozenLake; the ice can be configured to be slippery or not
    # actions: 0 left, 1 down, 2 right, 3 up
    env = gym.make("FrozenLake-v0", is_slippery=False)
    env = FrozenLakeWapper(env)

    # Environment 2: CliffWalking, the cliff-walking environment
    # env = gym.make("CliffWalking-v0")  # actions: 0 up, 1 right, 2 down, 3 left
    # env = CliffWalkingWapper(env)

    # Environment 3: custom grid world with a configurable map
    # S = Start, F = Floor, H = Hole, G = Goal
    # gridmap = [
    #     'SFFF',
    #     'FHFF',
    #     'FFFF',
    #     'HFGF']
    # env = GridWorld(gridmap)

    env.reset()
    for step in range(10):
        action = np.random.randint(0, 4)
        obs, reward, done, info = env.step(action)
        print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(
            step, action, obs, reward, done, info))
        # env.render()  # render one frame
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
import gym
from gridworld import CliffWalkingWapper, FrozenLakeWapper
from agent import SarsaAgent
import time


def run_episode(env, agent, render=False):
    total_steps = 0  # number of steps taken in this episode
    total_reward = 0

    obs = env.reset()  # reset the environment, i.e. start a new episode
    action = agent.sample(obs)  # pick an action according to the algorithm

    while True:
        next_obs, reward, done, _ = env.step(action)  # one interaction with the environment
        next_action = agent.sample(next_obs)  # pick the next action according to the algorithm
        # train the Sarsa algorithm
        agent.learn(obs, action, reward, next_obs, next_action, done)

        action = next_action
        obs = next_obs  # move on to the next observation
        total_reward += reward
        total_steps += 1  # count the steps
        if render:
            env.render()  # render a new frame
        if done:
            break
    return total_reward, total_steps


def test_episode(env, agent):
    total_reward = 0
    obs = env.reset()
    while True:
        action = agent.predict(obs)  # greedy
        next_obs, reward, done, _ = env.step(action)
        total_reward += reward
        obs = next_obs
        time.sleep(0.5)
        env.render()
        if done:
            print('test reward = %.1f' % (total_reward))
            break


def main():
    # env = gym.make("FrozenLake-v0", is_slippery=False)  # 0 left, 1 down, 2 right, 3 up
    # env = FrozenLakeWapper(env)
    env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
    env = CliffWalkingWapper(env)

    agent = SarsaAgent(
        obs_n=env.observation_space.n,
        act_n=env.action_space.n,
        learning_rate=0.1,
        gamma=0.9,
        e_greed=0.1)

    is_render = False
    for episode in range(500):
        ep_reward, ep_steps = run_episode(env, agent, is_render)
        print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps,
                                                          ep_reward))

        # render every 20 episodes to inspect progress
        if episode % 20 == 0:
            is_render = True
        else:
            is_render = False
    # training finished; evaluate the learned policy
    test_episode(env, agent)


if __name__ == "__main__":
    main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import numpy as np
import paddle.fluid as fluid
import parl
from parl import layers


class Agent(parl.Agent):
    def __init__(self,
                 algorithm,
                 obs_dim,
                 act_dim,
                 e_greed=0.1,
                 e_greed_decrement=0):
        assert isinstance(obs_dim, int)
        assert isinstance(act_dim, int)
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(Agent, self).__init__(algorithm)

        self.global_step = 0
        self.update_target_steps = 200  # copy the model parameters to target_model every 200 training steps

        self.e_greed = e_greed  # probability of picking a random action (exploration)
        self.e_greed_decrement = e_greed_decrement  # gradually reduce exploration as training converges

    def build_program(self):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):  # build the graph that predicts actions; define its inputs and outputs
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            self.value = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):  # build the graph that updates the Q network; define its inputs and outputs
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            action = layers.data(name='act', shape=[1], dtype='int32')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            next_obs = layers.data(
                name='next_obs', shape=[self.obs_dim], dtype='float32')
            terminal = layers.data(name='terminal', shape=[], dtype='bool')
            self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)

    def sample(self, obs):
        sample = np.random.rand()  # a random float in [0, 1)
        if sample < self.e_greed:
            act = np.random.randint(self.act_dim)  # explore: every action has a chance of being picked
        else:
            act = self.predict(obs)  # pick the best action
        self.e_greed = max(
            0.01, self.e_greed - self.e_greed_decrement)  # gradually reduce exploration as training converges
        return act

    def predict(self, obs):  # pick the best action
        obs = np.expand_dims(obs, axis=0)
        pred_Q = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.value])[0]
        pred_Q = np.squeeze(pred_Q, axis=0)
        act = np.argmax(pred_Q)  # the index of the largest Q value is the chosen action
        return act

    def learn(self, obs, act, reward, next_obs, terminal):
        # sync the parameters of model and target_model every 200 training steps
        if self.global_step % self.update_target_steps == 0:
            self.alg.sync_target()
        self.global_step += 1

        act = np.expand_dims(act, -1)
        feed = {
            'obs': obs.astype('float32'),
            'act': act.astype('int32'),
            'reward': reward,
            'next_obs': next_obs.astype('float32'),
            'terminal': terminal
        }
        cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.cost])[0]  # one training step of the network
        return cost
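The schedule in `sample()` decays epsilon linearly by `e_greed_decrement` per call, with a floor of 0.01. With the values used in train.py below (e_greed=0.1, e_greed_decrement=1e-6), the floor is reached after (0.1 - 0.01) / 1e-6 = 90,000 calls:

```
e_greed, e_greed_decrement = 0.1, 1e-6
for step in range(90000):
    e_greed = max(0.01, e_greed - e_greed_decrement)
print(e_greed)  # ~0.01 (exact value depends on float rounding)
```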
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import copy
import paddle.fluid as fluid
import parl
from parl import layers


class DQN(parl.Algorithm):
    def __init__(self, model, act_dim=None, gamma=None, lr=None):
        """ DQN algorithm

        Args:
            model (parl.Model): forward network that defines the Q function
            act_dim (int): dimension of the action space, i.e. the number of actions
            gamma (float): discount factor for rewards
            lr (float): learning rate
        """
        self.model = model
        self.target_model = copy.deepcopy(model)

        assert isinstance(act_dim, int)
        assert isinstance(gamma, float)
        assert isinstance(lr, float)
        self.act_dim = act_dim
        self.gamma = gamma
        self.lr = lr

    def predict(self, obs):
        """ Use the value network of self.model to get [Q(s,a1), Q(s,a2), ...]
        """
        return self.model.value(obs)

    def learn(self, obs, action, reward, next_obs, terminal):
        """ Update the value network of self.model with the DQN algorithm
        """
        # get max Q' from target_model, used to compute target_Q
        next_pred_value = self.target_model.value(next_obs)
        best_v = layers.reduce_max(next_pred_value, dim=1)
        best_v.stop_gradient = True  # block gradient propagation
        terminal = layers.cast(terminal, dtype='float32')
        target = reward + (1.0 - terminal) * self.gamma * best_v

        pred_value = self.model.value(obs)  # predicted Q values
        # turn the action into a one-hot vector, e.g. 3 => [0,0,0,1,0]
        action_onehot = layers.one_hot(action, self.act_dim)
        action_onehot = layers.cast(action_onehot, dtype='float32')
        # The next line is an element-wise multiplication that extracts Q(s,a) for the taken action.
        # E.g. pred_value = [[2.3, 5.7, 1.2, 3.9, 1.4]], action_onehot = [[0,0,0,1,0]]
        #  ==> pred_action_value = [[3.9]]
        pred_action_value = layers.reduce_sum(
            layers.elementwise_mul(action_onehot, pred_value), dim=1)

        # the loss is the mean squared error between Q(s,a) and target_Q
        cost = layers.square_error_cost(pred_action_value, target)
        cost = layers.reduce_mean(cost)
        optimizer = fluid.optimizer.Adam(learning_rate=self.lr)  # use the Adam optimizer
        optimizer.minimize(cost)
        return cost

    def sync_target(self):
        """ Copy the parameters of self.model to self.target_model
        """
        self.model.sync_weights_to(self.target_model)
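The one-hot trick in `learn()` that extracts Q(s,a) for the taken action is easy to verify in plain NumPy, using the values from the comment above:

```
import numpy as np

pred_value = np.array([[2.3, 5.7, 1.2, 3.9, 1.4]])  # Q(s, a) for 5 actions
action_onehot = np.array([[0., 0., 0., 1., 0.]])    # action 3 as a one-hot vector

# element-wise multiply, then sum over the action axis
pred_action_value = np.sum(action_onehot * pred_value, axis=1)
print(pred_action_value)  # [3.9]
```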
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import parl
from parl import layers  # wraps the paddle.fluid.layers API


class Model(parl.Model):
    def __init__(self, act_dim):
        hid1_size = 128
        hid2_size = 128
        # 3-layer fully connected network
        self.fc1 = layers.fc(size=hid1_size, act='relu')
        self.fc2 = layers.fc(size=hid2_size, act='relu')
        self.fc3 = layers.fc(size=act_dim, act=None)

    def value(self, obs):
        h1 = self.fc1(obs)
        h2 = self.fc2(h1)
        Q = self.fc3(h2)
        return Q
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/seungeunrho/minimalRL/blob/master/dqn.py
import random
import collections
import numpy as np


# A fixed-size FIFO buffer of (obs, action, reward, next_obs, done) transitions.
class ReplayMemory(object):
    def __init__(self, max_size):
        self.buffer = collections.deque(maxlen=max_size)

    def append(self, exp):
        self.buffer.append(exp)

    def sample(self, batch_size):
        mini_batch = random.sample(self.buffer, batch_size)
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []

        for experience in mini_batch:
            s, a, r, s_p, done = experience
            obs_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_obs_batch.append(s_p)
            done_batch.append(done)

        return np.array(obs_batch).astype('float32'), \
            np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'), \
            np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32')

    def __len__(self):
        return len(self.buffer)
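A quick usage sketch of the buffer (the transitions are made up, and the import assumes this file is saved as replay_memory.py, as train.py below does):

```
import numpy as np
from replay_memory import ReplayMemory

rpm = ReplayMemory(max_size=100)
for i in range(10):
    # each entry is an (obs, action, reward, next_obs, done) tuple
    rpm.append((np.random.rand(4), 0, 1.0, np.random.rand(4), False))

obs_b, act_b, rew_b, next_obs_b, done_b = rpm.sample(batch_size=4)
print(obs_b.shape, act_b.shape)  # (4, 4) (4,)
```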
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import os
import gym
import numpy as np
import parl
from parl.utils import logger  # logging utility

from model import Model
from algorithm import DQN
from agent import Agent
from replay_memory import ReplayMemory

LEARN_FREQ = 5  # learning frequency: no need to learn on every step; accumulate some new experience first
MEMORY_SIZE = 20000  # size of the replay memory; larger means more RAM
MEMORY_WARMUP_SIZE = 200  # pre-fill the replay memory with some experience before sampling batches for the agent to learn from
BATCH_SIZE = 32  # number of samples per learning step, drawn at random from the replay memory
LEARNING_RATE = 0.001  # learning rate
GAMMA = 0.99  # discount factor for rewards, usually between 0.9 and 0.999


# run one training episode
def run_episode(env, agent, rpm):
    total_reward = 0
    obs = env.reset()
    step = 0
    while True:
        step += 1
        action = agent.sample(obs)  # sample an action; every action has a chance of being tried
        next_obs, reward, done, _ = env.step(action)
        rpm.append((obs, action, reward, next_obs, done))

        # train model
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            (batch_obs, batch_action, batch_reward, batch_next_obs,
             batch_done) = rpm.sample(BATCH_SIZE)
            train_loss = agent.learn(batch_obs, batch_action, batch_reward,
                                     batch_next_obs,
                                     batch_done)  # s,a,r,s',done

        total_reward += reward
        obs = next_obs
        if done:
            break
    return total_reward


# evaluate the agent: run 5 episodes and average the total reward
def evaluate(env, agent, render=False):
    eval_reward = []
    for i in range(5):
        obs = env.reset()
        episode_reward = 0
        while True:
            action = agent.predict(obs)  # always pick the best action
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if done:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)


def main():
    env = gym.make('CartPole-v0')
    # CartPole-v0: expected reward > 180; MountainCar-v0: expected reward > -120
    action_dim = env.action_space.n  # CartPole-v0: 2
    obs_shape = env.observation_space.shape  # CartPole-v0: (4,)

    rpm = ReplayMemory(MEMORY_SIZE)  # DQN's experience replay buffer

    # build the agent with the PARL framework
    model = Model(act_dim=action_dim)
    algorithm = DQN(model, act_dim=action_dim, gamma=GAMMA, lr=LEARNING_RATE)
    agent = Agent(
        algorithm,
        obs_dim=obs_shape[0],
        act_dim=action_dim,
        e_greed=0.1,  # probability of picking a random action (exploration)
        e_greed_decrement=1e-6)  # gradually reduce exploration as training converges

    # load a saved model
    # save_path = './dqn_model.ckpt'
    # agent.restore(save_path)

    # pre-fill the replay memory so the first training batches are diverse enough
    while len(rpm) < MEMORY_WARMUP_SIZE:
        run_episode(env, agent, rpm)

    max_episode = 2000

    # start training
    episode = 0
    while episode < max_episode:  # train for max_episode episodes; test episodes are not counted
        # train part
        for i in range(0, 50):
            total_reward = run_episode(env, agent, rpm)
            episode += 1

        # test part
        eval_reward = evaluate(env, agent, render=True)  # render=True to watch the result
        logger.info('episode:{} e_greed:{} test_reward:{}'.format(
            episode, agent.e_greed, eval_reward))

    # training finished; save the model
    save_path = './dqn_model.ckpt'
    agent.save(save_path)


if __name__ == '__main__':
    main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import numpy as np
import paddle.fluid as fluid
import parl
from parl import layers


class Agent(parl.Agent):
    def __init__(self, algorithm, obs_dim, act_dim):
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(Agent, self).__init__(algorithm)

    def build_program(self):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):  # build the graph that predicts actions; define its inputs and outputs
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            self.act_prob = self.alg.predict(obs)

        with fluid.program_guard(
                self.learn_program):  # build the graph that updates the policy network; define its inputs and outputs
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            act = layers.data(name='act', shape=[1], dtype='int64')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            self.cost = self.alg.learn(obs, act, reward)

    def sample(self, obs):
        obs = np.expand_dims(obs, axis=0)  # add a batch dimension
        act_prob = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.act_prob])[0]
        act_prob = np.squeeze(act_prob, axis=0)  # drop the batch dimension
        act = np.random.choice(range(self.act_dim), p=act_prob)  # sample an action from the action probabilities
        return act

    def predict(self, obs):
        obs = np.expand_dims(obs, axis=0)
        act_prob = self.fluid_executor.run(
            self.pred_program,
            feed={'obs': obs.astype('float32')},
            fetch_list=[self.act_prob])[0]
        act_prob = np.squeeze(act_prob, axis=0)
        act = np.argmax(act_prob)  # pick the action with the highest probability
        return act

    def learn(self, obs, act, reward):
        act = np.expand_dims(act, axis=-1)
        feed = {
            'obs': obs.astype('float32'),
            'act': act.astype('int64'),
            'reward': reward.astype('float32')
        }
        cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.cost])[0]
        return cost
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import paddle.fluid as fluid
import parl
from parl import layers


class PolicyGradient(parl.Algorithm):
    def __init__(self, model, lr=None):
        """ Policy Gradient algorithm

        Args:
            model (parl.Model): forward network of the policy
            lr (float): learning rate
        """
        self.model = model
        assert isinstance(lr, float)
        self.lr = lr

    def predict(self, obs):
        """ Use the policy model to predict action probabilities
        """
        return self.model(obs)

    def learn(self, obs, action, reward):
        """ Update the policy model with the policy gradient algorithm
        """
        act_prob = self.model(obs)  # predicted action probabilities
        # log_prob = layers.cross_entropy(act_prob, action)  # cross entropy
        log_prob = layers.reduce_sum(
            -1.0 * layers.log(act_prob) * layers.one_hot(
                action, act_prob.shape[1]),
            dim=1)
        cost = log_prob * reward
        cost = layers.reduce_mean(cost)

        optimizer = fluid.optimizer.Adam(self.lr)
        optimizer.minimize(cost)
        return cost
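The loss built in `learn()` is the REINFORCE objective: the negative log-probability of each taken action, weighted by its return. The same computation in plain NumPy, on a hypothetical batch of two steps:

```
import numpy as np

act_prob = np.array([[0.2, 0.8],
                     [0.6, 0.4]])  # softmax outputs for a batch of 2 observations
action = np.array([1, 0])          # actions actually taken
reward = np.array([2.0, 1.0])      # return G_t assigned to each step

log_prob = -np.log(act_prob[np.arange(2), action])  # cross-entropy of the taken actions
cost = np.mean(log_prob * reward)
print(cost)  # mean(-log(0.8) * 2.0, -log(0.6) * 1.0) ≈ 0.48
```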
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import parl
from parl import layers


class Model(parl.Model):
    def __init__(self, act_dim):
        hid1_size = act_dim * 10

        self.fc1 = layers.fc(size=hid1_size, act='tanh')
        self.fc2 = layers.fc(size=act_dim, act='softmax')

    def forward(self, obs):  # callable directly, e.g. model = Model(5); model(obs)
        out = self.fc1(obs)
        out = self.fc2(out)
        return out
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import os
import gym
import numpy as np
import parl

from agent import Agent
from model import Model
from algorithm import PolicyGradient  # from parl.algorithms import PolicyGradient

from parl.utils import logger

LEARNING_RATE = 1e-3


def run_episode(env, agent):
    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()
    while True:
        obs_list.append(obs)
        action = agent.sample(obs)
        action_list.append(action)

        obs, reward, done, info = env.step(action)
        reward_list.append(reward)

        if done:
            break
    return obs_list, action_list, reward_list


# evaluate the agent: run one episode
def evaluate(env, agent, render=False):
    obs = env.reset()
    episode_reward = 0
    while True:
        action = agent.predict(obs)
        obs, reward, isOver, _ = env.step(action)
        episode_reward += reward
        if render:
            env.render()
        if isOver:
            break
    return episode_reward


def calc_reward_to_go(reward_list, gamma=1.0):
    for i in range(len(reward_list) - 2, -1, -1):
        # G_t = r_t + gamma * r_{t+1} + ... = r_t + gamma * G_{t+1}
        reward_list[i] += gamma * reward_list[i + 1]  # G_t
    return np.array(reward_list)


def main():
    env = gym.make('CartPole-v0')
    # env = env.unwrapped  # cancel the minimum score limit
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

    # build the agent with the PARL framework
    model = Model(act_dim=act_dim)
    alg = PolicyGradient(model, lr=LEARNING_RATE)
    agent = Agent(alg, obs_dim=obs_dim, act_dim=act_dim)

    # load a saved model
    # if os.path.exists('./model.ckpt'):
    #     agent.restore('./model.ckpt')
    #     run_episode(env, agent, train_or_test='test', render=True)
    #     exit()

    for i in range(1000):
        obs_list, action_list, reward_list = run_episode(env, agent)
        if i % 10 == 0:
            logger.info("Episode {}, Reward Sum {}.".format(
                i, sum(reward_list)))

        batch_obs = np.array(obs_list)
        batch_action = np.array(action_list)
        batch_reward = calc_reward_to_go(reward_list)

        agent.learn(batch_obs, batch_action, batch_reward)
        if (i + 1) % 100 == 0:
            total_reward = evaluate(env, agent, render=True)
            logger.info('Test reward: {}'.format(total_reward))

    # save the parameters to ./model.ckpt
    agent.save('./model.ckpt')


if __name__ == '__main__':
    main()
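`calc_reward_to_go` accumulates the per-step rewards from the back of the episode, so each entry becomes the return G_t. A worked example with gamma=1.0:

```
rewards = [1.0, 1.0, 1.0]
gamma = 1.0
for i in range(len(rewards) - 2, -1, -1):
    rewards[i] += gamma * rewards[i + 1]
print(rewards)  # [3.0, 2.0, 1.0]: G_0 = r_0 + r_1 + r_2, G_1 = r_1 + r_2, G_2 = r_2
```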
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import numpy as np
import parl
from parl import layers
from paddle import fluid


class Agent(parl.Agent):
    def __init__(self, algorithm, obs_dim, act_dim):
        assert isinstance(obs_dim, int)
        assert isinstance(act_dim, int)
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(Agent, self).__init__(algorithm)

        # Note: sync the parameters of self.model and self.target_model at the very start.
        self.alg.sync_target(decay=0)

    def build_program(self):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            self.pred_act = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            act = layers.data(
                name='act', shape=[self.act_dim], dtype='float32')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            next_obs = layers.data(
                name='next_obs', shape=[self.obs_dim], dtype='float32')
            terminal = layers.data(name='terminal', shape=[], dtype='bool')
            _, self.critic_cost = self.alg.learn(obs, act, reward, next_obs,
                                                 terminal)

    def predict(self, obs):
        obs = np.expand_dims(obs, axis=0)
        act = self.fluid_executor.run(
            self.pred_program, feed={'obs': obs},
            fetch_list=[self.pred_act])[0]
        act = np.squeeze(act)
        return act

    def learn(self, obs, act, reward, next_obs, terminal):
        feed = {
            'obs': obs,
            'act': act,
            'reward': reward,
            'next_obs': next_obs,
            'terminal': terminal
        }
        critic_cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.critic_cost])[0]
        self.alg.sync_target()
        return critic_cost
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import parl
from parl import layers
from copy import deepcopy
from paddle import fluid


class DDPG(parl.Algorithm):
    def __init__(self,
                 model,
                 gamma=None,
                 tau=None,
                 actor_lr=None,
                 critic_lr=None):
        """ DDPG algorithm

        Args:
            model (parl.Model): forward network of the actor and the critic;
                                the model must implement get_actor_params()
            gamma (float): discount factor for rewards
            tau (float): soft-update coefficient for syncing self.target_model with self.model
            actor_lr (float): learning rate of the actor
            critic_lr (float): learning rate of the critic
        """
        assert isinstance(gamma, float)
        assert isinstance(tau, float)
        assert isinstance(actor_lr, float)
        assert isinstance(critic_lr, float)
        self.gamma = gamma
        self.tau = tau
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr

        self.model = model
        self.target_model = deepcopy(model)

    def predict(self, obs):
        """ Predict the action with the actor model of self.model
        """
        return self.model.policy(obs)

    def learn(self, obs, action, reward, next_obs, terminal):
        """ Update the actor and the critic with the DDPG algorithm
        """
        actor_cost = self._actor_learn(obs)
        critic_cost = self._critic_learn(obs, action, reward, next_obs,
                                         terminal)
        return actor_cost, critic_cost

    def _actor_learn(self, obs):
        action = self.model.policy(obs)
        Q = self.model.value(obs, action)
        cost = layers.reduce_mean(-1.0 * Q)
        optimizer = fluid.optimizer.AdamOptimizer(self.actor_lr)
        optimizer.minimize(cost, parameter_list=self.model.get_actor_params())
        return cost

    def _critic_learn(self, obs, action, reward, next_obs, terminal):
        next_action = self.target_model.policy(next_obs)
        next_Q = self.target_model.value(next_obs, next_action)

        terminal = layers.cast(terminal, dtype='float32')
        target_Q = reward + (1.0 - terminal) * self.gamma * next_Q
        target_Q.stop_gradient = True

        Q = self.model.value(obs, action)
        cost = layers.square_error_cost(Q, target_Q)
        cost = layers.reduce_mean(cost)
        optimizer = fluid.optimizer.AdamOptimizer(self.critic_lr)
        optimizer.minimize(cost)
        return cost

    def sync_target(self, decay=None, share_vars_parallel_executor=None):
        """ Copy the parameters of self.model to self.target_model;
            if decay is not None, this is a soft update
        """
        if decay is None:
            decay = 1.0 - self.tau
        self.model.sync_weights_to(
            self.target_model,
            decay=decay,
            share_vars_parallel_executor=share_vars_parallel_executor)
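In `sync_target`, `decay = 1 - tau` is the weight kept on the old target parameters, so each sync is the standard soft update target <- (1 - tau) * target + tau * model. A scalar illustration with the TAU = 0.001 used by train.py below (the parameter values are hypothetical):

```
tau = 0.001
target_param, model_param = 0.0, 1.0
for _ in range(1000):
    target_param = (1.0 - tau) * target_param + tau * model_param
print(target_param)  # ≈ 0.632 = 1 - (1 - tau)**1000: the target tracks the model slowly
```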
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
"""
Classic cart-pole system implemented by Rich Sutton et al.
Copied from http://incompleteideas.net/sutton/book/code/pole.c
permalink: https://perma.cc/C9ZM-652R
Continuous version by Ian Danforth
"""
import math
import gym
from gym import spaces, logger
from gym.utils import seeding
import numpy as np
class ContinuousCartPoleEnv(gym.Env):
metadata = {
'render.modes': ['human', 'rgb_array'],
'video.frames_per_second': 50
}
def __init__(self):
self.gravity = 9.8
self.masscart = 1.0
self.masspole = 0.1
self.total_mass = (self.masspole + self.masscart)
self.length = 0.5 # actually half the pole's length
self.polemass_length = (self.masspole * self.length)
self.force_mag = 30.0
self.tau = 0.02 # seconds between state updates
self.min_action = -1.0
self.max_action = 1.0
# Angle at which to fail the episode
self.theta_threshold_radians = 12 * 2 * math.pi / 360
self.x_threshold = 2.4
# Angle limit set to 2 * theta_threshold_radians so failing observation
# is still within bounds
high = np.array([
self.x_threshold * 2,
np.finfo(np.float32).max, self.theta_threshold_radians * 2,
np.finfo(np.float32).max
])
self.action_space = spaces.Box(
low=self.min_action, high=self.max_action, shape=(1, ))
self.observation_space = spaces.Box(-high, high)
self.seed()
self.viewer = None
self.state = None
self.steps_beyond_done = None
def seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]
def stepPhysics(self, force):
x, x_dot, theta, theta_dot = self.state
costheta = math.cos(theta)
sintheta = math.sin(theta)
temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta
) / self.total_mass
thetaacc = (self.gravity * sintheta - costheta * temp) / \
(self.length * (4.0/3.0 - self.masspole * costheta * costheta / self.total_mass))
xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass
x = x + self.tau * x_dot
x_dot = x_dot + self.tau * xacc
theta = theta + self.tau * theta_dot
theta_dot = theta_dot + self.tau * thetaacc
return (x, x_dot, theta, theta_dot)
def step(self, action):
action = np.expand_dims(action, 0)
assert self.action_space.contains(action), \
"%r (%s) invalid" % (action, type(action))
# Cast action to float to strip np trappings
force = self.force_mag * float(action)
self.state = self.stepPhysics(force)
x, x_dot, theta, theta_dot = self.state
done = x < -self.x_threshold \
or x > self.x_threshold \
or theta < -self.theta_threshold_radians \
or theta > self.theta_threshold_radians
done = bool(done)
if not done:
reward = 1.0
elif self.steps_beyond_done is None:
# Pole just fell!
self.steps_beyond_done = 0
reward = 1.0
else:
if self.steps_beyond_done == 0:
logger.warn("""
You are calling 'step()' even though this environment has already returned
done = True. You should always call 'reset()' once you receive 'done = True'
Any further steps are undefined behavior.
""")
self.steps_beyond_done += 1
reward = 0.0
return np.array(self.state), reward, done, {}
def reset(self):
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4, ))
self.steps_beyond_done = None
return np.array(self.state)
def render(self, mode='human'):
screen_width = 600
screen_height = 400
world_width = self.x_threshold * 2
scale = screen_width / world_width
carty = 100 # TOP OF CART
polewidth = 10.0
polelen = scale * 1.0
cartwidth = 50.0
cartheight = 30.0
if self.viewer is None:
from gym.envs.classic_control import rendering
self.viewer = rendering.Viewer(screen_width, screen_height)
l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
axleoffset = cartheight / 4.0
cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
self.carttrans = rendering.Transform()
cart.add_attr(self.carttrans)
self.viewer.add_geom(cart)
l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
pole.set_color(.8, .6, .4)
self.poletrans = rendering.Transform(translation=(0, axleoffset))
pole.add_attr(self.poletrans)
pole.add_attr(self.carttrans)
self.viewer.add_geom(pole)
self.axle = rendering.make_circle(polewidth / 2)
self.axle.add_attr(self.poletrans)
self.axle.add_attr(self.carttrans)
self.axle.set_color(.5, .5, .8)
self.viewer.add_geom(self.axle)
self.track = rendering.Line((0, carty), (screen_width, carty))
self.track.set_color(0, 0, 0)
self.viewer.add_geom(self.track)
if self.state is None:
return None
x = self.state
cartx = x[0] * scale + screen_width / 2.0 # MIDDLE OF CART
self.carttrans.set_translation(cartx, carty)
self.poletrans.set_rotation(-x[2])
return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))
def close(self):
if self.viewer:
self.viewer.close()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import paddle.fluid as fluid
import parl
from parl import layers


class Model(parl.Model):
    def __init__(self, act_dim):
        self.actor_model = ActorModel(act_dim)
        self.critic_model = CriticModel()

    def policy(self, obs):
        return self.actor_model.policy(obs)

    def value(self, obs, act):
        return self.critic_model.value(obs, act)

    def get_actor_params(self):
        return self.actor_model.parameters()


class ActorModel(parl.Model):
    def __init__(self, act_dim):
        hid_size = 100

        self.fc1 = layers.fc(size=hid_size, act='relu')
        self.fc2 = layers.fc(size=act_dim, act='tanh')

    def policy(self, obs):
        hid = self.fc1(obs)
        means = self.fc2(hid)
        return means


class CriticModel(parl.Model):
    def __init__(self):
        hid_size = 100

        self.fc1 = layers.fc(size=hid_size, act='relu')
        self.fc2 = layers.fc(size=1, act=None)

    def value(self, obs, act):
        concat = layers.concat([obs, act], axis=1)
        hid = self.fc1(concat)
        Q = self.fc2(hid)
        Q = layers.squeeze(Q, axes=[1])
        return Q
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from https://github.com/seungeunrho/minimalRL/blob/master/dqn.py
import random
import collections
import numpy as np


# A fixed-size FIFO buffer of (obs, action, reward, next_obs, done) transitions.
class ReplayMemory(object):
    def __init__(self, max_size):
        self.buffer = collections.deque(maxlen=max_size)

    def append(self, exp):
        self.buffer.append(exp)

    def sample(self, batch_size):
        mini_batch = random.sample(self.buffer, batch_size)
        obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []

        for experience in mini_batch:
            s, a, r, s_p, done = experience
            obs_batch.append(s)
            action_batch.append(a)
            reward_batch.append(r)
            next_obs_batch.append(s_p)
            done_batch.append(done)

        return np.array(obs_batch).astype('float32'), \
            np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'), \
            np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32')

    def __len__(self):
        return len(self.buffer)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import gym
import numpy as np
import parl
from parl.utils import logger

from agent import Agent
from model import Model
from algorithm import DDPG  # from parl.algorithms import DDPG
from env import ContinuousCartPoleEnv
from replay_memory import ReplayMemory

ACTOR_LR = 1e-3  # learning rate of the actor network
CRITIC_LR = 1e-3  # learning rate of the critic network
GAMMA = 0.99  # discount factor for rewards
TAU = 0.001  # soft-update coefficient
MEMORY_SIZE = int(1e6)  # size of the replay memory
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20  # pre-fill part of the memory before training starts
BATCH_SIZE = 128
REWARD_SCALE = 0.1  # reward scaling factor
NOISE = 0.05  # standard deviation of the action noise
TRAIN_EPISODE = 6e3  # total number of training episodes


def run_train_episode(agent, env, rpm):
    obs = env.reset()
    total_reward = 0
    steps = 0
    while True:
        steps += 1
        batch_obs = np.expand_dims(obs, axis=0)
        action = agent.predict(batch_obs.astype('float32'))

        # add exploration noise and clip the output into [-1.0, 1.0]
        action = np.clip(np.random.normal(action, NOISE), -1.0, 1.0)

        next_obs, reward, done, info = env.step(action)

        action = [action]  # wrap for storage in the replay memory
        rpm.append((obs, action, REWARD_SCALE * reward, next_obs, done))

        if len(rpm) > MEMORY_WARMUP_SIZE and (steps % 5) == 0:
            (batch_obs, batch_action, batch_reward, batch_next_obs,
             batch_done) = rpm.sample(BATCH_SIZE)
            agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs,
                        batch_done)

        obs = next_obs
        total_reward += reward

        if done or steps >= 200:
            break
    return total_reward


def run_evaluate_episode(env, agent, render=False):
    eval_reward = []
    for i in range(5):
        obs = env.reset()
        total_reward = 0
        steps = 0
        while True:
            batch_obs = np.expand_dims(obs, axis=0)
            action = agent.predict(batch_obs.astype('float32'))
            action = np.clip(action, -1.0, 1.0)

            steps += 1
            next_obs, reward, done, info = env.step(action)

            obs = next_obs
            total_reward += reward

            if render:
                env.render()
            if done or steps >= 200:
                break
        eval_reward.append(total_reward)
    return np.mean(eval_reward)


def main():
    env = ContinuousCartPoleEnv()

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # build the agent with the PARL framework
    model = Model(act_dim)
    algorithm = DDPG(
        model, gamma=GAMMA, tau=TAU, actor_lr=ACTOR_LR, critic_lr=CRITIC_LR)
    agent = Agent(algorithm, obs_dim, act_dim)

    # create the replay memory
    rpm = ReplayMemory(MEMORY_SIZE)
    # pre-fill the replay memory
    while len(rpm) < MEMORY_WARMUP_SIZE:
        run_train_episode(agent, env, rpm)

    episode = 0
    while episode < TRAIN_EPISODE:
        for i in range(50):
            total_reward = run_train_episode(agent, env, rpm)
            episode += 1

        eval_reward = run_evaluate_episode(env, agent, render=False)
        logger.info('episode:{} test_reward:{}'.format(
            episode, eval_reward))


if __name__ == '__main__':
    main()
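Exploration in `run_train_episode` perturbs the deterministic actor output with Gaussian noise and clips the result back into the valid action range. In isolation (the action value is hypothetical):

```
import numpy as np

NOISE = 0.05
action = 0.97  # deterministic actor output near the action bound
noisy_action = np.clip(np.random.normal(action, NOISE), -1.0, 1.0)
print(noisy_action)  # always within [-1.0, 1.0], even when the noise pushes past the bound
```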