Commit 39846831 authored by Hongsheng Zeng, committed by Bo Zhou

A2C example (#62)

* add IMPALA algorithm and some common utils

* update README.md

* refactor file structure of impala algorithm; separate numpy utils from utils

* add hyper parameter scheduler module; add entropy and lr scheduler in impala

* clip reward in atari wrapper instead of learner side; fix codestyle

* add benchmark result of impala; refine code of impala example; add obs_format in atari_wrappers

* Update README.md

* add a3c algorithm, A2C example and rl_utils

* require training in single gpu/cpu

* only check cpu/gpu num in learner

* refine Readme

* update impala benchmark picture; update Readme

* add benchmark result of A2C

* move get_params/set_params in agent_base

* fix shell script cannot run in ubuntu

* refine comment and document

* Update README.md

* Update README.md
Parent 452050a0
......@@ -117,6 +117,7 @@ pip install parl
- [DDPG](examples/DDPG/)
- [PPO](examples/PPO/)
- [IMPALA](examples/IMPALA/)
- [A2C](examples/A2C/)
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
<img src=".github/NeurlIPS2018.gif" width = "300" height ="200" alt="NeurlIPS2018"/> <img src=".github/Half-Cheetah.gif" width = "300" height ="200" alt="Half-Cheetah"/> <img src=".github/Breakout.gif" width = "200" height ="200" alt="Breakout"/>
......
## Reproduce A2C with PARL
Based on PARL, the A2C algorithm of deep reinforcement learning has been reproduced, reaching the same level of performance as the paper on Atari benchmarks.
A2C is a synchronous, deterministic variant of [Asynchronous Advantage Actor Critic (A3C)](https://arxiv.org/abs/1602.01783). Instead of updating asynchronously as A3C and GA3C do, A2C uses a synchronous approach that waits for every actor to finish sampling before performing an update. Since the loss definitions of these A3C variants are identical, we use a common A3C algorithm, `parl.algorithms.A3C`, for both the A2C and GA3C examples.
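The synchronous update pattern can be summarized with the short sketch below. This is not the PARL implementation; `learner_agent` and `actors` are placeholders for the agent and remote actors created in `train.py` and `actor.py`.
```python
from collections import defaultdict
import numpy as np

def merge_batches(batches):
    """Concatenate per-actor sample dicts into one training batch."""
    merged = defaultdict(list)
    for batch in batches:
        for key, value in batch.items():
            merged[key].append(value)
    return {key: np.concatenate(value) for key, value in merged.items()}

def a2c_training_loop(learner_agent, actors, num_updates):
    """Synchronous A2C: every actor finishes its rollout before one update."""
    for _ in range(num_updates):
        latest_params = learner_agent.get_params()   # broadcast the current policy
        batches = []
        for actor in actors:                         # wait for each actor in turn
            actor.set_params(latest_params)
            batches.append(actor.sample())           # fixed-length rollout per actor
        learner_agent.learn(merge_batches(batches))  # one update on the joint batch
```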
### Atari games introduction
Please see [here](https://gym.openai.com/envs/#atari) to learn more about Atari games.
### Benchmark result
Results with one learner (in a P40 GPU) and 5 actors, over 10 million sample steps.
<img src=".benchmark/A2C_Pong.jpg" width = "400" height ="300" alt="A2C_Pong" /> <img src=".benchmark/A2C_Breakout.jpg" width = "400" height ="300" alt="A2C_Breakout"/>
## How to use
### Dependencies
+ python2.7 or python3.5+
+ [paddlepaddle>=1.3.0](https://github.com/PaddlePaddle/Paddle)
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
+ opencv-python
+ atari_py
### Distributed Training
#### Learner
```sh
python train.py
```
#### Actors (Suggestion: 5 actors on 5 CPUs)
```sh
for i in $(seq 1 5); do
python actor.py &
done;
wait
```
You can change the training settings (e.g. `env_name`, `server_ip`) in `a2c_config.py`.
Training results will be saved to `log_dir/train/result.csv`.
### Reference
+ [Ray](https://github.com/ray-project/ray)
+ [OpenAI Baselines: ACKTR & A2C](https://openai.com/blog/baselines-acktr-a2c/)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
config = {
#========== remote config ==========
'server_ip': 'localhost',
'server_port': 8037,
#========== env config ==========
'env_name': 'PongNoFrameskip-v4',
'env_dim': 42,
#========== actor config ==========
'actor_num': 5,
'env_num': 5,
'sample_batch_steps': 20,
#========== learner config ==========
'gamma': 0.99,
'lambda': 1.0, # GAE
# learning rate adjustment schedule: (train_step, learning_rate)
'lr_scheduler': [(0, 0.001), (20000, 0.0005), (40000, 0.0001)],
# coefficient of policy entropy adjustment schedule: (train_step, coefficient)
'entropy_coeff_scheduler': [(0, -0.01)],
'vf_loss_coeff': 0.5,
'get_remote_metrics_interval': 10,
'log_metrics_interval_s': 10,
}
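The `lr_scheduler` and `entropy_coeff_scheduler` entries are lists of `(train_step, value)` breakpoints consumed by `PiecewiseScheduler` in `atari_agent.py`. The sketch below shows how such a piecewise schedule is typically interpreted; this is an assumption for illustration, not the `PiecewiseScheduler` source.
```python
def piecewise_value(schedule, step):
    """Return the value of the most recent breakpoint that `step` has reached."""
    value = schedule[0][1]
    for boundary, v in schedule:
        if step >= boundary:
            value = v
    return value

lr_schedule = [(0, 0.001), (20000, 0.0005), (40000, 0.0001)]
assert piecewise_value(lr_schedule, 0) == 0.001
assert piecewise_value(lr_schedule, 25000) == 0.0005
assert piecewise_value(lr_schedule, 100000) == 0.0001
```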
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import numpy as np
import parl
import six
from atari_model import AtariModel
from collections import defaultdict
from atari_agent import AtariAgent
from parl.algorithms import A3C
from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls
from parl.env.vector_env import VectorEnv
from parl.utils.rl_utils import calc_gae
@parl.remote_class
class Actor(object):
def __init__(self, config):
self.config = config
self.envs = []
for _ in six.moves.range(config['env_num']):
env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
self.envs.append(env)
self.vector_env = VectorEnv(self.envs)
self.obs_batch = self.vector_env.reset()
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = AtariModel(act_dim)
algorithm = A3C(model, hyperparas=config)
self.agent = AtariAgent(algorithm, config)
def sample(self):
sample_data = defaultdict(list)
env_sample_data = {}
for env_id in six.moves.range(self.config['env_num']):
env_sample_data[env_id] = defaultdict(list)
for i in six.moves.range(self.config['sample_batch_steps']):
actions_batch, values_batch = self.agent.sample(
np.stack(self.obs_batch))
next_obs_batch, reward_batch, done_batch, info_batch = \
self.vector_env.step(actions_batch)
for env_id in six.moves.range(self.config['env_num']):
env_sample_data[env_id]['obs'].append(self.obs_batch[env_id])
env_sample_data[env_id]['actions'].append(
actions_batch[env_id])
env_sample_data[env_id]['rewards'].append(reward_batch[env_id])
env_sample_data[env_id]['dones'].append(done_batch[env_id])
env_sample_data[env_id]['values'].append(values_batch[env_id])
# Calculate advantages when the episode is done or the max sample steps are reached.
if done_batch[
env_id] or i == self.config['sample_batch_steps'] - 1:
next_value = 0
if not done_batch[env_id]:
next_obs = np.expand_dims(next_obs_batch[env_id], 0)
next_value = self.agent.value(next_obs)
values = env_sample_data[env_id]['values']
rewards = env_sample_data[env_id]['rewards']
advantages = calc_gae(rewards, values, next_value,
self.config['gamma'],
self.config['lambda'])
target_values = advantages + values
sample_data['obs'].extend(env_sample_data[env_id]['obs'])
sample_data['actions'].extend(
env_sample_data[env_id]['actions'])
sample_data['advantages'].extend(advantages)
sample_data['target_values'].extend(target_values)
env_sample_data[env_id] = defaultdict(list)
self.obs_batch = next_obs_batch
# size of sample_data: env_num * sample_batch_steps
for key in sample_data:
sample_data[key] = np.stack(sample_data[key])
return sample_data
def get_metrics(self):
metrics = defaultdict(list)
for env in self.envs:
monitor = get_wrapper_by_cls(env, MonitorEnv)
if monitor is not None:
for episode_rewards, episode_steps in monitor.next_episode_results(
):
metrics['episode_rewards'].append(episode_rewards)
metrics['episode_steps'].append(episode_steps)
return metrics
def set_params(self, params):
self.agent.set_params(params)
if __name__ == '__main__':
from a2c_config import config
actor = Actor(config)
actor.as_remote(config['server_ip'], config['server_port'])
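With the config above (`env_num=5`, `sample_batch_steps=20`), each call to `Actor.sample()` returns 100 transitions. A sketch of the expected structure follows; the channel count of `obs` is an assumption about `wrap_deepmind`'s frame stacking, not taken from the source.
```python
import numpy as np

env_num, sample_batch_steps = 5, 20
batch_size = env_num * sample_batch_steps  # 100 transitions per sample() call
sample_data = {
    'obs': np.zeros((batch_size, 4, 42, 42), dtype='float32'),  # assumed 4 stacked 42x42 frames
    'actions': np.zeros((batch_size,), dtype='int64'),
    'advantages': np.zeros((batch_size,), dtype='float32'),
    'target_values': np.zeros((batch_size,), dtype='float32'),
}
```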
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.agent_base import Agent
from parl.utils.scheduler import PiecewiseScheduler
class AtariAgent(Agent):
def __init__(self, algorithm, config):
self.config = config
super(AtariAgent, self).__init__(algorithm)
self.lr_scheduler = PiecewiseScheduler(config['lr_scheduler'])
self.entropy_coeff_scheduler = PiecewiseScheduler(
config['entropy_coeff_scheduler'])
use_cuda = True if self.gpu_id >= 0 else False
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 4
build_strategy = fluid.BuildStrategy()
build_strategy.remove_unnecessary_lock = True
# Use ParallelExecutor to make learn program run faster
self.learn_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=self.learn_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
def build_program(self):
self.sample_program = fluid.Program()
self.predict_program = fluid.Program()
self.value_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.sample_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
sample_actions, values = self.alg.sample(obs)
self.sample_outputs = [sample_actions, values]
with fluid.program_guard(self.predict_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
self.predict_actions = self.alg.predict(obs)
with fluid.program_guard(self.value_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
self.values = self.alg.value(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
actions = layers.data(name='actions', shape=[], dtype='int64')
advantages = layers.data(
name='advantages', shape=[], dtype='float32')
target_values = layers.data(
name='target_values', shape=[], dtype='float32')
lr = layers.data(
name='lr', shape=[1], dtype='float32', append_batch_size=False)
entropy_coeff = layers.data(
name='entropy_coeff', shape=[], dtype='float32')
total_loss, pi_loss, vf_loss, entropy = self.alg.learn(
obs, actions, advantages, target_values, lr, entropy_coeff)
self.learn_outputs = [
total_loss.name, pi_loss.name, vf_loss.name, entropy.name
]
def sample(self, obs_np):
"""
Args:
obs_np: a numpy float32 array of shape ([B] + observation_space).
Image input should be in NCHW format.
Returns:
sample_ids: a numpy int64 array of shape [B]
values: a numpy float32 array of shape [B]
"""
obs_np = obs_np.astype('float32')
sample_actions, values = self.fluid_executor.run(
self.sample_program,
feed={'obs': obs_np},
fetch_list=self.sample_outputs)
return sample_actions, values
def predict(self, obs_np):
"""
Args:
obs_np: a numpy float32 array of shape ([B] + observation_space).
Image input should be in NCHW format.
Returns:
sample_ids: a numpy int64 array of shape [B]
"""
obs_np = obs_np.astype('float32')
predict_actions = self.fluid_executor.run(
self.predict_program,
feed={'obs': obs_np},
fetch_list=[self.predict_actions])[0]
return predict_actions
def value(self, obs_np):
"""
Args:
obs_np: a numpy float32 array of shape ([B] + observation_space).
Image input should be in NCHW format.
Returns:
values: a numpy float32 array of shape [B]
"""
obs_np = obs_np.astype('float32')
values = self.fluid_executor.run(
self.value_program, feed={'obs': obs_np},
fetch_list=[self.values])[0]
return values
def learn(self, obs_np, actions_np, advantages_np, target_values_np):
"""
Args:
obs_np: a numpy float32 array of shape ([B] + observation_space).
Image input should be in NCHW format.
actions_np: a numpy int64 array of shape [B]
advantages_np: a numpy float32 array of shape [B]
target_values_np: a numpy float32 array of shape [B]
"""
obs_np = obs_np.astype('float32')
actions_np = actions_np.astype('int64')
advantages_np = advantages_np.astype('float32')
target_values_np = target_values_np.astype('float32')
lr = self.lr_scheduler.step()
entropy_coeff = self.entropy_coeff_scheduler.step()
total_loss, pi_loss, vf_loss, entropy = self.learn_exe.run(
feed={
'obs': obs_np,
'actions': actions_np,
'advantages': advantages_np,
'target_values': target_values_np,
'lr': np.array([lr], dtype='float32'),
'entropy_coeff': np.array([entropy_coeff], dtype='float32')
},
fetch_list=self.learn_outputs)
return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff
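For reference, this is a sketch of the feed that `AtariAgent.learn` builds above: batch tensors of length B, plus the scheduler outputs wrapped as shape-`[1]` float32 arrays (the `lr` input is declared with `append_batch_size=False`). The `obs` shape is an assumption about the wrapped observation.
```python
import numpy as np

B = 100
feed = {
    'obs': np.zeros((B, 4, 42, 42), dtype='float32'),     # assumed stacked-frame observations
    'actions': np.zeros((B,), dtype='int64'),
    'advantages': np.zeros((B,), dtype='float32'),
    'target_values': np.zeros((B,), dtype='float32'),
    'lr': np.array([0.001], dtype='float32'),             # from lr_scheduler.step()
    'entropy_coeff': np.array([-0.01], dtype='float32'),  # from entropy_coeff_scheduler.step()
}
```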
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
from paddle.fluid.param_attr import ParamAttr
class AtariModel(Model):
def __init__(self, act_dim):
self.conv1 = layers.conv2d(
num_filters=16, filter_size=4, stride=2, padding=1, act='relu')
self.conv2 = layers.conv2d(
num_filters=32, filter_size=4, stride=2, padding=2, act='relu')
self.conv3 = layers.conv2d(
num_filters=256, filter_size=11, stride=1, padding=0, act='relu')
self.policy_conv = layers.conv2d(
num_filters=act_dim,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(initializer=fluid.initializer.Normal()))
self.value_fc = layers.fc(
size=1,
param_attr=ParamAttr(initializer=fluid.initializer.Normal()))
def policy(self, obs):
"""
Args:
obs: A float32 tensor of shape [B, C, H, W]
Returns:
policy_logits: B * ACT_DIM
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
policy_conv = self.policy_conv(conv3)
policy_logits = layers.flatten(policy_conv, axis=1)
return policy_logits
def value(self, obs):
"""
Args:
obs: A float32 tensor of shape [B, C, H, W]
Returns:
values: B
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
flatten = layers.flatten(conv3, axis=1)
values = self.value_fc(flatten)
values = layers.squeeze(values, axes=[1])
return values
def policy_and_value(self, obs):
"""
Args:
obs: A float32 tensor of shape [B, C, H, W]
Returns:
policy_logits: B * ACT_DIM
values: B
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
policy_conv = self.policy_conv(conv3)
policy_logits = layers.flatten(policy_conv, axis=1)
flatten = layers.flatten(conv3, axis=1)
values = self.value_fc(flatten)
values = layers.squeeze(values, axes=[1])
return policy_logits, values
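A quick sanity check of the network geometry above, using the standard convolution output-size formula (floor division assumed): with the 42x42 input used in this example, `conv3` collapses the feature map to 1x1, so flattening the 1x1 `policy_conv` output yields exactly `act_dim` logits.
```python
def conv_out(size, kernel, stride, padding):
    """Standard conv output-size arithmetic with floor division."""
    return (size + 2 * padding - kernel) // stride + 1

size = 42                        # env_dim in a2c_config.py
size = conv_out(size, 4, 2, 1)   # conv1 -> 21
size = conv_out(size, 4, 2, 2)   # conv2 -> 11
size = conv_out(size, 11, 1, 0)  # conv3 -> 1
assert size == 1
```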
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import numpy as np
import os
import queue
import six
import time
import threading
from atari_model import AtariModel
from atari_agent import AtariAgent
from collections import defaultdict
from parl import RemoteManager
from parl.algorithms import A3C
from parl.env.atari_wrappers import wrap_deepmind
from parl.utils import logger, CSVLogger, get_gpu_count
from parl.utils.scheduler import PiecewiseScheduler
from parl.utils.time_stat import TimeStat
from parl.utils.window_stat import WindowStat
class Learner(object):
def __init__(self, config):
self.config = config
#=========== Create Agent ==========
env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = AtariModel(act_dim)
algorithm = A3C(model, hyperparas=config)
self.agent = AtariAgent(algorithm, config)
if self.agent.gpu_id >= 0:
assert get_gpu_count() == 1, 'Only training on a single GPU is supported.\
Please set the environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_YOU_WANT_TO_USE]`.'
else:
cpu_num = os.environ.get('CPU_NUM')
assert cpu_num is not None and cpu_num == '1', 'Only training on a single CPU is supported.\
Please set the environment variable: `export CPU_NUM=1`.'
#========== Learner ==========
self.total_loss_stat = WindowStat(100)
self.pi_loss_stat = WindowStat(100)
self.vf_loss_stat = WindowStat(100)
self.entropy_stat = WindowStat(100)
self.lr = None
self.entropy_coeff = None
self.learn_time_stat = TimeStat(100)
self.start_time = None
#========== Remote Actor ===========
self.remote_count = 0
self.sample_data_queue = queue.Queue()
self.remote_metrics_queue = queue.Queue()
self.sample_total_steps = 0
self.params_queues = []
self.run_remote_manager()
self.csv_logger = CSVLogger(
os.path.join(logger.get_dir(), 'result.csv'))
def run_remote_manager(self):
""" Accept connection of new remote actor and start sampling of the remote actor.
"""
remote_manager = RemoteManager(port=self.config['server_port'])
logger.info('Waiting for {} remote actors to connect.'.format(
self.config['actor_num']))
for i in six.moves.range(self.config['actor_num']):
remote_actor = remote_manager.get_remote()
params_queue = queue.Queue()
self.params_queues.append(params_queue)
self.remote_count += 1
logger.info('Remote actor count: {}'.format(self.remote_count))
remote_thread = threading.Thread(
target=self.run_remote_sample,
args=(remote_actor, params_queue))
remote_thread.setDaemon(True)
remote_thread.start()
logger.info('All remote actors are ready, begin to learn.')
self.start_time = time.time()
def run_remote_sample(self, remote_actor, params_queue):
""" Sample data from remote actor and update parameters of remote actor.
"""
cnt = 0
while True:
latest_params = params_queue.get()
remote_actor.set_params(latest_params)
batch = remote_actor.sample()
self.sample_data_queue.put(batch)
cnt += 1
if cnt % self.config['get_remote_metrics_interval'] == 0:
metrics = remote_actor.get_metrics()
if metrics:
self.remote_metrics_queue.put(metrics)
def step(self):
"""
1. kick off all actors to synchronize parameters and sample data;
2. collect sample data of all actors;
3. update parameters.
"""
latest_params = self.agent.get_params()
for params_queue in self.params_queues:
params_queue.put(latest_params)
train_batch = defaultdict(list)
for i in range(self.config['actor_num']):
sample_data = self.sample_data_queue.get()
for key, value in sample_data.items():
train_batch[key].append(value)
self.sample_total_steps += sample_data['obs'].shape[0]
for key, value in train_batch.items():
train_batch[key] = np.concatenate(value)
with self.learn_time_stat:
total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = self.agent.learn(
obs_np=train_batch['obs'],
actions_np=train_batch['actions'],
advantages_np=train_batch['advantages'],
target_values_np=train_batch['target_values'])
self.total_loss_stat.add(total_loss)
self.pi_loss_stat.add(pi_loss)
self.vf_loss_stat.add(vf_loss)
self.entropy_stat.add(entropy)
self.lr = lr
self.entropy_coeff = entropy_coeff
def log_metrics(self):
""" Log metrics of learner and actors
"""
if self.start_time is None:
return
metrics = []
while True:
try:
metric = self.remote_metrics_queue.get_nowait()
metrics.append(metric)
except queue.Empty:
break
episode_rewards, episode_steps = [], []
for x in metrics:
episode_rewards.extend(x['episode_rewards'])
episode_steps.extend(x['episode_steps'])
max_episode_rewards, mean_episode_rewards, min_episode_rewards, \
max_episode_steps, mean_episode_steps, min_episode_steps =\
None, None, None, None, None, None
if episode_rewards:
mean_episode_rewards = np.mean(np.array(episode_rewards).flatten())
max_episode_rewards = np.max(np.array(episode_rewards).flatten())
min_episode_rewards = np.min(np.array(episode_rewards).flatten())
mean_episode_steps = np.mean(np.array(episode_steps).flatten())
max_episode_steps = np.max(np.array(episode_steps).flatten())
min_episode_steps = np.min(np.array(episode_steps).flatten())
metric = {
'Sample steps': self.sample_total_steps,
'max_episode_rewards': max_episode_rewards,
'mean_episode_rewards': mean_episode_rewards,
'min_episode_rewards': min_episode_rewards,
'max_episode_steps': max_episode_steps,
'mean_episode_steps': mean_episode_steps,
'min_episode_steps': min_episode_steps,
'total_loss': self.total_loss_stat.mean,
'pi_loss': self.pi_loss_stat.mean,
'vf_loss': self.vf_loss_stat.mean,
'entropy': self.entropy_stat.mean,
'learn_time_s': self.learn_time_stat.mean,
'elapsed_time_s': int(time.time() - self.start_time),
'lr': self.lr,
'entropy_coeff': self.entropy_coeff,
}
logger.info(metric)
self.csv_logger.log_dict(metric)
def close(self):
self.csv_logger.close()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=""
for i in $(seq 1 5); do
python actor.py &
done;
wait
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from learner import Learner
def main(config):
learner = Learner(config)
try:
while True:
start = time.time()
while time.time() - start < config['log_metrics_interval_s']:
learner.step()
learner.log_metrics()
except KeyboardInterrupt:
learner.close()
if __name__ == '__main__':
from a2c_config import config
main(config)
## Reproduce DDPG with PARL
Based on PARL, the DDPG model of deep reinforcement learning is reproduced, and the same level of indicators of the paper is reproduced in the classic Mujoco game.
Based on PARL, the DDPG algorithm of deep reinforcement learning has been reproduced, reaching the same level of performance as the paper on Mujoco benchmarks.
+ DDPG in
[Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971)
......
## Reproduce DQN with PARL
Based on PARL, the DQN model of deep reinforcement learning is reproduced, and the same level of indicators of the paper is reproduced in the classic Atari game.
Based on PARL, the DQN algorithm of deep reinforcement learning has been reproduced, reaching the same level of performance as the paper on Atari benchmarks.
+ DQN in
[Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
......
......@@ -8,7 +8,7 @@ Based on PARL, the IMPALA algorithm of deep reinforcement learning is reproduced
Please see [here](https://gym.openai.com/envs/#atari) to know more about Atari game.
### Benchmark result
Result with one learner (in P40 GPU) and 32 actors (in 32 CPUs).
Result with one learner (in a P40 GPU) and 32 actors (in 32 CPUs).
+ PongNoFrameskip-v4: mean_episode_rewards can reach a score of 18-19 in about 7-10 minutes.
<img src=".benchmark/IMPALA_Pong.jpg" width = "400" height ="300" alt="IMPALA_Pong" />
......@@ -37,7 +37,7 @@ python train.py
#### Actors (Suggest: 32+ actors in 32+ CPUs)
```sh
for index in {1..32}; do
for i in $(seq 1 32); do
python actor.py &
done;
wait
......
......@@ -31,7 +31,7 @@ class AtariAgent(Agent):
build_strategy = fluid.BuildStrategy()
build_strategy.remove_unnecessary_lock = True
# Use ParallelExecutor to make learn program running faster
# Use ParallelExecutor to make learn program run faster
self.learn_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=self.learn_program,
......@@ -127,9 +127,3 @@ class AtariAgent(Agent):
total_loss, pi_loss, vf_loss, entropy, kl = self.learn_exe.run(
fetch_list=self.learn_outputs)
return total_loss, pi_loss, vf_loss, entropy, kl
def get_params(self):
return self.alg.get_params()
def set_params(self, params):
self.alg.set_params(params, gpu_id=self.gpu_id)
......@@ -43,7 +43,7 @@ class AtariModel(Model):
def policy(self, obs):
"""
Args:
obs: An float32 tensor of shape [B, C, H, W]
obs: A float32 tensor of shape [B, C, H, W]
Returns:
policy_logits: B * ACT_DIM
"""
......@@ -59,7 +59,7 @@ class AtariModel(Model):
def value(self, obs):
"""
Args:
obs: An float32 tensor of shape [B, C, H, W]
obs: A float32 tensor of shape [B, C, H, W]
Returns:
value: B
"""
......
......@@ -22,7 +22,6 @@ config = {
#========== env config ==========
'env_name': 'PongNoFrameskip-v4',
'env_dim': 42,
'obs_format': 'NHWC',
#========== actor config ==========
'env_num': 5,
......
......@@ -2,7 +2,7 @@
export CUDA_VISIBLE_DEVICES=""
for index in {1..32}; do
for i in $(seq 1 32); do
python actor.py &
done;
wait
## Reproduce PPO with PARL
Based on PARL, the PPO model of deep reinforcement learning is reproduced, and the same level of indicators of the paper is reproduced in the classic Mujoco game.
Based on PARL, the PPO algorithm of deep reinforcement learning has been reproduced, reaching the same level of performance as the paper on Mujoco benchmarks.
It includes the following approaches:
+ Clipped Surrogate Objective
+ Adaptive KL Penalty Coefficient
......
......@@ -6,22 +6,7 @@
import numpy as np
import scipy.signal
__all__ = ['calc_discount_sum_rewards', 'calc_gae', 'Scaler']
def calc_discount_sum_rewards(rewards, gamma):
""" Calculate discounted forward sum of a sequence at each point """
return scipy.signal.lfilter([1.0], [1.0, -gamma], rewards[::-1])[::-1]
def calc_gae(rewards, values, gamma, lam):
""" Calculate generalized advantage estimator.
See: https://arxiv.org/pdf/1506.02438.pdf
"""
# temporal differences
tds = rewards - values + np.append(values[1:] * gamma, 0)
advantages = calc_discount_sum_rewards(tds, gamma * lam)
return advantages
__all__ = ['Scaler']
class Scaler(object):
......
......@@ -19,7 +19,8 @@ from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from parl.algorithms import PPO
from parl.utils import logger, action_mapping
from utils import *
from parl.utils.rl_utils import calc_gae, calc_discount_sum_rewards
from scaler import Scaler
def run_train_episode(env, agent, scaler):
......@@ -110,7 +111,8 @@ def build_train_data(trajectories, agent):
discount_sum_rewards = calc_discount_sum_rewards(
scale_rewards, args.gamma).astype('float32')
advantages = calc_gae(scale_rewards, pred_values, args.gamma, args.lam)
advantages = calc_gae(scale_rewards, pred_values, 0, args.gamma,
args.lam)
# normalize advantages
advantages = (advantages - advantages.mean()) / (
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.algorithms.a3c import *
from parl.algorithms.ddpg import *
from parl.algorithms.dqn import *
from parl.algorithms.policy_gradient import *
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.algorithm_base import Algorithm
from parl.framework.policy_distribution import CategoricalDistribution
__all__ = ['A3C']
class A3C(Algorithm):
def __init__(self, model, hyperparas):
super(A3C, self).__init__(model, hyperparas)
def learn(self, obs, actions, advantages, target_values, learning_rate,
entropy_coeff):
"""
Args:
obs: A float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
actions: An int64 tensor of shape [B].
advantages: A float32 tensor of shape [B].
target_values: A float32 tensor of shape [B].
learning_rate: float scalar of learning rate.
entropy_coeff: float scalar of entropy coefficient.
"""
logits = self.model.policy(obs)
policy_distribution = CategoricalDistribution(logits)
actions_log_probs = policy_distribution.logp(actions)
# The policy gradient loss
pi_loss = -1.0 * layers.reduce_sum(actions_log_probs * advantages)
# The value function loss
values = self.model.value(obs)
delta = values - target_values
vf_loss = 0.5 * layers.reduce_sum(layers.square(delta))
# The entropy loss (we want to maximize entropy, so entropy_coeff < 0)
policy_entropy = policy_distribution.entropy()
entropy = layers.reduce_sum(policy_entropy)
total_loss = (pi_loss + vf_loss * self.hp['vf_loss_coeff'] +
entropy * entropy_coeff)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=40.0))
optimizer = fluid.optimizer.AdamOptimizer(learning_rate)
optimizer.minimize(total_loss)
return total_loss, pi_loss, vf_loss, entropy
def sample(self, obs):
"""
Args:
obs: A float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
"""
logits, values = self.model.policy_and_value(obs)
policy_dist = CategoricalDistribution(logits)
sample_actions = policy_dist.sample()
return sample_actions, values
def predict(self, obs):
"""
Args:
obs: A float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
"""
logits = self.model.policy(obs)
probs = layers.softmax(logits)
predict_actions = layers.argmax(probs, 1)
return predict_actions
def value(self, obs):
"""
Args:
obs: A float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
"""
values = self.model.value(obs)
return values
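For readers who want to see the loss arithmetic without the Fluid graph, here is a small NumPy sketch of the same terms on a toy batch. It only illustrates how `pi_loss`, `vf_loss`, and the entropy bonus (with a negative `entropy_coeff`) combine; the numbers are arbitrary, and in the real algorithm `values` comes from `model.value(obs)`.
```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)

logits = np.array([[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]])  # [B, ACT_DIM]
actions = np.array([0, 2])
advantages = np.array([1.2, -0.4])
target_values = np.array([0.8, 0.3])
values = np.array([0.5, 0.6])            # stand-in for model.value(obs)

probs = softmax(logits)
log_probs = np.log(probs)
action_logp = log_probs[np.arange(len(actions)), actions]

pi_loss = -np.sum(action_logp * advantages)              # policy gradient loss
vf_loss = 0.5 * np.sum((values - target_values) ** 2)    # value function loss
entropy = np.sum(-(probs * log_probs).sum(axis=1))       # policy entropy

vf_loss_coeff = 0.5
entropy_coeff = -0.01   # negative, so minimizing total_loss maximizes entropy
total_loss = pi_loss + vf_loss * vf_loss_coeff + entropy * entropy_coeff
```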
......@@ -89,3 +89,19 @@ class Agent(object):
this function is the training interface for Agent.
"""
raise NotImplementedError
def get_params(self):
""" Get parameters of self.alg
Returns:
List of numpy array.
"""
return self.alg.get_params()
def set_params(self, params):
""" Set parameters of self.alg
Args:
params: List of numpy array.
"""
self.alg.set_params(params, gpu_id=self.gpu_id)
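`get_params`/`set_params` exchange a list of NumPy arrays; this is exactly what the learner broadcasts to the remote actors on every step. The stand-in class below is purely illustrative (not a PARL class) and only shows the contract.
```python
import numpy as np

class DummyAlg(object):
    """Illustrative stand-in for an Algorithm with get_params/set_params."""
    def __init__(self):
        self.weights = [np.zeros((4, 4), dtype='float32'),
                        np.zeros((4,), dtype='float32')]
    def get_params(self):
        return [w.copy() for w in self.weights]
    def set_params(self, params, gpu_id=-1):
        self.weights = [p.copy() for p in params]

learner_alg, actor_alg = DummyAlg(), DummyAlg()
learner_alg.weights[0] += 1.0                    # pretend the learner took a gradient step
actor_alg.set_params(learner_alg.get_params())   # same flow as Learner.step -> Actor.set_params
assert np.allclose(actor_alg.weights[0], 1.0)
```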
......@@ -18,4 +18,5 @@ from parl.utils.csv_logger import *
from parl.utils.machine_info import *
from parl.utils.np_utils import *
from parl.utils.replay_memory import *
from parl.utils.rl_utils import *
from parl.utils.scheduler import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import scipy.signal
__all__ = ['calc_discount_sum_rewards', 'calc_gae']
def calc_discount_sum_rewards(rewards, gamma):
""" Calculate discounted forward sum of a sequence at each point.
Args:
rewards (List/Tuple/np.array): rewards of (s_t, s_{t+1}, ..., s_T)
gamma (Scalar): gamma coefficient
Returns:
np.array: discounted sum rewards of (s_t, s_{t+1}, ..., s_T)
"""
return scipy.signal.lfilter([1.0], [1.0, -gamma], rewards[::-1])[::-1]
def calc_gae(rewards, values, next_value, gamma, lam):
""" Calculate generalized advantage estimator (GAE).
See: https://arxiv.org/pdf/1506.02438.pdf
Args:
rewards (List/Tuple/np.array): rewards of (s_t, s_{t+1}, ..., s_T)
values (List/Tuple/np.array): values of (s_t, s_{t+1}, ..., s_T)
next_value (Scalar): value of s_{T+1}
gamma (Scalar): gamma coefficient
lam (Scalar): lambda coefficient
Returns:
advantages (np.array): advantages of (s_t, s_{t+1}, ..., s_T)
"""
# temporal differences
tds = rewards + gamma * np.append(values[1:], next_value) - values
advantages = calc_discount_sum_rewards(tds, gamma * lam)
return advantages
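A small usage check of the functions defined above (numbers chosen arbitrarily): with `lam=1.0`, GAE reduces to the bootstrapped discounted return minus the value baseline, and `advantages + values` gives the regression targets used in the A2C actor.
```python
import numpy as np

rewards = np.array([1.0, 0.0, 1.0], dtype='float32')
values = np.array([0.5, 0.4, 0.3], dtype='float32')
next_value = 0.2
gamma, lam = 0.99, 1.0

advantages = calc_gae(rewards, values, next_value, gamma, lam)

# With lam=1.0 the advantage telescopes to (discounted return) - (value baseline).
bootstrapped = np.append(rewards[:-1], rewards[-1] + gamma * next_value)
returns = calc_discount_sum_rewards(bootstrapped, gamma)
assert np.allclose(advantages, returns - values, atol=1e-5)

target_values = advantages + values   # the targets used in actor.py
```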