Commit b28289ac authored by Hongsheng Zeng, committed by Bo Zhou

implementation of IMPALA with the newest parallel design (#60)

* add IMPALA algorithm and some common utils

* update README.md

* refactor file structure of impala algorithm; separate numpy utils from utils

* add hyper parameter scheduler module; add entropy and lr scheduler in impala

* clip rewards in the atari wrapper instead of on the learner side; fix code style

* add benchmark result of impala; refine code of impala example; add obs_format in atari_wrappers

* Update README.md
Parent 7346a23d
......@@ -14,7 +14,6 @@ __pycache__/
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
......
......@@ -4,3 +4,4 @@ details
termcolor
pyarrow
zmq
parameterized
......@@ -79,6 +79,7 @@ pip install parl
- [DQN](examples/DQN/)
- [DDPG](examples/DDPG/)
- [PPO](examples/PPO/)
- [IMPALA](examples/IMPALA/)
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
<img src=".github/NeurlIPS2018.gif" width = "300" height ="200" alt="NeurlIPS2018"/> <img src=".github/Half-Cheetah.gif" width = "300" height ="200" alt="Half-Cheetah"/> <img src=".github/Breakout.gif" width = "200" height ="200" alt="Breakout"/>
......
## Reproduce IMPALA with PARL
Based on PARL, we reproduce the IMPALA deep reinforcement learning algorithm and reach the same level of performance reported in the paper on classic Atari games.
+ IMPALA is proposed in
[IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures](https://arxiv.org/abs/1802.01561)
### Atari games introduction
Please see [here](https://gym.openai.com/envs/#atari) to learn more about Atari games.
### Benchmark result
Results are obtained with one learner (on a P40 GPU) and 32 actors (on 32 CPUs).
+ PongNoFrameskip-v4: mean_episode_rewards reaches a score of 18-19 in about 7-10 minutes.
<img src=".benchmark/IMPALA_Pong.jpg" width = "400" height ="300" alt="IMPALA_Pong" />
+ Results of other games after one hour of training.
<img src=".benchmark/IMPALA_Breakout.jpg" width = "400" height ="300" alt="IMPALA_Breakout" /> <img src=".benchmark/IMPALA_BeamRider.jpg" width = "400" height ="300" alt="IMPALA_BeamRider"/>
<br>
<img src=".benchmark/IMPALA_Qbert.jpg" width = "400" height ="300" alt="IMPALA_Qbert" /> <img src=".benchmark/IMPALA_SpaceInvaders.jpg" width = "400" height ="300" alt="IMPALA_SpaceInvaders"/>
## How to use
### Dependencies
+ python2.7 or python3.5+
+ [paddlepaddle>=1.3.0](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
+ opencv-python
+ atari_py
### Distributed Training:
#### Learner
```sh
python train.py
```
#### Actors (suggested: 32+ actors on 32+ CPUs)
```sh
for index in {1..32}; do
python actor.py &
done;
wait
```
You can change training settings (e.g. `env_name`, `server_ip`) in `impala_config.py`.
Training results will be saved in `log_dir/train/result.csv`.
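For example, a minimal sketch of the fields most setups need to change (the values below mirror the shipped `impala_config.py`; adjust them to your own machines):
```python
# impala_config.py (excerpt)
config = {
    'env_name': 'PongNoFrameskip-v4',
    'server_ip': 'localhost',   # set to the learner machine's IP for multi-machine runs
    'server_port': 8037,
    # ... remaining keys unchanged ...
}
```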
### Reference
+ [deepmind/scalable_agent](https://github.com/deepmind/scalable_agent)
+ [Ray](https://github.com/ray-project/ray)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import numpy as np
import parl
import six
from atari_model import AtariModel
from collections import defaultdict
from atari_agent import AtariAgent
from parl.algorithms import IMPALA
from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls
from parl.env.vector_env import VectorEnv
@parl.remote_class
class Actor(object):
def __init__(self, config):
self.config = config
self.envs = []
for _ in six.moves.range(config['env_num']):
env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
self.envs.append(env)
self.vector_env = VectorEnv(self.envs)
self.obs_batch = self.vector_env.reset()
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = AtariModel(act_dim)
algorithm = IMPALA(model, hyperparas=config)
self.agent = AtariAgent(algorithm, config)
def sample(self):
env_sample_data = {}
for env_id in six.moves.range(self.config['env_num']):
env_sample_data[env_id] = defaultdict(list)
for i in six.moves.range(self.config['sample_batch_steps']):
actions, behaviour_logits = self.agent.sample(
np.stack(self.obs_batch))
next_obs_batch, reward_batch, done_batch, info_batch = \
self.vector_env.step(actions)
for env_id in six.moves.range(self.config['env_num']):
env_sample_data[env_id]['obs'].append(self.obs_batch[env_id])
env_sample_data[env_id]['actions'].append(actions[env_id])
env_sample_data[env_id]['behaviour_logits'].append(
behaviour_logits[env_id])
env_sample_data[env_id]['rewards'].append(reward_batch[env_id])
env_sample_data[env_id]['dones'].append(done_batch[env_id])
self.obs_batch = next_obs_batch
# Merge data of envs
sample_data = defaultdict(list)
for env_id in six.moves.range(self.config['env_num']):
for data_name in [
'obs', 'actions', 'behaviour_logits', 'rewards', 'dones'
]:
sample_data[data_name].extend(
env_sample_data[env_id][data_name])
# size of sample_data: env_num * sample_batch_steps
for key in sample_data:
sample_data[key] = np.stack(sample_data[key])
return sample_data
def get_metrics(self):
metrics = defaultdict(list)
for env in self.envs:
monitor = get_wrapper_by_cls(env, MonitorEnv)
if monitor is not None:
for episode_rewards, episode_steps in monitor.next_episode_results(
):
metrics['episode_rewards'].append(episode_rewards)
metrics['episode_steps'].append(episode_steps)
return metrics
def set_params(self, params):
self.agent.set_params(params)
if __name__ == '__main__':
from impala_config import config
actor = Actor(config)
actor.as_remote(config['server_ip'], config['server_port'])
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.agent_base import Agent
class AtariAgent(Agent):
def __init__(self, algorithm, config, learn_data_provider=None):
self.config = config
super(AtariAgent, self).__init__(algorithm)
use_cuda = self.gpu_id >= 0
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 4
build_strategy = fluid.BuildStrategy()
build_strategy.remove_unnecessary_lock = True
# Use ParallelExecutor to make the learn program run faster
self.learn_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=self.learn_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
if learn_data_provider:
self.learn_reader.decorate_tensor_provider(learn_data_provider)
self.learn_reader.start()
def build_program(self):
self.sample_program = fluid.Program()
self.predict_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.sample_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
self.sample_actions, self.behaviour_logits = self.alg.sample(obs)
with fluid.program_guard(self.predict_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
self.predict_actions = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
actions = layers.data(name='actions', shape=[], dtype='int64')
behaviour_logits = layers.data(
name='behaviour_logits',
shape=[self.config['act_dim']],
dtype='float32')
rewards = layers.data(name='rewards', shape=[], dtype='float32')
dones = layers.data(name='dones', shape=[], dtype='float32')
lr = layers.data(
name='lr', shape=[1], dtype='float32', append_batch_size=False)
entropy_coeff = layers.data(
name='entropy_coeff', shape=[], dtype='float32')
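# The py_reader below feeds the learn program asynchronously; its data
# generator is supplied via learn_data_provider (see learner.py).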
self.learn_reader = fluid.layers.create_py_reader_by_data(
capacity=self.config['train_batch_size'],
feed_list=[
obs, actions, behaviour_logits, rewards, dones, lr,
entropy_coeff
])
obs, actions, behaviour_logits, rewards, dones, lr, entropy_coeff = fluid.layers.read_file(
self.learn_reader)
vtrace_loss, kl = self.alg.learn(obs, actions, behaviour_logits,
rewards, dones, lr, entropy_coeff)
self.learn_outputs = [
vtrace_loss.total_loss.name, vtrace_loss.pi_loss.name,
vtrace_loss.vf_loss.name, vtrace_loss.entropy.name, kl.name
]
def sample(self, obs_np):
"""
Args:
obs_np: a numpy float32 array of shape ([B] + observation_space).
Image input should be in NCHW format.
Returns:
sample_actions: a numpy int64 array of shape [B]
behaviour_logits: a numpy float32 array of shape [B, NUM_ACTIONS]
"""
obs_np = obs_np.astype('float32')
sample_actions, behaviour_logits = self.fluid_executor.run(
self.sample_program,
feed={'obs': obs_np},
fetch_list=[self.sample_actions, self.behaviour_logits])
return sample_actions, behaviour_logits
def predict(self, obs_np):
"""
Args:
obs_np: a numpy float32 array of shape ([B] + observation_space)
Image input should be in NCHW format.
Returns:
predict_actions: a numpy int64 array of shape [B]
"""
obs_np = obs_np.astype('float32')
predict_actions = self.fluid_executor.run(
self.predict_program,
feed={'obs': obs_np},
fetch_list=[self.predict_actions])[0]
return predict_actions
def learn(self):
total_loss, pi_loss, vf_loss, entropy, kl = self.learn_exe.run(
fetch_list=self.learn_outputs)
return total_loss, pi_loss, vf_loss, entropy, kl
def get_params(self):
return self.alg.get_params()
def set_params(self, params):
self.alg.set_params(params, gpu_id=self.gpu_id)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
from paddle.fluid.param_attr import ParamAttr
class AtariModel(Model):
def __init__(self, act_dim):
self.conv1 = layers.conv2d(
num_filters=16, filter_size=4, stride=2, padding=1, act='relu')
self.conv2 = layers.conv2d(
num_filters=32, filter_size=4, stride=2, padding=2, act='relu')
self.conv3 = layers.conv2d(
num_filters=256, filter_size=11, stride=1, padding=0, act='relu')
self.policy_conv = layers.conv2d(
num_filters=act_dim,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(initializer=fluid.initializer.Normal()))
self.value_fc = layers.fc(
size=1,
param_attr=ParamAttr(initializer=fluid.initializer.Normal()))
def policy(self, obs):
"""
Args:
obs: A float32 tensor of shape [B, C, H, W]
Returns:
policy_logits: A float32 tensor of shape [B, ACT_DIM]
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
policy_conv = self.policy_conv(conv3)
policy_logits = layers.flatten(policy_conv, axis=1)
return policy_logits
def value(self, obs):
"""
Args:
obs: A float32 tensor of shape [B, C, H, W]
Returns:
value: A float32 tensor of shape [B]
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
flatten = layers.flatten(conv3, axis=1)
value = self.value_fc(flatten)
value = layers.squeeze(value, axes=[1])
return value
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
config = {
'experiment_name': 'Pong',
#========== remote config ==========
'server_ip': 'localhost',
'server_port': 8037,
#========== env config ==========
'env_name': 'PongNoFrameskip-v4',
'env_dim': 42,
'obs_format': 'NHWC',
#========== actor config ==========
'env_num': 5,
'sample_batch_steps': 50,
#========== learner config ==========
'train_batch_size': 1000,
'learner_queue_max_size': 16,
'sample_queue_max_size': 8,
'gamma': 0.99,
# learning rate adjustment schedule: (train_step, learning_rate)
'lr_scheduler': [(0, 0.001), (20000, 0.0005), (40000, 0.0001)],
# coefficient of policy entropy adjustment schedule: (train_step, coefficient)
'entropy_coeff_scheduler': [(0, -0.01)],
'vf_loss_coeff': 0.5,
'clip_rho_threshold': 1.0,
'clip_pg_rho_threshold': 1.0,
'get_remote_metrics_interval': 10,
'log_metrics_interval_s': 10,
'params_broadcast_interval': 5,
}
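The two schedules above are consumed in `learner.py` through `PiecewiseScheduler`. A hedged sketch of that usage (assuming `step()` returns the value associated with the current train step, as the `(train_step, value)` pairs suggest):
```python
# Sketch only: how learner.py consumes the schedules defined above.
from parl.utils.scheduler import PiecewiseScheduler
from impala_config import config

lr_scheduler = PiecewiseScheduler(config['lr_scheduler'])
entropy_coeff_scheduler = PiecewiseScheduler(config['entropy_coeff_scheduler'])

# One step() call per training batch; e.g. lr stays at 0.001 before
# train step 20000, then drops to 0.0005, then to 0.0001.
lr = lr_scheduler.step()
entropy_coeff = entropy_coeff_scheduler.step()
```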
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import numpy as np
import os
import queue
import time
import threading
from atari_model import AtariModel
from atari_agent import AtariAgent
from parl import RemoteManager
from parl.algorithms import IMPALA
from parl.env.atari_wrappers import wrap_deepmind
from parl.utils import logger, CSVLogger
from parl.utils.scheduler import PiecewiseScheduler
from parl.utils.time_stat import TimeStat
from parl.utils.window_stat import WindowStat
class Learner(object):
def __init__(self, config):
self.config = config
self.learner_queue = queue.Queue(
maxsize=config['learner_queue_max_size'])
#=========== Create Agent ==========
env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = AtariModel(act_dim)
algorithm = IMPALA(model, hyperparas=config)
self.agent = AtariAgent(algorithm, config, self.learn_data_provider)
self.cache_params = self.agent.get_params()
self.params_lock = threading.Lock()
self.params_updated = False
self.cache_params_sent_cnt = 0
self.total_params_sync = 0
#========== Learner ==========
self.lr, self.entropy_coeff = None, None
self.lr_scheduler = PiecewiseScheduler(config['lr_scheduler'])
self.entropy_coeff_scheduler = PiecewiseScheduler(
config['entropy_coeff_scheduler'])
self.total_loss_stat = WindowStat(100)
self.pi_loss_stat = WindowStat(100)
self.vf_loss_stat = WindowStat(100)
self.entropy_stat = WindowStat(100)
self.kl_stat = WindowStat(100)
self.learn_time_stat = TimeStat(100)
self.start_time = None
self.learn_thread = threading.Thread(target=self.run_learn)
self.learn_thread.setDaemon(True)
self.learn_thread.start()
#========== Remote Actor ===========
self.remote_count = 0
self.sample_data_queue = queue.Queue(
maxsize=config['sample_queue_max_size'])
self.batch_buffer = []
self.remote_metrics_queue = queue.Queue()
self.sample_total_steps = 0
self.remote_manager_thread = threading.Thread(
target=self.run_remote_manager)
self.remote_manager_thread.setDaemon(True)
self.remote_manager_thread.start()
self.csv_logger = CSVLogger(
os.path.join(logger.get_dir(), 'result.csv'))
def learn_data_provider(self):
""" Data generator for fluid.layers.py_reader
"""
while True:
batch = self.learner_queue.get()
obs_np = batch['obs'].astype('float32')
actions_np = batch['actions'].astype('int64')
behaviour_logits_np = batch['behaviour_logits'].astype('float32')
rewards_np = batch['rewards'].astype('float32')
dones_np = batch['dones'].astype('float32')
self.lr = self.lr_scheduler.step()
self.entropy_coeff = self.entropy_coeff_scheduler.step()
yield [
obs_np, actions_np, behaviour_logits_np, rewards_np, dones_np,
self.lr, self.entropy_coeff
]
def run_learn(self):
""" Learn loop
"""
while True:
with self.learn_time_stat:
total_loss, pi_loss, vf_loss, entropy, kl = self.agent.learn()
self.params_updated = True
self.total_loss_stat.add(total_loss)
self.pi_loss_stat.add(pi_loss)
self.vf_loss_stat.add(vf_loss)
self.entropy_stat.add(entropy)
self.kl_stat.add(kl)
def run_remote_manager(self):
""" Accept connection of new remote actor and start sampling of the remote actor.
"""
remote_manager = RemoteManager(port=self.config['server_port'])
logger.info('Waiting for remote actors connecting.')
while True:
remote_actor = remote_manager.get_remote()
self.remote_count += 1
logger.info('Remote actor count: {}'.format(self.remote_count))
if self.start_time is None:
self.start_time = time.time()
remote_thread = threading.Thread(
target=self.run_remote_sample, args=(remote_actor, ))
remote_thread.setDaemon(True)
remote_thread.start()
def run_remote_sample(self, remote_actor):
""" Sample data from remote actor and update parameters of remote actor.
"""
cnt = 0
remote_actor.set_params(self.cache_params)
while True:
batch = remote_actor.sample()
self.sample_data_queue.put(batch)
cnt += 1
if cnt % self.config['get_remote_metrics_interval'] == 0:
metrics = remote_actor.get_metrics()
if metrics:
self.remote_metrics_queue.put(metrics)
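# Push the latest parameters to this actor after every sampled batch;
# the cached snapshot taken from the learner is refreshed at most once
# every params_broadcast_interval pushes, and only if the learner has
# updated its parameters since the last snapshot.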
self.params_lock.acquire()
if self.params_updated and self.cache_params_sent_cnt >= self.config[
'params_broadcast_interval']:
self.params_updated = False
self.cache_params = self.agent.get_params()
self.cache_params_sent_cnt = 0
self.cache_params_sent_cnt += 1
self.total_params_sync += 1
self.params_lock.release()
remote_actor.set_params(self.cache_params)
def step(self):
""" Merge and generate batch learn data from sample_data_queue,
and put it in learner_queue.
"""
assert self.learn_thread.is_alive()
while True:
try:
sample_data = self.sample_data_queue.get_nowait()
self.sample_total_steps += sample_data['obs'].shape[0]
self.batch_buffer.append(sample_data)
buffer_size = sum(
[data['obs'].shape[0] for data in self.batch_buffer])
if buffer_size >= self.config['train_batch_size']:
train_batch = {}
for key in self.batch_buffer[0].keys():
train_batch[key] = np.concatenate(
[data[key] for data in self.batch_buffer])
self.learner_queue.put(train_batch)
self.batch_buffer = []
except queue.Empty:
break
def log_metrics(self):
""" Log metrics of learner and actors
"""
if self.start_time is None:
return
metrics = []
while True:
try:
metric = self.remote_metrics_queue.get_nowait()
metrics.append(metric)
except queue.Empty:
break
episode_rewards, episode_steps = [], []
for x in metrics:
episode_rewards.extend(x['episode_rewards'])
episode_steps.extend(x['episode_steps'])
max_episode_rewards, mean_episode_rewards, min_episode_rewards, \
max_episode_steps, mean_episode_steps, min_episode_steps =\
None, None, None, None, None, None
if episode_rewards:
mean_episode_rewards = np.mean(np.array(episode_rewards).flatten())
max_episode_rewards = np.max(np.array(episode_rewards).flatten())
min_episode_rewards = np.min(np.array(episode_rewards).flatten())
mean_episode_steps = np.mean(np.array(episode_steps).flatten())
max_episode_steps = np.max(np.array(episode_steps).flatten())
min_episode_steps = np.min(np.array(episode_steps).flatten())
metric = {
'Sample steps': self.sample_total_steps,
'max_episode_rewards': max_episode_rewards,
'mean_episode_rewards': mean_episode_rewards,
'min_episode_rewards': min_episode_rewards,
'max_episode_steps': max_episode_steps,
'mean_episode_steps': mean_episode_steps,
'min_episode_steps': min_episode_steps,
'learner_queue_size': self.learner_queue.qsize(),
'sample_queue_size': self.sample_data_queue.qsize(),
'total_params_sync': self.total_params_sync,
'cache_params_sent_cnt': self.cache_params_sent_cnt,
'total_loss': self.total_loss_stat.mean,
'pi_loss': self.pi_loss_stat.mean,
'vf_loss': self.vf_loss_stat.mean,
'entropy': self.entropy_stat.mean,
'kl': self.kl_stat.mean,
'learn_time_s': self.learn_time_stat.mean,
'elapsed_time_s': int(time.time() - self.start_time),
'lr': self.lr,
'entropy_coeff': self.entropy_coeff,
}
logger.info(metric)
self.csv_logger.log_dict(metric)
def close(self):
self.csv_logger.close()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=""
for index in {1..32}; do
python actor.py &
done;
wait
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from learner import Learner
def main(config):
learner = Learner(config)
assert config['log_metrics_interval_s'] > 0
try:
while True:
start = time.time()
while time.time() - start < config['log_metrics_interval_s']:
learner.step()
learner.log_metrics()
except KeyboardInterrupt:
learner.close()
if __name__ == '__main__':
from impala_config import config
main(config)
......@@ -16,3 +16,4 @@ from parl.algorithms.ddpg import *
from parl.algorithms.dqn import *
from parl.algorithms.policy_gradient import *
from parl.algorithms.ppo import *
from parl.algorithms.impala.impala import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.algorithms.impala.impala import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.algorithms.impala import vtrace
from parl.framework.algorithm_base import Algorithm
from parl.framework.policy_distribution import CategoricalDistribution
from parl.plutils import inverse
__all__ = ['IMPALA']
class VTraceLoss(object):
def __init__(self,
behaviour_actions_log_probs,
target_actions_log_probs,
policy_entropy,
dones,
discount,
rewards,
values,
bootstrap_value,
entropy_coeff=-0.01,
vf_loss_coeff=0.5,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0):
"""Policy gradient loss with vtrace importance weighting.
VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
batch_size. The reason we need to know `B` is for V-trace to properly
handle episode cut boundaries.
Args:
behaviour_actions_log_probs: A float32 tensor of shape [T, B].
target_actions_log_probs: A float32 tensor of shape [T, B].
policy_entropy: A float32 tensor of shape [T, B].
dones: A float32 tensor of shape [T, B].
discount: A float32 scalar.
rewards: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].
bootstrap_value: A float32 tensor of shape [B].
"""
self.vtrace_returns = vtrace.from_importance_weights(
behaviour_actions_log_probs=behaviour_actions_log_probs,
target_actions_log_probs=target_actions_log_probs,
discounts=inverse(dones) * discount,
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold)
# The policy gradient loss
self.pi_loss = -1.0 * layers.reduce_sum(
target_actions_log_probs * self.vtrace_returns.pg_advantages)
# The baseline loss
delta = values - self.vtrace_returns.vs
self.vf_loss = 0.5 * layers.reduce_sum(layers.square(delta))
# The entropy loss (we want to maximize entropy, so entropy_coeff < 0)
self.entropy = layers.reduce_sum(policy_entropy)
# The summed weighted loss
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff +
self.entropy * entropy_coeff)
class IMPALA(Algorithm):
def __init__(self, model, hyperparas):
super(IMPALA, self).__init__(model, hyperparas)
def learn(self, obs, actions, behaviour_logits, rewards, dones,
learning_rate, entropy_coeff):
"""
Args:
obs: A float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in Atari.
actions: An int64 tensor of shape [B].
behaviour_logits: A float32 tensor of shape [B, NUM_ACTIONS].
rewards: A float32 tensor of shape [B].
dones: A float32 tensor of shape [B].
learning_rate: float scalar of learning rate.
entropy_coeff: float scalar of entropy coefficient.
"""
values = self.model.value(obs)
target_logits = self.model.policy(obs)
target_policy_distribution = CategoricalDistribution(target_logits)
behaviour_policy_distribution = CategoricalDistribution(
behaviour_logits)
policy_entropy = target_policy_distribution.entropy()
target_actions_log_probs = target_policy_distribution.logp(actions)
behaviour_actions_log_probs = behaviour_policy_distribution.logp(
actions)
# Calculate KL divergence between target and behaviour policies for debugging
kl = target_policy_distribution.kl(behaviour_policy_distribution)
kl = layers.reduce_mean(kl)
"""
Split the tensor into batches at known episode cut boundaries.
[B * T] -> [T, B]
"""
T = self.hp["sample_batch_steps"]
def split_batches(tensor):
B = tensor.shape[0] // T
splited_tensor = layers.reshape(tensor,
[B, T] + list(tensor.shape[1:]))
# transpose B and T
return layers.transpose(
splited_tensor, [1, 0] + list(range(2, 1 + len(tensor.shape))))
behaviour_actions_log_probs = split_batches(
behaviour_actions_log_probs)
target_actions_log_probs = split_batches(target_actions_log_probs)
policy_entropy = split_batches(policy_entropy)
dones = split_batches(dones)
rewards = split_batches(rewards)
values = split_batches(values)
# [T, B] -> [T - 1, B] for V-trace calc.
behaviour_actions_log_probs = layers.slice(
behaviour_actions_log_probs, axes=[0], starts=[0], ends=[-1])
target_actions_log_probs = layers.slice(
target_actions_log_probs, axes=[0], starts=[0], ends=[-1])
policy_entropy = layers.slice(
policy_entropy, axes=[0], starts=[0], ends=[-1])
dones = layers.slice(dones, axes=[0], starts=[0], ends=[-1])
rewards = layers.slice(rewards, axes=[0], starts=[0], ends=[-1])
bootstrap_value = layers.slice(
values, axes=[0], starts=[T - 1], ends=[T])
values = layers.slice(values, axes=[0], starts=[0], ends=[-1])
bootstrap_value = layers.squeeze(bootstrap_value, axes=[0])
vtrace_loss = VTraceLoss(
behaviour_actions_log_probs=behaviour_actions_log_probs,
target_actions_log_probs=target_actions_log_probs,
policy_entropy=policy_entropy,
dones=dones,
discount=self.hp['gamma'],
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
entropy_coeff=entropy_coeff,
vf_loss_coeff=self.hp['vf_loss_coeff'],
clip_rho_threshold=self.hp['clip_rho_threshold'],
clip_pg_rho_threshold=self.hp['clip_pg_rho_threshold'])
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=40.0))
optimizer = fluid.optimizer.AdamOptimizer(learning_rate)
optimizer.minimize(vtrace_loss.total_loss)
return vtrace_loss, kl
def sample(self, obs):
"""
Args:
obs: A float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in Atari.
"""
logits = self.model.policy(obs)
policy_dist = CategoricalDistribution(logits)
sample_actions = policy_dist.sample()
return sample_actions, logits
def predict(self, obs):
"""
Args:
obs: A float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in Atari.
"""
logits = self.model.policy(obs)
probs = layers.softmax(logits)
predict_actions = layers.argmax(probs, 1)
return predict_actions
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for V-trace.
The following code is mainly referenced and copied from:
https://github.com/deepmind/scalable_agent/blob/master/vtrace_test.py
"""
import copy
import numpy as np
import unittest
import parl.layers as layers
from paddle import fluid
from parameterized import parameterized
from parl.algorithms.impala import vtrace
from parl.utils import get_gpu_count
def _shaped_arange(*shape):
"""Runs np.arange, converts to float and reshapes."""
return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape)
def _ground_truth_calculation(behaviour_actions_log_probs,
target_actions_log_probs, discounts, rewards,
values, bootstrap_value, clip_rho_threshold,
clip_pg_rho_threshold):
"""Calculates the ground truth for V-trace in Python/Numpy."""
log_rhos = target_actions_log_probs - behaviour_actions_log_probs
vs = []
seq_len = len(discounts)
rhos = np.exp(log_rhos)
cs = np.minimum(rhos, 1.0)
clipped_rhos = rhos
if clip_rho_threshold:
clipped_rhos = np.minimum(rhos, clip_rho_threshold)
clipped_pg_rhos = rhos
if clip_pg_rho_threshold:
clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold)
# This is a very inefficient way to calculate the V-trace ground truth.
# We calculate it this way because it is close to the mathematical notation of
# V-trace.
# v_s = V(x_s)
# + \sum^{T-1}_{t=s} \gamma^{t-s}
# * \prod_{i=s}^{t-1} c_i
# * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))
# Note that when we take the product over c_i, we write `s:t` as the notation
# of the paper is inclusive of the `t-1`, but Python is exclusive.
# Also note that np.prod([]) == 1.
values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]],
axis=0)
for s in range(seq_len):
v_s = np.copy(values[s]) # Very important copy.
for t in range(s, seq_len):
v_s += (np.prod(discounts[s:t], axis=0) * np.prod(cs[s:t], axis=0)
* clipped_rhos[t] * (rewards[t] + discounts[t] *
values_t_plus_1[t + 1] - values[t]))
vs.append(v_s)
vs = np.stack(vs, axis=0)
pg_advantages = (clipped_pg_rhos * (rewards + discounts * np.concatenate(
[vs[1:], bootstrap_value[None, :]], axis=0) - values))
return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages)
class VtraceTest(unittest.TestCase):
def setUp(self):
gpu_count = get_gpu_count()
if gpu_count > 0:
place = fluid.CUDAPlace(0)
self.gpu_id = 0
else:
place = fluid.CPUPlace()
self.gpu_id = -1
self.executor = fluid.Executor(place)
@parameterized.expand([('Batch1', 1), ('Batch4', 4)])
def test_from_importance_weights(self, name, batch_size):
"""Tests V-trace against ground truth data calculated in python."""
seq_len = 5
# Create log_rhos such that rho will span from near-zero to above the
# clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5),
# so that rho is in approx [0.08, 12.2).
log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len)
log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5).
# Fake behaviour_actions_log_probs, target_actions_log_probs
target_actions_log_probs = log_rhos + 1.0
behaviour_actions_log_probs = np.ones(
shape=log_rhos.shape, dtype='float32')
values = {
'behaviour_actions_log_probs':
behaviour_actions_log_probs,
'target_actions_log_probs':
target_actions_log_probs,
# T, B where B_i: [0.9 / (i+1)] * T
'discounts':
np.array([[0.9 / (b + 1) for b in range(batch_size)]
for _ in range(seq_len)],
dtype=np.float32),
'rewards':
_shaped_arange(seq_len, batch_size),
'values':
_shaped_arange(seq_len, batch_size) / batch_size,
'bootstrap_value':
_shaped_arange(batch_size) + 1.0,
'clip_rho_threshold':
3.7,
'clip_pg_rho_threshold':
2.2,
}
# Calculated by numpy/python
ground_truth_v = _ground_truth_calculation(**values)
# Calculated by Fluid
test_program = fluid.Program()
with fluid.program_guard(test_program):
behaviour_actions_log_probs_input = layers.data(
name='behaviour_actions_log_probs',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
target_actions_log_probs_input = layers.data(
name='target_actions_log_probs',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
discounts_input = layers.data(
name='discounts',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
rewards_input = layers.data(
name='rewards',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
values_input = layers.data(
name='values',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
bootstrap_value_input = layers.data(
name='bootstrap_value',
shape=[batch_size],
dtype='float32',
append_batch_size=False)
fluid_inputs = {
'behaviour_actions_log_probs':
behaviour_actions_log_probs_input,
'target_actions_log_probs': target_actions_log_probs_input,
'discounts': discounts_input,
'rewards': rewards_input,
'values': values_input,
'bootstrap_value': bootstrap_value_input,
'clip_rho_threshold': 3.7,
'clip_pg_rho_threshold': 2.2,
}
output = vtrace.from_importance_weights(**fluid_inputs)
self.executor.run(fluid.default_startup_program())
feed = copy.copy(values)
del feed['clip_rho_threshold']
del feed['clip_pg_rho_threshold']
[output_vs, output_pg_advantage] = self.executor.run(
test_program,
feed=feed,
fetch_list=[output.vs, output.pg_advantages])
np.testing.assert_almost_equal(ground_truth_v.vs, output_vs, 5)
np.testing.assert_almost_equal(ground_truth_v.pg_advantages,
output_pg_advantage, 5)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to compute V-trace off-policy actor critic targets,
which used in IMAPLA algorithm.
The following code is mainly referenced and copied from:
https://github.com/deepmind/scalable_agent/blob/master/vtrace.py
For details and theory see:
"Espeholt L, Soyer H, Munos R, et al. Impala: Scalable distributed
deep-rl with importance weighted actor-learner
architectures[J]. arXiv preprint arXiv:1802.01561, 2018."
"""
import collections
import paddle.fluid as fluid
import parl.layers as layers
from parl.utils import MAX_INT32
VTraceReturns = collections.namedtuple('VTraceReturns',
['vs', 'pg_advantages'])
def from_importance_weights(behaviour_actions_log_probs,
target_actions_log_probs,
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name='vtrace_from_logits'):
r"""V-trace for softmax policies.
Calculates V-trace actor-critic targets for softmax policies as described in
"IMPALA: Scalable Distributed Deep-RL with
Importance Weighted Actor-Learner Architectures"
by Espeholt, Soyer, Munos et al.
Target policy refers to the policy we are interested in improving and
behaviour policy refers to the policy that generated the given
rewards and actions.
In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
NUM_ACTIONS refers to the number of actions.
Args:
behaviour_actions_log_probs: A float32 tensor of shape [T, B] of
log-probabilities of actions in behaviour policy.
target_actions_log_probs: A float32 tensor of shape [T, B] of
log-probabilities of actions in target policy.
discounts: A float32 tensor of shape [T, B] with the discount encountered
when following the behaviour policy.
rewards: A float32 tensor of shape [T, B] with the rewards generated by
following the behaviour policy.
values: A float32 tensor of shape [T, B] with the value function estimates
wrt. the target policy.
bootstrap_value: A float32 tensor of shape [B] with the value function estimate at
time T.
clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
importance weights (rho) when calculating the baseline targets (vs).
rho^bar in the paper.
clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold
on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
name: The name scope that all V-trace operations will be created in.
Returns:
A VTraceReturns namedtuple (vs, pg_advantages) where:
vs: A float32 tensor of shape [T, B]. Can be used as target to
train a baseline (V(x_t) - vs_t)^2.
pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
advantage in the calculation of policy gradients.
"""
rank = len(behaviour_actions_log_probs.shape) # Usually 2.
assert len(target_actions_log_probs.shape) == rank
assert len(values.shape) == rank
assert len(bootstrap_value.shape) == (rank - 1)
assert len(discounts.shape) == rank
assert len(rewards.shape) == rank
# log importance sampling weights.
# V-trace performs operations on rhos in log-space for numerical stability.
log_rhos = target_actions_log_probs - behaviour_actions_log_probs
if clip_rho_threshold is not None:
clip_rho_threshold = layers.fill_constant([1], 'float32',
clip_rho_threshold)
if clip_pg_rho_threshold is not None:
clip_pg_rho_threshold = layers.fill_constant([1], 'float32',
clip_pg_rho_threshold)
rhos = layers.exp(log_rhos)
if clip_rho_threshold is not None:
clipped_rhos = layers.elementwise_min(rhos, clip_rho_threshold)
else:
clipped_rhos = rhos
constant_one = layers.fill_constant([1], 'float32', 1.0)
cs = layers.elementwise_min(rhos, constant_one)
# Append bootstrapped value to get [v1, ..., v_t+1]
values_1_t = layers.slice(values, axes=[0], starts=[1], ends=[MAX_INT32])
values_t_plus_1 = layers.concat(
[values_1_t, layers.unsqueeze(bootstrap_value, [0])], axis=0)
# \delta_s * V
deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
vs_minus_v_xs = recursively_scan(discounts, cs, deltas)
# Add V(x_s) to get v_s.
vs = layers.elementwise_add(vs_minus_v_xs, values)
# Advantage for policy gradient.
vs_1_t = layers.slice(vs, axes=[0], starts=[1], ends=[MAX_INT32])
vs_t_plus_1 = layers.concat(
[vs_1_t, layers.unsqueeze(bootstrap_value, [0])], axis=0)
if clip_pg_rho_threshold is not None:
clipped_pg_rhos = layers.elementwise_min(rhos, clip_pg_rho_threshold)
else:
clipped_pg_rhos = rhos
pg_advantages = (
clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values))
# Make sure no gradients backpropagated through the returned values.
vs.stop_gradient = True
pg_advantages.stop_gradient = True
return VTraceReturns(vs=vs, pg_advantages=pg_advantages)
def recursively_scan(discounts, cs, deltas):
""" Recursively calculate vs_minus_v_xs according to following equation:
vs_minus_v_xs(t) = deltas(t) + discounts(t) * cs(t) * vs_minus_v_xs(t + 1)
Args:
discounts: A float32 tensor of shape [T, B] with discounts encountered when
following the behaviour policy.
cs: A float32 tensor of shape [T, B], which corresponds to $c_s$ in the
original paper.
deltas: A float32 tensor of shape [T, B], which corresponds to
$\delta_s * V$ in the original paper.
Returns:
vs_minus_v_xs: A float32 tensor of shape [T, B], which corresponds to
$v_s - V(x_s)$ in the original paper.
"""
# All sequences are reversed, computation starts from the back.
reverse_discounts = layers.reverse(x=discounts, axis=[0])
reverse_cs = layers.reverse(x=cs, axis=[0])
reverse_deltas = layers.reverse(x=deltas, axis=[0])
static_while = layers.StaticRNN()
# init: shape [B]
init = layers.fill_constant_batch_size_like(
discounts, shape=[1], dtype='float32', value=0.0, input_dim_idx=1)
with static_while.step():
discount_t = static_while.step_input(reverse_discounts)
c_t = static_while.step_input(reverse_cs)
delta_t = static_while.step_input(reverse_deltas)
vs_minus_v_xs_t_plus_1 = static_while.memory(init=init)
vs_minus_v_xs_t = delta_t + discount_t * c_t * vs_minus_v_xs_t_plus_1
static_while.update_memory(vs_minus_v_xs_t_plus_1, vs_minus_v_xs_t)
static_while.step_output(vs_minus_v_xs_t)
vs_minus_v_xs = static_while()
# Reverse the results back to original order.
vs_minus_v_xs = layers.reverse(vs_minus_v_xs, [0])
return vs_minus_v_xs
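A hedged NumPy sketch (not part of this module) of the backward scan that `recursively_scan` implements with a `StaticRNN`; it follows the recursion in the docstring above, with `vs_minus_v_xs(T)` initialized to zero:
```python
import numpy as np

def recursively_scan_np(discounts, cs, deltas):
    """vs_minus_v_xs(t) = deltas(t) + discounts(t) * cs(t) * vs_minus_v_xs(t + 1)."""
    out = np.zeros_like(deltas)
    acc = np.zeros_like(deltas[0])               # vs_minus_v_xs at t = T is zero
    for t in reversed(range(deltas.shape[0])):   # computation starts from the back
        acc = deltas[t] + discounts[t] * cs[t] * acc
        out[t] = acc
    return out
```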
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.env.vector_env import *
# Third party code
#
# The following code are copied or modified from:
# https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/atari_wrappers.py
import numpy as np
from collections import deque
import gym
from gym import spaces
import cv2
cv2.ocl.setUseOpenCL(False)
def get_wrapper_by_cls(env, cls):
"""Returns the gym env wrapper of the given class, or None."""
currentenv = env
while True:
if isinstance(currentenv, cls):
return currentenv
elif isinstance(currentenv, gym.Wrapper):
currentenv = currentenv.env
else:
return None
class MonitorEnv(gym.Wrapper):
def __init__(self, env=None):
"""Record episodes stats prior to EpisodicLifeEnv, etc."""
gym.Wrapper.__init__(self, env)
self._current_reward = None
self._num_steps = None
self._total_steps = None
self._episode_rewards = []
self._episode_lengths = []
self._num_episodes = 0
self._num_returned = 0
def reset(self, **kwargs):
obs = self.env.reset(**kwargs)
if self._total_steps is None:
self._total_steps = sum(self._episode_lengths)
if self._current_reward is not None:
self._episode_rewards.append(self._current_reward)
self._episode_lengths.append(self._num_steps)
self._num_episodes += 1
self._current_reward = 0
self._num_steps = 0
return obs
def step(self, action):
obs, rew, done, info = self.env.step(action)
self._current_reward += rew
self._num_steps += 1
self._total_steps += 1
return (obs, rew, done, info)
def get_episode_rewards(self):
return self._episode_rewards
def get_episode_lengths(self):
return self._episode_lengths
def get_total_steps(self):
return self._total_steps
def next_episode_results(self):
for i in range(self._num_returned, len(self._episode_rewards)):
yield (self._episode_rewards[i], self._episode_lengths[i])
self._num_returned = len(self._episode_rewards)
class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
gym.Wrapper.__init__(self, env)
self.noop_max = noop_max
self.override_num_noops = None
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
def reset(self, **kwargs):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset(**kwargs)
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(self.noop_action)
if done:
obs = self.env.reset(**kwargs)
return obs
def step(self, ac):
return self.env.step(ac)
class ClipRewardEnv(gym.RewardWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env)
def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward)
class FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset.
For environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def reset(self, **kwargs):
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset(**kwargs)
return obs
def step(self, ac):
return self.env.step(ac)
class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True
def step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert sometimes we stay in lives == 0 condition for a few frames,
# so it's important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
def reset(self, **kwargs):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.lives = self.env.unwrapped.ale.lives()
return obs
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env, skip=4):
"""Return only every `skip`-th frame"""
gym.Wrapper.__init__(self, env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = np.zeros(
(2, ) + env.observation_space.shape, dtype=np.uint8)
self._skip = skip
def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
if i == self._skip - 2:
self._obs_buffer[0] = obs
if i == self._skip - 1:
self._obs_buffer[1] = obs
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0)
return max_frame, total_reward, done, info
def reset(self, **kwargs):
return self.env.reset(**kwargs)
class WarpFrame(gym.ObservationWrapper):
def __init__(self, env, dim):
"""Warp frames to the specified size (dim x dim)."""
gym.ObservationWrapper.__init__(self, env)
self.width = dim
self.height = dim
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.height, self.width), dtype=np.uint8)
def observation(self, frame):
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(
frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
return frame
class FrameStack(gym.Wrapper):
def __init__(self, env, k, obs_format='NHWC'):
"""Stack k last frames."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
assert obs_format == 'NHWC' or obs_format == 'NCHW'
self.obs_format = obs_format
if obs_format == 'NHWC':
obs_shape = (shp[0], shp[1], k)
else:
obs_shape = (k, shp[0], shp[1])
self.observation_space = spaces.Box(
low=0,
high=255,
shape=obs_shape,
dtype=env.observation_space.dtype)
def reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()
def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info
def _get_ob(self):
assert len(self.frames) == self.k
if self.obs_format == 'NHWC':
return np.stack(self.frames, axis=2)
else:
return np.array(self.frames)
def wrap_deepmind(env, dim=84, framestack=True, obs_format='NHWC'):
"""Configure environment for DeepMind-style Atari.
Args:
dim (int): Dimension to resize observations to (dim x dim).
framestack (bool): Whether to framestack observations.
"""
env = MonitorEnv(env)
env = NoopResetEnv(env, noop_max=30)
if 'NoFrameskip' in env.spec.id:
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env, dim)
env = ClipRewardEnv(env)
if framestack:
env = FrameStack(env, 4, obs_format)
return env
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
from collections import defaultdict
__all__ = ['VectorEnv']
class VectorEnv(object):
""" vector of envs to support vector reset and vector step.
`vector_step` api will automatically reset envs which are done.
"""
def __init__(self, envs):
"""
Args:
envs: List of env
"""
self.envs = envs
self.envs_num = len(envs)
def reset(self):
"""
Returns:
List of obs
"""
return [env.reset() for env in self.envs]
def step(self, actions):
"""
Args:
actions: List or array of actions
Returns:
obs_batch: List of next observations of envs
reward_batch: List of rewards of envs
done_batch: List of done flags of envs
info_batch: List of info dicts of envs
"""
obs_batch, reward_batch, done_batch, info_batch = [], [], [], []
for env_id in six.moves.range(self.envs_num):
obs, reward, done, info = self.envs[env_id].step(actions[env_id])
if done:
obs = self.envs[env_id].reset()
obs_batch.append(obs)
reward_batch.append(reward)
done_batch.append(done)
info_batch.append(info)
return obs_batch, reward_batch, done_batch, info_batch
......@@ -54,3 +54,20 @@ class Algorithm(object):
3. optimize model defined in Model
"""
raise NotImplementedError()
def get_params(self):
""" Get parameters of self.model
Returns:
List of numpy arrays.
"""
return self.model.get_params()
def set_params(self, params, gpu_id):
""" Set parameters of self.model
Args:
params: List of numpy arrays.
gpu_id: id of the GPU that self.model is on (gpu_id < 0 means CPU).
"""
self.model.set_params(params, gpu_id=gpu_id)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl.layers as layers
__all__ = ['PolicyDistribution', 'CategoricalDistribution']
class PolicyDistribution(object):
def sample(self):
"""Sampling from the policy distribution."""
raise NotImplementedError
def entropy(self):
"""The entropy of the policy distribution."""
raise NotImplementedError
def kl(self, other):
"""The KL-divergence between self policy distributions and other."""
raise NotImplementedError
def logp(self, actions):
"""The log-probabilities of the actions in this policy distribution."""
raise NotImplementedError
class CategoricalDistribution(PolicyDistribution):
"""Categorical distribution for discrete action spaces."""
def __init__(self, logits):
"""
Args:
logits: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of unnormalized policy logits
"""
assert len(logits.shape) == 2
self.logits = logits
def sample(self):
"""
Returns:
sample_actions: An int64 tensor with shape [BATCH_SIZE] of multinomial sampling ids.
Each value in sample_actions is in [0, NUM_ACTIONS - 1].
"""
probs = layers.softmax(self.logits)
sample_actions = layers.sampling_id(probs)
return sample_actions
def entropy(self):
"""
Returns:
entropy: A float32 tensor with shape [BATCH_SIZE] with the entropy of this policy distribution.
"""
logits = self.logits - layers.reduce_max(self.logits, dim=1)
e_logits = layers.exp(logits)
z = layers.reduce_sum(e_logits, dim=1)
prob = e_logits / z
entropy = -1.0 * layers.reduce_sum(
prob * (logits - layers.log(z)), dim=1)
return entropy
def logp(self, actions):
"""
Args:
actions: An int64 tensor with shape [BATCH_SIZE]
Returns:
actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
"""
assert len(actions.shape) == 1
actions = layers.unsqueeze(actions, axes=[1])
cross_entropy = layers.softmax_with_cross_entropy(
logits=self.logits, label=actions)
actions_log_prob = -1.0 * layers.squeeze(cross_entropy, axes=[-1])
return actions_log_prob
def kl(self, other):
"""
Args:
other: object of CategoricalDistribution
Returns:
kl: A float32 tensor with shape [BATCH_SIZE]
"""
assert isinstance(other, CategoricalDistribution)
logits = self.logits - layers.reduce_max(self.logits, dim=1)
other_logits = other.logits - layers.reduce_max(other.logits, dim=1)
e_logits = layers.exp(logits)
other_e_logits = layers.exp(other_logits)
z = layers.reduce_sum(e_logits, dim=1)
other_z = layers.reduce_sum(other_e_logits, dim=1)
prob = e_logits / z
kl = layers.reduce_sum(
prob *
(logits - layers.log(z) - other_logits + layers.log(other_z)),
dim=1)
return kl
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl.layers as layers
import unittest
from paddle import fluid
from parameterized import parameterized
from parl.framework.policy_distribution import *
from parl.utils import get_gpu_count, np_softmax, np_cross_entropy
class PolicyDistributionTest(unittest.TestCase):
def setUp(self):
gpu_count = get_gpu_count()
if gpu_count > 0:
place = fluid.CUDAPlace(0)
self.gpu_id = 0
else:
place = fluid.CPUPlace()
self.gpu_id = -1
self.executor = fluid.Executor(place)
@parameterized.expand([('Batch1', 1), ('Batch5', 5)])
def test_categorical_distribution(self, name, batch_size):
ACTIONS_NUM = 4
test_program = fluid.Program()
with fluid.program_guard(test_program):
logits = layers.data(
name='logits', shape=[ACTIONS_NUM], dtype='float32')
other_logits = layers.data(
name='other_logits', shape=[ACTIONS_NUM], dtype='float32')
actions = layers.data(name='actions', shape=[], dtype='int64')
categorical_distribution = CategoricalDistribution(logits)
other_categorical_distribution = CategoricalDistribution(
other_logits)
sample_actions = categorical_distribution.sample()
entropy = categorical_distribution.entropy()
actions_logp = categorical_distribution.logp(actions)
kl = categorical_distribution.kl(other_categorical_distribution)
self.executor.run(fluid.default_startup_program())
logits_np = np.random.randn(batch_size, ACTIONS_NUM).astype('float32')
other_logits_np = np.random.randn(batch_size,
ACTIONS_NUM).astype('float32')
actions_np = np.random.randint(
0, high=ACTIONS_NUM, size=(batch_size, 1), dtype='int64')
# ground truth calculated by numpy/python
gt_probs = np_softmax(logits_np)
gt_other_probs = np_softmax(other_logits_np)
gt_log_probs = np.log(gt_probs)
gt_entropy = -1.0 * np.sum(gt_probs * gt_log_probs, axis=1)
gt_actions_logp = -1.0 * np_cross_entropy(
np_softmax(logits_np), actions_np)
gt_actions_logp = np.squeeze(gt_actions_logp, -1)
gt_kl = np.sum(
np.where(gt_probs != 0,
gt_probs * np.log(gt_probs / gt_other_probs), 0),
axis=-1)
# result calculated by CategoricalDistribution
[
output_sample_actions, output_entropy, output_actions_logp,
output_kl
] = self.executor.run(
program=test_program,
feed={
'logits': logits_np,
'other_logits': other_logits_np,
'actions': np.squeeze(actions_np, axis=1)
},
fetch_list=[sample_actions, entropy, actions_logp, kl])
# test entropy
np.testing.assert_almost_equal(output_entropy, gt_entropy, 5)
# test logp
np.testing.assert_almost_equal(output_actions_logp, gt_actions_logp, 5)
# test sample
action_ids = np.arange(ACTIONS_NUM)
assert np.isin(output_sample_actions, action_ids).all()
# test kl
np.testing.assert_almost_equal(output_kl, gt_kl, 5)
if __name__ == '__main__':
unittest.main()
......@@ -18,7 +18,7 @@ Common functions of PARL framework
import paddle.fluid as fluid
from paddle.fluid.executor import _fetch_var
__all__ = ['fetch_framework_var', 'fetch_value', 'set_value']
__all__ = ['fetch_framework_var', 'fetch_value', 'set_value', 'inverse']
def fetch_framework_var(attr_name):
......@@ -64,3 +64,16 @@ def set_value(attr_name, value, gpu_id):
else fluid.CUDAPlace(gpu_id)
var = _fetch_var(attr_name, return_numpy=False)
var.set(value, place)
def inverse(x):
""" Inverse 0/1 variable
Args:
x: variable with float32 dtype
Returns:
inverse_x: variable with float32 dtype
"""
inverse_x = -1.0 * x + 1.0
return inverse_x
......@@ -14,5 +14,8 @@
from parl.utils.exceptions import *
from parl.utils.utils import *
from parl.utils.csv_logger import *
from parl.utils.machine_info import *
from parl.utils.np_utils import *
from parl.utils.replay_memory import *
from parl.utils.scheduler import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
__all__ = ['CSVLogger']
class CSVLogger(object):
def __init__(self, output_file):
"""CSV Logger which can write dict result to csv file
"""
self.output_file = open(output_file, "w")
self.csv_writer = None
def log_dict(self, result):
if self.csv_writer is None:
self.csv_writer = csv.DictWriter(self.output_file, result.keys())
self.csv_writer.writeheader()
self.csv_writer.writerow({
k: v
for k, v in result.items() if k in self.csv_writer.fieldnames
})
def flush(self):
self.output_file.flush()
def close(self):
self.output_file.close()
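A minimal usage sketch of CSVLogger as defined above; the file name and metric keys are illustrative only. The first log_dict call writes the header from the keys of the dict, and later rows are filtered to those fields.

```python
from parl.utils import CSVLogger  # exported via parl/utils/__init__.py

logger = CSVLogger('result.csv')                # illustrative path
logger.log_dict({'step': 100, 'reward': 18.5})  # writes the header, then the first row
logger.log_dict({'step': 200, 'reward': 19.1})  # appends rows with the same fields
logger.flush()
logger.close()
```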
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
__all__ = ['np_softmax', 'np_cross_entropy']
def np_softmax(logits):
return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
def np_cross_entropy(probs, labels):
if labels.shape[-1] == 1:
# sparse label
n_classes = probs.shape[-1]
result_shape = list(labels.shape[:-1]) + [n_classes]
labels = np.eye(n_classes)[labels.reshape(-1)]
labels = labels.reshape(result_shape)
return -np.sum(labels * np.log(probs), axis=-1, keepdims=True)
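A short sketch of the expected shapes, mirroring the ground-truth computation in the policy-distribution test above: np_softmax normalizes along the last axis, and np_cross_entropy treats a trailing dimension of size 1 as sparse integer labels.

```python
import numpy as np
from parl.utils.np_utils import np_softmax, np_cross_entropy

logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 0.5, 3.0]], dtype='float32')
probs = np_softmax(logits)                    # shape [2, 3], each row sums to 1

labels = np.array([[0], [2]], dtype='int64')  # sparse labels, shape [2, 1]
ce = np_cross_entropy(probs, labels)          # shape [2, 1], -log(prob of the labeled class)
log_prob = -np.squeeze(ce, axis=-1)           # shape [2], log-probability of each chosen action
```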
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
__all__ = ['PiecewiseScheduler']
class PiecewiseScheduler(object):
def __init__(self, scheduler_list):
""" Piecewise scheduler of hyper parameter.
Args:
scheduler_list: list of (step, value) pair. E.g. [(0, 0.001), (10000, 0.0005)]
"""
assert len(scheduler_list) > 0
for i in six.moves.range(len(scheduler_list) - 1):
assert scheduler_list[i][0] < scheduler_list[i + 1][0], \
'steps in scheduler_list should be strictly increasing.'
self.scheduler_list = scheduler_list
self.cur_index = 0
self.cur_step = 0
self.cur_value = self.scheduler_list[0][1]
self.scheduler_num = len(self.scheduler_list)
def step(self):
""" Step one and fetch value according to following rule:
Given scheduler_list: [(step_0, value_0), (step_1, value_1), ..., (step_N, value_N)],
function will return value_K which satisfying self.cur_step >= step_K and self.cur_step < step_K+1
"""
if self.cur_index < self.scheduler_num - 1:
if self.cur_step >= self.scheduler_list[self.cur_index + 1][0]:
self.cur_index += 1
self.cur_value = self.scheduler_list[self.cur_index][1]
self.cur_step += 1
return self.cur_value
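A usage sketch of the scheduler above with arbitrary boundaries (the same ones used in the test below): each call to step() returns the value of the segment the internal counter is in, then advances the counter.

```python
from parl.utils.scheduler import PiecewiseScheduler

scheduler = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])
values = [scheduler.step() for _ in range(10)]
# steps 0-2 -> 0.1, steps 3-6 -> 0.2, steps 7-9 -> 0.3
assert values == [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3]
```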
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from parl.utils.scheduler import PiecewiseScheduler
class TestScheduler(unittest.TestCase):
def test_PiecewiseScheduler_with_multi_values(self):
scheduler = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])
for i in range(10):
value = scheduler.step()
if i < 3:
assert value == 0.1
elif i < 7:
assert value == 0.2
else:
assert value == 0.3
def test_PiecewiseScheduler_with_one_value(self):
scheduler = PiecewiseScheduler([(0, 0.1)])
for i in range(10):
value = scheduler.step()
assert value == 0.1
scheduler = PiecewiseScheduler([(3, 0.1)])
for i in range(10):
value = scheduler.step()
assert value == 0.1
def test_PiecewiseScheduler_with_empty(self):
try:
scheduler = PiecewiseScheduler([])
except AssertionError:
# expected
return
assert False
def test_PiecewiseScheduler_with_incorrect_steps(self):
try:
scheduler = PiecewiseScheduler([(10, 0.1), (1, 0.2)])
except AssertionError:
# expected
return
assert False
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from parl.utils.window_stat import WindowStat
__all__ = ['TimeStat']
class TimeStat(object):
"""A time stat for logging the elapsed time of code running
Example:
time_stat = TimeStat()
with time_stat:
# some code
print(time_stat.mean)
"""
def __init__(self, window_size=1):
self.time_samples = WindowStat(window_size)
self._start_time = None
def __enter__(self):
self._start_time = time.time()
def __exit__(self, type, value, tb):
time_delta = time.time() - self._start_time
self.time_samples.add(time_delta)
@property
def mean(self):
return self.time_samples.mean
@property
def min(self):
return self.time_samples.min
@property
def max(self):
return self.time_samples.max
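A small usage sketch of TimeStat; the sleep stands in for the code being profiled, and the import path is assumed from the file layout above.

```python
import time
from parl.utils.time_stat import TimeStat  # assumed module path

time_stat = TimeStat(window_size=10)  # keep stats over the last 10 measurements
for _ in range(3):
    with time_stat:
        time.sleep(0.01)              # placeholder for the profiled code
print(time_stat.mean, time_stat.min, time_stat.max)  # each roughly 0.01 seconds
```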
......@@ -15,7 +15,8 @@
import sys
__all__ = [
'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3'
'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3',
'MAX_INT32'
]
......@@ -68,3 +69,6 @@ def is_PY2():
def is_PY3():
return sys.version_info[0] == 3
MAX_INT32 = 0x7fffffff
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['WindowStat']
import numpy as np
class WindowStat(object):
""" Tool to maintain statistical data in a window.
"""
def __init__(self, window_size):
self.items = [None] * window_size
self.idx = 0
self.count = 0
def add(self, obj):
self.items[self.idx] = obj
self.idx += 1
self.count += 1
self.idx %= len(self.items)
@property
def mean(self):
if self.count > 0:
return np.mean(self.items[:self.count])
else:
return None
@property
def min(self):
if self.count > 0:
return np.min(self.items[:self.count])
else:
return None
@property
def max(self):
if self.count > 0:
return np.max(self.items[:self.count])
else:
return None
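A brief sketch of the ring-buffer behaviour: add() overwrites the oldest slot once more than window_size items have been added, so the statistics cover at most the last window_size samples.

```python
stat = WindowStat(window_size=3)
for v in [1.0, 2.0, 5.0]:
    stat.add(v)
print(stat.mean)   # 2.666... over [1.0, 2.0, 5.0]

stat.add(9.0)      # overwrites the oldest slot (1.0)
print(stat.mean)   # 5.333... over [9.0, 2.0, 5.0]
```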
......@@ -35,5 +35,7 @@ setup(
package_data={'': ['*.so']},
install_requires=[
"termcolor>=1.1.0",
"pyzmq>=17.1.2",
"pyarrow>=0.12.0",
],
)