Commit 6a672c80 authored by LI Yunxiang, committed by Bo Zhou

add offline q learning (#193)

* add offline q learning

* Update README.md

* update

* yapf
Parent c070db83
## Parallel Training with PARL
Use `parl.compile` to train the model in parallel. When training offline, or when the dataset is too large to train on a single GPU, parallel computing can be used to accelerate training.
```python
# Set CUDA_VISIBLE_DEVICES to select which GPUs to train on
import parl
import paddle.fluid as fluid

learn_program = fluid.Program()
with fluid.program_guard(learn_program):
    # Define your learn program and training loss here
    pass

# Pass the training loss to parl.compile, which distributes the model and data to the GPUs.
learn_program = parl.compile(learn_program, loss=training_loss)
```
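When the compiled program is run, the fed batch is split across the visible GPUs, which is why the demonstration below samples `batch_size * gpu_num` transitions per training step. A minimal sketch, assuming `exe` is the `fluid.Executor` used for training and `feed` already holds such a batch:
```python
# Minimal sketch (assumes `exe` is a fluid.Executor and `feed` holds a batch of
# size batch_size * gpu_num); the compiled program splits the batch across GPUs.
cost = exe.run(learn_program, feed=feed, fetch_list=[training_loss])[0]
```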
## Demonstration
We provide a demonstration of offline Q-learning with parallel execution, in which the procedures of collecting data and training the model are separated. First we collect data by interacting with the environment and save it to a replay memory file; then we fit and evaluate the Q network with the collected data. Repeating these two steps gradually improves performance.
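The sketch below condenses the workflow implemented in `parallel_run.py` (names such as `ReplayMemory`, `collect_exp`, and `run_train_step` follow the sources shipped with this example; treat it as an outline rather than a standalone script):
```python
# Outline of the offline Q-learning workflow in this example (see parallel_run.py).
rpm = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN,
                   load_file=True, file_path='memory.npz')
if args.train:
    run_train_step(agent, rpm)    # fit the Q network offline, evaluating periodically
    agent.save('./model.ckpt')
else:
    collect_exp(env, rpm, agent)  # gather experience with the current policy
    rpm.save_memory()             # persist transitions to memory.npz
```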
### Dependencies:
+ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
+ tqdm
+ atari-py
### How to Run:
```shell
# Collect training data
python parallel_run.py --rom rom_files/pong.bin
# Train the model offline with multi-GPU
python parallel_run.py --rom rom_files/pong.bin --train
```
../DQN/atari.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl
from parl import layers
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
class AtariAgent(parl.Agent):
def __init__(self, algorithm, act_dim, total_step):
super(AtariAgent, self).__init__(algorithm)
assert isinstance(act_dim, int)
self.act_dim = act_dim
self.exploration = 1.1
self.global_step = 0
self.update_target_steps = 10000 // 4
def build_program(self):
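        # Build three separate fluid programs: one for prediction, one for
        # parallel learning across GPUs, and one for held-out evaluation.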
self.pred_program = fluid.Program()
self.learn_program = fluid.Program()
self.supervised_eval_program = fluid.Program()
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
self.value = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
action = layers.data(name='act', shape=[1], dtype='int32')
reward = layers.data(name='reward', shape=[], dtype='float32')
next_obs = layers.data(
name='next_obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)
# use parl.compile to distribute data and model to GPUs
self.learn_program = parl.compile(self.learn_program, loss=self.cost)
with fluid.program_guard(self.supervised_eval_program):
obs = layers.data(
name='obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
action = layers.data(name='act', shape=[1], dtype='int32')
reward = layers.data(name='reward', shape=[], dtype='float32')
next_obs = layers.data(
name='next_obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
self.supervised_cost = self.alg.supervised_eval(
obs, action, reward, next_obs, terminal)
def sample(self, obs):
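        # Epsilon-greedy action selection: explore with probability
        # self.exploration (annealed from 1.1 towards 0.1), otherwise act
        # greedily on the predicted Q-values (keeping a 1% random chance).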
sample = np.random.random()
if sample < self.exploration:
act = np.random.randint(self.act_dim)
else:
if np.random.random() < 0.01:
act = np.random.randint(self.act_dim)
else:
obs = np.expand_dims(obs, axis=0)
pred_Q = self.fluid_executor.run(
self.pred_program,
feed={'obs': obs.astype('float32')},
fetch_list=[self.value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
self.exploration = max(0.1, self.exploration - 1e-6)
return act
def predict(self, obs):
obs = np.expand_dims(obs, axis=0)
pred_Q = self.fluid_executor.run(
self.pred_program,
feed={'obs': obs.astype('float32')},
fetch_list=[self.value])[0]
pred_Q = np.squeeze(pred_Q, axis=0)
act = np.argmax(pred_Q)
return act
def learn(self, obs, act, reward, next_obs, terminal):
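        # Sync the target network every self.update_target_steps calls, clip
        # rewards to [-1, 1], and run one parallel training step on the batch.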
if self.global_step % self.update_target_steps == 0:
self.alg.sync_target()
self.global_step += 1
act = np.expand_dims(act, -1)
reward = np.clip(reward, -1, 1)
feed = {
'obs': obs.astype('float32'),
'act': act.astype('int32'),
'reward': reward,
'next_obs': next_obs.astype('float32'),
'terminal': terminal
}
cost = self.fluid_executor.run(
self.learn_program, feed=feed, fetch_list=[self.cost])[0]
return cost
def supervised_eval(self, obs, act, reward, next_obs, terminal):
act = np.expand_dims(act, -1)
reward = np.clip(reward, -1, 1)
feed = {
'obs': obs.astype('float32'),
'act': act.astype('int32'),
'reward': reward,
'next_obs': next_obs.astype('float32'),
'terminal': terminal
}
cost = self.fluid_executor.run(
self.supervised_eval_program,
feed=feed,
fetch_list=[self.supervised_cost])[0]
return cost
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl
from parl import layers
class AtariModel(parl.Model):
def __init__(self, act_dim):
self.act_dim = act_dim
self.conv1 = layers.conv2d(
num_filters=32, filter_size=5, stride=1, padding=2, act='relu')
self.conv2 = layers.conv2d(
num_filters=32, filter_size=5, stride=1, padding=2, act='relu')
self.conv3 = layers.conv2d(
num_filters=64, filter_size=4, stride=1, padding=1, act='relu')
self.conv4 = layers.conv2d(
num_filters=64, filter_size=3, stride=1, padding=1, act='relu')
self.fc1 = layers.fc(size=act_dim)
def value(self, obs):
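        # Normalize pixels to [0, 1]; four conv layers (max-pooling after the
        # first three) followed by a fully connected layer that outputs one
        # Q-value per action.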
obs = obs / 255.0
out = self.conv1(obs)
out = layers.pool2d(
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv2(out)
out = layers.pool2d(
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv3(out)
out = layers.pool2d(
input=out, pool_size=2, pool_stride=2, pool_type='max')
out = self.conv4(out)
out = layers.flatten(out, axis=1)
Q = self.fc1(out)
return Q
../DQN/atari_wrapper.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
import copy
import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
__all__ = ['DQN']
class DQN(Algorithm):
def __init__(self,
model,
hyperparas=None,
act_dim=None,
gamma=None,
lr=None):
""" DQN algorithm
Args:
model (parl.Model): model defining forward network of Q function
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (int): dimension of the action space
gamma (float): discounted factor for reward computation.
lr (float): learning rate.
"""
self.model = model
self.target_model = copy.deepcopy(model)
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.act_dim = hyperparas['action_dim']
self.gamma = hyperparas['gamma']
else:
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
assert isinstance(lr, float)
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
def predict(self, obs):
""" use value model self.model to predict the action value
"""
return self.model.value(obs)
def learn(self, obs, action, reward, next_obs, terminal):
""" update value model self.model with DQN algorithm
"""
cost = self.cal_bellman_residual(obs, action, reward, next_obs,
terminal)
optimizer = fluid.optimizer.Adam(learning_rate=self.lr, epsilon=1e-3)
optimizer.minimize(cost)
return cost
def supervised_eval(self, obs, action, reward, next_obs, terminal):
""" Calculate squared Bellman residual with test dataset. The operations are the same as learn method above,
except backpropagation.
"""
cost = self.cal_bellman_residual(obs, action, reward, next_obs,
terminal)
cost.stop_gradient = True
return cost
def cal_bellman_residual(self, obs, action, reward, next_obs, terminal):
""" use self.model to get squared Bellman residual with fed data
"""
pred_value = self.model.value(obs)
next_pred_value = self.target_model.value(next_obs)
best_v = layers.reduce_max(next_pred_value, dim=1)
best_v.stop_gradient = True
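        # One-step TD target: r + gamma * max_a' Q_target(s', a'); the
        # bootstrap term is zeroed for terminal transitions.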
target = reward + (
1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v
action_onehot = layers.one_hot(action, self.act_dim)
action_onehot = layers.cast(action_onehot, dtype='float32')
pred_action_value = layers.reduce_sum(
layers.elementwise_mul(action_onehot, pred_value), dim=1)
cost = layers.square_error_cost(pred_action_value, target)
cost = layers.reduce_mean(cost)
return cost
def sync_target(self, gpu_id=None):
""" sync weights of self.model to self.target_model
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.sync_weights_to(self.target_model)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gym
import numpy as np
import os
import time
from tqdm import tqdm
import parl
import paddle.fluid as fluid
from parl.utils import get_gpu_count
from parl.utils import tensorboard, logger
from dqn import DQN # slight changes from parl.algorithms.DQN
from atari_agent import AtariAgent
from atari_model import AtariModel
from replay_memory import ReplayMemory, Experience
from utils import get_player
MEMORY_SIZE = int(1e6)
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
FRAME_SKIP = 4
UPDATE_FREQ = 4
GAMMA = 0.99
LEARNING_RATE = 3e-4
gpu_num = get_gpu_count()
def run_train_step(agent, rpm):
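    # Offline training loop: sample multi-GPU batches from the training split
    # and periodically report the Bellman residual on the held-out split.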
for step in range(args.train_total_steps):
# use the first 80% data to train
batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
args.batch_size * gpu_num)
batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
batch_next_state = batch_all_state[:, 1:, :, :]
cost = agent.learn(batch_state, batch_action, batch_reward,
batch_next_state, batch_isOver)
if step % 100 == 0:
# use the last 20% data to evaluate
batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_test_batch(
args.batch_size)
batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
batch_next_state = batch_all_state[:, 1:, :, :]
eval_cost = agent.supervised_eval(batch_state, batch_action,
batch_reward, batch_next_state,
batch_isOver)
logger.info(
"train step {}, train costs are {}, eval cost is {}.".format(
step, cost, eval_cost))
def collect_exp(env, rpm, agent):
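    # Roll out the exploratory policy for MEMORY_SIZE steps, appending each
    # transition to the replay memory.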
state = env.reset()
# collect data to fulfill replay memory
for i in tqdm(range(MEMORY_SIZE)):
context = rpm.recent_state()
context.append(state)
context = np.stack(context, axis=0)
action = agent.sample(context)
next_state, reward, isOver, _ = env.step(action)
rpm.append(Experience(state, action, reward, isOver))
state = next_state
def main():
env = get_player(
args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP)
file_path = "memory.npz"
rpm = ReplayMemory(
MEMORY_SIZE,
IMAGE_SIZE,
CONTEXT_LEN,
load_file=True, # load replay memory data from file
file_path=file_path)
act_dim = env.action_space.n
model = AtariModel(act_dim)
algorithm = DQN(
model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE * gpu_num)
agent = AtariAgent(
algorithm, act_dim=act_dim, total_step=args.train_total_steps)
if os.path.isfile('./model.ckpt'):
logger.info("load model from file")
agent.restore('./model.ckpt')
if args.train:
logger.info("train with memory data")
run_train_step(agent, rpm)
logger.info("finish training. Save the model.")
agent.save('./model.ckpt')
else:
logger.info("collect experience")
collect_exp(env, rpm, agent)
rpm.save_memory()
logger.info("finish collecting, save successfully")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--rom', help='path of the rom of the atari game', required=True)
parser.add_argument(
'--batch_size', type=int, default=64, help='batch size for each GPU')
parser.add_argument(
'--train',
action="store_true",
help='update the value function (default: False)')
parser.add_argument(
'--train_total_steps',
type=int,
default=int(1e6),
        help='number of offline training (gradient update) steps')
args = parser.parse_args()
main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import copy
import os
from collections import deque, namedtuple
from parl.utils import logger
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'isOver'])
class ReplayMemory(object):
def __init__(self,
max_size,
state_shape,
context_len,
load_file=False,
file_path=None):
self.max_size = int(max_size)
self.state_shape = state_shape
self.context_len = int(context_len)
self.file_path = file_path
if load_file and os.path.isfile(file_path):
logger.info("load memory from file" + self.file_path)
self.load_memory()
logger.info("memory size is {}".format(self._curr_size))
else:
self.state = np.zeros(
(self.max_size, ) + state_shape, dtype='uint8')
self.action = np.zeros((self.max_size, ), dtype='int32')
self.reward = np.zeros((self.max_size, ), dtype='float32')
self.isOver = np.zeros((self.max_size, ), dtype='bool')
self._curr_size = 0
self._curr_pos = 0
self._context = deque(maxlen=context_len - 1)
def append(self, exp):
"""append a new experience into replay memory
"""
if self._curr_size < self.max_size:
self._assign(self._curr_pos, exp)
self._curr_size += 1
else:
self._assign(self._curr_pos, exp)
self._curr_pos = (self._curr_pos + 1) % self.max_size
if exp.isOver:
self._context.clear()
else:
self._context.append(exp)
def recent_state(self):
""" maintain recent state for training"""
lst = list(self._context)
states = [np.zeros(self.state_shape, dtype='uint8')] * \
(self._context.maxlen - len(lst))
states.extend([k.state for k in lst])
return states
def sample(self, idx):
""" return state, action, reward, isOver,
note that some frames in state may be generated from last episode,
they should be removed from state
"""
state = np.zeros(
(self.context_len + 1, ) + self.state_shape, dtype=np.uint8)
state_idx = np.arange(idx,
idx + self.context_len + 1) % self._curr_size
# confirm that no frame was generated from last episode
has_last_episode = False
for k in range(self.context_len - 2, -1, -1):
to_check_idx = state_idx[k]
if self.isOver[to_check_idx]:
has_last_episode = True
state_idx = state_idx[k + 1:]
state[k + 1:] = self.state[state_idx]
break
if not has_last_episode:
state = self.state[state_idx]
real_idx = (idx + self.context_len - 1) % self._curr_size
action = self.action[real_idx]
reward = self.reward[real_idx]
isOver = self.isOver[real_idx]
return state, reward, action, isOver
def __len__(self):
return self._curr_size
def size(self):
return self._curr_size
def _assign(self, pos, exp):
self.state[pos] = exp.state
self.reward[pos] = exp.reward
self.action[pos] = exp.action
self.isOver[pos] = exp.isOver
def sample_batch(self, batch_size):
"""sample a batch from replay memory for training
"""
batch_idx = np.random.randint(
int(self.max_size * 0.8) - self.context_len - 1, size=batch_size)
# batch_idx = (self._curr_pos + batch_idx) % self._curr_size
batch_exp = [self.sample(i) for i in batch_idx]
return self._process_batch(batch_exp)
def sample_test_batch(self, batch_size):
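        # Indices are drawn from the last 20% of the memory, which is held
        # out for evaluation.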
batch_idx = np.random.randint(
int(self.max_size * 0.2) - self.context_len - 1,
size=batch_size) + int(self.max_size * 0.8)
# batch_idx = (self._curr_pos + batch_idx) % self._curr_size
batch_exp = [self.sample(i) for i in batch_idx]
return self._process_batch(batch_exp)
def _process_batch(self, batch_exp):
state = np.asarray([e[0] for e in batch_exp], dtype='uint8')
reward = np.asarray([e[1] for e in batch_exp], dtype='float32')
action = np.asarray([e[2] for e in batch_exp], dtype='int8')
isOver = np.asarray([e[3] for e in batch_exp], dtype='bool')
return [state, action, reward, isOver]
def save_memory(self):
save_data = [
self.state, self.reward, self.action, self.isOver, self._curr_size,
self._curr_pos, self._context
]
np.savez(self.file_path, *save_data)
def load_memory(self):
container = np.load(self.file_path, allow_pickle=True)
[
self.state, self.reward, self.action, self.isOver, self._curr_size,
self._curr_pos, self._context
] = [container[key] for key in container]
self._curr_size = self._curr_size.astype(int)
self._curr_pos = self._curr_pos.astype(int)
self._context = deque([Experience(*row) for row in self._context],
maxlen=self.context_len - 1)
../DQN/rom_files/
../DQN/utils.py