Commit 60d68135 authored by Hongsheng Zeng, committed by Bo Zhou

ES example (#105)

* ES example

* refine settings

* fix yapf

* refine documentation; remove csv logger

* fix bug

* merge learner.py and train.py; add version requirements of gym and atari_py

* unify actor num
Parent 2ad3c4c0
@@ -18,8 +18,8 @@ Mean episode reward in training process after 10 million sample steps.
 ### Dependencies
 + [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
 + [parl](https://github.com/PaddlePaddle/PARL)
-+ gym
-+ atari-py
++ gym==0.12.1
++ atari-py==0.1.7
 ### Distributed Training
@@ -34,10 +34,10 @@ Note that if you have started a master before, you don't have to run the above
 command. For more information about the cluster, please refer to our
 [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html)
-Then we can start the distributed training by running `train.py`.
+Then we can start the distributed training by running `learner.py`.
 ```bash
-python train.py
+python learner.py
 ```
 ### Reference
...
@@ -216,3 +216,16 @@ class Learner(object):
     def should_stop(self):
         return self.sample_total_steps >= self.config['max_sample_steps']
+
+
+if __name__ == '__main__':
+    from a2c_config import config
+    learner = Learner(config)
+    assert config['log_metrics_interval_s'] > 0
+
+    while not learner.should_stop():
+        start = time.time()
+        while time.time() - start < config['log_metrics_interval_s']:
+            learner.step()
+        learner.log_metrics()
## Reproduce ES with PARL
Based on PARL, we reproduce the Evolution Strategies (ES) algorithm, reaching the same level of performance as reported in the paper on Mujoco benchmarks; the core update is sketched below.
+ ES in
[Evolution Strategies as a Scalable Alternative to Reinforcement Learning](https://arxiv.org/abs/1703.03864)
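
To make the algorithm concrete, the update this example implements can be summarized in a few lines of NumPy. The snippet below is only an illustrative sketch (the helper names and the toy reward function are ours, not code from this repository): it draws mirrored perturbations, rank-normalizes the returns, and steps the parameters along the resulting gradient estimate.
```python
import numpy as np


def centered_ranks(x):
    """Map raw returns to ranks rescaled into [-0.5, 0.5], as in the ES paper."""
    ranks = np.empty(x.size, dtype=np.float32)
    ranks[x.ravel().argsort()] = np.arange(x.size, dtype=np.float32)
    return (ranks / (x.size - 1) - 0.5).reshape(x.shape)


def es_step(theta, reward_fn, rng, num_pairs=50, sigma=0.02, stepsize=0.05):
    """One ES update with mirrored sampling: evaluate theta +/- sigma * eps."""
    eps = rng.randn(num_pairs, theta.size)
    returns = np.array([[reward_fn(theta + sigma * e), reward_fn(theta - sigma * e)]
                        for e in eps])
    ranked = centered_ranks(returns)
    # Weight each perturbation by the rank gap between its +eps and -eps rollouts.
    grad = np.dot(ranked[:, 0] - ranked[:, 1], eps) / returns.size
    return theta + stepsize * grad


if __name__ == '__main__':
    # Toy objective: maximize -||theta - 3||^2.
    rng = np.random.RandomState(0)
    theta = np.zeros(5)
    for _ in range(1000):
        theta = es_step(theta, lambda w: -np.sum((w - 3.0) ** 2), rng)
    print(theta)  # theta drifts from 0 towards 3 in every coordinate
```
The production version in this directory distributes the rollouts across remote actors and shares the perturbations through `SharedNoiseTable`, but the update rule is the same.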
### Mujoco games introduction
Please see [here](https://github.com/openai/mujoco-py) to learn more about Mujoco games.
### Benchmark result
TODO
## How to use
### Dependencies
+ [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym==0.9.4
+ mujoco-py==0.5.1
### Distributed Training
#### Learner
```sh
python learner.py
```
#### Actors
```sh
sh run_actors.sh
```
You can change training settings (e.g. `env_name`, `server_ip`) in `es_config.py`. If you want to use a different number of actors, please modify `actor_num` in both `es_config.py` and `run_actors.sh`.
Training results will be saved in `log_dir/train/result.csv`.
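For example, to run the learner on one machine and 48 actors (the IP address and actor count below are placeholders, not recommended settings), the relevant entries of `es_config.py` would look like this, with `actor_num=48` set accordingly in `run_actors.sh`:
```python
# es_config.py (excerpt) -- 'actor_num' must stay in sync with actor_num in run_actors.sh
config = {
    'server_ip': '172.18.0.5',  # placeholder: the machine that runs `python learner.py`
    'server_port': 8037,
    'env_name': 'Humanoid-v1',
    'actor_num': 48,            # placeholder: also set actor_num=48 in run_actors.sh
    # the remaining keys keep their default values from es_config.py
}
```
Each actor launched by `run_actors.sh` connects back to `server_ip:server_port`, so both values must be reachable from the actor machines.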
### Reference
+ [Ray](https://github.com/ray-project/ray)
+ [evolution-strategies-starter](https://github.com/openai/evolution-strategies-starter)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import parl
import time
import numpy as np
from es import ES
from obs_filter import MeanStdFilter
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from noise import SharedNoiseTable
@parl.remote_class
class Actor(object):
def __init__(self, config):
self.config = config
self.env = gym.make(self.config['env_name'])
self.config['obs_dim'] = self.env.observation_space.shape[0]
self.config['act_dim'] = self.env.action_space.shape[0]
self.obs_filter = MeanStdFilter(self.config['obs_dim'])
self.noise = SharedNoiseTable(self.config['noise_size'])
model = MujocoModel(self.config['act_dim'])
algorithm = ES(model)
self.agent = MujocoAgent(algorithm, self.config)
def _play_one_episode(self, add_noise=False):
episode_reward = 0
episode_step = 0
obs = self.env.reset()
while True:
if np.random.uniform() < self.config['filter_update_prob']:
obs = self.obs_filter(obs[None], update=True)
else:
obs = self.obs_filter(obs[None], update=False)
action = self.agent.predict(obs)
if add_noise:
action += np.random.randn(
*action.shape) * self.config['action_noise_std']
obs, reward, done, _ = self.env.step(action)
episode_reward += reward
episode_step += 1
if done:
break
return episode_reward, episode_step
def sample(self, flat_weights):
noise_indices, rewards, lengths = [], [], []
eval_rewards, eval_lengths = [], []
# Perform some rollouts with noise.
task_tstart = time.time()
while (len(noise_indices) == 0
or time.time() - task_tstart < self.config['min_task_runtime']):
if np.random.uniform() < self.config["eval_prob"]:
# Do an evaluation run with no perturbation.
self.agent.set_flat_weights(flat_weights)
episode_reward, episode_step = self._play_one_episode(
add_noise=False)
eval_rewards.append(episode_reward)
eval_lengths.append(episode_step)
else:
# Do a regular run with parameter perturbations.
noise_index = self.noise.sample_index(
self.agent.weights_total_size)
perturbation = self.config["noise_stdev"] * self.noise.get(
noise_index, self.agent.weights_total_size)
# mirrored sampling: evaluate pairs of perturbations \epsilon, −\epsilon
self.agent.set_flat_weights(flat_weights + perturbation)
episode_reward_pos, episode_step_pos = self._play_one_episode(
add_noise=True)
self.agent.set_flat_weights(flat_weights - perturbation)
episode_reward_neg, episode_step_neg = self._play_one_episode(
add_noise=True)
noise_indices.append(noise_index)
rewards.append([episode_reward_pos, episode_reward_neg])
lengths.append([episode_step_pos, episode_step_neg])
return {
'noise_indices': noise_indices,
'noisy_rewards': rewards,
'noisy_lengths': lengths,
'eval_rewards': eval_rewards,
'eval_lengths': eval_lengths
}
def get_filter(self, flush_after=False):
return_filter = self.obs_filter.as_serializable()
if flush_after:
self.obs_filter.clear_buffer()
return return_filter
def set_filter(self, new_filter):
self.obs_filter.sync(new_filter)
if __name__ == '__main__':
from es_config import config
actor = Actor(config)
actor.as_remote(config['server_ip'], config['server_port'])
@@ -12,20 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import time
-from learner import Learner
-
-
-def main(config):
-    learner = Learner(config)
-    assert config['log_metrics_interval_s'] > 0
-
-    while True:
-        time.sleep(config['log_metrics_interval_s'])
-        learner.log_metrics()
-
-
-if __name__ == '__main__':
-    from impala_config import config
-    main(config)
+import parl
+
+__all__ = ['ES']
+
+
+class ES(parl.Algorithm):
+    def __init__(self, model):
+        """ES algorithm.
+
+        Since the parameters of the model are updated at the numpy level, a `learn` function is not
+        needed in this algorithm.
+
+        Args:
+            model(`parl.Model`): policy model of the ES algorithm.
+        """
+        self.model = model
+
+    def predict(self, obs):
+        """Use the policy model to predict actions for observations.
+
+        Args:
+            obs(layers.data): data layer of observations.
+
+        Returns:
+            tensor of predicted actions.
+        """
+        return self.model(obs)
@@ -12,20 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import time
-from learner import Learner
-
-
-def main(config):
-    learner = Learner(config)
-    assert config['log_metrics_interval_s'] > 0
-
-    while True:
-        time.sleep(config['log_metrics_interval_s'])
-        learner.log_metrics()
-
-
-if __name__ == '__main__':
-    from ga3c_config import config
-    main(config)
+config = {
+    #========== remote config ==========
+    'server_ip': 'localhost',
+    'server_port': 8037,
+
+    #========== env config ==========
+    'env_name': 'Humanoid-v1',
+
+    #========== actor config ==========
+    'actor_num': 96,
+    'action_noise_std': 0.01,
+    'min_task_runtime': 0.2,
+    'eval_prob': 0.003,
+    'filter_update_prob': 0.01,
+
+    #========== learner config ==========
+    'stepsize': 0.01,
+    'min_episodes_per_batch': 1000,
+    'min_steps_per_batch': 10000,
+    'noise_size': 200000000,
+    'noise_stdev': 0.02,
+    'l2_coeff': 0.005,
+    'report_window_size': 10,
+}
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import os
import parl
import numpy as np
import threading
import utils
from es import ES
from obs_filter import MeanStdFilter
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from noise import SharedNoiseTable
from parl import RemoteManager
from parl.utils import logger, tensorboard
from parl.utils.window_stat import WindowStat
from six.moves import queue
class Learner(object):
def __init__(self, config):
self.config = config
env = gym.make(self.config['env_name'])
self.config['obs_dim'] = env.observation_space.shape[0]
self.config['act_dim'] = env.action_space.shape[0]
self.obs_filter = MeanStdFilter(self.config['obs_dim'])
self.noise = SharedNoiseTable(self.config['noise_size'])
model = MujocoModel(self.config['act_dim'])
algorithm = ES(model)
self.agent = MujocoAgent(algorithm, self.config)
self.latest_flat_weights = self.agent.get_flat_weights()
self.latest_obs_filter = self.obs_filter.as_serializable()
self.sample_total_episodes = 0
self.sample_total_steps = 0
self.actors_signal_input_queues = []
self.actors_output_queues = []
self.run_remote_manager()
self.eval_rewards_stat = WindowStat(self.config['report_window_size'])
self.eval_lengths_stat = WindowStat(self.config['report_window_size'])
def run_remote_manager(self):
""" Accept connection of new remote actor and start sampling of the remote actor.
"""
remote_manager = RemoteManager(port=self.config['server_port'])
logger.info('Waiting for {} remote actors to connect.'.format(
self.config['actor_num']))
self.remote_count = 0
for i in range(self.config['actor_num']):
remote_actor = remote_manager.get_remote()
signal_queue = queue.Queue()
output_queue = queue.Queue()
self.actors_signal_input_queues.append(signal_queue)
self.actors_output_queues.append(output_queue)
self.remote_count += 1
logger.info('Remote actor count: {}'.format(self.remote_count))
remote_thread = threading.Thread(
target=self.run_remote_sample,
args=(remote_actor, signal_queue, output_queue))
remote_thread.setDaemon(True)
remote_thread.start()
logger.info('All remote actors are ready, begin to learn.')
def run_remote_sample(self, remote_actor, signal_queue, output_queue):
""" Sample data from remote actor or get filters of remote actor.
"""
while True:
info = signal_queue.get()
if info['signal'] == 'sample':
result = remote_actor.sample(self.latest_flat_weights)
output_queue.put(result)
elif info['signal'] == 'get_filter':
actor_filter = remote_actor.get_filter(flush_after=True)
output_queue.put(actor_filter)
elif info['signal'] == 'set_filter':
remote_actor.set_filter(self.latest_obs_filter)
else:
raise NotImplementedError
def step(self):
"""Run a step in ES.
1. Kick off all actors to synchronize weights and sample data;
2. update the parameters of the model based on the sampled data;
3. update the global observation filter based on the local filters of all actors, and synchronize the
global filter to all actors.
"""
num_episodes, num_timesteps = 0, 0
results = []
while num_episodes < self.config['min_episodes_per_batch'] or \
num_timesteps < self.config['min_steps_per_batch']:
# Send sample signal to all actors
for q in self.actors_signal_input_queues:
q.put({'signal': 'sample'})
# Collect results from all actors
for q in self.actors_output_queues:
result = q.get()
results.append(result)
# result['noisy_lengths'] is a list of lists, where the inner lists have length 2.
num_episodes += sum(
len(pair) for pair in result['noisy_lengths'])
num_timesteps += sum(
sum(pair) for pair in result['noisy_lengths'])
all_noise_indices = []
all_training_rewards = []
all_training_lengths = []
all_eval_rewards = []
all_eval_lengths = []
for result in results:
all_eval_rewards.extend(result['eval_rewards'])
all_eval_lengths.extend(result['eval_lengths'])
all_noise_indices.extend(result['noise_indices'])
all_training_rewards.extend(result['noisy_rewards'])
all_training_lengths.extend(result['noisy_lengths'])
assert len(all_eval_rewards) == len(all_eval_lengths)
assert (len(all_noise_indices) == len(all_training_rewards) ==
len(all_training_lengths))
self.sample_total_episodes += num_episodes
self.sample_total_steps += num_timesteps
eval_rewards = np.array(all_eval_rewards)
eval_lengths = np.array(all_eval_lengths)
noise_indices = np.array(all_noise_indices)
noisy_rewards = np.array(all_training_rewards)
noisy_lengths = np.array(all_training_lengths)
# normalize rewards to centered ranks in [-0.5, 0.5]
proc_noisy_rewards = utils.compute_centered_ranks(noisy_rewards)
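# Each actor only reported a table index per perturbation; rebuild the actual
# noise vectors locally from the shared noise table.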
noises = [
self.noise.get(index, self.agent.weights_total_size)
for index in noise_indices
]
# Update the parameters of the model.
self.agent.learn(proc_noisy_rewards, noises)
self.latest_flat_weights = self.agent.get_flat_weights()
# Update obs filter
self._update_filter()
# Store the evaluation rewards
if len(all_eval_rewards) > 0:
self.eval_rewards_stat.add(np.mean(eval_rewards))
self.eval_lengths_stat.add(np.mean(eval_lengths))
metrics = {
"episodes_this_iter": noisy_lengths.size,
"sample_total_episodes": self.sample_total_episodes,
'sample_total_steps': self.sample_total_steps,
"evaluate_rewards_mean": self.eval_rewards_stat.mean,
"evaluate_steps_mean": self.eval_lengths_stat.mean,
"timesteps_this_iter": noisy_lengths.sum(),
}
self.log_metrics(metrics)
return metrics
def _update_filter(self):
# Send get_filter signal to all actors
for q in self.actors_signal_input_queues:
q.put({'signal': 'get_filter'})
filters = []
# Collect filters from all actors and update global filter
for q in self.actors_output_queues:
actor_filter = q.get()
self.obs_filter.apply_changes(actor_filter)
# Send set_filter signal to all actors
self.latest_obs_filter = self.obs_filter.as_serializable()
for q in self.actors_signal_input_queues:
q.put({'signal': 'set_filter'})
def log_metrics(self, metrics):
logger.info(metrics)
for k, v in metrics.items():
if v is not None:
tensorboard.add_scalar(k, v, self.sample_total_steps)
if __name__ == '__main__':
from es_config import config
learner = Learner(config)
while True:
learner.step()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl
import utils
from parl import layers
from optimizers import Adam
class MujocoAgent(parl.Agent):
def __init__(self, algorithm, config):
self.config = config
super(MujocoAgent, self).__init__(algorithm)
weights = self.get_weights()
assert len(
weights) == 1, "There should be only one model in the algorithm."
self.weights_name = list(weights.keys())[0]
weights = list(weights.values())[0]
self.weights_shapes = [x.shape for x in weights]
self.weights_total_size = np.sum(
[np.prod(x) for x in self.weights_shapes])
self.optimizer = Adam(self.weights_total_size, self.config['stepsize'])
def build_program(self):
self.predict_program = fluid.Program()
with fluid.program_guard(self.predict_program):
obs = layers.data(
name='obs', shape=[self.config['obs_dim']], dtype='float32')
self.predict_action = self.alg.predict(obs)
self.predict_program = parl.compile(self.predict_program)
def learn(self, noisy_rewards, noises):
""" Update weights of the model in the numpy level.
Compute the grident and take a step.
Args:
noisy_rewards(np.float32): [batch_size, 2]
noises(np.float32): [batch_size, weights_total_size]
"""
g, count = utils.batched_weighted_sum(
# mirrored sampling: evaluate pairs of perturbations \epsilon, −\epsilon
noisy_rewards[:, 0] - noisy_rewards[:, 1],
noises,
batch_size=500)
g /= noisy_rewards.size
latest_flat_weights = self.get_flat_weights()
# Compute the new weights theta.
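# The optimizer minimizes, so the negative of the estimated gradient is passed in;
# the l2_coeff term acts as weight decay on the flattened parameters.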
theta, update_ratio = self.optimizer.update(
latest_flat_weights,
-g + self.config["l2_coeff"] * latest_flat_weights)
self.set_flat_weights(theta)
def predict(self, obs):
obs = obs.astype('float32')
obs = np.expand_dims(obs, axis=0)
act = self.fluid_executor.run(
program=self.predict_program,
feed={'obs': obs},
fetch_list=[self.predict_action])[0]
return act
def get_flat_weights(self):
weights = list(self.get_weights().values())[0]
flat_weights = np.concatenate([x.flatten() for x in weights])
return flat_weights
def set_flat_weights(self, flat_weights):
weights = utils.unflatten(flat_weights, self.weights_shapes)
self.set_weights({self.weights_name: weights})
@@ -12,22 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import time
-from learner import Learner
+import paddle.fluid as fluid
 import parl
+from parl import layers
-
-
-def main(config):
-    learner = Learner(config)
-    assert config['log_metrics_interval_s'] > 0
-
-    while not learner.should_stop():
-        start = time.time()
-        while time.time() - start < config['log_metrics_interval_s']:
-            learner.step()
-        learner.log_metrics()
-
-
-if __name__ == '__main__':
-    from a2c_config import config
-    main(config)
+
+
+class MujocoModel(parl.Model):
+    def __init__(self, act_dim):
+        hid1_size = 256
+        hid2_size = 256
+
+        self.fc1 = layers.fc(size=hid1_size, act='tanh')
+        self.fc2 = layers.fc(size=hid2_size, act='tanh')
+        self.fc3 = layers.fc(size=act_dim)
+
+    def forward(self, obs):
+        hid1 = self.fc1(obs)
+        hid2 = self.fc2(hid1)
+        means = self.fc3(hid2)
+        return means
# Third party code
#
# The following code are copied or modified from:
# https://github.com/ray-project/ray/blob/master/python/ray/rllib/utils/filter.py
import numpy as np
class SharedNoiseTable(object):
"""Shared noise table used by learner and actor.
The learner and actors create identical noise tables by passing the same seed.
With the same noise table, they can communicate noise by table index instead of
sending numpy arrays of noise.
"""
def __init__(self, noise_size, seed=1024):
self.noise_size = noise_size
self.seed = seed
self.noise = self._create_noise()
def _create_noise(self):
noise = np.random.RandomState(self.seed).randn(self.noise_size).astype(
np.float32)
return noise
def get(self, i, dim):
return self.noise[i:i + dim]
def sample_index(self, dim):
return np.random.randint(0, len(self.noise) - dim + 1)
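# Note: the learner and every actor construct SharedNoiseTable with the same default
# seed, so an actor only needs to report the integer from sample_index(dim) and the
# learner can reproduce the identical perturbation via get(index, dim).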
# Third party code
#
# The following code are copied or modified from:
# https://github.com/ray-project/ray/blob/master/python/ray/rllib/utils/filter.py
import numpy as np
class Filter(object):
"""Processes input, possibly statefully."""
def apply_changes(self, other, *args, **kwargs):
"""Updates self with "new state" from other filter."""
raise NotImplementedError
def copy(self):
"""Creates a new object with same state as self.
Returns:
A copy of self.
"""
raise NotImplementedError
def sync(self, other):
"""Copies all state from other filter to self."""
raise NotImplementedError
def clear_buffer(self):
"""Creates copy of current state and clears accumulated state"""
raise NotImplementedError
def as_serializable(self):
raise NotImplementedError
# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):
def __init__(self, shape=None):
self._n = 0
self._M = np.zeros(shape)
self._S = np.zeros(shape)
def copy(self):
other = RunningStat()
other._n = self._n
other._M = np.copy(self._M)
other._S = np.copy(self._S)
return other
def push(self, x):
x = np.asarray(x)
# Unvectorized update of the running statistics.
if x.shape != self._M.shape:
raise ValueError(
"Unexpected input shape {}, expected {}, value = {}".format(
x.shape, self._M.shape, x))
n1 = self._n
self._n += 1
if self._n == 1:
self._M[...] = x
else:
delta = x - self._M
self._M[...] += delta / self._n
self._S[...] += delta * delta * n1 / self._n
def update(self, other):
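# Merge another RunningStat into this one (parallel update of the count, mean,
# and sum of squared deviations).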
n1 = self._n
n2 = other._n
n = n1 + n2
if n == 0:
# Avoid divide by zero, which creates nans
return
delta = self._M - other._M
delta2 = delta * delta
M = (n1 * self._M + n2 * other._M) / n
S = self._S + other._S + delta2 * n1 * n2 / n
self._n = n
self._M = M
self._S = S
def __repr__(self):
return '(n={}, mean_mean={}, mean_std={})'.format(
self.n, np.mean(self.mean), np.mean(self.std))
@property
def n(self):
return self._n
@property
def mean(self):
return self._M
@property
def var(self):
return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)
@property
def std(self):
return np.sqrt(self.var)
@property
def shape(self):
return self._M.shape
class MeanStdFilter(Filter):
"""Keeps track of a running mean for seen states.
The filter is used to normalize observations and is updated online
according to the observations seen by all actors.
"""
is_concurrent = False
def __init__(self, shape, demean=True, destd=True, clip=10.0):
self.shape = shape
self.demean = demean
self.destd = destd
self.clip = clip
self.rs = RunningStat(shape)
# In distributed rollouts, each worker sees different states.
# The buffer is used to keep track of deltas amongst all the
# observation filters.
self.buffer = RunningStat(shape)
def clear_buffer(self):
self.buffer = RunningStat(self.shape)
def apply_changes(self, other, with_buffer=False):
"""Applies updates from the buffer of another filter.
Params:
other (MeanStdFilter): Other filter to apply info from
with_buffer (bool): Flag for specifying if the buffer should be
copied from other.
Examples:
>>> a = MeanStdFilter(())
>>> a(1)
>>> a(2)
>>> print([a.rs.n, a.rs.mean, a.buffer.n])
[2, 1.5, 2]
>>> b = MeanStdFilter(())
>>> b(10)
>>> a.apply_changes(b, with_buffer=False)
>>> print([a.rs.n, a.rs.mean, a.buffer.n])
[3, 4.333333333333333, 2]
>>> a.apply_changes(b, with_buffer=True)
>>> print([a.rs.n, a.rs.mean, a.buffer.n])
[4, 5.75, 1]
"""
self.rs.update(other.buffer)
if with_buffer:
self.buffer = other.buffer.copy()
def copy(self):
"""Returns a copy of Filter."""
other = MeanStdFilter(self.shape)
other.sync(self)
return other
def as_serializable(self):
return self.copy()
def sync(self, other):
"""Syncs all fields together from other filter.
Examples:
>>> a = MeanStdFilter(())
>>> a(1)
>>> a(2)
>>> print([a.rs.n, a.rs.mean, a.buffer.n])
[2, array(1.5), 2]
>>> b = MeanStdFilter(())
>>> b(10)
>>> print([b.rs.n, b.rs.mean, b.buffer.n])
[1, array(10.0), 1]
>>> a.sync(b)
>>> print([a.rs.n, a.rs.mean, a.buffer.n])
[1, array(10.0), 1]
"""
assert other.shape == self.shape, "Shapes don't match!"
self.demean = other.demean
self.destd = other.destd
self.clip = other.clip
self.rs = other.rs.copy()
self.buffer = other.buffer.copy()
def __call__(self, x, update=True):
x = np.asarray(x)
if update:
if len(x.shape) == len(self.rs.shape) + 1:
# The vectorized case.
for i in range(x.shape[0]):
self.rs.push(x[i])
self.buffer.push(x[i])
else:
# The unvectorized case.
self.rs.push(x)
self.buffer.push(x)
if self.demean:
x = x - self.rs.mean
if self.destd:
x = x / (self.rs.std + 1e-8)
if self.clip:
x = np.clip(x, -self.clip, self.clip)
return x
def __repr__(self):
return 'MeanStdFilter({}, {}, {}, {}, {}, {})'.format(
self.shape, self.demean, self.destd, self.clip, self.rs,
self.buffer)
# Third party code
#
# The following code are copied or modified from:
# https://github.com/openai/evolution-strategies-starter/blob/master/es_distributed/optimizers.py
import numpy as np
class Optimizer(object):
def __init__(self, parameter_size):
self.dim = parameter_size
self.t = 0
def update(self, theta, global_g):
self.t += 1
step = self._compute_step(global_g)
ratio = np.linalg.norm(step) / np.linalg.norm(theta)
return theta + step, ratio
def _compute_step(self, global_g):
raise NotImplementedError
class SGD(Optimizer):
def __init__(self, parameter_size, stepsize, momentum=0.9):
Optimizer.__init__(self, parameter_size)
self.v = np.zeros(self.dim, dtype=np.float32)
self.stepsize, self.momentum = stepsize, momentum
def _compute_step(self, global_g):
self.v = self.momentum * self.v + (1. - self.momentum) * global_g
step = -self.stepsize * self.v
return step
class Adam(Optimizer):
def __init__(self,
parameter_size,
stepsize,
beta1=0.9,
beta2=0.999,
epsilon=1e-08):
Optimizer.__init__(self, parameter_size)
self.stepsize = stepsize
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.m = np.zeros(self.dim, dtype=np.float32)
self.v = np.zeros(self.dim, dtype=np.float32)
def _compute_step(self, global_g):
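# The sqrt(1 - beta2^t) / (1 - beta1^t) factor folds Adam's bias correction for
# m and v into a single scaled step size.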
a = self.stepsize * (np.sqrt(1 - self.beta2**self.t) /
(1 - self.beta1**self.t))
self.m = self.beta1 * self.m + (1 - self.beta1) * global_g
self.v = self.beta2 * self.v + (1 - self.beta2) * (global_g * global_g)
step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
return step
#!/bin/bash
export CPU_NUM=1
actor_num=96
for i in $(seq 1 $actor_num); do
python actor.py &
done;
wait
# Third party code
#
# The following code are copied or modified from:
# https://github.com/openai/evolution-strategies-starter.
import numpy as np
def compute_ranks(x):
"""Returns ranks in [0, len(x))
Note: This is different from scipy.stats.rankdata, which returns ranks in
[1, len(x)].
"""
assert x.ndim == 1
ranks = np.empty(len(x), dtype=int)
ranks[x.argsort()] = np.arange(len(x))
return ranks
def compute_centered_ranks(x):
y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
y /= (x.size - 1)
y -= 0.5
return y
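# For example, compute_centered_ranks(np.array([10., -5., 3.], dtype=np.float32))
# returns [0.5, -0.5, 0.0]: the best return maps to 0.5 and the worst to -0.5,
# which makes the update invariant to the scale of the raw returns.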
def itergroups(items, group_size):
assert group_size >= 1
group = []
for x in items:
group.append(x)
if len(group) == group_size:
yield tuple(group)
del group[:]
if group:
yield tuple(group)
def batched_weighted_sum(weights, vecs, batch_size):
total = 0
num_items_summed = 0
for batch_weights, batch_vecs in zip(
itergroups(weights, batch_size), itergroups(vecs, batch_size)):
assert len(batch_weights) == len(batch_vecs) <= batch_size
total += np.dot(
np.asarray(batch_weights, dtype=np.float32),
np.asarray(batch_vecs, dtype=np.float32))
num_items_summed += len(batch_weights)
return total, num_items_summed
def unflatten(flat_array, array_shapes):
i = 0
arrays = []
for shape in array_shapes:
size = np.prod(shape, dtype=np.int)
array = flat_array[i:(i + size)].reshape(shape)
arrays.append(array)
i += size
assert len(flat_array) == i
return arrays
@@ -18,8 +18,8 @@ Results with one learner (in a P40 GPU) and 24 simulators (in 12 CPU) in 10 mill
 ### Dependencies
 + [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
 + [parl](https://github.com/PaddlePaddle/PARL)
-+ gym
-+ atari-py
++ gym==0.12.1
++ atari-py==0.1.7
 ### Distributed Training
@@ -33,10 +33,10 @@ Note that if you have started a master before, you don't have to run the above
 command. For more information about the cluster, please refer to our
 [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html)
-Then we can start the distributed training by running `train.py`.
+Then we can start the distributed training by running `learner.py`.
 ```bash
-python train.py
+python learner.py
 ```
 [Tips] The performance can be influenced dramatically in a slower computational
...
@@ -319,3 +319,15 @@ class Learner(object):
                 tensorboard.add_scalar(key, value, self.sample_total_steps)
         logger.info(metric)
+
+
+if __name__ == '__main__':
+    from ga3c_config import config
+    learner = Learner(config)
+    assert config['log_metrics_interval_s'] > 0
+
+    while True:
+        time.sleep(config['log_metrics_interval_s'])
+        learner.log_metrics()
@@ -22,9 +22,8 @@ Result with one learner (in a P40 GPU) and 32 actors (in 32 CPUs).
 ### Dependencies
 + [paddlepaddle>=1.5.1](https://github.com/PaddlePaddle/Paddle)
 + [parl](https://github.com/PaddlePaddle/PARL)
-+ gym
-+ atari-py
++ gym==0.12.1
++ atari-py==0.1.7
 ### Distributed Training:
@@ -38,10 +37,10 @@ Note that if you have started a master before, you don't have to run the above
 command. For more information about the cluster, please refer to our
 [documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html)
-Then we can start the distributed training by running `train.py`.
+Then we can start the distributed training by running `learner.py`.
 ```bash
-python train.py
+python learner.py
 ```
 ### Reference
...
@@ -243,3 +243,15 @@ class Learner(object):
                 tensorboard.add_scalar(key, value, self.sample_total_steps)
         logger.info(metric)
+
+
+if __name__ == '__main__':
+    from impala_config import config
+    learner = Learner(config)
+    assert config['log_metrics_interval_s'] > 0
+
+    while True:
+        time.sleep(config['log_metrics_interval_s'])
+        learner.log_metrics()
@@ -263,8 +263,14 @@ class Model(ModelBase):
         """
         assert len(weights) == len(self.parameters()), \
             'size of input weights should be same as weights number of current model'
+        try:
+            is_gpu_available = self._is_gpu_available
+        except AttributeError:
+            self._is_gpu_available = machine_info.is_gpu_available()
+            is_gpu_available = self._is_gpu_available
+
         for (param_name, weight) in list(zip(self.parameters(), weights)):
-            set_value(param_name, weight)
+            set_value(param_name, weight, is_gpu_available)

     def _get_parameter_names(self, obj):
         """ Recursively get parameter names in a model and its child attributes.
...
@@ -53,15 +53,15 @@ def fetch_value(attr_name):
     return _fetch_var(attr_name, return_numpy=True)

-def set_value(attr_name, value):
+def set_value(attr_name, value, is_gpu_available):
     """ Given name of ParamAttr, set numpy value to the parameter in global_scope

     Args:
-        attr_name: ParamAttr name of parameter
-        value: numpy array
+        attr_name(string): ParamAttr name of parameter
+        value(np.array): numpy value
+        is_gpu_available(bool): whether gpu is available
     """
-    place = fluid.CUDAPlace(
-        0) if machine_info.is_gpu_available() else fluid.CPUPlace()
+    place = fluid.CUDAPlace(0) if is_gpu_available else fluid.CPUPlace()
     var = _fetch_var(attr_name, return_numpy=False)
     var.set(value, place)
...
@@ -26,10 +26,12 @@ def compile(program, loss=None):
         program(fluid.Program): a normal fluid program.
         loss_name(str): Optional. The loss tensor of a trainable program. Set it to None if you are transferring a prediction or evaluation program.
     """
+    loss_name = None
     if loss is not None:
         assert isinstance(
             loss, fluid.framework.
             Variable), 'type of loss is expected to be a fluid tensor'
+        loss_name = loss.name

     # TODO: after solving the learning rate issue that occurs in training A2C algorithm, set it to 3.
     os.environ['CPU_NUM'] = '1'
     exec_strategy = fluid.ExecutionStrategy()
@@ -39,6 +41,6 @@
     build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
     return fluid.compiler.CompiledProgram(program).with_data_parallel(
-        loss_name=loss.name,
+        loss_name=loss_name,
         exec_strategy=exec_strategy,
         build_strategy=build_strategy)