diff --git a/.gitignore b/.gitignore
index 956065769a0734c2a267a87e0fe7bb5fa5bea3fb..0dd122d3d331c8bff7f27fd2703c134f7c177064 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,7 +14,6 @@ __pycache__/
 # Distribution / packaging
 .Python
-env/
 build/
 develop-eggs/
 dist/
diff --git a/.teamcity/requirements.txt b/.teamcity/requirements.txt
index 47cf56b71616c94412b8a30801c8738455ad5045..5e3600128b668d7a1e0fce7480a4cbf1d614e665 100644
--- a/.teamcity/requirements.txt
+++ b/.teamcity/requirements.txt
@@ -4,3 +4,4 @@ details
 termcolor
 pyarrow
 zmq
+parameterized
diff --git a/README.md b/README.md
index 629634ab9c5fba80d8b716794e6b293600c0619b..fbeff6689561a48012f9ed17ccd70e5b903af888 100644
--- a/README.md
+++ b/README.md
@@ -79,6 +79,7 @@ pip install parl
 - [DQN](examples/DQN/)
 - [DDPG](examples/DDPG/)
 - [PPO](examples/PPO/)
+- [IMPALA](examples/IMPALA/)
 - [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
 NeurlIPS2018 Half-Cheetah Breakout
diff --git a/examples/IMPALA/.benchmark/IMPALA_BeamRider.jpg b/examples/IMPALA/.benchmark/IMPALA_BeamRider.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d112b5c7596a7d7ef98c711c70e536939f08d0b9
Binary files /dev/null and b/examples/IMPALA/.benchmark/IMPALA_BeamRider.jpg differ
diff --git a/examples/IMPALA/.benchmark/IMPALA_Breakout.jpg b/examples/IMPALA/.benchmark/IMPALA_Breakout.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..91693816f986891a8ad14aed93b8e7d073363efe
Binary files /dev/null and b/examples/IMPALA/.benchmark/IMPALA_Breakout.jpg differ
diff --git a/examples/IMPALA/.benchmark/IMPALA_Pong.jpg b/examples/IMPALA/.benchmark/IMPALA_Pong.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1067e7b7326d96b9333b442e40677944c4a1d53d
Binary files /dev/null and b/examples/IMPALA/.benchmark/IMPALA_Pong.jpg differ
diff --git a/examples/IMPALA/.benchmark/IMPALA_Qbert.jpg b/examples/IMPALA/.benchmark/IMPALA_Qbert.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..79580a13115b1ea960c25f588c3d4d96c7f36ac3
Binary files /dev/null and b/examples/IMPALA/.benchmark/IMPALA_Qbert.jpg differ
diff --git a/examples/IMPALA/.benchmark/IMPALA_SpaceInvaders.jpg b/examples/IMPALA/.benchmark/IMPALA_SpaceInvaders.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b04475e737a5d68c5c7ebcf4cd79d35db167e241
Binary files /dev/null and b/examples/IMPALA/.benchmark/IMPALA_SpaceInvaders.jpg differ
diff --git a/examples/IMPALA/README.md b/examples/IMPALA/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..866579f17fb421ac3e471475a047c37a1ee27a3b
--- /dev/null
+++ b/examples/IMPALA/README.md
@@ -0,0 +1,51 @@
+## Reproduce IMPALA with PARL
+Based on PARL, we reproduce the IMPALA deep reinforcement learning algorithm and match the results reported in the paper on classic Atari games.
+
++ Original paper:
+[Impala: Scalable distributed deep-rl with importance weighted actor-learner architectures](https://arxiv.org/abs/1802.01561)
+
+### Atari games introduction
+Please see [here](https://gym.openai.com/envs/#atari) to learn more about Atari games.
+
+### Benchmark result
+Results were obtained with one learner (on a P40 GPU) and 32 actors (on 32 CPUs).
++ PongNoFrameskip-v4: mean_episode_rewards reaches a score of 18-19 in about 7-10 minutes.
+IMPALA_Pong
+
++ Results of other games after an hour of training.
+
+IMPALA_Breakout IMPALA_BeamRider
+
+IMPALA_Qbert IMPALA_SpaceInvaders + +## How to use +### Dependencies ++ python2.7 or python3.5+ ++ [paddlepaddle>=1.3.0](https://github.com/PaddlePaddle/Paddle) ++ [parl](https://github.com/PaddlePaddle/PARL) ++ gym ++ opencv-python ++ atari_py + + +### Distributed Training: + +#### Learner +```sh +python train.py +``` + +#### Actors (Suggest: 32+ actors in 32+ CPUs) +```sh +for index in {1..32}; do + python actor.py & +done; +wait +``` + +You can change training settings (e.g. `env_name`, `server_ip`) in `impala_config.py`. +Training result will be saved in `log_dir/train/result.csv`. + +### Reference ++ [deepmind/scalable_agent](https://github.com/deepmind/scalable_agent) ++ [Ray](https://github.com/ray-project/ray) diff --git a/examples/IMPALA/actor.py b/examples/IMPALA/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..bea5f624ca7f4a7305b48186d329a3f626c5d4f6 --- /dev/null +++ b/examples/IMPALA/actor.py @@ -0,0 +1,106 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gym +import numpy as np +import parl +import six +from atari_model import AtariModel +from collections import defaultdict +from atari_agent import AtariAgent +from parl.algorithms import IMPALA +from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls +from parl.env.vector_env import VectorEnv + + +@parl.remote_class +class Actor(object): + def __init__(self, config): + self.config = config + + self.envs = [] + for _ in six.moves.range(config['env_num']): + env = gym.make(config['env_name']) + env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW') + self.envs.append(env) + self.vector_env = VectorEnv(self.envs) + + self.obs_batch = self.vector_env.reset() + + obs_shape = env.observation_space.shape + act_dim = env.action_space.n + + self.config['obs_shape'] = obs_shape + self.config['act_dim'] = act_dim + + model = AtariModel(act_dim) + algorithm = IMPALA(model, hyperparas=config) + self.agent = AtariAgent(algorithm, config) + + def sample(self): + env_sample_data = {} + for env_id in six.moves.range(self.config['env_num']): + env_sample_data[env_id] = defaultdict(list) + + for i in six.moves.range(self.config['sample_batch_steps']): + actions, behaviour_logits = self.agent.sample( + np.stack(self.obs_batch)) + next_obs_batch, reward_batch, done_batch, info_batch = \ + self.vector_env.step(actions) + + for env_id in six.moves.range(self.config['env_num']): + env_sample_data[env_id]['obs'].append(self.obs_batch[env_id]) + env_sample_data[env_id]['actions'].append(actions[env_id]) + env_sample_data[env_id]['behaviour_logits'].append( + behaviour_logits[env_id]) + env_sample_data[env_id]['rewards'].append(reward_batch[env_id]) + env_sample_data[env_id]['dones'].append(done_batch[env_id]) + + self.obs_batch = next_obs_batch + + # Merge data of envs + sample_data = defaultdict(list) + for env_id in six.moves.range(self.config['env_num']): + for data_name in [ + 'obs', 'actions', 
'behaviour_logits', 'rewards', 'dones' + ]: + sample_data[data_name].extend( + env_sample_data[env_id][data_name]) + + # size of sample_data: env_num * sample_batch_steps + for key in sample_data: + sample_data[key] = np.stack(sample_data[key]) + + return sample_data + + def get_metrics(self): + metrics = defaultdict(list) + for env in self.envs: + monitor = get_wrapper_by_cls(env, MonitorEnv) + if monitor is not None: + for episode_rewards, episode_steps in monitor.next_episode_results( + ): + metrics['episode_rewards'].append(episode_rewards) + metrics['episode_steps'].append(episode_steps) + return metrics + + def set_params(self, params): + self.agent.set_params(params) + + +if __name__ == '__main__': + from impala_config import config + + actor = Actor(config) + actor.as_remote(config['server_ip'], config['server_port']) diff --git a/examples/IMPALA/atari_agent.py b/examples/IMPALA/atari_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..86782744365fbe86645839a746387ec01c77994f --- /dev/null +++ b/examples/IMPALA/atari_agent.py @@ -0,0 +1,135 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid +import parl.layers as layers +from parl.framework.agent_base import Agent + + +class AtariAgent(Agent): + def __init__(self, algorithm, config, learn_data_provider=None): + self.config = config + super(AtariAgent, self).__init__(algorithm) + + use_cuda = True if self.gpu_id >= 0 else False + + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.use_experimental_executor = True + exec_strategy.num_threads = 4 + build_strategy = fluid.BuildStrategy() + build_strategy.remove_unnecessary_lock = True + + # Use ParallelExecutor to make learn program running faster + self.learn_exe = fluid.ParallelExecutor( + use_cuda=use_cuda, + main_program=self.learn_program, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + if learn_data_provider: + self.learn_reader.decorate_tensor_provider(learn_data_provider) + self.learn_reader.start() + + def build_program(self): + self.sample_program = fluid.Program() + self.predict_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.sample_program): + obs = layers.data( + name='obs', shape=self.config['obs_shape'], dtype='float32') + self.sample_actions, self.behaviour_logits = self.alg.sample(obs) + + with fluid.program_guard(self.predict_program): + obs = layers.data( + name='obs', shape=self.config['obs_shape'], dtype='float32') + self.predict_actions = self.alg.predict(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data( + name='obs', shape=self.config['obs_shape'], dtype='float32') + actions = layers.data(name='actions', shape=[], dtype='int64') + behaviour_logits = layers.data( + name='behaviour_logits', + shape=[self.config['act_dim']], + dtype='float32') + rewards = layers.data(name='rewards', shape=[], dtype='float32') + dones = 
layers.data(name='dones', shape=[], dtype='float32') + lr = layers.data( + name='lr', shape=[1], dtype='float32', append_batch_size=False) + entropy_coeff = layers.data( + name='entropy_coeff', shape=[], dtype='float32') + + self.learn_reader = fluid.layers.create_py_reader_by_data( + capacity=self.config['train_batch_size'], + feed_list=[ + obs, actions, behaviour_logits, rewards, dones, lr, + entropy_coeff + ]) + + obs, actions, behaviour_logits, rewards, dones, lr, entropy_coeff = fluid.layers.read_file( + self.learn_reader) + + vtrace_loss, kl = self.alg.learn(obs, actions, behaviour_logits, + rewards, dones, lr, entropy_coeff) + self.learn_outputs = [ + vtrace_loss.total_loss.name, vtrace_loss.pi_loss.name, + vtrace_loss.vf_loss.name, vtrace_loss.entropy.name, kl.name + ] + + def sample(self, obs_np): + """ + Args: + obs_np: a numpy float32 array of shape ([B] + observation_space). + Format of image input should be NCHW format. + + Returns: + sample_ids: a numpy int64 array of shape [B] + """ + obs_np = obs_np.astype('float32') + + sample_actions, behaviour_logits = self.fluid_executor.run( + self.sample_program, + feed={'obs': obs_np}, + fetch_list=[self.sample_actions, self.behaviour_logits]) + return sample_actions, behaviour_logits + + def predict(self, obs_np): + """ + Args: + obs_np: a numpy float32 array of shape ([B] + observation_space) + Format of image input should be NCHW format. + + Returns: + sample_ids: a numpy int64 array of shape [B] + """ + obs_np = obs_np.astype('float32') + + predict_actions = self.fluid_executor.run( + self.predict_program, + feed={'obs': obs_np}, + fetch_list=[self.predict_actions])[0] + return predict_actions + + def learn(self): + total_loss, pi_loss, vf_loss, entropy, kl = self.learn_exe.run( + fetch_list=self.learn_outputs) + return total_loss, pi_loss, vf_loss, entropy, kl + + def get_params(self): + return self.alg.get_params() + + def set_params(self, params): + self.alg.set_params(params, gpu_id=self.gpu_id) diff --git a/examples/IMPALA/atari_model.py b/examples/IMPALA/atari_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b65993767f8c728d460f3d91699ad3706add96 --- /dev/null +++ b/examples/IMPALA/atari_model.py @@ -0,0 +1,74 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
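+
+# Shape note: assuming the 42x42 observations configured in impala_config.py
+# ('env_dim': 42), the conv stack below reduces the spatial size
+# 42 -> 21 -> 11 -> 1, so conv3 produces a 256-channel 1x1 feature map;
+# policy_conv (a 1x1 conv) maps it to act_dim logits and value_fc maps the
+# flattened features to a scalar state value.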
+ +import paddle.fluid as fluid +import parl.layers as layers +from parl.framework.model_base import Model +from paddle.fluid.param_attr import ParamAttr + + +class AtariModel(Model): + def __init__(self, act_dim): + + self.conv1 = layers.conv2d( + num_filters=16, filter_size=4, stride=2, padding=1, act='relu') + self.conv2 = layers.conv2d( + num_filters=32, filter_size=4, stride=2, padding=2, act='relu') + self.conv3 = layers.conv2d( + num_filters=256, filter_size=11, stride=1, padding=0, act='relu') + + self.policy_conv = layers.conv2d( + num_filters=act_dim, + filter_size=1, + stride=1, + padding=0, + act=None, + param_attr=ParamAttr(initializer=fluid.initializer.Normal())) + + self.value_fc = layers.fc( + size=1, + param_attr=ParamAttr(initializer=fluid.initializer.Normal())) + + def policy(self, obs): + """ + Args: + obs: An float32 tensor of shape [B, C, H, W] + Returns: + policy_logits: B * ACT_DIM + """ + obs = obs / 255.0 + conv1 = self.conv1(obs) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + + policy_conv = self.policy_conv(conv3) + policy_logits = layers.flatten(policy_conv, axis=1) + return policy_logits + + def value(self, obs): + """ + Args: + obs: An float32 tensor of shape [B, C, H, W] + Returns: + value: B + """ + obs = obs / 255.0 + conv1 = self.conv1(obs) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + + flatten = layers.flatten(conv3, axis=1) + value = self.value_fc(flatten) + value = layers.squeeze(value, axes=[1]) + return value diff --git a/examples/IMPALA/impala_config.py b/examples/IMPALA/impala_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e36af5c4add6b2de590c499d777074f21bb6687a --- /dev/null +++ b/examples/IMPALA/impala_config.py @@ -0,0 +1,48 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
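+
+# Size relationships (derived from the values below): each actor batch holds
+# env_num * sample_batch_steps = 5 * 50 = 250 transitions, so with
+# train_batch_size = 1000 the learner concatenates four actor batches per
+# train batch. 'sample_batch_steps' is also the time dimension T used by
+# IMPALA.learn when reshaping flattened [B * T] tensors back to [T, B] for
+# V-trace, so actors and learner must share this value.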
+ +config = { + 'experiment_name': 'Pong', + + #========== remote config ========== + 'server_ip': 'localhost', + 'server_port': 8037, + + #========== env config ========== + 'env_name': 'PongNoFrameskip-v4', + 'env_dim': 42, + 'obs_format': 'NHWC', + + #========== actor config ========== + 'env_num': 5, + 'sample_batch_steps': 50, + + #========== learner config ========== + 'train_batch_size': 1000, + 'learner_queue_max_size': 16, + 'sample_queue_max_size': 8, + 'gamma': 0.99, + + # learning rate adjustment schedule: (train_step, learning_rate) + 'lr_scheduler': [(0, 0.001), (20000, 0.0005), (40000, 0.0001)], + + # coefficient of policy entropy adjustment schedule: (train_step, coefficient) + 'entropy_coeff_scheduler': [(0, -0.01)], + 'vf_loss_coeff': 0.5, + 'clip_rho_threshold': 1.0, + 'clip_pg_rho_threshold': 1.0, + 'get_remote_metrics_interval': 10, + 'log_metrics_interval_s': 10, + 'params_broadcast_interval': 5, +} diff --git a/examples/IMPALA/learner.py b/examples/IMPALA/learner.py new file mode 100644 index 0000000000000000000000000000000000000000..272413c1b3499aae6149975b8835acfd4ea603dc --- /dev/null +++ b/examples/IMPALA/learner.py @@ -0,0 +1,255 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
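+
+# Data flow in this module: each connected remote actor pushes sampled batches
+# into sample_data_queue (one thread per actor, started by run_remote_manager);
+# the main thread's step() merges them into train batches and feeds
+# learner_queue; learn_data_provider streams those batches into the py_reader
+# consumed by the learn thread (run_learn). The cached parameters returned to
+# actors are refreshed from the model at most once every
+# 'params_broadcast_interval' broadcasts.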
+ +import gym +import numpy as np +import os +import queue +import time +import threading +from atari_model import AtariModel +from atari_agent import AtariAgent +from parl import RemoteManager +from parl.algorithms import IMPALA +from parl.env.atari_wrappers import wrap_deepmind +from parl.utils import logger, CSVLogger +from parl.utils.scheduler import PiecewiseScheduler +from parl.utils.time_stat import TimeStat +from parl.utils.window_stat import WindowStat + + +class Learner(object): + def __init__(self, config): + self.config = config + + self.learner_queue = queue.Queue( + maxsize=config['learner_queue_max_size']) + + #=========== Create Agent ========== + env = gym.make(config['env_name']) + env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW') + obs_shape = env.observation_space.shape + + act_dim = env.action_space.n + self.config['obs_shape'] = obs_shape + self.config['act_dim'] = act_dim + + model = AtariModel(act_dim) + algorithm = IMPALA(model, hyperparas=config) + self.agent = AtariAgent(algorithm, config, self.learn_data_provider) + + self.cache_params = self.agent.get_params() + self.params_lock = threading.Lock() + self.params_updated = False + self.cache_params_sent_cnt = 0 + self.total_params_sync = 0 + + #========== Learner ========== + self.lr, self.entropy_coeff = None, None + self.lr_scheduler = PiecewiseScheduler(config['lr_scheduler']) + self.entropy_coeff_scheduler = PiecewiseScheduler( + config['entropy_coeff_scheduler']) + + self.total_loss_stat = WindowStat(100) + self.pi_loss_stat = WindowStat(100) + self.vf_loss_stat = WindowStat(100) + self.entropy_stat = WindowStat(100) + self.kl_stat = WindowStat(100) + self.learn_time_stat = TimeStat(100) + self.start_time = None + + self.learn_thread = threading.Thread(target=self.run_learn) + self.learn_thread.setDaemon(True) + self.learn_thread.start() + + #========== Remote Actor =========== + self.remote_count = 0 + self.sample_data_queue = queue.Queue( + maxsize=config['sample_queue_max_size']) + self.batch_buffer = [] + self.remote_metrics_queue = queue.Queue() + self.sample_total_steps = 0 + + self.remote_manager_thread = threading.Thread( + target=self.run_remote_manager) + self.remote_manager_thread.setDaemon(True) + self.remote_manager_thread.start() + + self.csv_logger = CSVLogger( + os.path.join(logger.get_dir(), 'result.csv')) + + def learn_data_provider(self): + """ Data generator for fluid.layers.py_reader + """ + while True: + batch = self.learner_queue.get() + + obs_np = batch['obs'].astype('float32') + actions_np = batch['actions'].astype('int64') + behaviour_logits_np = batch['behaviour_logits'].astype('float32') + rewards_np = batch['rewards'].astype('float32') + dones_np = batch['dones'].astype('float32') + + self.lr = self.lr_scheduler.step() + self.entropy_coeff = self.entropy_coeff_scheduler.step() + + yield [ + obs_np, actions_np, behaviour_logits_np, rewards_np, dones_np, + self.lr, self.entropy_coeff + ] + + def run_learn(self): + """ Learn loop + """ + while True: + with self.learn_time_stat: + total_loss, pi_loss, vf_loss, entropy, kl = self.agent.learn() + + self.params_updated = True + + self.total_loss_stat.add(total_loss) + self.pi_loss_stat.add(pi_loss) + self.vf_loss_stat.add(vf_loss) + self.entropy_stat.add(entropy) + self.kl_stat.add(kl) + + def run_remote_manager(self): + """ Accept connection of new remote actor and start sampling of the remote actor. 
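+        Each connected actor is served by a dedicated thread running
+        run_remote_sample, so sampling from multiple actors proceeds in parallel.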
+ """ + remote_manager = RemoteManager(port=self.config['server_port']) + logger.info('Waiting for remote actors connecting.') + while True: + remote_actor = remote_manager.get_remote() + self.remote_count += 1 + logger.info('Remote actor count: {}'.format(self.remote_count)) + if self.start_time is None: + self.start_time = time.time() + + remote_thread = threading.Thread( + target=self.run_remote_sample, args=(remote_actor, )) + remote_thread.setDaemon(True) + remote_thread.start() + + def run_remote_sample(self, remote_actor): + """ Sample data from remote actor and update parameters of remote actor. + """ + cnt = 0 + remote_actor.set_params(self.cache_params) + while True: + batch = remote_actor.sample() + self.sample_data_queue.put(batch) + + cnt += 1 + if cnt % self.config['get_remote_metrics_interval'] == 0: + metrics = remote_actor.get_metrics() + if metrics: + self.remote_metrics_queue.put(metrics) + + self.params_lock.acquire() + + if self.params_updated and self.cache_params_sent_cnt >= self.config[ + 'params_broadcast_interval']: + self.params_updated = False + self.cache_params = self.agent.get_params() + self.cache_params_sent_cnt = 0 + self.cache_params_sent_cnt += 1 + self.total_params_sync += 1 + + self.params_lock.release() + + remote_actor.set_params(self.cache_params) + + def step(self): + """ Merge and generate batch learn data from sample_data_queue, + and put it in learner_queue. + """ + assert self.learn_thread.is_alive() + + while True: + try: + sample_data = self.sample_data_queue.get_nowait() + self.sample_total_steps += sample_data['obs'].shape[0] + self.batch_buffer.append(sample_data) + + buffer_size = sum( + [data['obs'].shape[0] for data in self.batch_buffer]) + if buffer_size >= self.config['train_batch_size']: + train_batch = {} + for key in self.batch_buffer[0].keys(): + train_batch[key] = np.concatenate( + [data[key] for data in self.batch_buffer]) + self.learner_queue.put(train_batch) + self.batch_buffer = [] + + except queue.Empty: + break + + def log_metrics(self): + """ Log metrics of learner and actors + """ + if self.start_time is None: + return + + metrics = [] + while True: + try: + metric = self.remote_metrics_queue.get_nowait() + metrics.append(metric) + except queue.Empty: + break + + episode_rewards, episode_steps = [], [] + for x in metrics: + episode_rewards.extend(x['episode_rewards']) + episode_steps.extend(x['episode_steps']) + max_episode_rewards, mean_episode_rewards, min_episode_rewards, \ + max_episode_steps, mean_episode_steps, min_episode_steps =\ + None, None, None, None, None, None + if episode_rewards: + mean_episode_rewards = np.mean(np.array(episode_rewards).flatten()) + max_episode_rewards = np.max(np.array(episode_rewards).flatten()) + min_episode_rewards = np.min(np.array(episode_rewards).flatten()) + + mean_episode_steps = np.mean(np.array(episode_steps).flatten()) + max_episode_steps = np.max(np.array(episode_steps).flatten()) + min_episode_steps = np.min(np.array(episode_steps).flatten()) + + metric = { + 'Sample steps': self.sample_total_steps, + 'max_episode_rewards': max_episode_rewards, + 'mean_episode_rewards': mean_episode_rewards, + 'min_episode_rewards': min_episode_rewards, + 'max_episode_steps': max_episode_steps, + 'mean_episode_steps': mean_episode_steps, + 'min_episode_steps': min_episode_steps, + 'learner_queue_size': self.learner_queue.qsize(), + 'sample_queue_size': self.sample_data_queue.qsize(), + 'total_params_sync': self.total_params_sync, + 'cache_params_sent_cnt': self.cache_params_sent_cnt, + 
'total_loss': self.total_loss_stat.mean, + 'pi_loss': self.pi_loss_stat.mean, + 'vf_loss': self.vf_loss_stat.mean, + 'entropy': self.entropy_stat.mean, + 'kl': self.kl_stat.mean, + 'learn_time_s': self.learn_time_stat.mean, + 'elapsed_time_s': int(time.time() - self.start_time), + 'lr': self.lr, + 'entropy_coeff': self.entropy_coeff, + } + + logger.info(metric) + self.csv_logger.log_dict(metric) + + def close(self): + self.csv_logger.close() diff --git a/examples/IMPALA/run_actors.sh b/examples/IMPALA/run_actors.sh new file mode 100644 index 0000000000000000000000000000000000000000..0e97bb177bfc547ae7381d5730e34c6ca5c60958 --- /dev/null +++ b/examples/IMPALA/run_actors.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES="" + +for index in {1..32}; do + python actor.py & +done; +wait diff --git a/examples/IMPALA/train.py b/examples/IMPALA/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f15075106487595d549f4aa28af01f117a4480aa --- /dev/null +++ b/examples/IMPALA/train.py @@ -0,0 +1,36 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from learner import Learner + + +def main(config): + learner = Learner(config) + assert config['log_metrics_interval_s'] > 0 + + try: + while True: + start = time.time() + while time.time() - start < config['log_metrics_interval_s']: + learner.step() + learner.log_metrics() + + except KeyboardInterrupt: + learner.close() + + +if __name__ == '__main__': + from impala_config import config + main(config) diff --git a/parl/algorithms/__init__.py b/parl/algorithms/__init__.py index 16f032edfe429685a36d1f7eaef4c1390c62dd3f..7f890e19361649870c832fb31cc9d442fb335e53 100644 --- a/parl/algorithms/__init__.py +++ b/parl/algorithms/__init__.py @@ -16,3 +16,4 @@ from parl.algorithms.ddpg import * from parl.algorithms.dqn import * from parl.algorithms.policy_gradient import * from parl.algorithms.ppo import * +from parl.algorithms.impala.impala import * diff --git a/parl/algorithms/impala/__init__.py b/parl/algorithms/impala/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0f310cf32b7d1b22d06f625cb6df8449507e8a6 --- /dev/null +++ b/parl/algorithms/impala/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from parl.algorithms.impala.impala import * diff --git a/parl/algorithms/impala/impala.py b/parl/algorithms/impala/impala.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3e6195c9fa3da146e72b91fcc37fdffc279118 --- /dev/null +++ b/parl/algorithms/impala/impala.py @@ -0,0 +1,195 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import parl.layers as layers +from parl.algorithms.impala import vtrace +from parl.framework.algorithm_base import Algorithm +from parl.framework.policy_distribution import CategoricalDistribution +from parl.plutils import inverse + +__all__ = ['IMPALA'] + + +class VTraceLoss(object): + def __init__(self, + behaviour_actions_log_probs, + target_actions_log_probs, + policy_entropy, + dones, + discount, + rewards, + values, + bootstrap_value, + entropy_coeff=-0.01, + vf_loss_coeff=0.5, + clip_rho_threshold=1.0, + clip_pg_rho_threshold=1.0): + """Policy gradient loss with vtrace importance weighting. + + VTraceLoss takes tensors of shape [T, B, ...], where `B` is the + batch_size. The reason we need to know `B` is for V-trace to properly + handle episode cut boundaries. + + Args: + behaviour_actions_log_probs: A float32 tensor of shape [T, B]. + target_actions_log_probs: A float32 tensor of shape [T, B]. + policy_entropy: A float32 tensor of shape [T, B]. + dones: A float32 tensor of shape [T, B]. + discount: A float32 scalar. + rewards: A float32 tensor of shape [T, B]. + values: A float32 tensor of shape [T, B]. + bootstrap_value: A float32 tensor of shape [B]. + """ + + self.vtrace_returns = vtrace.from_importance_weights( + behaviour_actions_log_probs=behaviour_actions_log_probs, + target_actions_log_probs=target_actions_log_probs, + discounts=inverse(dones) * discount, + rewards=rewards, + values=values, + bootstrap_value=bootstrap_value, + clip_rho_threshold=clip_rho_threshold, + clip_pg_rho_threshold=clip_pg_rho_threshold) + + # The policy gradients loss + self.pi_loss = -1.0 * layers.reduce_sum( + target_actions_log_probs * self.vtrace_returns.pg_advantages) + + # The baseline loss + delta = values - self.vtrace_returns.vs + self.vf_loss = 0.5 * layers.reduce_sum(layers.square(delta)) + + # The entropy loss (We want to maximize entropy, so entropy_ceoff < 0) + self.entropy = layers.reduce_sum(policy_entropy) + + # The summed weighted loss + self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff + + self.entropy * entropy_coeff) + + +class IMPALA(Algorithm): + def __init__(self, model, hyperparas): + super(IMPALA, self).__init__(model, hyperparas) + + def learn(self, obs, actions, behaviour_logits, rewards, dones, + learning_rate, entropy_coeff): + """ + Args: + obs: An float32 tensor of shape ([B] + observation_space). + E.g. [B, C, H, W] in atari. + actions: An int64 tensor of shape [B]. + behaviour_logits: A float32 tensor of shape [B, NUM_ACTIONS]. + rewards: A float32 tensor of shape [B]. + dones: A float32 tensor of shape [B]. 
+ learning_rate: float scalar of learning rate. + entropy_coeff: float scalar of entropy coefficient. + """ + + values = self.model.value(obs) + target_logits = self.model.policy(obs) + + target_policy_distribution = CategoricalDistribution(target_logits) + behaviour_policy_distribution = CategoricalDistribution( + behaviour_logits) + + policy_entropy = target_policy_distribution.entropy() + target_actions_log_probs = target_policy_distribution.logp(actions) + behaviour_actions_log_probs = behaviour_policy_distribution.logp( + actions) + + # Calculating kl for debug + kl = target_policy_distribution.kl(behaviour_policy_distribution) + kl = layers.reduce_mean(kl) + """ + Split the tensor into batches at known episode cut boundaries. + [B * T] -> [T, B] + """ + T = self.hp["sample_batch_steps"] + + def split_batches(tensor): + B = tensor.shape[0] // T + splited_tensor = layers.reshape(tensor, + [B, T] + list(tensor.shape[1:])) + # transpose B and T + return layers.transpose( + splited_tensor, [1, 0] + list(range(2, 1 + len(tensor.shape)))) + + behaviour_actions_log_probs = split_batches( + behaviour_actions_log_probs) + target_actions_log_probs = split_batches(target_actions_log_probs) + policy_entropy = split_batches(policy_entropy) + dones = split_batches(dones) + rewards = split_batches(rewards) + values = split_batches(values) + + # [T, B] -> [T - 1, B] for V-trace calc. + behaviour_actions_log_probs = layers.slice( + behaviour_actions_log_probs, axes=[0], starts=[0], ends=[-1]) + target_actions_log_probs = layers.slice( + target_actions_log_probs, axes=[0], starts=[0], ends=[-1]) + policy_entropy = layers.slice( + policy_entropy, axes=[0], starts=[0], ends=[-1]) + dones = layers.slice(dones, axes=[0], starts=[0], ends=[-1]) + rewards = layers.slice(rewards, axes=[0], starts=[0], ends=[-1]) + bootstrap_value = layers.slice( + values, axes=[0], starts=[T - 1], ends=[T]) + values = layers.slice(values, axes=[0], starts=[0], ends=[-1]) + + bootstrap_value = layers.squeeze(bootstrap_value, axes=[0]) + + vtrace_loss = VTraceLoss( + behaviour_actions_log_probs=behaviour_actions_log_probs, + target_actions_log_probs=target_actions_log_probs, + policy_entropy=policy_entropy, + dones=dones, + discount=self.hp['gamma'], + rewards=rewards, + values=values, + bootstrap_value=bootstrap_value, + entropy_coeff=entropy_coeff, + vf_loss_coeff=self.hp['vf_loss_coeff'], + clip_rho_threshold=self.hp['clip_rho_threshold'], + clip_pg_rho_threshold=self.hp['clip_pg_rho_threshold']) + + fluid.clip.set_gradient_clip( + clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=40.0)) + + optimizer = fluid.optimizer.AdamOptimizer(learning_rate) + optimizer.minimize(vtrace_loss.total_loss) + return vtrace_loss, kl + + def sample(self, obs): + """ + Args: + obs: An float32 tensor of shape ([B] + observation_space). + E.g. [B, C, H, W] in atari. + """ + logits = self.model.policy(obs) + policy_dist = CategoricalDistribution(logits) + sample_actions = policy_dist.sample() + return sample_actions, logits + + def predict(self, obs): + """ + Args: + obs: An float32 tensor of shape ([B] + observation_space). + E.g. [B, C, H, W] in atari. 
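+
+        Returns:
+            predict_actions: An int64 tensor of shape [B] of greedy (argmax) actions.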
+ """ + logits = self.model.policy(obs) + probs = layers.softmax(logits) + + predict_actions = layers.argmax(probs, 1) + + return predict_actions diff --git a/parl/algorithms/impala/tests/vtrace_test.py b/parl/algorithms/impala/tests/vtrace_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c8bc8d273d663f4983e0a7b416c55fb9d3d7d3e6 --- /dev/null +++ b/parl/algorithms/impala/tests/vtrace_test.py @@ -0,0 +1,191 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for V-trace. + +The following code is mainly referenced and copied from: +https://github.com/deepmind/scalable_agent/blob/master/vtrace_test.py +""" + +import copy +import numpy as np +import unittest +import parl.layers as layers +from paddle import fluid +from parameterized import parameterized +from parl.algorithms.impala import vtrace +from parl.utils import get_gpu_count + + +def _shaped_arange(*shape): + """Runs np.arange, converts to float and reshapes.""" + return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape) + + +def _ground_truth_calculation(behaviour_actions_log_probs, + target_actions_log_probs, discounts, rewards, + values, bootstrap_value, clip_rho_threshold, + clip_pg_rho_threshold): + """Calculates the ground truth for V-trace in Python/Numpy.""" + log_rhos = target_actions_log_probs - behaviour_actions_log_probs + + vs = [] + seq_len = len(discounts) + rhos = np.exp(log_rhos) + cs = np.minimum(rhos, 1.0) + clipped_rhos = rhos + if clip_rho_threshold: + clipped_rhos = np.minimum(rhos, clip_rho_threshold) + clipped_pg_rhos = rhos + if clip_pg_rho_threshold: + clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold) + + # This is a very inefficient way to calculate the V-trace ground truth. + # We calculate it this way because it is close to the mathematical notation of + # V-trace. + # v_s = V(x_s) + # + \sum^{T-1}_{t=s} \gamma^{t-s} + # * \prod_{i=s}^{t-1} c_i + # * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t)) + # Note that when we take the product over c_i, we write `s:t` as the notation + # of the paper is inclusive of the `t-1`, but Python is exclusive. + # Also note that np.prod([]) == 1. + values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]], + axis=0) + for s in range(seq_len): + v_s = np.copy(values[s]) # Very important copy. 
+ for t in range(s, seq_len): + v_s += (np.prod(discounts[s:t], axis=0) * np.prod(cs[s:t], axis=0) + * clipped_rhos[t] * (rewards[t] + discounts[t] * + values_t_plus_1[t + 1] - values[t])) + vs.append(v_s) + vs = np.stack(vs, axis=0) + pg_advantages = (clipped_pg_rhos * (rewards + discounts * np.concatenate( + [vs[1:], bootstrap_value[None, :]], axis=0) - values)) + + return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages) + + +class VtraceTest(unittest.TestCase): + def setUp(self): + gpu_count = get_gpu_count() + if gpu_count > 0: + place = fluid.CUDAPlace(0) + self.gpu_id = 0 + else: + place = fluid.CPUPlace() + self.gpu_id = -1 + self.executor = fluid.Executor(place) + + @parameterized.expand([('Batch1', 1), ('Batch4', 4)]) + def test_from_importance_weights(self, name, batch_size): + """Tests V-trace against ground truth data calculated in python.""" + seq_len = 5 + + # Create log_rhos such that rho will span from near-zero to above the + # clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5), + # so that rho is in approx [0.08, 12.2). + log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len) + log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5). + # Fake behaviour_actions_log_probs, target_actions_log_probs + target_actions_log_probs = log_rhos + 1.0 + behaviour_actions_log_probs = np.ones( + shape=log_rhos.shape, dtype='float32') + + values = { + 'behaviour_actions_log_probs': + behaviour_actions_log_probs, + 'target_actions_log_probs': + target_actions_log_probs, + # T, B where B_i: [0.9 / (i+1)] * T + 'discounts': + np.array([[0.9 / (b + 1) for b in range(batch_size)] + for _ in range(seq_len)], + dtype=np.float32), + 'rewards': + _shaped_arange(seq_len, batch_size), + 'values': + _shaped_arange(seq_len, batch_size) / batch_size, + 'bootstrap_value': + _shaped_arange(batch_size) + 1.0, + 'clip_rho_threshold': + 3.7, + 'clip_pg_rho_threshold': + 2.2, + } + + # Calculated by numpy/python + ground_truth_v = _ground_truth_calculation(**values) + + # Calculated by Fluid + test_program = fluid.Program() + with fluid.program_guard(test_program): + behaviour_actions_log_probs_input = layers.data( + name='behaviour_actions_log_probs', + shape=[seq_len, batch_size], + dtype='float32', + append_batch_size=False) + target_actions_log_probs_input = layers.data( + name='target_actions_log_probs', + shape=[seq_len, batch_size], + dtype='float32', + append_batch_size=False) + discounts_input = layers.data( + name='discounts', + shape=[seq_len, batch_size], + dtype='float32', + append_batch_size=False) + rewards_input = layers.data( + name='rewards', + shape=[seq_len, batch_size], + dtype='float32', + append_batch_size=False) + values_input = layers.data( + name='values', + shape=[seq_len, batch_size], + dtype='float32', + append_batch_size=False) + bootstrap_value_input = layers.data( + name='bootstrap_value', + shape=[batch_size], + dtype='float32', + append_batch_size=False) + fluid_inputs = { + 'behaviour_actions_log_probs': + behaviour_actions_log_probs_input, + 'target_actions_log_probs': target_actions_log_probs_input, + 'discounts': discounts_input, + 'rewards': rewards_input, + 'values': values_input, + 'bootstrap_value': bootstrap_value_input, + 'clip_rho_threshold': 3.7, + 'clip_pg_rho_threshold': 2.2, + } + output = vtrace.from_importance_weights(**fluid_inputs) + + self.executor.run(fluid.default_startup_program()) + feed = copy.copy(values) + del feed['clip_rho_threshold'] + del feed['clip_pg_rho_threshold'] + [output_vs, 
output_pg_advantage] = self.executor.run(
+            test_program,
+            feed=feed,
+            fetch_list=[output.vs, output.pg_advantages])
+
+        np.testing.assert_almost_equal(ground_truth_v.vs, output_vs, 5)
+        np.testing.assert_almost_equal(ground_truth_v.pg_advantages,
+                                       output_pg_advantage, 5)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/parl/algorithms/impala/vtrace.py b/parl/algorithms/impala/vtrace.py
new file mode 100644
index 0000000000000000000000000000000000000000..71ed1d2c198b23634711983f87bfa83712fbe983
--- /dev/null
+++ b/parl/algorithms/impala/vtrace.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Functions to compute V-trace off-policy actor-critic targets,
+as used in the IMPALA algorithm.
+
+The following code is mainly referenced and copied from:
+https://github.com/deepmind/scalable_agent/blob/master/vtrace.py
+
+For details and theory see:
+
+"Espeholt L, Soyer H, Munos R, et al. Impala: Scalable distributed
+deep-rl with importance weighted actor-learner
+architectures[J]. arXiv preprint arXiv:1802.01561, 2018."
+
+"""
+
+import collections
+import paddle.fluid as fluid
+import parl.layers as layers
+from parl.utils import MAX_INT32
+
+VTraceReturns = collections.namedtuple('VTraceReturns',
+                                       ['vs', 'pg_advantages'])
+
+
+def from_importance_weights(behaviour_actions_log_probs,
+                            target_actions_log_probs,
+                            discounts,
+                            rewards,
+                            values,
+                            bootstrap_value,
+                            clip_rho_threshold=1.0,
+                            clip_pg_rho_threshold=1.0,
+                            name='vtrace_from_logits'):
+    r"""V-trace for softmax policies.
+
+    Calculates V-trace actor-critic targets for softmax policies as described in
+
+    "IMPALA: Scalable Distributed Deep-RL with
+    Importance Weighted Actor-Learner Architectures"
+    by Espeholt, Soyer, Munos et al.
+
+    Target policy refers to the policy we are interested in improving and
+    behaviour policy refers to the policy that generated the given
+    rewards and actions.
+
+    In the notation used throughout documentation and comments, T refers to the
+    time dimension ranging from 0 to T-1. B refers to the batch size and
+    NUM_ACTIONS refers to the number of actions.
+
+    Args:
+      behaviour_actions_log_probs: A float32 tensor of shape [T, B] with the
+        log-probabilities of the taken actions under the behaviour policy.
+      target_actions_log_probs: A float32 tensor of shape [T, B] with the
+        log-probabilities of the taken actions under the target policy.
+      discounts: A float32 tensor of shape [T, B] with the discount encountered
+        when following the behaviour policy.
+      rewards: A float32 tensor of shape [T, B] with the rewards generated by
+        following the behaviour policy.
+      values: A float32 tensor of shape [T, B] with the value function estimates
+        wrt. the target policy.
+      bootstrap_value: A float32 tensor of shape [B] with the value function
+        estimate at time T.
+      clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
+        importance weights (rho) when calculating the baseline targets (vs).
+        rho^bar in the paper.
+ clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold + on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)). + name: The name scope that all V-trace operations will be created in. + + Returns: + A VTraceReturns namedtuple (vs, pg_advantages) where: + vs: A float32 tensor of shape [T, B]. Can be used as target to + train a baseline (V(x_t) - vs_t)^2. + pg_advantages: A float32 tensor of shape [T, B]. Can be used as the + advantage in the calculation of policy gradients. + """ + + rank = len(behaviour_actions_log_probs.shape) # Usually 2. + assert len(target_actions_log_probs.shape) == rank + assert len(values.shape) == rank + assert len(bootstrap_value.shape) == (rank - 1) + assert len(discounts.shape) == rank + assert len(rewards.shape) == rank + + # log importance sampling weights. + # V-trace performs operations on rhos in log-space for numerical stability. + log_rhos = target_actions_log_probs - behaviour_actions_log_probs + + if clip_rho_threshold is not None: + clip_rho_threshold = layers.fill_constant([1], 'float32', + clip_rho_threshold) + if clip_pg_rho_threshold is not None: + clip_pg_rho_threshold = layers.fill_constant([1], 'float32', + clip_pg_rho_threshold) + + rhos = layers.exp(log_rhos) + if clip_rho_threshold is not None: + clipped_rhos = layers.elementwise_min(rhos, clip_rho_threshold) + else: + clipped_rhos = rhos + + constant_one = layers.fill_constant([1], 'float32', 1.0) + cs = layers.elementwise_min(rhos, constant_one) + + # Append bootstrapped value to get [v1, ..., v_t+1] + values_1_t = layers.slice(values, axes=[0], starts=[1], ends=[MAX_INT32]) + values_t_plus_1 = layers.concat( + [values_1_t, layers.unsqueeze(bootstrap_value, [0])], axis=0) + + # \delta_s * V + deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values) + + vs_minus_v_xs = recursively_scan(discounts, cs, deltas) + + # Add V(x_s) to get v_s. + vs = layers.elementwise_add(vs_minus_v_xs, values) + + # Advantage for policy gradient. + vs_1_t = layers.slice(vs, axes=[0], starts=[1], ends=[MAX_INT32]) + vs_t_plus_1 = layers.concat( + [vs_1_t, layers.unsqueeze(bootstrap_value, [0])], axis=0) + + if clip_pg_rho_threshold is not None: + clipped_pg_rhos = layers.elementwise_min(rhos, clip_pg_rho_threshold) + else: + clipped_pg_rhos = rhos + pg_advantages = ( + clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)) + + # Make sure no gradients backpropagated through the returned values. + vs.stop_gradient = True + pg_advantages.stop_gradient = True + return VTraceReturns(vs=vs, pg_advantages=pg_advantages) + + +def recursively_scan(discounts, cs, deltas): + """ Recursively calculate vs_minus_v_xs according to following equation: + vs_minus_v_xs(t) = deltas(t) + discounts(t) * cs(t) * vs_minus_v_xs(t + 1) + + Args: + discounts: A float32 tensor of shape [T, B] with discounts encountered when + following the behaviour policy. + cs: A float32 tensor of shape [T, B], which corresponding to $c_s$ in the + origin paper. + deltas: A float32 tensor of shape [T, B], which corresponding to + $\delta_s * V$ in the origin paper. + + Returns: + vs_minus_v_xs: A float32 tensor of shape [T, B], which corresponding to + $v_s - V(x_s)$ in the origin paper. + """ + + # All sequences are reversed, computation starts from the back. 
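+    # The backward recurrence evaluated below (with A_T = 0) is:
+    #     A_t = delta_t + discount_t * c_t * A_{t+1}
+    # Reversing the inputs lets the StaticRNN scan forward over reversed time
+    # while carrying A_{t+1} in its memory; the outputs are reversed again at
+    # the end to restore the original [T, B] time order.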
+ reverse_discounts = layers.reverse(x=discounts, axis=[0]) + reverse_cs = layers.reverse(x=cs, axis=[0]) + reverse_deltas = layers.reverse(x=deltas, axis=[0]) + + static_while = layers.StaticRNN() + # init: shape [B] + init = layers.fill_constant_batch_size_like( + discounts, shape=[1], dtype='float32', value=0.0, input_dim_idx=1) + + with static_while.step(): + discount_t = static_while.step_input(reverse_discounts) + c_t = static_while.step_input(reverse_cs) + delta_t = static_while.step_input(reverse_deltas) + + vs_minus_v_xs_t_plus_1 = static_while.memory(init=init) + vs_minus_v_xs_t = delta_t + discount_t * c_t * vs_minus_v_xs_t_plus_1 + + static_while.update_memory(vs_minus_v_xs_t_plus_1, vs_minus_v_xs_t) + + static_while.step_output(vs_minus_v_xs_t) + + vs_minus_v_xs = static_while() + + # Reverse the results back to original order. + vs_minus_v_xs = layers.reverse(vs_minus_v_xs, [0]) + + return vs_minus_v_xs diff --git a/parl/env/__init__.py b/parl/env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..079f2ac4a7260d5318c68ceca4ac1faab2b195d2 --- /dev/null +++ b/parl/env/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from parl.env.vector_env import * diff --git a/parl/env/atari_wrappers.py b/parl/env/atari_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa6d95c8f050a25fedc97dcddd9c8567bbc3e74 --- /dev/null +++ b/parl/env/atari_wrappers.py @@ -0,0 +1,280 @@ +# Third party code +# +# The following code are copied or modified from: +# https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/atari_wrappers.py + +import numpy as np +from collections import deque +import gym +from gym import spaces +import cv2 +cv2.ocl.setUseOpenCL(False) + + +def get_wrapper_by_cls(env, cls): + """Returns the gym env wrapper of the given class, or None.""" + currentenv = env + while True: + if isinstance(currentenv, cls): + return currentenv + elif isinstance(currentenv, gym.Wrapper): + currentenv = currentenv.env + else: + return None + + +class MonitorEnv(gym.Wrapper): + def __init__(self, env=None): + """Record episodes stats prior to EpisodicLifeEnv, etc.""" + gym.Wrapper.__init__(self, env) + self._current_reward = None + self._num_steps = None + self._total_steps = None + self._episode_rewards = [] + self._episode_lengths = [] + self._num_episodes = 0 + self._num_returned = 0 + + def reset(self, **kwargs): + obs = self.env.reset(**kwargs) + + if self._total_steps is None: + self._total_steps = sum(self._episode_lengths) + + if self._current_reward is not None: + self._episode_rewards.append(self._current_reward) + self._episode_lengths.append(self._num_steps) + self._num_episodes += 1 + + self._current_reward = 0 + self._num_steps = 0 + + return obs + + def step(self, action): + obs, rew, done, info = self.env.step(action) + self._current_reward += rew + self._num_steps += 1 + self._total_steps += 1 + return (obs, rew, done, info) 
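+
+    # The getters below expose cumulative episode statistics;
+    # next_episode_results() yields only episodes finished since the previous
+    # call (tracked by _num_returned), which is what Actor.get_metrics() relies
+    # on to report fresh episode rewards/lengths to the learner.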
+ + def get_episode_rewards(self): + return self._episode_rewards + + def get_episode_lengths(self): + return self._episode_lengths + + def get_total_steps(self): + return self._total_steps + + def next_episode_results(self): + for i in range(self._num_returned, len(self._episode_rewards)): + yield (self._episode_rewards[i], self._episode_lengths[i]) + self._num_returned = len(self._episode_rewards) + + +class NoopResetEnv(gym.Wrapper): + def __init__(self, env, noop_max=30): + """Sample initial states by taking random number of no-ops on reset. + No-op is assumed to be action 0. + """ + gym.Wrapper.__init__(self, env) + self.noop_max = noop_max + self.override_num_noops = None + self.noop_action = 0 + assert env.unwrapped.get_action_meanings()[0] == 'NOOP' + + def reset(self, **kwargs): + """ Do no-op action for a number of steps in [1, noop_max].""" + self.env.reset(**kwargs) + if self.override_num_noops is not None: + noops = self.override_num_noops + else: + noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + assert noops > 0 + obs = None + for _ in range(noops): + obs, _, done, _ = self.env.step(self.noop_action) + if done: + obs = self.env.reset(**kwargs) + return obs + + def step(self, ac): + return self.env.step(ac) + + +class ClipRewardEnv(gym.RewardWrapper): + def __init__(self, env): + gym.RewardWrapper.__init__(self, env) + + def reward(self, reward): + """Bin reward to {+1, 0, -1} by its sign.""" + return np.sign(reward) + + +class FireResetEnv(gym.Wrapper): + def __init__(self, env): + """Take action on reset. + + For environments that are fixed until firing.""" + gym.Wrapper.__init__(self, env) + assert env.unwrapped.get_action_meanings()[1] == 'FIRE' + assert len(env.unwrapped.get_action_meanings()) >= 3 + + def reset(self, **kwargs): + self.env.reset(**kwargs) + obs, _, done, _ = self.env.step(1) + if done: + self.env.reset(**kwargs) + obs, _, done, _ = self.env.step(2) + if done: + self.env.reset(**kwargs) + return obs + + def step(self, ac): + return self.env.step(ac) + + +class EpisodicLifeEnv(gym.Wrapper): + def __init__(self, env): + """Make end-of-life == end-of-episode, but only reset on true game over. + Done by DeepMind for the DQN and co. since it helps value estimation. + """ + gym.Wrapper.__init__(self, env) + self.lives = 0 + self.was_real_done = True + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self.was_real_done = done + # check current lives, make loss of life terminal, + # then update lives to handle bonus lives + lives = self.env.unwrapped.ale.lives() + if lives < self.lives and lives > 0: + # for Qbert sometimes we stay in lives == 0 condtion for a few fr + # so its important to keep lives > 0, so that we only reset once + # the environment advertises done. + done = True + self.lives = lives + return obs, reward, done, info + + def reset(self, **kwargs): + """Reset only when lives are exhausted. + This way all states are still reachable even though lives are episodic, + and the learner need not know about any of this behind-the-scenes. 
+ """ + if self.was_real_done: + obs = self.env.reset(**kwargs) + else: + # no-op step to advance from terminal/lost life state + obs, _, _, _ = self.env.step(0) + self.lives = self.env.unwrapped.ale.lives() + return obs + + +class MaxAndSkipEnv(gym.Wrapper): + def __init__(self, env, skip=4): + """Return only every `skip`-th frame""" + gym.Wrapper.__init__(self, env) + # most recent raw observations (for max pooling across time steps) + self._obs_buffer = np.zeros( + (2, ) + env.observation_space.shape, dtype=np.uint8) + self._skip = skip + + def step(self, action): + """Repeat action, sum reward, and max over last observations.""" + total_reward = 0.0 + done = None + for i in range(self._skip): + obs, reward, done, info = self.env.step(action) + if i == self._skip - 2: + self._obs_buffer[0] = obs + if i == self._skip - 1: + self._obs_buffer[1] = obs + total_reward += reward + if done: + break + # Note that the observation on the done=True frame + # doesn't matter + max_frame = self._obs_buffer.max(axis=0) + + return max_frame, total_reward, done, info + + def reset(self, **kwargs): + return self.env.reset(**kwargs) + + +class WarpFrame(gym.ObservationWrapper): + def __init__(self, env, dim): + """Warp frames to the specified size (dim x dim).""" + gym.ObservationWrapper.__init__(self, env) + self.width = dim + self.height = dim + self.observation_space = spaces.Box( + low=0, high=255, shape=(self.height, self.width), dtype=np.uint8) + + def observation(self, frame): + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + frame = cv2.resize( + frame, (self.width, self.height), interpolation=cv2.INTER_AREA) + return frame + + +class FrameStack(gym.Wrapper): + def __init__(self, env, k, obs_format='NHWC'): + """Stack k last frames.""" + gym.Wrapper.__init__(self, env) + self.k = k + self.frames = deque([], maxlen=k) + shp = env.observation_space.shape + assert obs_format == 'NHWC' or obs_format == 'NCHW' + self.obs_format = obs_format + if obs_format == 'NHWC': + obs_shape = (shp[0], shp[1], k) + else: + obs_shape = (k, shp[0], shp[1]) + + self.observation_space = spaces.Box( + low=0, + high=255, + shape=obs_shape, + dtype=env.observation_space.dtype) + + def reset(self): + ob = self.env.reset() + for _ in range(self.k): + self.frames.append(ob) + return self._get_ob() + + def step(self, action): + ob, reward, done, info = self.env.step(action) + self.frames.append(ob) + return self._get_ob(), reward, done, info + + def _get_ob(self): + assert len(self.frames) == self.k + if self.obs_format == 'NHWC': + return np.stack(self.frames, axis=2) + else: + return np.array(self.frames) + + +def wrap_deepmind(env, dim=84, framestack=True, obs_format='NHWC'): + """Configure environment for DeepMind-style Atari. + + Args: + dim (int): Dimension to resize observations to (dim x dim). + framestack (bool): Whether to framestack observations. + """ + env = MonitorEnv(env) + env = NoopResetEnv(env, noop_max=30) + if 'NoFrameskip' in env.spec.id: + env = MaxAndSkipEnv(env, skip=4) + env = EpisodicLifeEnv(env) + if 'FIRE' in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + env = WarpFrame(env, dim) + env = ClipRewardEnv(env) + if framestack: + env = FrameStack(env, 4, obs_format) + return env diff --git a/parl/env/vector_env.py b/parl/env/vector_env.py new file mode 100644 index 0000000000000000000000000000000000000000..5e605811f39384f3799fdb93af402b91da8bc301 --- /dev/null +++ b/parl/env/vector_env.py @@ -0,0 +1,63 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+from collections import defaultdict
+
+__all__ = ['VectorEnv']
+
+
+class VectorEnv(object):
+    """ A vector of envs that supports batched (vector) reset and step.
+        The `step` method automatically resets any env that is done.
+    """
+
+    def __init__(self, envs):
+        """
+        Args:
+            envs: List of envs
+        """
+        self.envs = envs
+        self.envs_num = len(envs)
+
+    def reset(self):
+        """
+        Returns:
+            List of initial observations, one per env
+        """
+        return [env.reset() for env in self.envs]
+
+    def step(self, actions):
+        """
+        Args:
+            actions: List or array of actions, one per env
+
+        Returns:
+            obs_batch: List of next observations of the envs
+            reward_batch: List of rewards returned by the envs
+            done_batch: List of done flags of the envs
+            info_batch: List of info dicts of the envs
+        """
+        obs_batch, reward_batch, done_batch, info_batch = [], [], [], []
+        for env_id in six.moves.range(self.envs_num):
+            obs, reward, done, info = self.envs[env_id].step(actions[env_id])
+
+            if done:
+                obs = self.envs[env_id].reset()
+
+            obs_batch.append(obs)
+            reward_batch.append(reward)
+            done_batch.append(done)
+            info_batch.append(info)
+        return obs_batch, reward_batch, done_batch, info_batch
diff --git a/parl/framework/algorithm_base.py b/parl/framework/algorithm_base.py
index ca8a3a3b14eaf4402e13b273d1aca85370b33fcc..9672134814458102ae89545dc5eee97988af5976 100644
--- a/parl/framework/algorithm_base.py
+++ b/parl/framework/algorithm_base.py
@@ -54,3 +54,20 @@ class Algorithm(object):
         3. optimize model defined in Model
         """
         raise NotImplementedError()
+
+    def get_params(self):
+        """ Get parameters of self.model
+
+        Returns:
+            List of numpy arrays.
+        """
+        return self.model.get_params()
+
+    def set_params(self, params, gpu_id):
+        """ Set parameters of self.model
+
+        Args:
+            params: List of numpy arrays.
+            gpu_id: id of the GPU that holds self.model (a value < 0 means CPU).
+        """
+        self.model.set_params(params, gpu_id=gpu_id)
diff --git a/parl/framework/policy_distribution.py b/parl/framework/policy_distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..25aab040330ba3766c8266d463fc75a98ebca9e8
--- /dev/null
+++ b/parl/framework/policy_distribution.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
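The `VectorEnv` added above batches `reset` and `step` over a list of envs and resets any env that finishes. A minimal usage sketch (assuming plain gym `CartPole-v0` environments for brevity, rather than the wrapped Atari envs used by the IMPALA actor):

```python
import gym
from parl.env.vector_env import VectorEnv

# Build a small vector of identical environments.
envs = [gym.make('CartPole-v0') for _ in range(4)]
vector_env = VectorEnv(envs)

obs_batch = vector_env.reset()  # list of 4 observations
actions = [env.action_space.sample() for env in envs]

# step() takes one action per env and auto-resets any env that is done,
# so obs_batch always holds a valid next observation for every env.
obs_batch, reward_batch, done_batch, info_batch = vector_env.step(actions)
```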
+
+import parl.layers as layers
+
+__all__ = ['PolicyDistribution', 'CategoricalDistribution']
+
+
+class PolicyDistribution(object):
+    def sample(self):
+        """Sample from the policy distribution."""
+        raise NotImplementedError
+
+    def entropy(self):
+        """The entropy of the policy distribution."""
+        raise NotImplementedError
+
+    def kl(self, other):
+        """The KL divergence between this policy distribution and another one."""
+        raise NotImplementedError
+
+    def logp(self, actions):
+        """The log-probabilities of the actions under this policy distribution."""
+        raise NotImplementedError
+
+
+class CategoricalDistribution(PolicyDistribution):
+    """Categorical distribution for discrete action spaces."""
+
+    def __init__(self, logits):
+        """
+        Args:
+            logits: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of unnormalized policy logits
+        """
+        assert len(logits.shape) == 2
+        self.logits = logits
+
+    def sample(self):
+        """
+        Returns:
+            sample_action: An int64 tensor with shape [BATCH_SIZE] of multinomial sampling ids.
+                Each value in sample_action is in [0, NUM_ACTIONS - 1].
+        """
+        probs = layers.softmax(self.logits)
+        sample_actions = layers.sampling_id(probs)
+        return sample_actions
+
+    def entropy(self):
+        """
+        Returns:
+            entropy: A float32 tensor with shape [BATCH_SIZE] of the entropy of this policy distribution.
+        """
+        logits = self.logits - layers.reduce_max(self.logits, dim=1)
+        e_logits = layers.exp(logits)
+        z = layers.reduce_sum(e_logits, dim=1)
+        prob = e_logits / z
+        entropy = -1.0 * layers.reduce_sum(
+            prob * (logits - layers.log(z)), dim=1)
+
+        return entropy
+
+    def logp(self, actions):
+        """
+        Args:
+            actions: An int64 tensor with shape [BATCH_SIZE]
+
+        Returns:
+            actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
+        """
+
+        assert len(actions.shape) == 1
+        actions = layers.unsqueeze(actions, axes=[1])
+
+        cross_entropy = layers.softmax_with_cross_entropy(
+            logits=self.logits, label=actions)
+
+        actions_log_prob = -1.0 * layers.squeeze(cross_entropy, axes=[-1])
+        return actions_log_prob
+
+    def kl(self, other):
+        """
+        Args:
+            other: object of CategoricalDistribution
+
+        Returns:
+            kl: A float32 tensor with shape [BATCH_SIZE]
+        """
+        assert isinstance(other, CategoricalDistribution)
+
+        logits = self.logits - layers.reduce_max(self.logits, dim=1)
+        other_logits = other.logits - layers.reduce_max(other.logits, dim=1)
+
+        e_logits = layers.exp(logits)
+        other_e_logits = layers.exp(other_logits)
+
+        z = layers.reduce_sum(e_logits, dim=1)
+        other_z = layers.reduce_sum(other_e_logits, dim=1)
+
+        prob = e_logits / z
+        kl = layers.reduce_sum(
+            prob *
+            (logits - layers.log(z) - other_logits + layers.log(other_z)),
+            dim=1)
+        return kl
diff --git a/parl/framework/tests/policy_distribution_test.py b/parl/framework/tests/policy_distribution_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb729580583bf5b76851765a5e4db12d97600926
--- /dev/null
+++ b/parl/framework/tests/policy_distribution_test.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import parl.layers as layers +import unittest +from paddle import fluid +from parameterized import parameterized +from parl.framework.policy_distribution import * +from parl.utils import get_gpu_count, np_softmax, np_cross_entropy + + +class PolicyDistributionTest(unittest.TestCase): + def setUp(self): + + gpu_count = get_gpu_count() + if gpu_count > 0: + place = fluid.CUDAPlace(0) + self.gpu_id = 0 + else: + place = fluid.CPUPlace() + self.gpu_id = -1 + self.executor = fluid.Executor(place) + + @parameterized.expand([('Batch1', 1), ('Batch5', 5)]) + def test_categorical_distribution(self, name, batch_size): + ACTIONS_NUM = 4 + test_program = fluid.Program() + with fluid.program_guard(test_program): + logits = layers.data( + name='logits', shape=[ACTIONS_NUM], dtype='float32') + other_logits = layers.data( + name='other_logits', shape=[ACTIONS_NUM], dtype='float32') + actions = layers.data(name='actions', shape=[], dtype='int64') + + categorical_distribution = CategoricalDistribution(logits) + other_categorical_distribution = CategoricalDistribution( + other_logits) + + sample_actions = categorical_distribution.sample() + entropy = categorical_distribution.entropy() + actions_logp = categorical_distribution.logp(actions) + kl = categorical_distribution.kl(other_categorical_distribution) + + self.executor.run(fluid.default_startup_program()) + + logits_np = np.random.randn(batch_size, ACTIONS_NUM).astype('float32') + other_logits_np = np.random.randn(batch_size, + ACTIONS_NUM).astype('float32') + actions_np = np.random.randint( + 0, high=ACTIONS_NUM, size=(batch_size, 1), dtype='int64') + + # ground truth calculated by numpy/python + gt_probs = np_softmax(logits_np) + gt_other_probs = np_softmax(other_logits_np) + gt_log_probs = np.log(gt_probs) + gt_entropy = -1.0 * np.sum(gt_probs * gt_log_probs, axis=1) + + gt_actions_logp = -1.0 * np_cross_entropy( + np_softmax(logits_np), actions_np) + gt_actions_logp = np.squeeze(gt_actions_logp, -1) + gt_kl = np.sum( + np.where(gt_probs != 0, + gt_probs * np.log(gt_probs / gt_other_probs), 0), + axis=-1) + + # result calculated by CategoricalDistribution + [ + output_sample_actions, output_entropy, output_actions_logp, + output_kl + ] = self.executor.run( + program=test_program, + feed={ + 'logits': logits_np, + 'other_logits': other_logits_np, + 'actions': np.squeeze(actions_np, axis=1) + }, + fetch_list=[sample_actions, entropy, actions_logp, kl]) + + # test entropy + np.testing.assert_almost_equal(output_entropy, gt_entropy, 5) + + # test logp + np.testing.assert_almost_equal(output_actions_logp, gt_actions_logp, 5) + + # test sample + action_ids = np.arange(ACTIONS_NUM) + assert np.isin(output_sample_actions, action_ids).all() + + # test kl + np.testing.assert_almost_equal(output_kl, gt_kl, 5) + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/plutils/common.py b/parl/plutils/common.py index 762c463cd5ae505c39f9f806dd5f3cfb88cc725d..5462c795f3999248264566f1518a17654d3cc69e 100644 --- a/parl/plutils/common.py +++ b/parl/plutils/common.py @@ -18,7 +18,7 @@ Common functions of PARL framework import paddle.fluid as fluid from paddle.fluid.executor import _fetch_var -__all__ = ['fetch_framework_var', 'fetch_value', 'set_value'] +__all__ = ['fetch_framework_var', 'fetch_value', 'set_value', 'inverse'] def fetch_framework_var(attr_name): @@ -64,3 +64,16 @@ def set_value(attr_name, value, gpu_id): else 
fluid.CUDAPlace(gpu_id) var = _fetch_var(attr_name, return_numpy=False) var.set(value, place) + + +def inverse(x): + """ Inverse 0/1 variable + + Args: + x: variable with float32 dtype + + Returns: + inverse_x: variable with float32 dtype + """ + inverse_x = -1.0 * x + 1.0 + return inverse_x diff --git a/parl/utils/__init__.py b/parl/utils/__init__.py index 22d164d1fc6628bcab945a181a13498bd1f8c53a..d1b2bac928c2f0582b2587964f780172c1799a5a 100644 --- a/parl/utils/__init__.py +++ b/parl/utils/__init__.py @@ -14,5 +14,8 @@ from parl.utils.exceptions import * from parl.utils.utils import * +from parl.utils.csv_logger import * from parl.utils.machine_info import * +from parl.utils.np_utils import * from parl.utils.replay_memory import * +from parl.utils.scheduler import * diff --git a/parl/utils/csv_logger.py b/parl/utils/csv_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e5152e599b831fab0dbed3a43d3bb40b3412bcb5 --- /dev/null +++ b/parl/utils/csv_logger.py @@ -0,0 +1,41 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv + +__all__ = ['CSVLogger'] + + +class CSVLogger(object): + def __init__(self, output_file): + """CSV Logger which can write dict result to csv file + """ + self.output_file = open(output_file, "w") + self.csv_writer = None + + def log_dict(self, result): + if self.csv_writer is None: + self.csv_writer = csv.DictWriter(self.output_file, result.keys()) + self.csv_writer.writeheader() + + self.csv_writer.writerow({ + k: v + for k, v in result.items() if k in self.csv_writer.fieldnames + }) + + def flush(self): + self.output_file.flush() + + def close(self): + self.output_file.close() diff --git a/parl/utils/np_utils.py b/parl/utils/np_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f83c8c42f24aa11e68b234ba7c95a86681254a6b --- /dev/null +++ b/parl/utils/np_utils.py @@ -0,0 +1,32 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
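The `CSVLogger` added above fixes its CSV header from the keys of the first dict it logs and silently drops any later keys that are not in that header. A small sketch of the intended flow (the output file name is illustrative):

```python
from parl.utils import CSVLogger

logger = CSVLogger('result.csv')

# The first call fixes the CSV header to the keys of this dict.
logger.log_dict({'step': 100, 'mean_reward': 1.5})
# Later rows may contain extra keys; only the original fieldnames are written.
logger.log_dict({'step': 200, 'mean_reward': 2.3, 'extra': 'ignored'})

logger.flush()
logger.close()
```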
+
+import numpy as np
+
+__all__ = ['np_softmax', 'np_cross_entropy']
+
+
+def np_softmax(logits):
+    return np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)
+
+
+def np_cross_entropy(probs, labels):
+    if labels.shape[-1] == 1:
+        # sparse label
+        n_classes = probs.shape[-1]
+        result_shape = list(labels.shape[:-1]) + [n_classes]
+        labels = np.eye(n_classes)[labels.reshape(-1)]
+        labels = labels.reshape(result_shape)
+
+    return -np.sum(labels * np.log(probs), axis=-1, keepdims=True)
diff --git a/parl/utils/scheduler.py b/parl/utils/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf936e2c1de356ae0aa5745cc26118a1be5d283
--- /dev/null
+++ b/parl/utils/scheduler.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import six
+
+__all__ = ['PiecewiseScheduler']
+
+
+class PiecewiseScheduler(object):
+    def __init__(self, scheduler_list):
+        """ Piecewise scheduler for a hyperparameter.
+
+        Args:
+            scheduler_list: list of (step, value) pairs. E.g. [(0, 0.001), (10000, 0.0005)]
+        """
+        assert len(scheduler_list) > 0
+
+        for i in six.moves.range(len(scheduler_list) - 1):
+            assert scheduler_list[i][0] < scheduler_list[i + 1][0], \
+                'steps in scheduler_list should be in increasing order.'
+
+        self.scheduler_list = scheduler_list
+
+        self.cur_index = 0
+        self.cur_step = 0
+        self.cur_value = self.scheduler_list[0][1]
+
+        self.scheduler_num = len(self.scheduler_list)
+
+    def step(self):
+        """ Advance one step and return the current value according to the following rule:
+
+        Given scheduler_list: [(step_0, value_0), (step_1, value_1), ..., (step_N, value_N)],
+        this function returns value_K such that step_K <= self.cur_step < step_(K+1).
+        """
+
+        if self.cur_index < self.scheduler_num - 1:
+            if self.cur_step >= self.scheduler_list[self.cur_index + 1][0]:
+                self.cur_index += 1
+                self.cur_value = self.scheduler_list[self.cur_index][1]
+
+        self.cur_step += 1
+
+        return self.cur_value
diff --git a/parl/utils/tests/scheduler_test.py b/parl/utils/tests/scheduler_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2dd6e488d2172a24493af3552b2cde528e27577
--- /dev/null
+++ b/parl/utils/tests/scheduler_test.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
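`PiecewiseScheduler` keeps an internal step counter: every `step()` call advances it by one and returns the value of the latest boundary already reached. A short sketch of the resulting schedule, using the same boundaries as the unit test that follows:

```python
from parl.utils.scheduler import PiecewiseScheduler

# Use 0.1 for steps 0-2, 0.2 for steps 3-6, and 0.3 from step 7 onwards.
scheduler = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])

values = [scheduler.step() for _ in range(10)]
# values == [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3]
```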
+
+import unittest
+from parl.utils.scheduler import PiecewiseScheduler
+
+
+class TestScheduler(unittest.TestCase):
+    def test_PiecewiseScheduler_with_multi_values(self):
+        scheduler = PiecewiseScheduler([(0, 0.1), (3, 0.2), (7, 0.3)])
+        for i in range(10):
+            value = scheduler.step()
+            if i < 3:
+                assert value == 0.1
+            elif i < 7:
+                assert value == 0.2
+            else:
+                assert value == 0.3
+
+    def test_PiecewiseScheduler_with_one_value(self):
+        scheduler = PiecewiseScheduler([(0, 0.1)])
+        for i in range(10):
+            value = scheduler.step()
+            assert value == 0.1
+
+        scheduler = PiecewiseScheduler([(3, 0.1)])
+        for i in range(10):
+            value = scheduler.step()
+            assert value == 0.1
+
+    def test_PiecewiseScheduler_with_empty(self):
+        try:
+            scheduler = PiecewiseScheduler([])
+        except AssertionError:
+            # expected
+            return
+        assert False
+
+    def test_PiecewiseScheduler_with_incorrect_steps(self):
+        try:
+            scheduler = PiecewiseScheduler([(10, 0.1), (1, 0.2)])
+        except AssertionError:
+            # expected
+            return
+        assert False
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/parl/utils/time_stat.py b/parl/utils/time_stat.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b2997c0e387ff0acfabad123580244f73f035ad
--- /dev/null
+++ b/parl/utils/time_stat.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from parl.utils.window_stat import WindowStat
+
+__all__ = ['TimeStat']
+
+
+class TimeStat(object):
+    """A time statistic for logging the elapsed time of a code block.
+
+    Example:
+        time_stat = TimeStat()
+        with time_stat:
+            # ... the code to be timed ...
+        print(time_stat.mean)
+    """
+
+    def __init__(self, window_size=1):
+        self.time_samples = WindowStat(window_size)
+        self._start_time = None
+
+    def __enter__(self):
+        self._start_time = time.time()
+
+    def __exit__(self, type, value, tb):
+        time_delta = time.time() - self._start_time
+        self.time_samples.add(time_delta)
+
+    @property
+    def mean(self):
+        return self.time_samples.mean
+
+    @property
+    def min(self):
+        return self.time_samples.min
+
+    @property
+    def max(self):
+        return self.time_samples.max
diff --git a/parl/utils/utils.py b/parl/utils/utils.py
index a604d116877059f259f28176881e852ed730c202..206f9bc4975b213d2f98cac327b448eccba2b6a0 100644
--- a/parl/utils/utils.py
+++ b/parl/utils/utils.py
@@ -15,7 +15,8 @@
 import sys
 
 __all__ = [
-    'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3'
+    'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3',
+    'MAX_INT32'
 ]
@@ -68,3 +69,6 @@ def is_PY2():
 
 def is_PY3():
     return sys.version_info[0] == 3
+
+
+MAX_INT32 = 0x7fffffff
diff --git a/parl/utils/window_stat.py b/parl/utils/window_stat.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9ae55c5b1b3f9119f67ca00d05d303c3f9ae324
--- /dev/null
+++ b/parl/utils/window_stat.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['WindowStat'] + +import numpy as np + + +class WindowStat(object): + """ Tool to maintain statistical data in a window. + """ + + def __init__(self, window_size): + self.items = [None] * window_size + self.idx = 0 + self.count = 0 + + def add(self, obj): + self.items[self.idx] = obj + self.idx += 1 + self.count += 1 + self.idx %= len(self.items) + + @property + def mean(self): + if self.count > 0: + return np.mean(self.items[:self.count]) + else: + return None + + @property + def min(self): + if self.count > 0: + return np.min(self.items[:self.count]) + else: + return None + + @property + def max(self): + if self.count > 0: + return np.max(self.items[:self.count]) + else: + return None diff --git a/setup.py b/setup.py index d6b2559e0c8e9f8b0302a8ecba65dea74938afd5..a6fad59ebb65074be9434c328b616b2547ab4bec 100644 --- a/setup.py +++ b/setup.py @@ -35,5 +35,7 @@ setup( package_data={'': ['*.so']}, install_requires=[ "termcolor>=1.1.0", + "pyzmq>=17.1.2", + "pyarrow>=0.12.0", ], )
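`WindowStat` above keeps the most recent `window_size` samples in a ring buffer, and the `TimeStat` helper added earlier in this patch builds on it to time a `with` block. A brief usage sketch combining the two (the timed workload is a stand-in):

```python
import time

from parl.utils.time_stat import TimeStat
from parl.utils.window_stat import WindowStat

# Track the last 10 rewards.
reward_stat = WindowStat(window_size=10)
for reward in [1.0, 2.0, 3.0]:
    reward_stat.add(reward)
print(reward_stat.mean, reward_stat.min, reward_stat.max)  # 2.0 1.0 3.0

# Average the elapsed time of the last 5 runs of a block.
sample_time_stat = TimeStat(window_size=5)
for _ in range(3):
    with sample_time_stat:
        time.sleep(0.01)  # placeholder for real work, e.g. one sampling step
print(sample_time_stat.mean)
```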