diff --git a/README.md b/README.md
index 46bf73504cad5a090f94dedf43afa5874e882f6a..29b22db8c196c520a2f4bb10a5f613df949f60f1 100644
--- a/README.md
+++ b/README.md
@@ -98,7 +98,7 @@ Two steps to use outer computation resources:
As shown in the above figure, real actors (orange circles) are running on the CPU cluster, while the learner (blue circle) is running on the local GPU with several remote actors (yellow circles with dotted edges).
-For users, they can write code in a simple way, just like writing multi-thread code, but with actors consuming remote resources. We have also provided examples of parallized algorithms like IMPALA, A2C and GA3C. For more details in usage please refer to these examples.
+Users can write code in a simple way, just as they would write multi-threaded code, but with actors consuming remote resources. We have also provided examples of parallelized algorithms such as [IMPALA](examples/IMPALA), [A2C](examples/A2C) and [GA3C](examples/GA3C). For more details on usage, please refer to these examples.
# Install:
@@ -118,6 +118,7 @@ pip install parl
- [PPO](examples/PPO/)
- [IMPALA](examples/IMPALA/)
- [A2C](examples/A2C/)
+- [GA3C](examples/GA3C/)
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
diff --git a/examples/A2C/README.md b/examples/A2C/README.md
index 483120bdd07465eb2e4fb2ef59dfd6e7a599b3e7..d76cc683b2b09fc4b85af05bc0c3e3de4f48ca9d 100644
--- a/examples/A2C/README.md
+++ b/examples/A2C/README.md
@@ -4,7 +4,7 @@ Based on PARL, the A2C algorithm of deep reinforcement learning has been reprodu
A2C is a synchronous, deterministic variant of [Asynchronous Advantage Actor Critic (A3C)](https://arxiv.org/abs/1602.01783). Instead of updating asynchronously as in A3C or GA3C, A2C uses a synchronous approach that waits for each actor to finish its sampling before performing an update. Since the loss definitions of these A3C variants are identical, we use a common A3C algorithm, `parl.algorithms.A3C`, for the A2C and GA3C examples.
### Atari games introduction
-Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari game.
+Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari games.
### Benchmark result
Results with one learner (in a P40 GPU) and 5 actors in 10 million sample steps.
@@ -16,7 +16,6 @@ Results with one learner (in a P40 GPU) and 5 actors in 10 million sample steps.
+ [paddlepaddle>=1.3.0](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
-+ opencv-python
+ atari_py
diff --git a/examples/DDPG/README.md b/examples/DDPG/README.md
index 3c39fb63b487419eb9e485b47d89d9ae188be46f..f962897694be43087c8d6cce28229853779eb3f0 100644
--- a/examples/DDPG/README.md
+++ b/examples/DDPG/README.md
@@ -5,7 +5,7 @@ Based on PARL, the DDPG algorithm of deep reinforcement learning has been reprod
[Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971)
### Mujoco games introduction
-Please see [here](https://github.com/openai/mujoco-py) to know more about Mujoco game.
+Please see [here](https://github.com/openai/mujoco-py) to know more about Mujoco games.
### Benchmark result
diff --git a/examples/DQN/README.md b/examples/DQN/README.md
index af232ade4201644636f82b632d745831a3013b01..b3fda95a71f0757c566d35bc5e2289098600b504 100644
--- a/examples/DQN/README.md
+++ b/examples/DQN/README.md
@@ -5,7 +5,7 @@ Based on PARL, the DQN algorithm of deep reinforcement learning has been reprodu
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
### Atari games introduction
-Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari game.
+Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari games.
### Benchmark result
@@ -20,7 +20,6 @@ Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari g
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
+ tqdm
-+ opencv-python
+ atari_py
+ [ale_python_interface](https://github.com/mgbellemare/Arcade-Learning-Environment)
diff --git a/examples/GA3C/.benchmark/GA3C_Breakout.jpg b/examples/GA3C/.benchmark/GA3C_Breakout.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..553dcf9c6a72712b603c5dff72cfcf072636e167
Binary files /dev/null and b/examples/GA3C/.benchmark/GA3C_Breakout.jpg differ
diff --git a/examples/GA3C/.benchmark/GA3C_Pong.jpg b/examples/GA3C/.benchmark/GA3C_Pong.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dded7d32075598da6b555883452093b9bf273395
Binary files /dev/null and b/examples/GA3C/.benchmark/GA3C_Pong.jpg differ
diff --git a/examples/GA3C/README.md b/examples/GA3C/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f961343a057b7d3d19c8c9d19fa1f8a9964aa308
--- /dev/null
+++ b/examples/GA3C/README.md
@@ -0,0 +1,45 @@
+## Reproduce GA3C with PARL
+Based on PARL, the GA3C algorithm of deep reinforcement learning has been reproduced, reaching the same level of performance as the original paper on Atari benchmarks.
+
+Original paper: [GA3C: GPU-based A3C for Deep Reinforcement Learning](https://www.researchgate.net/profile/Iuri_Frosio2/publication/310610848_GA3C_GPU-based_A3C_for_Deep_Reinforcement_Learning/links/583c6c0b08ae502a85e3dbb9/GA3C-GPU-based-A3C-for-Deep-Reinforcement-Learning.pdf)
+
+A hybrid CPU/GPU version of the [Asynchronous Advantage Actor-Critic (A3C)](https://arxiv.org/abs/1602.01783) algorithm.
+
+### Atari games introduction
+Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari games.
+
+### Benchmark result
+Results with one learner (in a P40 GPU) and 24 simulators (in 12 CPUs) in 10 million sample steps.
+
+<img src=".benchmark/GA3C_Pong.jpg" width="400"/> <img src=".benchmark/GA3C_Breakout.jpg" width="400"/>
+## How to use
+### Dependencies
++ python2.7 or python3.5+
++ [paddlepaddle>=1.3.0](https://github.com/PaddlePaddle/Paddle)
++ [parl](https://github.com/PaddlePaddle/PARL)
++ gym
++ atari_py
+
+
+### Distributed Training
+
+#### Learner
+```sh
+python train.py
+```
+
+#### Simulators (Suggested: 24 simulators on 12+ CPUs)
+```sh
+for i in $(seq 1 24); do
+ python simulator.py &
+done;
+wait
+```
+
+You can change training settings (e.g. `env_name`, `server_ip`) in `ga3c_config.py`.
+Training result will be saved in `log_dir/train/result.csv`.
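+
+A minimal sketch of the keys you will most likely adjust (the values shown are the defaults from `ga3c_config.py`; when simulators run on other machines, point `server_ip` at the learner host):
+
+```python
+config = {
+    'server_ip': 'localhost',          # IP of the machine running train.py (the learner)
+    'server_port': 8037,               # must match on the learner and every simulator
+    'env_name': 'PongNoFrameskip-v4',  # any Atari NoFrameskip-v4 environment
+    # ... the remaining learner hyperparameters are documented in ga3c_config.py
+}
+```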
+
+[Tips] Performance can degrade dramatically in a slower computing environment, especially when training with low-speed CPUs; this is likely caused by the policy-lag problem.
+
+### Reference
++ [tensorpack](https://github.com/tensorpack/tensorpack)
diff --git a/examples/GA3C/atari_agent.py b/examples/GA3C/atari_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..c27311fb2030838744e491d89b645b729c07f080
--- /dev/null
+++ b/examples/GA3C/atari_agent.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle.fluid as fluid
+import parl.layers as layers
+from parl.framework.agent_base import Agent
+
+
+class AtariAgent(Agent):
+ def __init__(self, algorithm, config, learn_data_provider=None):
+ self.config = config
+ super(AtariAgent, self).__init__(algorithm)
+
+ use_cuda = True if self.gpu_id >= 0 else False
+
+ exec_strategy = fluid.ExecutionStrategy()
+ exec_strategy.use_experimental_executor = True
+ exec_strategy.num_threads = 4
+ build_strategy = fluid.BuildStrategy()
+ build_strategy.remove_unnecessary_lock = True
+
+ # Use ParallelExecutor to make learn program run faster
+ self.learn_exe = fluid.ParallelExecutor(
+ use_cuda=use_cuda,
+ main_program=self.learn_program,
+ build_strategy=build_strategy,
+ exec_strategy=exec_strategy)
+
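+        # One sampling executor per predict thread, each in its own scope, so prediction threads can run in parallel.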
+ self.sample_exes = []
+ for _ in range(config['predict_thread_num']):
+ with fluid.scope_guard(fluid.global_scope().new_scope()):
+ pe = fluid.ParallelExecutor(
+ use_cuda=use_cuda,
+ main_program=self.sample_program,
+ build_strategy=build_strategy,
+ exec_strategy=exec_strategy)
+ self.sample_exes.append(pe)
+
+ if learn_data_provider:
+ self.learn_reader.decorate_tensor_provider(learn_data_provider)
+ self.learn_reader.start()
+
+ def build_program(self):
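+        # Three programs are built: stochastic sampling for simulators, greedy prediction, and the learning update.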
+ self.sample_program = fluid.Program()
+ self.predict_program = fluid.Program()
+ self.learn_program = fluid.Program()
+
+ with fluid.program_guard(self.sample_program):
+ obs = layers.data(
+ name='obs', shape=self.config['obs_shape'], dtype='float32')
+ sample_actions, values = self.alg.sample(obs)
+ self.sample_outputs = [sample_actions.name, values.name]
+
+ with fluid.program_guard(self.predict_program):
+ obs = layers.data(
+ name='obs', shape=self.config['obs_shape'], dtype='float32')
+ self.predict_actions = self.alg.predict(obs)
+
+ with fluid.program_guard(self.learn_program):
+ obs = layers.data(
+ name='obs', shape=self.config['obs_shape'], dtype='float32')
+ actions = layers.data(name='actions', shape=[], dtype='int64')
+ advantages = layers.data(
+ name='advantages', shape=[], dtype='float32')
+ target_values = layers.data(
+ name='target_values', shape=[], dtype='float32')
+ lr = layers.data(
+ name='lr', shape=[1], dtype='float32', append_batch_size=False)
+ entropy_coeff = layers.data(
+ name='entropy_coeff', shape=[], dtype='float32')
+
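+            # The learn program reads its inputs from a py_reader that is fed asynchronously by learn_data_provider.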
+ self.learn_reader = fluid.layers.create_py_reader_by_data(
+ capacity=self.config['train_batch_size'],
+ feed_list=[
+ obs, actions, advantages, target_values, lr, entropy_coeff
+ ])
+ obs, actions, advantages, target_values, lr, entropy_coeff = fluid.layers.read_file(
+ self.learn_reader)
+
+ total_loss, pi_loss, vf_loss, entropy = self.alg.learn(
+ obs, actions, advantages, target_values, lr, entropy_coeff)
+ self.learn_outputs = [
+ total_loss.name, pi_loss.name, vf_loss.name, entropy.name
+ ]
+
+ def sample(self, obs_np, thread_id):
+ """
+ Args:
+            obs_np: a numpy float32 array of shape ([B] + observation_space).
+                    Image input should be in NCHW format.
+            thread_id: index of the sample executor (one per predict thread) to use.
+
+ Returns:
+ sample_ids: a numpy int64 array of shape [B]
+ values: a numpy float32 array of shape [B]
+ """
+ obs_np = obs_np.astype('float32')
+
+ sample_actions, values = self.sample_exes[thread_id].run(
+ feed={'obs': obs_np}, fetch_list=self.sample_outputs)
+ return sample_actions, values
+
+ def predict(self, obs_np):
+ """
+ Args:
+            obs_np: a numpy float32 array of shape ([B] + observation_space).
+                    Image input should be in NCHW format.
+
+        Returns:
+            predict_actions: a numpy int64 array of shape [B]
+ """
+ obs_np = obs_np.astype('float32')
+
+ predict_actions = self.fluid_executor.run(
+ self.predict_program,
+ feed={'obs': obs_np},
+ fetch_list=[self.predict_actions])[0]
+ return predict_actions
+
+ def learn(self):
+ total_loss, pi_loss, vf_loss, entropy = self.learn_exe.run(
+ fetch_list=self.learn_outputs)
+ return total_loss, pi_loss, vf_loss, entropy
diff --git a/examples/GA3C/atari_model.py b/examples/GA3C/atari_model.py
new file mode 120000
index 0000000000000000000000000000000000000000..56a659a3d3bd87171b8393135d8929a831abce16
--- /dev/null
+++ b/examples/GA3C/atari_model.py
@@ -0,0 +1 @@
+../A2C/atari_model.py
\ No newline at end of file
diff --git a/examples/GA3C/ga3c_config.py b/examples/GA3C/ga3c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0cfa759009cbb345d73eb373735b8430191d020
--- /dev/null
+++ b/examples/GA3C/ga3c_config.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+config = {
+ #========== remote config ==========
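+    # Remote simulators connect to this address; it must be reachable from the machines running simulator.py.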
+ 'server_ip': 'localhost',
+ 'server_port': 8037,
+
+ #========== env config ==========
+ 'env_name': 'PongNoFrameskip-v4',
+ 'env_dim': 42,
+
+ #========== learner config ==========
+ 'train_batch_size': 128,
+ 'max_predict_batch_size': 16,
+ 'predict_thread_num': 2,
+ 't_max': 5,
+ 'gamma': 0.99,
+ 'lambda': 1.0, # GAE
+
+ # learning rate adjustment schedule: (train_step, learning_rate)
+ 'lr_scheduler': [(0, 0.0005), (100000, 0.0003), (200000, 0.0001)],
+
+ # coefficient of policy entropy adjustment schedule: (train_step, coefficient)
+ 'entropy_coeff_scheduler': [(0, -0.01)],
+ 'vf_loss_coeff': 0.5,
+ 'log_metrics_interval_s': 10,
+}
diff --git a/examples/GA3C/learner.py b/examples/GA3C/learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a5fff2fd6fde49794e8b80ffbd3289dd7227d01
--- /dev/null
+++ b/examples/GA3C/learner.py
@@ -0,0 +1,320 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gym
+import numpy as np
+import os
+import queue
+import six
+import time
+import threading
+from atari_model import AtariModel
+from atari_agent import AtariAgent
+from collections import defaultdict
+from parl import RemoteManager
+from parl.algorithms import A3C
+from parl.env.atari_wrappers import wrap_deepmind
+from parl.utils import logger, CSVLogger, get_gpu_count
+from parl.utils.scheduler import PiecewiseScheduler
+from parl.utils.time_stat import TimeStat
+from parl.utils.window_stat import WindowStat
+from parl.utils.rl_utils import calc_gae
+
+
+class Learner(object):
+ def __init__(self, config):
+ self.config = config
+
+ self.sample_data_queue = queue.Queue()
+ self.batch_buffer = defaultdict(list)
+
+ #=========== Create Agent ==========
+ env = gym.make(config['env_name'])
+ env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
+ obs_shape = env.observation_space.shape
+ act_dim = env.action_space.n
+
+ self.config['obs_shape'] = obs_shape
+ self.config['act_dim'] = act_dim
+
+ model = AtariModel(act_dim)
+ algorithm = A3C(model, hyperparas=config)
+ self.agent = AtariAgent(algorithm, config, self.learn_data_provider)
+
+ if self.agent.gpu_id >= 0:
+ assert get_gpu_count() == 1, 'Only support training in single GPU,\
+ Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_YOU_WANT_TO_USE]` .'
+
+ else:
+ cpu_num = os.environ.get('CPU_NUM')
+ assert cpu_num is not None and cpu_num == '1', 'Only support training in single CPU,\
+ Please set environment variable: `export CPU_NUM=1`.'
+
+ #========== Learner ==========
+ self.lr, self.entropy_coeff = None, None
+ self.lr_scheduler = PiecewiseScheduler(config['lr_scheduler'])
+ self.entropy_coeff_scheduler = PiecewiseScheduler(
+ config['entropy_coeff_scheduler'])
+
+ self.total_loss_stat = WindowStat(100)
+ self.pi_loss_stat = WindowStat(100)
+ self.vf_loss_stat = WindowStat(100)
+ self.entropy_stat = WindowStat(100)
+
+ self.learn_time_stat = TimeStat(100)
+ self.start_time = None
+
+ # learn thread
+ self.learn_thread = threading.Thread(target=self.run_learn)
+ self.learn_thread.setDaemon(True)
+ self.learn_thread.start()
+
+ self.predict_input_queue = queue.Queue()
+
+ # predict thread
+ self.predict_threads = []
+ for i in six.moves.range(self.config['predict_thread_num']):
+ predict_thread = threading.Thread(
+ target=self.run_predict, args=(i, ))
+ predict_thread.setDaemon(True)
+ predict_thread.start()
+ self.predict_threads.append(predict_thread)
+
+ #========== Remote Simulator ===========
+ self.remote_count = 0
+
+ self.remote_metrics_queue = queue.Queue()
+ self.sample_total_steps = 0
+
+ self.remote_manager_thread = threading.Thread(
+ target=self.run_remote_manager)
+ self.remote_manager_thread.setDaemon(True)
+ self.remote_manager_thread.start()
+
+ self.csv_logger = CSVLogger(
+ os.path.join(logger.get_dir(), 'result.csv'))
+
+ def learn_data_provider(self):
+ """ Data generator for fluid.layers.py_reader
+ """
+ B = self.config['train_batch_size']
+ while True:
+ sample_data = self.sample_data_queue.get()
+ self.sample_total_steps += len(sample_data['obs'])
+ for key in sample_data:
+ self.batch_buffer[key].extend(sample_data[key])
+
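+            # Once train_batch_size transitions have accumulated, assemble a batch and feed it to the py_reader.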
+ if len(self.batch_buffer['obs']) >= B:
+ batch = {}
+ for key in self.batch_buffer:
+ batch[key] = np.array(self.batch_buffer[key][:B])
+
+ obs_np = batch['obs'].astype('float32')
+ actions_np = batch['actions'].astype('int64')
+ advantages_np = batch['advantages'].astype('float32')
+ target_values_np = batch['target_values'].astype('float32')
+
+ self.lr = self.lr_scheduler.step()
+ self.entropy_coeff = self.entropy_coeff_scheduler.step()
+
+ yield [
+ obs_np, actions_np, advantages_np, target_values_np,
+ self.lr, self.entropy_coeff
+ ]
+
+ for key in self.batch_buffer:
+ self.batch_buffer[key] = self.batch_buffer[key][B:]
+
+ def run_predict(self, thread_id):
+ """ predict thread
+ """
+ batch_ident = []
+ batch_obs = []
+ while True:
+ ident, obs = self.predict_input_queue.get()
+
+ batch_ident.append(ident)
+ batch_obs.append(obs)
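+            # Greedily drain any queued requests (up to max_predict_batch_size) so a single forward pass serves many simulators.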
+ while len(batch_obs) < self.config['max_predict_batch_size']:
+ try:
+ ident, obs = self.predict_input_queue.get_nowait()
+ batch_ident.append(ident)
+ batch_obs.append(obs)
+ except queue.Empty:
+ break
+ if batch_obs:
+ batch_obs = np.array(batch_obs)
+ actions, values = self.agent.sample(batch_obs, thread_id)
+
+ for i, ident in enumerate(batch_ident):
+ self.predict_output_queues[ident].put((actions[i],
+ values[i]))
+ batch_ident = []
+ batch_obs = []
+
+ def run_learn(self):
+ """ Learn loop
+ """
+ while True:
+ with self.learn_time_stat:
+ total_loss, pi_loss, vf_loss, entropy = self.agent.learn()
+
+ self.total_loss_stat.add(total_loss)
+ self.pi_loss_stat.add(pi_loss)
+ self.vf_loss_stat.add(vf_loss)
+ self.entropy_stat.add(entropy)
+
+ def run_remote_manager(self):
+ """ Accept connection of new remote simulator and start simulation.
+ """
+ remote_manager = RemoteManager(port=self.config['server_port'])
+ logger.info("Waiting for the remote simulator's connection.")
+
+ ident = 0
+ self.predict_output_queues = []
+
+ while True:
+ remote_simulator = remote_manager.get_remote()
+
+ self.remote_count += 1
+ logger.info('Remote simulator count: {}'.format(self.remote_count))
+ if self.start_time is None:
+ self.start_time = time.time()
+
+ q = queue.Queue()
+ self.predict_output_queues.append(q)
+
+ remote_thread = threading.Thread(
+ target=self.run_remote_sample,
+ args=(
+ remote_simulator,
+ ident,
+ ))
+ remote_thread.setDaemon(True)
+ remote_thread.start()
+
+ ident += 1
+
+ def run_remote_sample(self, remote_simulator, ident):
+ """ Interacts with remote simulator.
+ """
+ mem = defaultdict(list)
+
+ obs = remote_simulator.reset()
+ while True:
+ self.predict_input_queue.put((ident, obs))
+ action, value = self.predict_output_queues[ident].get()
+
+ next_obs, reward, done = remote_simulator.step(action)
+
+ mem['obs'].append(obs)
+ mem['actions'].append(action)
+ mem['rewards'].append(reward)
+ mem['values'].append(value)
+
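+            # Episode finished: compute GAE with a terminal bootstrap value of 0 and start a new episode.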
+ if done:
+ next_value = 0
+ advantages = calc_gae(mem['rewards'], mem['values'],
+ next_value, self.config['gamma'],
+ self.config['lambda'])
+ target_values = advantages + mem['values']
+
+ self.sample_data_queue.put({
+ 'obs': mem['obs'],
+ 'actions': mem['actions'],
+ 'advantages': advantages,
+ 'target_values': target_values
+ })
+
+ mem = defaultdict(list)
+
+ next_obs = remote_simulator.reset()
+
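+            # Rollout reached t_max: bootstrap the return from the newest value estimate and carry the last transition over to the next rollout.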
+            elif len(mem['obs']) == self.config['t_max'] + 1:
+ next_value = mem['values'][-1]
+ advantages = calc_gae(mem['rewards'][:-1], mem['values'][:-1],
+ next_value, self.config['gamma'],
+ self.config['lambda'])
+ target_values = advantages + mem['values'][:-1]
+
+ self.sample_data_queue.put({
+ 'obs': mem['obs'][:-1],
+ 'actions': mem['actions'][:-1],
+ 'advantages': advantages,
+ 'target_values': target_values
+ })
+
+ for key in mem:
+ mem[key] = [mem[key][-1]]
+
+ obs = next_obs
+
+ if done:
+ metrics = remote_simulator.get_metrics()
+ if metrics:
+ self.remote_metrics_queue.put(metrics)
+
+ def log_metrics(self):
+ """ Log metrics of learner and simulators
+ """
+ if self.start_time is None:
+ return
+
+ metrics = []
+ while True:
+ try:
+ metric = self.remote_metrics_queue.get_nowait()
+ metrics.append(metric)
+ except queue.Empty:
+ break
+
+ episode_rewards, episode_steps = [], []
+ for x in metrics:
+ episode_rewards.extend(x['episode_rewards'])
+ episode_steps.extend(x['episode_steps'])
+ max_episode_rewards, mean_episode_rewards, min_episode_rewards, \
+ max_episode_steps, mean_episode_steps, min_episode_steps =\
+ None, None, None, None, None, None
+ if episode_rewards:
+ mean_episode_rewards = np.mean(np.array(episode_rewards).flatten())
+ max_episode_rewards = np.max(np.array(episode_rewards).flatten())
+ min_episode_rewards = np.min(np.array(episode_rewards).flatten())
+
+ mean_episode_steps = np.mean(np.array(episode_steps).flatten())
+ max_episode_steps = np.max(np.array(episode_steps).flatten())
+ min_episode_steps = np.min(np.array(episode_steps).flatten())
+
+ metric = {
+ 'Sample steps': self.sample_total_steps,
+ 'max_episode_rewards': max_episode_rewards,
+ 'mean_episode_rewards': mean_episode_rewards,
+ 'min_episode_rewards': min_episode_rewards,
+ 'max_episode_steps': max_episode_steps,
+ 'mean_episode_steps': mean_episode_steps,
+ 'min_episode_steps': min_episode_steps,
+ 'total_loss': self.total_loss_stat.mean,
+ 'pi_loss': self.pi_loss_stat.mean,
+ 'vf_loss': self.vf_loss_stat.mean,
+ 'entropy': self.entropy_stat.mean,
+ 'learn_time_s': self.learn_time_stat.mean,
+ 'elapsed_time_s': int(time.time() - self.start_time),
+ 'lr': self.lr,
+ 'entropy_coeff': self.entropy_coeff,
+ }
+
+ logger.info(metric)
+ self.csv_logger.log_dict(metric)
+
+ def close(self):
+ self.csv_logger.close()
diff --git a/examples/GA3C/run_simulators.sh b/examples/GA3C/run_simulators.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9ad2825f76ee0c2ac0ac3116d2c2b661b1fd93cf
--- /dev/null
+++ b/examples/GA3C/run_simulators.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+export CUDA_VISIBLE_DEVICES=""
+
+for i in $(seq 1 24); do
+ python simulator.py &
+done;
+wait
diff --git a/examples/GA3C/simulator.py b/examples/GA3C/simulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2019460e05272026f2e4a41c4729dbb38cfd2c68
--- /dev/null
+++ b/examples/GA3C/simulator.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gym
+import numpy as np
+import parl
+import six
+from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls
+from collections import defaultdict
+
+
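+# Decorated with @parl.remote_class so the learner can drive this simulator over the network via RemoteManager.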
+@parl.remote_class
+class Simulator(object):
+ def __init__(self, config):
+ self.config = config
+
+ env = gym.make(config['env_name'])
+ self.env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ return obs, reward, done
+
+ def reset(self):
+ obs = self.env.reset()
+ return obs
+
+ def get_metrics(self):
+ metrics = defaultdict(list)
+ monitor = get_wrapper_by_cls(self.env, MonitorEnv)
+ if monitor is not None:
+ for episode_rewards, episode_steps in monitor.next_episode_results(
+ ):
+ metrics['episode_rewards'].append(episode_rewards)
+ metrics['episode_steps'].append(episode_steps)
+ return metrics
+
+
+if __name__ == '__main__':
+ from ga3c_config import config
+
+ simulator = Simulator(config)
+ simulator.as_remote(config['server_ip'], config['server_port'])
diff --git a/examples/GA3C/train.py b/examples/GA3C/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..ade43899b5caad97ba9bf2faf02f46878c8ffbad
--- /dev/null
+++ b/examples/GA3C/train.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from learner import Learner
+
+
+def main(config):
+ learner = Learner(config)
+
+ try:
+ while True:
+ time.sleep(config['log_metrics_interval_s'])
+
+ learner.log_metrics()
+
+ except KeyboardInterrupt:
+ learner.close()
+
+
+if __name__ == '__main__':
+ from ga3c_config import config
+ main(config)
diff --git a/examples/IMPALA/README.md b/examples/IMPALA/README.md
index fd2d6b2a67d4c3bff0c775989fd01e7da9b3b8d8..35510d6d6f323f77bf4714de0a99d1b3c69fb385 100644
--- a/examples/IMPALA/README.md
+++ b/examples/IMPALA/README.md
@@ -5,7 +5,7 @@ Based on PARL, the IMPALA algorithm of deep reinforcement learning is reproduced
[Impala: Scalable distributed deep-rl with importance weighted actor-learner architectures](https://arxiv.org/abs/1802.01561)
### Atari games introduction
-Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari game.
+Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari games.
### Benchmark result
Result with one learner (in a P40 GPU) and 32 actors (in 32 CPUs).
@@ -24,7 +24,6 @@ Result with one learner (in a P40 GPU) and 32 actors (in 32 CPUs).
+ [paddlepaddle>=1.3.0](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
-+ opencv-python
+ atari_py
diff --git a/examples/PPO/README.md b/examples/PPO/README.md
index befdc47139eac2d990a1c5268ad9c072af507bcb..af88782b8e916eac05da89d6a968c251592f8e51 100644
--- a/examples/PPO/README.md
+++ b/examples/PPO/README.md
@@ -9,7 +9,7 @@ Include following approach:
[Proximal Policy Optimization Algorithms](https://arxiv.org/abs/1707.06347)
### Mujoco games introduction
-Please see [here](https://github.com/openai/mujoco-py) to know more about Mujoco game.
+Please see [here](https://github.com/openai/mujoco-py) to know more about Mujoco games.
### Benchmark result
diff --git a/parl/framework/policy_distribution.py b/parl/framework/policy_distribution.py
index 25aab040330ba3766c8266d463fc75a98ebca9e8..a037d9520f629e6aa433a29b06e21e13e21650dd 100644
--- a/parl/framework/policy_distribution.py
+++ b/parl/framework/policy_distribution.py
@@ -70,22 +70,31 @@ class CategoricalDistribution(PolicyDistribution):
return entropy
- def logp(self, actions):
+ def logp(self, actions, eps=1e-6):
"""
Args:
actions: An int64 tensor with shape [BATCH_SIZE]
+ eps: A small float constant that avoids underflows
Returns:
actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
"""
assert len(actions.shape) == 1
+
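+        # Numerically stable softmax: subtract the per-row max before exp; eps later keeps log() away from zero.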
+ logits = self.logits - layers.reduce_max(self.logits, dim=1)
+ e_logits = layers.exp(logits)
+ z = layers.reduce_sum(e_logits, dim=1)
+ prob = e_logits / z
+
actions = layers.unsqueeze(actions, axes=[1])
+ actions_onehot = layers.one_hot(actions, prob.shape[1])
+ actions_onehot = layers.cast(actions_onehot, dtype='float32')
+ actions_prob = layers.reduce_sum(prob * actions_onehot, dim=1)
- cross_entropy = layers.softmax_with_cross_entropy(
- logits=self.logits, label=actions)
+ actions_prob = actions_prob + eps
+ actions_log_prob = layers.log(actions_prob)
- actions_log_prob = -1.0 * layers.squeeze(cross_entropy, axes=[-1])
return actions_log_prob
def kl(self, other):
diff --git a/parl/framework/tests/policy_distribution_test.py b/parl/framework/tests/policy_distribution_test.py
index bb729580583bf5b76851765a5e4db12d97600926..ea4a521f3a856f63e52c5e24ef01b14229de18ef 100644
--- a/parl/framework/tests/policy_distribution_test.py
+++ b/parl/framework/tests/policy_distribution_test.py
@@ -67,8 +67,7 @@ class PolicyDistributionTest(unittest.TestCase):
gt_log_probs = np.log(gt_probs)
gt_entropy = -1.0 * np.sum(gt_probs * gt_log_probs, axis=1)
- gt_actions_logp = -1.0 * np_cross_entropy(
- np_softmax(logits_np), actions_np)
+ gt_actions_logp = -1.0 * np_cross_entropy(gt_probs + 1e-6, actions_np)
gt_actions_logp = np.squeeze(gt_actions_logp, -1)
gt_kl = np.sum(
np.where(gt_probs != 0,