diff --git a/examples/A2C/README.md b/examples/A2C/README.md
index 195a2af72708d9699d2d5e0b3baf191d44598b4a..2f5eec60aa9333047f1525a675677e3af5b27f15 100755
--- a/examples/A2C/README.md
+++ b/examples/A2C/README.md
@@ -18,8 +18,8 @@ Mean episode reward in training process after 10 million sample steps.
 
 ### Dependencies
 + [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
 + [parl](https://github.com/PaddlePaddle/PARL)
-+ gym
-+ atari-py
++ gym==0.12.1
++ atari-py==0.1.7
 
 ### Distributed Training
@@ -34,10 +34,10 @@ Note that if you have started a master before, you don't have to run the above
 command. For more information about the cluster, please refer to our
 [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html)
 
-Then we can start the distributed training by running `train.py`.
+Then we can start the distributed training by running `learner.py`.
 
 ```bash
-python train.py
+python learner.py
 ```
 
 ### Reference
diff --git a/examples/A2C/learner.py b/examples/A2C/learner.py
index 6624be44be9365e69dc28be492e38156272cd426..7409df6e53bdaef334729d396bec5c2b2efde5b9 100755
--- a/examples/A2C/learner.py
+++ b/examples/A2C/learner.py
@@ -216,3 +216,16 @@ class Learner(object):
 
     def should_stop(self):
         return self.sample_total_steps >= self.config['max_sample_steps']
+
+
+if __name__ == '__main__':
+    from a2c_config import config
+
+    learner = Learner(config)
+    assert config['log_metrics_interval_s'] > 0
+
+    while not learner.should_stop():
+        start = time.time()
+        while time.time() - start < config['log_metrics_interval_s']:
+            learner.step()
+        learner.log_metrics()
diff --git a/examples/ES/README.md b/examples/ES/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f25a449f1ac17a4d5190d08ff1986708d72cdcb0
--- /dev/null
+++ b/examples/ES/README.md
@@ -0,0 +1,38 @@
+## Reproduce ES with PARL
+Based on PARL, we have reproduced the Evolution Strategies (ES) algorithm, achieving performance on par with the results reported in the paper on Mujoco benchmarks.
+
++ ES in
+[Evolution Strategies as a Scalable Alternative to Reinforcement Learning](https://arxiv.org/abs/1703.03864)
+
+### Mujoco games introduction
+Please see [here](https://github.com/openai/mujoco-py) to learn more about Mujoco games.
+
+### Benchmark result
+TODO
+
+## How to use
+### Dependencies
++ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
++ [parl](https://github.com/PaddlePaddle/PARL)
++ gym==0.9.4
++ mujoco-py==0.5.1
+
+
+### Distributed Training
+
+#### Learner
+```sh
+python learner.py
+```
+
+#### Actors
+```sh
+sh run_actors.sh
+```
+
+You can change training settings (e.g. `env_name`, `server_ip`) in `es_config.py`. If you want to use a different number of actors, please modify `actor_num` in both `es_config.py` and `run_actors.sh`.
+Training results will be saved in `log_dir/train/result.csv`.
+
+### Reference
++ [Ray](https://github.com/ray-project/ray)
++ [evolution-strategies-starter](https://github.com/openai/evolution-strategies-starter)
diff --git a/examples/ES/actor.py b/examples/ES/actor.py
new file mode 100644
index 0000000000000000000000000000000000000000..683549c93b46c2ac973170a2343572181577e486
--- /dev/null
+++ b/examples/ES/actor.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gym +import parl +import time +import numpy as np +from es import ES +from obs_filter import MeanStdFilter +from mujoco_agent import MujocoAgent +from mujoco_model import MujocoModel +from noise import SharedNoiseTable + + +@parl.remote_class +class Actor(object): + def __init__(self, config): + self.config = config + + self.env = gym.make(self.config['env_name']) + self.config['obs_dim'] = self.env.observation_space.shape[0] + self.config['act_dim'] = self.env.action_space.shape[0] + + self.obs_filter = MeanStdFilter(self.config['obs_dim']) + self.noise = SharedNoiseTable(self.config['noise_size']) + + model = MujocoModel(self.config['act_dim']) + algorithm = ES(model) + self.agent = MujocoAgent(algorithm, self.config) + + def _play_one_episode(self, add_noise=False): + episode_reward = 0 + episode_step = 0 + + obs = self.env.reset() + while True: + if np.random.uniform() < self.config['filter_update_prob']: + obs = self.obs_filter(obs[None], update=True) + else: + obs = self.obs_filter(obs[None], update=False) + + action = self.agent.predict(obs) + if add_noise: + action += np.random.randn( + *action.shape) * self.config['action_noise_std'] + + obs, reward, done, _ = self.env.step(action) + episode_reward += reward + episode_step += 1 + if done: + break + return episode_reward, episode_step + + def sample(self, flat_weights): + noise_indices, rewards, lengths = [], [], [] + eval_rewards, eval_lengths = [], [] + + # Perform some rollouts with noise. + task_tstart = time.time() + while (len(noise_indices) == 0 + or time.time() - task_tstart < self.config['min_task_runtime']): + + if np.random.uniform() < self.config["eval_prob"]: + # Do an evaluation run with no perturbation. + self.agent.set_flat_weights(flat_weights) + episode_reward, episode_step = self._play_one_episode( + add_noise=False) + eval_rewards.append(episode_reward) + eval_lengths.append(episode_step) + else: + # Do a regular run with parameter perturbations. 
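+                # Only an index into the shared noise table is recorded here;
+                # the learner builds an identical table (same seed), so it can
+                # reconstruct the exact perturbation from the index alone.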
+                noise_index = self.noise.sample_index(
+                    self.agent.weights_total_size)
+
+                perturbation = self.config["noise_stdev"] * self.noise.get(
+                    noise_index, self.agent.weights_total_size)
+
+                # mirrored sampling: evaluate pairs of perturbations \epsilon, −\epsilon
+                self.agent.set_flat_weights(flat_weights + perturbation)
+                episode_reward_pos, episode_step_pos = self._play_one_episode(
+                    add_noise=True)
+
+                self.agent.set_flat_weights(flat_weights - perturbation)
+                episode_reward_neg, episode_step_neg = self._play_one_episode(
+                    add_noise=True)
+
+                noise_indices.append(noise_index)
+                rewards.append([episode_reward_pos, episode_reward_neg])
+                lengths.append([episode_step_pos, episode_step_neg])
+
+        return {
+            'noise_indices': noise_indices,
+            'noisy_rewards': rewards,
+            'noisy_lengths': lengths,
+            'eval_rewards': eval_rewards,
+            'eval_lengths': eval_lengths
+        }
+
+    def get_filter(self, flush_after=False):
+        return_filter = self.obs_filter.as_serializable()
+        if flush_after:
+            self.obs_filter.clear_buffer()
+        return return_filter
+
+    def set_filter(self, new_filter):
+        self.obs_filter.sync(new_filter)
+
+
+if __name__ == '__main__':
+    from es_config import config
+
+    actor = Actor(config)
+    actor.as_remote(config['server_ip'], config['server_port'])
diff --git a/examples/ES/es.py b/examples/ES/es.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a1cac50f33d21478576a67a50b6e63ae6073ce
--- /dev/null
+++ b/examples/ES/es.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import parl
+
+__all__ = ['ES']
+
+
+class ES(parl.Algorithm):
+    def __init__(self, model):
+        """ES algorithm.
+
+        Since the parameters of the model are updated at the numpy level, a `learn`
+        function is not needed in this algorithm.
+
+        Args:
+            model(`parl.Model`): policy model of ES algorithm.
+        """
+        self.model = model
+
+    def predict(self, obs):
+        """Use the policy model to predict actions of observations.
+
+        Args:
+            obs(layers.data): data layer of observations.
+
+        Returns:
+            tensor of predicted actions.
+        """
+        return self.model(obs)
diff --git a/examples/GA3C/train.py b/examples/ES/es_config.py
old mode 100755
new mode 100644
similarity index 50%
rename from examples/GA3C/train.py
rename to examples/ES/es_config.py
index 0d351ee55b10cdc884afc8f2b8607deb0ddae222..ad4c9c402ffd379b3b676a509f244110b4e7f347
--- a/examples/GA3C/train.py
+++ b/examples/ES/es_config.py
@@ -12,20 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import time -from learner import Learner +config = { + #========== remote config ========== + 'server_ip': 'localhost', + 'server_port': 8037, + #========== env config ========== + 'env_name': 'Humanoid-v1', -def main(config): - learner = Learner(config) - assert config['log_metrics_interval_s'] > 0 + #========== actor config ========== + 'actor_num': 96, + 'action_noise_std': 0.01, + 'min_task_runtime': 0.2, + 'eval_prob': 0.003, + 'filter_update_prob': 0.01, - while True: - time.sleep(config['log_metrics_interval_s']) - - learner.log_metrics() - - -if __name__ == '__main__': - from ga3c_config import config - main(config) + #========== learner config ========== + 'stepsize': 0.01, + 'min_episodes_per_batch': 1000, + 'min_steps_per_batch': 10000, + 'noise_size': 200000000, + 'noise_stdev': 0.02, + 'l2_coeff': 0.005, + 'report_window_size': 10, +} diff --git a/examples/ES/learner.py b/examples/ES/learner.py new file mode 100644 index 0000000000000000000000000000000000000000..478dfaf23570ebac65c59061a4b0a5ffb5fe452f --- /dev/null +++ b/examples/ES/learner.py @@ -0,0 +1,217 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gym +import os +import parl +import numpy as np +import threading +import utils +from es import ES +from obs_filter import MeanStdFilter +from mujoco_agent import MujocoAgent +from mujoco_model import MujocoModel +from noise import SharedNoiseTable +from parl import RemoteManager +from parl.utils import logger, tensorboard +from parl.utils.window_stat import WindowStat +from six.moves import queue + + +class Learner(object): + def __init__(self, config): + self.config = config + + env = gym.make(self.config['env_name']) + self.config['obs_dim'] = env.observation_space.shape[0] + self.config['act_dim'] = env.action_space.shape[0] + + self.obs_filter = MeanStdFilter(self.config['obs_dim']) + self.noise = SharedNoiseTable(self.config['noise_size']) + + model = MujocoModel(self.config['act_dim']) + algorithm = ES(model) + self.agent = MujocoAgent(algorithm, self.config) + + self.latest_flat_weights = self.agent.get_flat_weights() + self.latest_obs_filter = self.obs_filter.as_serializable() + + self.sample_total_episodes = 0 + self.sample_total_steps = 0 + + self.actors_signal_input_queues = [] + self.actors_output_queues = [] + + self.run_remote_manager() + + self.eval_rewards_stat = WindowStat(self.config['report_window_size']) + self.eval_lengths_stat = WindowStat(self.config['report_window_size']) + + def run_remote_manager(self): + """ Accept connection of new remote actor and start sampling of the remote actor. 
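+
+        A daemon thread is created for each connected actor, so the blocking
+        remote calls (sample / get_filter / set_filter) do not stall the main
+        learning loop.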
+ """ + remote_manager = RemoteManager(port=self.config['server_port']) + logger.info('Waiting for {} remote actors to connect.'.format( + self.config['actor_num'])) + + self.remote_count = 0 + for i in range(self.config['actor_num']): + remote_actor = remote_manager.get_remote() + signal_queue = queue.Queue() + output_queue = queue.Queue() + self.actors_signal_input_queues.append(signal_queue) + self.actors_output_queues.append(output_queue) + + self.remote_count += 1 + logger.info('Remote actor count: {}'.format(self.remote_count)) + + remote_thread = threading.Thread( + target=self.run_remote_sample, + args=(remote_actor, signal_queue, output_queue)) + remote_thread.setDaemon(True) + remote_thread.start() + + logger.info('All remote actors are ready, begin to learn.') + + def run_remote_sample(self, remote_actor, signal_queue, output_queue): + """ Sample data from remote actor or get filters of remote actor. + """ + while True: + info = signal_queue.get() + if info['signal'] == 'sample': + result = remote_actor.sample(self.latest_flat_weights) + output_queue.put(result) + elif info['signal'] == 'get_filter': + actor_filter = remote_actor.get_filter(flush_after=True) + output_queue.put(actor_filter) + elif info['signal'] == 'set_filter': + remote_actor.set_filter(self.latest_obs_filter) + else: + raise NotImplementedError + + def step(self): + """Run a step in ES. + + 1. kick off all actors to synchronize weights and sample data; + 2. update parameters of the model based on sampled data. + 3. update global observation filter based on local filters of all actors, and synchronize global + filter to all actors. + """ + num_episodes, num_timesteps = 0, 0 + results = [] + + while num_episodes < self.config['min_episodes_per_batch'] or \ + num_timesteps < self.config['min_steps_per_batch']: + # Send sample signal to all actors + for q in self.actors_signal_input_queues: + q.put({'signal': 'sample'}) + + # Collect results from all actors + for q in self.actors_output_queues: + result = q.get() + results.append(result) + # result['noisy_lengths'] is a list of lists, where the inner lists have length 2. + num_episodes += sum( + len(pair) for pair in result['noisy_lengths']) + num_timesteps += sum( + sum(pair) for pair in result['noisy_lengths']) + + all_noise_indices = [] + all_training_rewards = [] + all_training_lengths = [] + all_eval_rewards = [] + all_eval_lengths = [] + + for result in results: + all_eval_rewards.extend(result['eval_rewards']) + all_eval_lengths.extend(result['eval_lengths']) + + all_noise_indices.extend(result['noise_indices']) + all_training_rewards.extend(result['noisy_rewards']) + all_training_lengths.extend(result['noisy_lengths']) + + assert len(all_eval_rewards) == len(all_eval_lengths) + assert (len(all_noise_indices) == len(all_training_rewards) == + len(all_training_lengths)) + + self.sample_total_episodes += num_episodes + self.sample_total_steps += num_timesteps + + eval_rewards = np.array(all_eval_rewards) + eval_lengths = np.array(all_eval_lengths) + noise_indices = np.array(all_noise_indices) + noisy_rewards = np.array(all_training_rewards) + noisy_lengths = np.array(all_training_lengths) + + # normalize rewards to (-0.5, 0.5) + proc_noisy_rewards = utils.compute_centered_ranks(noisy_rewards) + noises = [ + self.noise.get(index, self.agent.weights_total_size) + for index in noise_indices + ] + + # Update the parameters of the model. 
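+        # agent.learn combines the rank-transformed returns of each mirrored
+        # pair with the corresponding noise vectors to form the ES gradient
+        # estimate, then applies a single Adam update (see mujoco_agent.py).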
+        self.agent.learn(proc_noisy_rewards, noises)
+        self.latest_flat_weights = self.agent.get_flat_weights()
+
+        # Update obs filter
+        self._update_filter()
+
+        # Store the evaluation rewards
+        if len(all_eval_rewards) > 0:
+            self.eval_rewards_stat.add(np.mean(eval_rewards))
+            self.eval_lengths_stat.add(np.mean(eval_lengths))
+
+        metrics = {
+            "episodes_this_iter": noisy_lengths.size,
+            "sample_total_episodes": self.sample_total_episodes,
+            'sample_total_steps': self.sample_total_steps,
+            "evaluate_rewards_mean": self.eval_rewards_stat.mean,
+            "evaluate_steps_mean": self.eval_lengths_stat.mean,
+            "timesteps_this_iter": noisy_lengths.sum(),
+        }
+
+        self.log_metrics(metrics)
+        return metrics
+
+    def _update_filter(self):
+        # Send get_filter signal to all actors
+        for q in self.actors_signal_input_queues:
+            q.put({'signal': 'get_filter'})
+
+        # Collect filters from all actors and update global filter
+        for q in self.actors_output_queues:
+            actor_filter = q.get()
+            self.obs_filter.apply_changes(actor_filter)
+
+        # Send set_filter signal to all actors
+        self.latest_obs_filter = self.obs_filter.as_serializable()
+        for q in self.actors_signal_input_queues:
+            q.put({'signal': 'set_filter'})
+
+    def log_metrics(self, metrics):
+        logger.info(metrics)
+        for k, v in metrics.items():
+            if v is not None:
+                tensorboard.add_scalar(k, v, self.sample_total_steps)
+
+
+if __name__ == '__main__':
+    from es_config import config
+
+    learner = Learner(config)
+
+    while True:
+        learner.step()
diff --git a/examples/ES/mujoco_agent.py b/examples/ES/mujoco_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..58260d7495a23c4f08085895eb6a5158813a6c83
--- /dev/null
+++ b/examples/ES/mujoco_agent.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle.fluid as fluid
+import parl
+import utils
+from parl import layers
+from optimizers import Adam
+
+
+class MujocoAgent(parl.Agent):
+    def __init__(self, algorithm, config):
+        self.config = config
+        super(MujocoAgent, self).__init__(algorithm)
+
+        weights = self.get_weights()
+        assert len(
+            weights) == 1, "There should be only one model in the algorithm."
+        self.weights_name = list(weights.keys())[0]
+        weights = list(weights.values())[0]
+        self.weights_shapes = [x.shape for x in weights]
+        self.weights_total_size = np.sum(
+            [np.prod(x) for x in self.weights_shapes])
+
+        self.optimizer = Adam(self.weights_total_size, self.config['stepsize'])
+
+    def build_program(self):
+        self.predict_program = fluid.Program()
+
+        with fluid.program_guard(self.predict_program):
+            obs = layers.data(
+                name='obs', shape=[self.config['obs_dim']], dtype='float32')
+            self.predict_action = self.alg.predict(obs)
+        self.predict_program = parl.compile(self.predict_program)
+
+    def learn(self, noisy_rewards, noises):
+        """ Update the weights of the model at the numpy level.
+
+        Compute the gradient and take a step.
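+
+        The gradient estimate is roughly
+        g = sum_i (reward_pos_i - reward_neg_i) * noise_i / (2 * N), where N is
+        the number of mirrored perturbation pairs; -g plus an L2 penalty on the
+        current weights is then fed to the Adam optimizer.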
+
+        Args:
+            noisy_rewards(np.float32): [batch_size, 2]
+            noises(np.float32): [batch_size, weights_total_size]
+        """
+
+        g, count = utils.batched_weighted_sum(
+            # mirrored sampling: evaluate pairs of perturbations \epsilon, −\epsilon
+            noisy_rewards[:, 0] - noisy_rewards[:, 1],
+            noises,
+            batch_size=500)
+        g /= noisy_rewards.size
+
+        latest_flat_weights = self.get_flat_weights()
+        # Compute the new weights theta.
+        theta, update_ratio = self.optimizer.update(
+            latest_flat_weights,
+            -g + self.config["l2_coeff"] * latest_flat_weights)
+        self.set_flat_weights(theta)
+
+    def predict(self, obs):
+        obs = obs.astype('float32')
+        obs = np.expand_dims(obs, axis=0)
+        act = self.fluid_executor.run(
+            program=self.predict_program,
+            feed={'obs': obs},
+            fetch_list=[self.predict_action])[0]
+        return act
+
+    def get_flat_weights(self):
+        weights = list(self.get_weights().values())[0]
+        flat_weights = np.concatenate([x.flatten() for x in weights])
+        return flat_weights
+
+    def set_flat_weights(self, flat_weights):
+        weights = utils.unflatten(flat_weights, self.weights_shapes)
+        self.set_weights({self.weights_name: weights})
diff --git a/examples/A2C/train.py b/examples/ES/mujoco_model.py
old mode 100755
new mode 100644
similarity index 57%
rename from examples/A2C/train.py
rename to examples/ES/mujoco_model.py
index e100e46677a5b5ca4d0b905587e800754b62d613..77c57aa68c99df0b0d25144f3e8926e0743c3df8
--- a/examples/A2C/train.py
+++ b/examples/ES/mujoco_model.py
@@ -12,22 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import time
-from learner import Learner
+import paddle.fluid as fluid
 import parl
+from parl import layers
 
 
-def main(config):
-    learner = Learner(config)
-    assert config['log_metrics_interval_s'] > 0
+class MujocoModel(parl.Model):
+    def __init__(self, act_dim):
+        hid1_size = 256
+        hid2_size = 256
 
-    while not learner.should_stop():
-        start = time.time()
-        while time.time() - start < config['log_metrics_interval_s']:
-            learner.step()
-        learner.log_metrics()
+        self.fc1 = layers.fc(size=hid1_size, act='tanh')
+        self.fc2 = layers.fc(size=hid2_size, act='tanh')
+        self.fc3 = layers.fc(size=act_dim)
 
-
-if __name__ == '__main__':
-    from a2c_config import config
-    main(config)
+    def forward(self, obs):
+        hid1 = self.fc1(obs)
+        hid2 = self.fc2(hid1)
+        means = self.fc3(hid2)
+        return means
diff --git a/examples/ES/noise.py b/examples/ES/noise.py
new file mode 100644
index 0000000000000000000000000000000000000000..30dacab910f561d85f7309caed581620c94eedac
--- /dev/null
+++ b/examples/ES/noise.py
@@ -0,0 +1,31 @@
+# Third party code
+#
+# The following code are copied or modified from:
+# https://github.com/ray-project/ray/blob/master/python/ray/rllib/utils/filter.py
+
+import numpy as np
+
+
+class SharedNoiseTable(object):
+    """Shared noise table used by learner and actor.
+
+    The learner and actors create identical noise tables by using the same seed.
+    With identical tables, a perturbation can be communicated by its index in
+    the table instead of transferring the full numpy array of noise.
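+
+    Note: with the default noise_size of 200000000 float32 entries (see
+    es_config.py), each copy of the table occupies roughly 800 MB of memory.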
+ """ + + def __init__(self, noise_size, seed=1024): + self.noise_size = noise_size + self.seed = seed + self.noise = self._create_noise() + + def _create_noise(self): + noise = np.random.RandomState(self.seed).randn(self.noise_size).astype( + np.float32) + return noise + + def get(self, i, dim): + return self.noise[i:i + dim] + + def sample_index(self, dim): + return np.random.randint(0, len(self.noise) - dim + 1) diff --git a/examples/ES/obs_filter.py b/examples/ES/obs_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..b21bbe4a0ebe4d09852890d02bfbddd8faaedbd1 --- /dev/null +++ b/examples/ES/obs_filter.py @@ -0,0 +1,212 @@ +# Third party code +# +# The following code are copied or modified from: +# https://github.com/ray-project/ray/blob/master/python/ray/rllib/utils/filter.py + +import numpy as np + + +class Filter(object): + """Processes input, possibly statefully.""" + + def apply_changes(self, other, *args, **kwargs): + """Updates self with "new state" from other filter.""" + raise NotImplementedError + + def copy(self): + """Creates a new object with same state as self. + + Returns: + A copy of self. + """ + raise NotImplementedError + + def sync(self, other): + """Copies all state from other filter to self.""" + raise NotImplementedError + + def clear_buffer(self): + """Creates copy of current state and clears accumulated state""" + raise NotImplementedError + + def as_serializable(self): + raise NotImplementedError + + +# http://www.johndcook.com/blog/standard_deviation/ +class RunningStat(object): + def __init__(self, shape=None): + self._n = 0 + self._M = np.zeros(shape) + self._S = np.zeros(shape) + + def copy(self): + other = RunningStat() + other._n = self._n + other._M = np.copy(self._M) + other._S = np.copy(self._S) + return other + + def push(self, x): + x = np.asarray(x) + # Unvectorized update of the running statistics. + if x.shape != self._M.shape: + raise ValueError( + "Unexpected input shape {}, expected {}, value = {}".format( + x.shape, self._M.shape, x)) + n1 = self._n + self._n += 1 + if self._n == 1: + self._M[...] = x + else: + delta = x - self._M + self._M[...] += delta / self._n + self._S[...] += delta * delta * n1 / self._n + + def update(self, other): + n1 = self._n + n2 = other._n + n = n1 + n2 + if n == 0: + # Avoid divide by zero, which creates nans + return + delta = self._M - other._M + delta2 = delta * delta + M = (n1 * self._M + n2 * other._M) / n + S = self._S + other._S + delta2 * n1 * n2 / n + self._n = n + self._M = M + self._S = S + + def __repr__(self): + return '(n={}, mean_mean={}, mean_std={})'.format( + self.n, np.mean(self.mean), np.mean(self.std)) + + @property + def n(self): + return self._n + + @property + def mean(self): + return self._M + + @property + def var(self): + return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) + + @property + def std(self): + return np.sqrt(self.var) + + @property + def shape(self): + return self._M.shape + + +class MeanStdFilter(Filter): + """Keeps track of a running mean for seen states. + + The filter will be used to normalize observations and will be + online updated according to the seen observations of all actors. + """ + is_concurrent = False + + def __init__(self, shape, demean=True, destd=True, clip=10.0): + self.shape = shape + self.demean = demean + self.destd = destd + self.clip = clip + self.rs = RunningStat(shape) + # In distributed rollouts, each worker sees different states. 
+ # The buffer is used to keep track of deltas amongst all the + # observation filters. + + self.buffer = RunningStat(shape) + + def clear_buffer(self): + self.buffer = RunningStat(self.shape) + + def apply_changes(self, other, with_buffer=False): + """Applies updates from the buffer of another filter. + + Params: + other (MeanStdFilter): Other filter to apply info from + with_buffer (bool): Flag for specifying if the buffer should be + copied from other. + + Examples: + >>> a = MeanStdFilter(()) + >>> a(1) + >>> a(2) + >>> print([a.rs.n, a.rs.mean, a.buffer.n]) + [2, 1.5, 2] + >>> b = MeanStdFilter(()) + >>> b(10) + >>> a.apply_changes(b, with_buffer=False) + >>> print([a.rs.n, a.rs.mean, a.buffer.n]) + [3, 4.333333333333333, 2] + >>> a.apply_changes(b, with_buffer=True) + >>> print([a.rs.n, a.rs.mean, a.buffer.n]) + [4, 5.75, 1] + """ + self.rs.update(other.buffer) + if with_buffer: + self.buffer = other.buffer.copy() + + def copy(self): + """Returns a copy of Filter.""" + other = MeanStdFilter(self.shape) + other.sync(self) + return other + + def as_serializable(self): + return self.copy() + + def sync(self, other): + """Syncs all fields together from other filter. + + Examples: + >>> a = MeanStdFilter(()) + >>> a(1) + >>> a(2) + >>> print([a.rs.n, a.rs.mean, a.buffer.n]) + [2, array(1.5), 2] + >>> b = MeanStdFilter(()) + >>> b(10) + >>> print([b.rs.n, b.rs.mean, b.buffer.n]) + [1, array(10.0), 1] + >>> a.sync(b) + >>> print([a.rs.n, a.rs.mean, a.buffer.n]) + [1, array(10.0), 1] + """ + assert other.shape == self.shape, "Shapes don't match!" + self.demean = other.demean + self.destd = other.destd + self.clip = other.clip + self.rs = other.rs.copy() + self.buffer = other.buffer.copy() + + def __call__(self, x, update=True): + x = np.asarray(x) + if update: + if len(x.shape) == len(self.rs.shape) + 1: + # The vectorized case. + for i in range(x.shape[0]): + self.rs.push(x[i]) + self.buffer.push(x[i]) + else: + # The unvectorized case. + self.rs.push(x) + self.buffer.push(x) + if self.demean: + x = x - self.rs.mean + if self.destd: + x = x / (self.rs.std + 1e-8) + if self.clip: + x = np.clip(x, -self.clip, self.clip) + return x + + def __repr__(self): + return 'MeanStdFilter({}, {}, {}, {}, {}, {})'.format( + self.shape, self.demean, self.destd, self.clip, self.rs, + self.buffer) diff --git a/examples/ES/optimizers.py b/examples/ES/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..235d57ebdc40ec8a33fda435d0daa260bd293120 --- /dev/null +++ b/examples/ES/optimizers.py @@ -0,0 +1,57 @@ +# Third party code +# +# The following code are copied or modified from: +# https://github.com/openai/evolution-strategies-starter/blob/master/es_distributed/optimizers.py + +import numpy as np + + +class Optimizer(object): + def __init__(self, parameter_size): + self.dim = parameter_size + self.t = 0 + + def update(self, theta, global_g): + self.t += 1 + step = self._compute_step(global_g) + ratio = np.linalg.norm(step) / np.linalg.norm(theta) + return theta + step, ratio + + def _compute_step(self, global_g): + raise NotImplementedError + + +class SGD(Optimizer): + def __init__(self, parameter_size, stepsize, momentum=0.9): + Optimizer.__init__(self, parameter_size) + self.v = np.zeros(self.dim, dtype=np.float32) + self.stepsize, self.momentum = stepsize, momentum + + def _compute_step(self, global_g): + self.v = self.momentum * self.v + (1. 
- self.momentum) * global_g + step = -self.stepsize * self.v + return step + + +class Adam(Optimizer): + def __init__(self, + parameter_size, + stepsize, + beta1=0.9, + beta2=0.999, + epsilon=1e-08): + Optimizer.__init__(self, parameter_size) + self.stepsize = stepsize + self.beta1 = beta1 + self.beta2 = beta2 + self.epsilon = epsilon + self.m = np.zeros(self.dim, dtype=np.float32) + self.v = np.zeros(self.dim, dtype=np.float32) + + def _compute_step(self, global_g): + a = self.stepsize * (np.sqrt(1 - self.beta2**self.t) / + (1 - self.beta1**self.t)) + self.m = self.beta1 * self.m + (1 - self.beta1) * global_g + self.v = self.beta2 * self.v + (1 - self.beta2) * (global_g * global_g) + step = -a * self.m / (np.sqrt(self.v) + self.epsilon) + return step diff --git a/examples/ES/run_actors.sh b/examples/ES/run_actors.sh new file mode 100644 index 0000000000000000000000000000000000000000..7df4f4bba18be6ce174b78278b93eeee50361dfb --- /dev/null +++ b/examples/ES/run_actors.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +export CPU_NUM=1 + +actor_num=96 + +for i in $(seq 1 $actor_num); do + python actor.py & +done; +wait diff --git a/examples/ES/utils.py b/examples/ES/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..29d43e02a7f1772e604b57ae2d086651a5a4d266 --- /dev/null +++ b/examples/ES/utils.py @@ -0,0 +1,62 @@ +# Third party code +# +# The following code are copied or modified from: +# https://github.com/openai/evolution-strategies-starter. + +import numpy as np + + +def compute_ranks(x): + """Returns ranks in [0, len(x)) + + Note: This is different from scipy.stats.rankdata, which returns ranks in + [1, len(x)]. + """ + assert x.ndim == 1 + ranks = np.empty(len(x), dtype=int) + ranks[x.argsort()] = np.arange(len(x)) + return ranks + + +def compute_centered_ranks(x): + y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32) + y /= (x.size - 1) + y -= 0.5 + return y + + +def itergroups(items, group_size): + assert group_size >= 1 + group = [] + for x in items: + group.append(x) + if len(group) == group_size: + yield tuple(group) + del group[:] + if group: + yield tuple(group) + + +def batched_weighted_sum(weights, vecs, batch_size): + total = 0 + num_items_summed = 0 + for batch_weights, batch_vecs in zip( + itergroups(weights, batch_size), itergroups(vecs, batch_size)): + assert len(batch_weights) == len(batch_vecs) <= batch_size + total += np.dot( + np.asarray(batch_weights, dtype=np.float32), + np.asarray(batch_vecs, dtype=np.float32)) + num_items_summed += len(batch_weights) + return total, num_items_summed + + +def unflatten(flat_array, array_shapes): + i = 0 + arrays = [] + for shape in array_shapes: + size = np.prod(shape, dtype=np.int) + array = flat_array[i:(i + size)].reshape(shape) + arrays.append(array) + i += size + assert len(flat_array) == i + return arrays diff --git a/examples/GA3C/README.md b/examples/GA3C/README.md index 2a6c1cbfd7081c85a94afe423d62ef85507dcee8..7c2ce7eb19fce1a3c13ec1dd1ea539fbb9a3c377 100755 --- a/examples/GA3C/README.md +++ b/examples/GA3C/README.md @@ -18,8 +18,8 @@ Results with one learner (in a P40 GPU) and 24 simulators (in 12 CPU) in 10 mill ### Dependencies + [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle) + [parl](https://github.com/PaddlePaddle/PARL) -+ gym -+ atari-py ++ gym==0.12.1 ++ atari-py==0.1.7 ### Distributed Training @@ -33,10 +33,10 @@ Note that if you have started a master before, you don't have to run the above command. 
For more information about the cluster, please refer to our [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html) -Then we can start the distributed training by running `train.py`. +Then we can start the distributed training by running `learner.py`. ```bash -python train.py +python learner.py ``` [Tips] The performance can be influenced dramatically in a slower computational diff --git a/examples/GA3C/learner.py b/examples/GA3C/learner.py index b875b7f73f254ac1e7a7e50ea94fc85f99c3c46a..58df0df5a7df276ac3403d3e609e826f71397e84 100755 --- a/examples/GA3C/learner.py +++ b/examples/GA3C/learner.py @@ -319,3 +319,15 @@ class Learner(object): tensorboard.add_scalar(key, value, self.sample_total_steps) logger.info(metric) + + +if __name__ == '__main__': + from ga3c_config import config + + learner = Learner(config) + assert config['log_metrics_interval_s'] > 0 + + while True: + time.sleep(config['log_metrics_interval_s']) + + learner.log_metrics() diff --git a/examples/IMPALA/README.md b/examples/IMPALA/README.md index 0fcf6bd85fa690969191a025724c2797c898f0a7..5a65f3115aa3efcfa99e64fd34af0776c32a78c3 100755 --- a/examples/IMPALA/README.md +++ b/examples/IMPALA/README.md @@ -22,9 +22,8 @@ Result with one learner (in a P40 GPU) and 32 actors (in 32 CPUs). ### Dependencies + [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle) + [parl](https://github.com/PaddlePaddle/PARL) -+ gym -+ atari-py - ++ gym==0.12.1 ++ atari-py==0.1.7 ### Distributed Training: @@ -38,10 +37,10 @@ Note that if you have started a master before, you don't have to run the above command. For more information about the cluster, please refer to our [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html) -Then we can start the distributed training by running `train.py`. +Then we can start the distributed training by running `learner.py`. ```bash -python train.py +python learner.py ``` ### Reference diff --git a/examples/IMPALA/learner.py b/examples/IMPALA/learner.py index ab1957313e3ef93ad7e0fd06d9d24a3e3d88f0eb..93d479d50eacba5c5f22bed58003aeb18bc7b4d7 100755 --- a/examples/IMPALA/learner.py +++ b/examples/IMPALA/learner.py @@ -243,3 +243,15 @@ class Learner(object): tensorboard.add_scalar(key, value, self.sample_total_steps) logger.info(metric) + + +if __name__ == '__main__': + from impala_config import config + + learner = Learner(config) + assert config['log_metrics_interval_s'] > 0 + + while True: + time.sleep(config['log_metrics_interval_s']) + + learner.log_metrics() diff --git a/examples/IMPALA/train.py b/examples/IMPALA/train.py deleted file mode 100755 index a30b6f2c0383beea264b296c148b0ad3a1c541c3..0000000000000000000000000000000000000000 --- a/examples/IMPALA/train.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import time -from learner import Learner - - -def main(config): - learner = Learner(config) - assert config['log_metrics_interval_s'] > 0 - - while True: - time.sleep(config['log_metrics_interval_s']) - - learner.log_metrics() - - -if __name__ == '__main__': - from impala_config import config - main(config) diff --git a/parl/core/fluid/model.py b/parl/core/fluid/model.py index 6afbc8f1b0f153c68fdf5e29b5b0a2993a64ccc7..7c23eda633bf7487ac577de9d50590c2e15626e9 100644 --- a/parl/core/fluid/model.py +++ b/parl/core/fluid/model.py @@ -263,8 +263,14 @@ class Model(ModelBase): """ assert len(weights) == len(self.parameters()), \ 'size of input weights should be same as weights number of current model' + try: + is_gpu_available = self._is_gpu_available + except AttributeError: + self._is_gpu_available = machine_info.is_gpu_available() + is_gpu_available = self._is_gpu_available + for (param_name, weight) in list(zip(self.parameters(), weights)): - set_value(param_name, weight) + set_value(param_name, weight, is_gpu_available) def _get_parameter_names(self, obj): """ Recursively get parameter names in a model and its child attributes. diff --git a/parl/core/fluid/plutils/common.py b/parl/core/fluid/plutils/common.py index fe1f9f7138a230da5f642ef8e62d7dc240d87cb6..c55173cb6a7faaf41d5dd0a81d6c2fd34372ca83 100644 --- a/parl/core/fluid/plutils/common.py +++ b/parl/core/fluid/plutils/common.py @@ -53,15 +53,15 @@ def fetch_value(attr_name): return _fetch_var(attr_name, return_numpy=True) -def set_value(attr_name, value): +def set_value(attr_name, value, is_gpu_available): """ Given name of ParamAttr, set numpy value to the parameter in global_scope Args: - attr_name: ParamAttr name of parameter - value: numpy array + attr_name(string): ParamAttr name of parameter + value(np.array): numpy value + is_gpu_available(bool): whether is gpu available """ - place = fluid.CUDAPlace( - 0) if machine_info.is_gpu_available() else fluid.CPUPlace() + place = fluid.CUDAPlace(0) if is_gpu_available else fluid.CPUPlace() var = _fetch_var(attr_name, return_numpy=False) var.set(value, place) diff --git a/parl/core/fluid/plutils/compiler.py b/parl/core/fluid/plutils/compiler.py index 079ed551d766208c995045b9c7414cdff65803fc..846343f7a3a0138f4f0c25af623d08124b6030dd 100644 --- a/parl/core/fluid/plutils/compiler.py +++ b/parl/core/fluid/plutils/compiler.py @@ -26,10 +26,12 @@ def compile(program, loss=None): program(fluid.Program): a normal fluid program. loss_name(str): Optional. The loss tensor of a trainable program. Set it to None if you are transferring a prediction or evaluation program. """ + loss_name = None if loss is not None: assert isinstance( loss, fluid.framework. Variable), 'type of loss is expected to be a fluid tensor' + loss_name = loss.name # TODO: after solving the learning rate issue that occurs in training A2C algorithm, set it to 3. os.environ['CPU_NUM'] = '1' exec_strategy = fluid.ExecutionStrategy() @@ -39,6 +41,6 @@ def compile(program, loss=None): build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce return fluid.compiler.CompiledProgram(program).with_data_parallel( - loss_name=loss.name, + loss_name=loss_name, exec_strategy=exec_strategy, build_strategy=build_strategy)