提交 bdc13f2f 编写于 作者: T TomorrowIsAnOtherDay 提交者: whs

【Fluid models】implement DQN model (#889)

* [DQN]source code commit

* Update README.md

* Update README.md

* add mountain-car curve

* Update README.md

* Update README.md

* clean code

* fix code style

* [fix code style]/2

* remove some tensorflow package

* a better way to sample from replay memory

* code style
上级 0eb2f4b9
#-*- coding: utf-8 -*-
#File: DQN.py
from agent import Model
import gym
import argparse
from tqdm import tqdm
from expreplay import ReplayMemory, Experience
import numpy as np
import os
UPDATE_FREQ = 4
MEMORY_WARMUP_SIZE = 1000
def run_episode(agent, env, exp, train_or_test):
assert train_or_test in ['train', 'test'], train_or_test
total_reward = 0
state = env.reset()
for step in range(200):
action = agent.act(state, train_or_test)
next_state, reward, isOver, _ = env.step(action)
if train_or_test == 'train':
exp.append(Experience(state, action, reward, isOver))
# train model
# start training
if len(exp) > MEMORY_WARMUP_SIZE:
batch_idx = np.random.randint(
len(exp) - 1, size=(args.batch_size))
if step % UPDATE_FREQ == 0:
batch_state, batch_action, batch_reward, \
batch_next_state, batch_isOver = exp.sample(batch_idx)
agent.train(batch_state, batch_action, batch_reward, \
batch_next_state, batch_isOver)
total_reward += reward
state = next_state
if isOver:
break
return total_reward
def train_agent():
env = gym.make(args.env)
state_shape = env.observation_space.shape
exp = ReplayMemory(args.mem_size, state_shape)
action_dim = env.action_space.n
agent = Model(state_shape[0], action_dim, gamma=0.99)
while len(exp) < MEMORY_WARMUP_SIZE:
run_episode(agent, env, exp, train_or_test='train')
max_episode = 4000
# train
total_episode = 0
pbar = tqdm(total=max_episode)
recent_100_reward = []
for episode in xrange(max_episode):
# start epoch
total_reward = run_episode(agent, env, exp, train_or_test='train')
pbar.set_description('[train]exploration:{}'.format(agent.exploration))
pbar.update()
# recent 100 reward
total_reward = run_episode(agent, env, exp, train_or_test='test')
recent_100_reward.append(total_reward)
if len(recent_100_reward) > 100:
recent_100_reward = recent_100_reward[1:]
pbar.write("episode:{} test_reward:{}".format(\
episode, np.mean(recent_100_reward)))
pbar.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str, default='MountainCar-v0', \
help='enviroment to train DQN model, e.g CartPole-v0')
parser.add_argument('--gamma', type=float, default=0.99, \
help='discount factor for accumulated reward computation')
parser.add_argument('--mem_size', type=int, default=500000, \
help='memory size for experience replay')
parser.add_argument('--batch_size', type=int, default=192, \
help='batch size for training')
args = parser.parse_args()
train_agent()
<img src="mountain_car.gif" width="300" height="200">
# Reproduce DQN model
+ DQN in:
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
# Mountain-CAR benchmark & performance
[MountainCar-v0](https://gym.openai.com/envs/MountainCar-v0/)
A car is on a one-dimensional track, positioned between two "mountains". The goal is to drive up the mountain on the right; however, the car's engine is not strong enough to scale the mountain in a single pass. Therefore, the only way to succeed is to drive back and forth to build up momentum.
<img src="curve.png" >
# How to use
+ Dependencies:
+ python2.7
+ gym
+ tqdm
+ paddle-fluid
+ Start Training:
```
# use mountain-car enviroment as default
python DQN.py
# use other enviorment
python DQN.py --env CartPole-v0
```
#-*- coding: utf-8 -*-
#File: agent.py
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import numpy as np
from tqdm import tqdm
import math
UPDATE_TARGET_STEPS = 200
class Model(object):
def __init__(self, state_dim, action_dim, gamma):
self.global_step = 0
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = gamma
self.exploration = 1.0
self._build_net()
def _get_inputs(self):
return [fluid.layers.data(\
name='state', shape=[self.state_dim], dtype='float32'),
fluid.layers.data(\
name='action', shape=[1], dtype='int32'),
fluid.layers.data(\
name='reward', shape=[], dtype='float32'),
fluid.layers.data(\
name='next_s', shape=[self.state_dim], dtype='float32'),
fluid.layers.data(\
name='isOver', shape=[], dtype='bool')]
def _build_net(self):
state, action, reward, next_s, isOver = self._get_inputs()
self.pred_value = self.get_DQN_prediction(state)
self.predict_program = fluid.default_main_program().clone()
action_onehot = fluid.layers.one_hot(action, self.action_dim)
action_onehot = fluid.layers.cast(action_onehot, dtype='float32')
pred_action_value = fluid.layers.reduce_sum(\
fluid.layers.elementwise_mul(action_onehot, self.pred_value), dim=1)
targetQ_predict_value = self.get_DQN_prediction(next_s, target=True)
best_v = fluid.layers.reduce_max(targetQ_predict_value, dim=1)
best_v.stop_gradient = True
target = reward + (1.0 - fluid.layers.cast(\
isOver, dtype='float32')) * self.gamma * best_v
cost = fluid.layers.square_error_cost(\
input=pred_action_value, label=target)
cost = fluid.layers.reduce_mean(cost)
self._sync_program = self._build_sync_target_network()
optimizer = fluid.optimizer.Adam(1e-3)
optimizer.minimize(cost)
# define program
self.train_program = fluid.default_main_program()
# fluid exe
place = fluid.CUDAPlace(0)
self.exe = fluid.Executor(place)
self.exe.run(fluid.default_startup_program())
def get_DQN_prediction(self, state, target=False):
variable_field = 'target' if target else 'policy'
# layer fc1
param_attr = ParamAttr(name='{}_fc1'.format(variable_field))
bias_attr = ParamAttr(name='{}_fc1_b'.format(variable_field))
fc1 = fluid.layers.fc(input=state,
size=256,
act='relu',
param_attr=param_attr,
bias_attr=bias_attr)
param_attr = ParamAttr(name='{}_fc2'.format(variable_field))
bias_attr = ParamAttr(name='{}_fc2_b'.format(variable_field))
fc2 = fluid.layers.fc(input=fc1,
size=128,
act='tanh',
param_attr=param_attr,
bias_attr=bias_attr)
param_attr = ParamAttr(name='{}_fc3'.format(variable_field))
bias_attr = ParamAttr(name='{}_fc3_b'.format(variable_field))
value = fluid.layers.fc(input=fc2,
size=self.action_dim,
param_attr=param_attr,
bias_attr=bias_attr)
return value
def _build_sync_target_network(self):
vars = fluid.default_main_program().list_vars()
policy_vars = []
target_vars = []
for var in vars:
if 'GRAD' in var.name: continue
if 'policy' in var.name:
policy_vars.append(var)
elif 'target' in var.name:
target_vars.append(var)
policy_vars.sort(key=lambda x: x.name.split('policy_')[1])
target_vars.sort(key=lambda x: x.name.split('target_')[1])
sync_program = fluid.default_main_program().clone()
with fluid.program_guard(sync_program):
sync_ops = []
for i, var in enumerate(policy_vars):
sync_op = fluid.layers.assign(policy_vars[i], target_vars[i])
sync_ops.append(sync_op)
sync_program = sync_program.prune(sync_ops)
return sync_program
def act(self, state, train_or_test):
sample = np.random.random()
if train_or_test == 'train' and sample < self.exploration:
act = np.random.randint(self.action_dim)
else:
state = np.expand_dims(state, axis=0)
pred_Q = self.exe.run(self.predict_program,
feed={'state': state.astype('float32')},
fetch_list=[self.pred_value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
self.exploration = max(0.1, self.exploration - 1e-6)
return act
def train(self, state, action, reward, next_state, isOver):
if self.global_step % UPDATE_TARGET_STEPS == 0:
self.sync_target_network()
self.global_step += 1
action = np.expand_dims(action, -1)
self.exe.run(self.train_program, \
feed={'state': state, \
'action': action, \
'reward': reward, \
'next_s': next_state, \
'isOver': isOver})
def sync_target_network(self):
self.exe.run(self._sync_program)
#-*- coding: utf-8 -*-
#File: expreplay.py
from collections import namedtuple
import numpy as np
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])
class ReplayMemory(object):
def __init__(self, max_size, state_shape):
self.max_size = int(max_size)
self.state_shape = state_shape
self.state = np.zeros((self.max_size, ) + state_shape, dtype='float32')
self.action = np.zeros((self.max_size, ), dtype='int32')
self.reward = np.zeros((self.max_size, ), dtype='float32')
self.isOver = np.zeros((self.max_size, ), dtype='bool')
self._curr_size = 0
self._curr_pos = 0
def append(self, exp):
if self._curr_size < self.max_size:
self._assign(self._curr_pos, exp)
self._curr_size += 1
else:
self._assign(self._curr_pos, exp)
self._curr_pos = (self._curr_pos + 1) % self.max_size
def _assign(self, pos, exp):
self.state[pos] = exp.state
self.action[pos] = exp.action
self.reward[pos] = exp.reward
self.isOver[pos] = exp.isOver
def __len__(self):
return self._curr_size
def sample(self, batch_idx):
# index mapping to avoid sampling lastest state
batch_idx = (self._curr_pos + batch_idx) % self._curr_size
next_idx = (batch_idx + 1) % self._curr_size
state = self.state[batch_idx]
reward = self.reward[batch_idx]
action = self.action[batch_idx]
next_state = self.state[next_idx]
isOver = self.isOver[batch_idx]
return (state, action, reward, next_state, isOver)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册