Commit 3556c786 authored by Hongsheng Zeng, committed by Bo Zhou

Refine (#67)

* fix typo

* Update README.md

* Update README.md

* Update README.md

* make fluid a soft dependency; add a module to monitor client status

* improve performance of IMPALA example

* fix bug where some clients cannot exit normally

* refine comment

* .
Parent 432d75b7
......@@ -9,9 +9,9 @@ English | [简体中文](./README.cn.md)
# Features
**Reproducible**. We provide algorithms that stably reproduce the result of many influential reinforcement learning algorithms.
**Large Scale**. Ability to support high performance parallelization of training with thousands of CPUs and multi-GPUs.
**Large Scale**. Ability to support high-performance parallelization of training with thousands of CPUs and multi-GPUs.
**Reusable**. Algorithms provided in repository could be directly adapted to a new task by defining a forward network and training mechanism will be built automatically.
**Reusable**. Algorithms provided in the repository can be directly adapted to a new task by simply defining a forward network; the training mechanism will be built automatically.
**Extensible**. Build new algorithms quickly by inheriting the abstract class in the framework.
......@@ -30,7 +30,7 @@ The main abstractions introduced by PARL that are used to build an agent recursi
### Agent
`Agent`, a data bridge between environment and algorithm, is responsible for data I/O with the outside environment and describes data preprocessing before feeding data into the training process.
Here is an example of building an agent with DQN algorithm for atari games.
Here is an example of building an agent with the DQN algorithm for Atari games.
```python
import parl
from parl.algorithms import DQN, DDQN
......@@ -38,7 +38,7 @@ from parl.algorithms import DQN, DDQN
class AtariModel(parl.Model):
"""AtariModel
This class defines the forward part for an algorithm,
its input is state observed on environment.
its input is the state observed from the environment.
"""
def __init__(self, img_shape, action_dim):
# define your layers
......@@ -58,7 +58,7 @@ class AtariModel(parl.Model):
three steps to build an agent
1. define a forward model which is critic_model in this example
2. a. to build a DQN algorithm, just pass the critic_model to `DQN`
b. to build a DDQN algorithm, just replace DQN in following line with DDQN
b. to build a DDQN algorithm, just replace DQN in the following line with DDQN
3. define the I/O part in AtariAgent so that it could update the algorithm based on the interactive data
"""
......@@ -69,7 +69,7 @@ agent = AtariAgent(algorithm)
# Parallelization
PARL provides a compact API for distributed training, allowing users to convert their code into a parallelized version by simply adding a decorator.
Here is a `Hello World!` example to demonstrate how easily it is to leverage outer computation resources.
Here is a `Hello World` example to demonstrate how easy it is to leverage outer computation resources.
```python
#============Agent.py=================
@parl.remote_class
......@@ -91,14 +91,14 @@ if __main__ == '__main__':
remote_manager = parl.RemoteManager()
agent = remote_manager.get_remote()
agent.say_hello()
ans = agent.sum(1,5) # run remotely and not comsume any local computation resources
ans = agent.sum(1, 5) # runs remotely and does not consume any local computation resources
```
Two steps to use outer computation resources:
1. use the `parl.remote_class` to decorate a class at first, after which it is transfered to be a new class that can run in other CPUs or machines.
1. use `parl.remote_class` to decorate a class first, after which it is transformed into a new class that can run on other CPUs or machines.
2. Get remote objects from the `RemoteManager`, and these objects have the same functions as the real ones. However, calling any function of these objects **does not** consume local computation resources since they are executed elsewhere.
<img src=".github/decorator.png" alt="PARL" width="450"/>
As shown in the above figure, real actors(orange circle) are running at the cpu cluster, while the learner(bule circle) is running at the local gpu with several remote actors(yellow circle with dotted edge).
As shown in the above figure, real actors (orange circles) are running on the CPU cluster, while the learner (blue circle) is running on the local GPU with several remote actors (yellow circles with dotted edges).
Users can write code in a simple way, just like writing multi-threaded code, but with actors consuming remote resources. We have also provided examples of parallelized algorithms such as [IMPALA](examples/IMPALA), [A2C](examples/A2C) and [GA3C](examples/GA3C). For more details on usage, please refer to these examples.
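Since the diff view splits the README's example across hunks, here is a consolidated, self-contained version of the two-file example for reference; the method bodies and the launch step are reconstructed from the calls shown above, so treat them as an illustrative sketch rather than the exact README text.

```python
#============Agent.py=================
import parl

@parl.remote_class
class Agent(object):
    def say_hello(self):
        print("Hello World!")

    def sum(self, a, b):
        return a + b

# Launch Agent.py on any computation platform, e.g. a CPU cluster, pointing it
# at the machine that runs Server.py (launch details are not part of this diff).

#============Server.py=================
import parl

remote_manager = parl.RemoteManager()
agent = remote_manager.get_remote()  # blocks until a remote client connects
agent.say_hello()                    # executed on the remote machine
ans = agent.sum(1, 5)                # runs remotely; consumes no local computation resources
```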
......
......@@ -18,6 +18,7 @@ from learner import Learner
def main(config):
learner = Learner(config)
assert config['log_metrics_interval_s'] > 0
try:
while True:
......
......@@ -82,7 +82,7 @@ class AtariAgent(Agent):
name='entropy_coeff', shape=[], dtype='float32')
self.learn_reader = fluid.layers.create_py_reader_by_data(
capacity=self.config['train_batch_size'],
capacity=32,
feed_list=[
obs, actions, advantages, target_values, lr, entropy_coeff
])
......
......@@ -18,6 +18,7 @@ from learner import Learner
def main(config):
learner = Learner(config)
assert config['log_metrics_interval_s'] > 0
try:
while True:
......
......@@ -73,7 +73,7 @@ class AtariAgent(Agent):
name='entropy_coeff', shape=[], dtype='float32')
self.learn_reader = fluid.layers.create_py_reader_by_data(
capacity=self.config['train_batch_size'],
capacity=32,
feed_list=[
obs, actions, behaviour_logits, rewards, dones, lr,
entropy_coeff
......
......@@ -29,7 +29,6 @@ config = {
#========== learner config ==========
'train_batch_size': 1000,
'learner_queue_max_size': 16,
'sample_queue_max_size': 8,
'gamma': 0.99,
......
......@@ -32,9 +32,8 @@ from parl.utils.window_stat import WindowStat
class Learner(object):
def __init__(self, config):
self.config = config
self.learner_queue = queue.Queue(
maxsize=config['learner_queue_max_size'])
self.sample_data_queue = queue.Queue(
maxsize=config['sample_queue_max_size'])
#=========== Create Agent ==========
env = gym.make(config['env_name'])
......@@ -75,8 +74,7 @@ class Learner(object):
#========== Remote Actor ===========
self.remote_count = 0
self.sample_data_queue = queue.Queue(
maxsize=config['sample_queue_max_size'])
self.batch_buffer = []
self.remote_metrics_queue = queue.Queue()
self.sample_total_steps = 0
......@@ -93,11 +91,23 @@ class Learner(object):
""" Data generator for fluid.layers.py_reader
"""
while True:
batch = self.learner_queue.get()
sample_data = self.sample_data_queue.get()
self.sample_total_steps += sample_data['obs'].shape[0]
self.batch_buffer.append(sample_data)
buffer_size = sum(
[data['obs'].shape[0] for data in self.batch_buffer])
if buffer_size >= self.config['train_batch_size']:
batch = {}
for key in self.batch_buffer[0].keys():
batch[key] = np.concatenate(
[data[key] for data in self.batch_buffer])
self.batch_buffer = []
obs_np = batch['obs'].astype('float32')
actions_np = batch['actions'].astype('int64')
behaviour_logits_np = batch['behaviour_logits'].astype('float32')
behaviour_logits_np = batch['behaviour_logits'].astype(
'float32')
rewards_np = batch['rewards'].astype('float32')
dones_np = batch['dones'].astype('float32')
......@@ -105,8 +115,8 @@ class Learner(object):
self.entropy_coeff = self.entropy_coeff_scheduler.step()
yield [
obs_np, actions_np, behaviour_logits_np, rewards_np, dones_np,
self.lr, self.entropy_coeff
obs_np, actions_np, behaviour_logits_np, rewards_np,
dones_np, self.lr, self.entropy_coeff
]
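Because the hunk above interleaves removed and added lines, a consolidated sketch of the new batching logic may be easier to follow. The function below is a simplified, hypothetical standalone version; the real code lives inside `Learner`, additionally casts each field to the required dtype, and tracks `sample_total_steps`.

```python
import numpy as np

def merged_batch_generator(sample_data_queue, train_batch_size):
    """Merge rollouts pulled from sample_data_queue into training batches.

    This replaces the step()/learner_queue path that the commit removes below:
    batching now happens directly in the py_reader data generator.
    """
    batch_buffer = []
    while True:
        sample_data = sample_data_queue.get()  # dict of numpy arrays, keyed by 'obs', 'actions', ...
        batch_buffer.append(sample_data)

        buffer_size = sum(data['obs'].shape[0] for data in batch_buffer)
        if buffer_size >= train_batch_size:
            # Concatenate the buffered rollouts into one training batch.
            batch = {
                key: np.concatenate([data[key] for data in batch_buffer])
                for key in batch_buffer[0].keys()
            }
            batch_buffer = []
            yield batch
```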
def run_learn(self):
......@@ -170,31 +180,6 @@ class Learner(object):
remote_actor.set_params(self.cache_params)
def step(self):
""" Merge and generate batch learn data from sample_data_queue,
and put it in learner_queue.
"""
assert self.learn_thread.is_alive()
while True:
try:
sample_data = self.sample_data_queue.get_nowait()
self.sample_total_steps += sample_data['obs'].shape[0]
self.batch_buffer.append(sample_data)
buffer_size = sum(
[data['obs'].shape[0] for data in self.batch_buffer])
if buffer_size >= self.config['train_batch_size']:
train_batch = {}
for key in self.batch_buffer[0].keys():
train_batch[key] = np.concatenate(
[data[key] for data in self.batch_buffer])
self.learner_queue.put(train_batch)
self.batch_buffer = []
except queue.Empty:
break
def log_metrics(self):
""" Log metrics of learner and actors
"""
......@@ -233,7 +218,6 @@ class Learner(object):
'max_episode_steps': max_episode_steps,
'mean_episode_steps': mean_episode_steps,
'min_episode_steps': min_episode_steps,
'learner_queue_size': self.learner_queue.qsize(),
'sample_queue_size': self.sample_data_queue.qsize(),
'total_params_sync': self.total_params_sync,
'cache_params_sent_cnt': self.cache_params_sent_cnt,
......
......@@ -22,9 +22,8 @@ def main(config):
try:
while True:
start = time.time()
while time.time() - start < config['log_metrics_interval_s']:
learner.step()
time.sleep(config['log_metrics_interval_s'])
learner.log_metrics()
except KeyboardInterrupt:
......
......@@ -14,5 +14,14 @@
"""
generates new PARL python API
"""
from parl.framework import *
from parl.utils.utils import _HAS_FLUID
if _HAS_FLUID:
from parl.framework import *
else:
print(
"WARNING:PARL: Failed to import paddle. Only APIs for parallelization are available."
)
from parl.remote import remote_class, RemoteManager
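A brief usage note on this change (a hypothetical session, not part of the diff): on a machine where PaddlePaddle is not installed, `import parl` no longer fails; only the parallelization API is exposed. The `Worker` class below is illustrative.

```python
# On a machine without paddle/fluid installed:
import parl  # prints the "Failed to import paddle" warning instead of raising ImportError

# The parallelization API is still usable:
@parl.remote_class
class Worker(object):
    def double(self, x):
        return x * 2

manager = parl.RemoteManager()

# The fluid-based framework API (e.g. parl.Model, parl.Algorithm) is not available in this mode.
```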
......@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from abc import ABCMeta, abstractmethod
from parl.framework.model_base import Network, Model
from parl.framework.model_base import Model
__all__ = ['Algorithm']
......@@ -23,13 +21,13 @@ __all__ = ['Algorithm']
class Algorithm(object):
"""
Algorithm defines how we update the model. For example,
after defining forward network in `Network` class, you should define how to update the model here.
after defining forward network in `Model` class, you should define how to update the model here.
Before creating a customized algorithm, please do check algorithms of PARL.
Most common used algorithms like DQN/DDPG/PPO have been providing in algorithms, go and have a try.
It's easy to use them and just try pl.algorithms.DQN.
Most commonly used algorithms, such as DQN/DDPG/PPO/A3C, are provided in algorithms; go and have a try.
It's easy to use them: just try parl.algorithms.DQN.
An Algorithm implements two functions:
1. define_predict() build forward process which was defined in Network
1. define_predict() builds the forward process which was defined in `Model`
2. define_learn() computes a cost for optimization
An algorithm should be updating part of a network. The user only needs to
......
......@@ -21,3 +21,6 @@ SERIALIZE_EXCEPTION_TAG = b'[SERIALIZE_EXCEPTION]'
DESERIALIZE_EXCEPTION_TAG = b'[DESERIALIZE_EXCEPTION]'
NORMAL_TAG = b'[NORMAL]'
# interval of heartbeat mechanism in the unit of second
HEARTBEAT_INTERVAL_S = 10
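The client-status monitoring announced in the commit message is split between `remote_decorator.py` (the client-side heartbeat) and `remote_manager.py` (the server-side check) in the hunks below. As an orientation aid, here is a hypothetical standalone sketch of the pattern using the constant defined above; the class and method names are illustrative, not PARL's.

```python
import time
import threading

HEARTBEAT_INTERVAL_S = 10  # mirrors the constant defined above

class ClientStatusMonitor(object):
    """Standalone illustration of the timestamp-based liveness check."""

    def __init__(self):
        self.latest_timestamp = {}  # client_id -> time of the last heartbeat

    def on_heartbeat(self, client_id):
        # Called whenever a heartbeat message arrives from a client.
        self.latest_timestamp[client_id] = time.time()

    def check_loop(self):
        # A client that stays silent for 3 heartbeat intervals is treated as lost.
        while True:
            for client_id in list(self.latest_timestamp.keys()):
                if time.time() - self.latest_timestamp[client_id] > 3 * HEARTBEAT_INTERVAL_S:
                    print('Remote client {} is lost.'.format(client_id))
                    self.latest_timestamp.pop(client_id)
            time.sleep(3 * HEARTBEAT_INTERVAL_S)

# The real manager runs the equivalent check in a daemon thread:
monitor = ClientStatusMonitor()
threading.Thread(target=monitor.check_loop, daemon=True).start()
```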
......@@ -66,36 +66,23 @@ def remote_class(cls):
and waits for requests (e.g. call a function) from server side.
"""
if remote_ip is None:
client_ip = get_ip_address()
else:
client_ip = remote_ip
remote_ip = get_ip_address()
self.zmq_context = zmq.Context()
socket = self.zmq_context.socket(zmq.REP)
free_port = None
if remote_port is None:
for port in range(6000, 8000):
try:
socket.bind("tcp://*:{}".format(port))
logger.info(
"[_create_reply_socket] free_port:{}".format(port))
free_port = port
break
except zmq.error.ZMQError:
logger.warning(
"[_create_reply_socket]cannot bind port:{}, retry".
format(port))
remote_port = socket.bind_to_random_port(addr="tcp://*")
except zmq.ZMQBindError:
logger.error(
'Can not bind to a random port, please set remote_port manually.'
)
sys.exit(1)
else:
socket.bind("tcp://*:{}".format(remote_port))
free_port = remote_port
if free_port is not None:
return socket, client_ip, free_port
else:
logger.error(
"cannot find any available port from 6000 to 8000")
sys.exit(1)
return socket, remote_ip, remote_port
def _connect_server(self, server_ip, server_port, remote_ip,
remote_port):
......@@ -116,26 +103,31 @@ def remote_class(cls):
socket = self.zmq_context.socket(zmq.REQ)
socket.connect("tcp://{}:{}".format(server_ip, server_port))
client_id = np.random.randint(int(1e18))
logger.info("client_id:{}".format(client_id))
logger.info("connecting {}:{}".format(server_ip, server_port))
client_info = '{}:{} {}'.format(local_ip, local_port, client_id)
client_addr = '{}:{}'.format(local_ip, local_port)
socket.send_multipart(
[remote_constants.CONNECT_TAG,
to_byte(client_info)])
to_byte(client_addr)])
message = socket.recv_multipart()
logger.info("[connect_server] done, message from server:{}".format(
message))
self.client_id = message[1]
logger.info("connect server done, client_id: {}".format(
self.client_id))
self.connect_socket = socket
self.connect_socket.linger = 0
def _exit_remote(self):
# Following release order matters
self.poller.unregister(self.connect_socket)
self.zmq_context.destroy()
self.connect_socket.close()
self.reply_socket.close()
# The program may hang when destroying zmq context manually.
# It will be destroyed automatically by the garbage collection mechanism of python,
# though it may raise some exceptions in C++.
#self.zmq_context.destroy()
def _heartbeat_loop(self):
"""
......@@ -146,7 +138,7 @@ def remote_class(cls):
while True:
self.connect_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
[remote_constants.HEARTBEAT_TAG, self.client_id])
# wait for at most 10s to receive response
socks = dict(self.poller.poll(10000))
......@@ -161,7 +153,7 @@ def remote_class(cls):
break
# HeartBeat interval 10s
time.sleep(10)
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
def __getattr__(self, attr):
"""
......
......@@ -14,6 +14,7 @@
import queue
import threading
import time
import zmq
from parl.utils import logger, to_byte, to_str
from parl.remote import remote_constants
......@@ -48,11 +49,16 @@ class RemoteManager(object):
self.socket.linger = 0
self.remote_pool = queue.Queue()
self.remote_latest_timestamp = {}
t = threading.Thread(target=self._wait_for_connection)
t.setDaemon(True) # The thread will exit when main thread exited
t.start()
t = threading.Thread(target=self._check_remote_status)
t.setDaemon(True) # The thread will exit when main thread exited
t.start()
def _wait_for_connection(self):
"""
A never-ending function that keeps waiting for connections from remote clients.
......@@ -61,6 +67,7 @@ class RemoteManager(object):
Note that this function has been called inside the `__init__` function.
"""
remote_id = 0
while True:
try:
message = self.socket.recv_multipart()
......@@ -68,17 +75,23 @@ class RemoteManager(object):
if tag == remote_constants.CONNECT_TAG:
self.socket.send_multipart([
remote_constants.NORMAL_TAG, b'Connect server success.'
remote_constants.NORMAL_TAG,
to_byte('{}'.format(remote_id))
])
client_info = to_str(message[1])
remote_client_address, remote_client_id = client_info.split(
)
remote_client_address = to_str(message[1])
remote_obj = RemoteObject(remote_client_address,
remote_client_id,
self.zmq_context)
self.remote_latest_timestamp[to_byte(
str(remote_id))] = time.time()
logger.info('[RemoteManager] Added a new remote object.')
self.remote_pool.put(remote_obj)
remote_id += 1
elif tag == remote_constants.HEARTBEAT_TAG:
self.remote_latest_timestamp[message[1]] = time.time()
self.socket.send_multipart(
[remote_constants.NORMAL_TAG, b'Server is alive.'])
else:
......@@ -88,6 +101,17 @@ class RemoteManager(object):
logger.warning('Zmq error, exiting server.')
break
def _check_remote_status(self):
while True:
for remote_id in list(self.remote_latest_timestamp.keys()):
if time.time() - self.remote_latest_timestamp[
remote_id] > 3 * remote_constants.HEARTBEAT_INTERVAL_S:
logger.error(
'Remote object {} is lost, please check whether anything went wrong in the remote client'
.format(remote_id))
self.remote_latest_timestamp.pop(remote_id)
time.sleep(3 * remote_constants.HEARTBEAT_INTERVAL_S)
def get_remote(self):
"""
A blocking function to obtain a remote object.
......
......@@ -25,15 +25,11 @@ class RemoteObject(object):
Provides interface to call functions of object in remote client.
"""
def __init__(self,
remote_client_address,
remote_client_id,
zmq_context=None):
def __init__(self, remote_client_address, zmq_context=None):
"""
Args:
remote_client_address: address(ip:port) of remote client
remote_client_id: id of remote client
zmq_context: zmq.Context()
"""
if zmq_context is None:
self.zmq_context = zmq.Context()
......@@ -44,7 +40,6 @@ class RemoteObject(object):
self.command_socket = None
# lock for thread safety
self.internal_lock = threading.Lock()
self.client_id = remote_client_id
self._connect_remote_client(remote_client_address)
def _connect_remote_client(self, remote_client_address):
......
......@@ -72,3 +72,9 @@ def is_PY3():
MAX_INT32 = 0x7fffffff
try:
from paddle import fluid
_HAS_FLUID = True
except ImportError:
_HAS_FLUID = False