diff --git a/README.md b/README.md
index d7546ab7d8c0d8df94f026839e023efc14929ab2..4e7a956a08e592c2e9612c8ad37f15631483b228 100644
--- a/README.md
+++ b/README.md
@@ -9,9 +9,9 @@ English | [简体中文](./README.cn.md)
 # Features
 **Reproducible**. We provide algorithms that stably reproduce the result of many influential reinforcement learning algorithms.
-**Large Scale**. Ability to support high performance parallelization of training with thousands of CPUs and multi-GPUs.
+**Large Scale**. Ability to support high-performance parallelization of training with thousands of CPUs and multiple GPUs.
-**Reusable**. Algorithms provided in repository could be directly adapted to a new task by defining a forward network and training mechanism will be built automatically.
+**Reusable**. Algorithms provided in the repository can be directly adapted to a new task by defining a forward network; the training mechanism will be built automatically.
 **Extensible**. Build new algorithms quickly by inheriting the abstract class in the framework.
@@ -30,35 +30,35 @@ The main abstractions introduced by PARL that are used to build an agent recursi
 ### Agent
 `Agent`, a data bridge between environment and algorithm, is responsible for data I/O with the outside environment and describes data preprocessing before feeding data into the training process.
-Here is an example of building an agent with DQN algorithm for atari games.
+Here is an example of building an agent with the DQN algorithm for Atari games.
 ```python
 import parl
 from parl.algorithms import DQN, DDQN
 
 class AtariModel(parl.Model):
-    """AtariModel
-    This class defines the forward part for an algorithm,
-    its input is state observed on environment.
-    """
-    def __init__(self, img_shape, action_dim):
-        # define your layers
-        self.cnn1 = layers.conv_2d(num_filters=32, filter_size=5,
-                                   stride=1, padding=2, act='relu')
-        ...
-        self.fc1 = layers.fc(action_dim)
-
-    def value(self, img):
-        # define how to estimate the Q value based on the image of atari games.
-        img = img / 255.0
-        l = self.cnn1(img)
-        ...
-        Q = self.fc1(l)
-        return Q
+    """AtariModel
+    This class defines the forward part for an algorithm,
+    its input is the state observed from the environment.
+    """
+    def __init__(self, img_shape, action_dim):
+        # define your layers
+        self.cnn1 = layers.conv_2d(num_filters=32, filter_size=5,
+                                   stride=1, padding=2, act='relu')
+        ...
+        self.fc1 = layers.fc(action_dim)
+
+    def value(self, img):
+        # define how to estimate the Q value based on the image of Atari games.
+        img = img / 255.0
+        l = self.cnn1(img)
+        ...
+        Q = self.fc1(l)
+        return Q
 
 """
 three steps to build an agent
     1. define a forward model which is critic_model in this example
     2. a. to build a DQN algorithm, just pass the critic_model to `DQN`
-       b. to build a DDQN algorithm, just replace DQN in following line with DDQN
+       b. to build a DDQN algorithm, just replace DQN in the following line with DDQN
     3. define the I/O part in AtariAgent so that it could update the algorithm based on the interactive data
 """
@@ -69,17 +69,17 @@ agent = AtariAgent(algorithm)
 # Parallelization
 PARL provides a compact API for distributed training, allowing users to transfer the code into a parallelized version by simply adding a decorator.
-Here is a `Hello World!` example to demonstrate how easily it is to leverage outer computation resources.
+Here is a `Hello World` example to demonstrate how easy it is to leverage outer computation resources.
 ```python
 #============Agent.py=================
 @parl.remote_class
 class Agent(object):
-    def say_hello(self):
-        print("Hello World!")
+    def say_hello(self):
+        print("Hello World!")
 
-    def sum(self, a, b):
-        return a+b
+    def sum(self, a, b):
+        return a+b
 
 # launch `Agent.py` at any computation platforms such as a CPU cluster.
 if __main__ == '__main__':
@@ -91,14 +91,14 @@ if __main__ == '__main__':
 remote_manager = parl.RemoteManager()
 agent = remote_manager.get_remote()
 agent.say_hello()
-ans = agent.sum(1,5) # run remotely and not comsume any local computation resources
+ans = agent.sum(1,5) # runs remotely and does not consume any local computation resources
 ```
 Two steps to use outer computation resources:
-1. use the `parl.remote_class` to decorate a class at first, after which it is transfered to be a new class that can run in other CPUs or machines.
+1. Use `parl.remote_class` to decorate a class first, after which it is transformed into a new class that can run on other CPUs or machines.
 2. Get remote objects from the `RemoteManager`, and these objects have the same functions as the real ones. However, calling any function of these objects **does not** consume local computation resources since they are executed elsewhere.
 PARL
-As shown in the above figure, real actors(orange circle) are running at the cpu cluster, while the learner(bule circle) is running at the local gpu with several remote actors(yellow circle with dotted edge).
+As shown in the above figure, real actors (orange circles) are running on the CPU cluster, while the learner (blue circle) is running on the local GPU with several remote actors (yellow circles with dotted edges).
 For users, they can write code in a simple way, just like writing multi-thread code, but with actors consuming remote resources.
 We have also provided examples of parallized algorithms like [IMPALA](examples/IMPALA), [A2C](examples/A2C) and [GA3C](examples/GA3C). For more details in usage please refer to these examples.
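The two-step recipe above leaves out what the learner side looks like once several remote actors are available. The sketch below is an editorial illustration, not part of the patch: it assumes a decorated actor class (like `Agent` in `Agent.py` above) exposing a hypothetical `sample(idx)` method has already been launched and connected on remote machines, and it only uses the `RemoteManager`/`get_remote` calls shown in the README snippet.

```python
# Hypothetical learner-side sketch of the "multi-thread code with remote actors"
# pattern described in the README above. The actor class, its `sample` method,
# and the number of actors are assumptions for illustration.
import threading

import parl

if __name__ == '__main__':
    remote_manager = parl.RemoteManager()

    # get_remote() blocks until a remote actor has connected.
    remote_actors = [remote_manager.get_remote() for _ in range(4)]

    results = [None] * len(remote_actors)

    def collect(idx, actor):
        # Reads like a local call, but the work runs on the remote machine
        # and does not consume local computation resources.
        results[idx] = actor.sample(idx)

    threads = [
        threading.Thread(target=collect, args=(idx, actor))
        for idx, actor in enumerate(remote_actors)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(results)
```

Each thread blocks on its own remote call, so the rollouts are collected concurrently while the local process stays mostly idle.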
diff --git a/examples/A2C/train.py b/examples/A2C/train.py
index bc9347298dcae8ede5d7ed816858e7621316b190..b57bd1644794c6a1d13c5f48921f7dd075ff5203 100644
--- a/examples/A2C/train.py
+++ b/examples/A2C/train.py
@@ -18,6 +18,7 @@ from learner import Learner
 
 def main(config):
     learner = Learner(config)
+    assert config['log_metrics_interval_s'] > 0
 
     try:
         while True:
diff --git a/examples/GA3C/atari_agent.py b/examples/GA3C/atari_agent.py
index c27311fb2030838744e491d89b645b729c07f080..c50be0bb19b4c9b964c72923de07c13e8fb2416f 100644
--- a/examples/GA3C/atari_agent.py
+++ b/examples/GA3C/atari_agent.py
@@ -82,7 +82,7 @@ class AtariAgent(Agent):
                 name='entropy_coeff', shape=[], dtype='float32')
             self.learn_reader = fluid.layers.create_py_reader_by_data(
-                capacity=self.config['train_batch_size'],
+                capacity=32,
                 feed_list=[
                     obs, actions, advantages, target_values, lr, entropy_coeff
                 ])
diff --git a/examples/GA3C/train.py b/examples/GA3C/train.py
index ade43899b5caad97ba9bf2faf02f46878c8ffbad..a84f377b73335b6a044c7fd00590f38574ab1fd6 100644
--- a/examples/GA3C/train.py
+++ b/examples/GA3C/train.py
@@ -18,6 +18,7 @@ from learner import Learner
 
 def main(config):
     learner = Learner(config)
+    assert config['log_metrics_interval_s'] > 0
 
     try:
         while True:
diff --git a/examples/IMPALA/atari_agent.py b/examples/IMPALA/atari_agent.py
index 6b7d34540a5c2e8ecc6a51b33f2b6767464e00fb..4e077701cde73ec12dfcf44e8b5104fc98452cec 100644
--- a/examples/IMPALA/atari_agent.py
+++ b/examples/IMPALA/atari_agent.py
@@ -73,7 +73,7 @@ class AtariAgent(Agent):
                 name='entropy_coeff', shape=[], dtype='float32')
             self.learn_reader = fluid.layers.create_py_reader_by_data(
-                capacity=self.config['train_batch_size'],
+                capacity=32,
                 feed_list=[
                     obs, actions, behaviour_logits, rewards, dones, lr, entropy_coeff
                 ])
diff --git a/examples/IMPALA/impala_config.py b/examples/IMPALA/impala_config.py
index b324d961eff08ec25a2f67ae60f4a3a6d3571794..dcfc2f5725f0a02b895e1426783274b2b944aefe 100644
--- a/examples/IMPALA/impala_config.py
+++ b/examples/IMPALA/impala_config.py
@@ -29,7 +29,6 @@ config = {
 
     #========== learner config ==========
     'train_batch_size': 1000,
-    'learner_queue_max_size': 16,
     'sample_queue_max_size': 8,
     'gamma': 0.99,
diff --git a/examples/IMPALA/learner.py b/examples/IMPALA/learner.py
index 272413c1b3499aae6149975b8835acfd4ea603dc..25606428876e2a0123a1d9e245348853c0f3f69b 100644
--- a/examples/IMPALA/learner.py
+++ b/examples/IMPALA/learner.py
@@ -32,9 +32,8 @@ from parl.utils.window_stat import WindowStat
 class Learner(object):
     def __init__(self, config):
         self.config = config
-
-        self.learner_queue = queue.Queue(
-            maxsize=config['learner_queue_max_size'])
+        self.sample_data_queue = queue.Queue(
+            maxsize=config['sample_queue_max_size'])
 
         #=========== Create Agent ==========
         env = gym.make(config['env_name'])
@@ -75,8 +74,7 @@ class Learner(object):
 
         #========== Remote Actor ===========
         self.remote_count = 0
-        self.sample_data_queue = queue.Queue(
-            maxsize=config['sample_queue_max_size'])
+        self.batch_buffer = []
 
         self.remote_metrics_queue = queue.Queue()
         self.sample_total_steps = 0
@@ -93,21 +91,33 @@ class Learner(object):
         """ Data generator for fluid.layers.py_reader
         """
         while True:
-            batch = self.learner_queue.get()
-
-            obs_np = batch['obs'].astype('float32')
-            actions_np = batch['actions'].astype('int64')
-            behaviour_logits_np = batch['behaviour_logits'].astype('float32')
-            rewards_np = batch['rewards'].astype('float32')
-            dones_np = batch['dones'].astype('float32')
-
-            self.lr = self.lr_scheduler.step()
-            self.entropy_coeff = self.entropy_coeff_scheduler.step()
-
-            yield [
-                obs_np, actions_np, behaviour_logits_np, rewards_np, dones_np,
-                self.lr, self.entropy_coeff
-            ]
+            sample_data = self.sample_data_queue.get()
+            self.sample_total_steps += sample_data['obs'].shape[0]
+            self.batch_buffer.append(sample_data)
+
+            buffer_size = sum(
+                [data['obs'].shape[0] for data in self.batch_buffer])
+            if buffer_size >= self.config['train_batch_size']:
+                batch = {}
+                for key in self.batch_buffer[0].keys():
+                    batch[key] = np.concatenate(
+                        [data[key] for data in self.batch_buffer])
+                self.batch_buffer = []
+
+                obs_np = batch['obs'].astype('float32')
+                actions_np = batch['actions'].astype('int64')
+                behaviour_logits_np = batch['behaviour_logits'].astype(
+                    'float32')
+                rewards_np = batch['rewards'].astype('float32')
+                dones_np = batch['dones'].astype('float32')
+
+                self.lr = self.lr_scheduler.step()
+                self.entropy_coeff = self.entropy_coeff_scheduler.step()
+
+                yield [
+                    obs_np, actions_np, behaviour_logits_np, rewards_np,
+                    dones_np, self.lr, self.entropy_coeff
+                ]
 
     def run_learn(self):
         """ Learn loop
@@ -170,31 +180,6 @@ class Learner(object):
 
                 remote_actor.set_params(self.cache_params)
 
-    def step(self):
-        """ Merge and generate batch learn data from sample_data_queue,
-        and put it in learner_queue.
-        """
-        assert self.learn_thread.is_alive()
-
-        while True:
-            try:
-                sample_data = self.sample_data_queue.get_nowait()
-                self.sample_total_steps += sample_data['obs'].shape[0]
-                self.batch_buffer.append(sample_data)
-
-                buffer_size = sum(
-                    [data['obs'].shape[0] for data in self.batch_buffer])
-                if buffer_size >= self.config['train_batch_size']:
-                    train_batch = {}
-                    for key in self.batch_buffer[0].keys():
-                        train_batch[key] = np.concatenate(
-                            [data[key] for data in self.batch_buffer])
-                    self.learner_queue.put(train_batch)
-                    self.batch_buffer = []
-
-            except queue.Empty:
-                break
-
     def log_metrics(self):
         """ Log metrics of learner and actors
         """
@@ -233,7 +218,6 @@
             'max_episode_steps': max_episode_steps,
             'mean_episode_steps': mean_episode_steps,
             'min_episode_steps': min_episode_steps,
-            'learner_queue_size': self.learner_queue.qsize(),
             'sample_queue_size': self.sample_data_queue.qsize(),
             'total_params_sync': self.total_params_sync,
             'cache_params_sent_cnt': self.cache_params_sent_cnt,
diff --git a/examples/IMPALA/train.py b/examples/IMPALA/train.py
index f15075106487595d549f4aa28af01f117a4480aa..bf38bb34281404975a620b04171cda53f7eaa224 100644
--- a/examples/IMPALA/train.py
+++ b/examples/IMPALA/train.py
@@ -22,9 +22,8 @@ def main(config):
 
     try:
         while True:
-            start = time.time()
-            while time.time() - start < config['log_metrics_interval_s']:
-                learner.step()
+            time.sleep(config['log_metrics_interval_s'])
+
             learner.log_metrics()
 
     except KeyboardInterrupt:
diff --git a/parl/__init__.py b/parl/__init__.py
index 205164a3d0aa1bf35d0f75f2ef92431979b804a0..c04b1a4ccbe9032a468d052af5d6bf09a0431030 100644
--- a/parl/__init__.py
+++ b/parl/__init__.py
@@ -14,5 +14,14 @@
 """
 generates new PARL python API
 """
-from parl.framework import *
+
+from parl.utils.utils import _HAS_FLUID
+
+if _HAS_FLUID:
+    from parl.framework import *
+else:
+    print(
+        "WARNING:PARL: Failed to import paddle. Only APIs for parallelization are available."
+    )
+from parl.remote import remote_class, RemoteManager
diff --git a/parl/framework/algorithm_base.py b/parl/framework/algorithm_base.py
index 9672134814458102ae89545dc5eee97988af5976..8b606d829b7d67a25888513bccb02848dad90d79 100644
--- a/parl/framework/algorithm_base.py
+++ b/parl/framework/algorithm_base.py
@@ -12,10 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import paddle.fluid as fluid
-import parl.layers as layers
 from abc import ABCMeta, abstractmethod
-from parl.framework.model_base import Network, Model
+from parl.framework.model_base import Model
 
 __all__ = ['Algorithm']
 
@@ -23,13 +21,13 @@ __all__ = ['Algorithm']
 class Algorithm(object):
     """ Algorithm defines the way how we update the model. For example,
-    after defining forward network in `Network` class, you should define how to update the model here.
+    after defining the forward network in the `Model` class, you should define how to update the model here.
     Before creating a customized algorithm, please do check algorithms of PARL.
-    Most common used algorithms like DQN/DDPG/PPO have been providing in algorithms, go and have a try.
-    It's easy to use them and just try pl.algorithms.DQN.
+    Most commonly used algorithms like DQN/DDPG/PPO/A3C are already provided in `parl.algorithms`; go and have a try.
+    It's easy to use them: just try parl.algorithms.DQN.
 
     An Algorithm implements two functions:
-    1. define_predict() build forward process which was defined in Network
+    1. define_predict() builds the forward process which was defined in `Model`
     2. define_learn() computes a cost for optimization
 
     An algorithm should be updating part of a network. The user only needs to
diff --git a/parl/remote/remote_constants.py b/parl/remote/remote_constants.py
index 6012f527ef256b0c1235f0556bde7810bb82e6cf..b30aad1c0c1076561af2dd41f634f9c891171290 100644
--- a/parl/remote/remote_constants.py
+++ b/parl/remote/remote_constants.py
@@ -21,3 +21,6 @@ SERIALIZE_EXCEPTION_TAG = b'[SERIALIZE_EXCEPTION]'
 DESERIALIZE_EXCEPTION_TAG = b'[DESERIALIZE_EXCEPTION]'
 
 NORMAL_TAG = b'[NORMAL]'
+
+# interval of the heartbeat mechanism, in seconds
+HEARTBEAT_INTERVAL_S = 10
diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py
index 41a9e2bf969c53aec2e2b2d6d52bd9dd3df94d20..2e88a558b0a4fb8b33809ccc90942a3dfb557fa4 100644
--- a/parl/remote/remote_decorator.py
+++ b/parl/remote/remote_decorator.py
@@ -66,36 +66,23 @@ def remote_class(cls):
             and waits for requests (e.g. call a function) from server side.
             """
             if remote_ip is None:
-                client_ip = get_ip_address()
-            else:
-                client_ip = remote_ip
+                remote_ip = get_ip_address()
 
             self.zmq_context = zmq.Context()
             socket = self.zmq_context.socket(zmq.REP)
-            free_port = None
             if remote_port is None:
-                for port in range(6000, 8000):
-                    try:
-                        socket.bind("tcp://*:{}".format(port))
-                        logger.info(
-                            "[_create_reply_socket] free_port:{}".format(port))
-                        free_port = port
-                        break
-                    except zmq.error.ZMQError:
-                        logger.warning(
-                            "[_create_reply_socket]cannot bind port:{}, retry".
-                            format(port))
+                try:
+                    remote_port = socket.bind_to_random_port(addr="tcp://*")
+                except zmq.ZMQBindError:
+                    logger.error(
+                        'Cannot bind to a random port, please set remote_port manually.'
+                    )
+                    sys.exit(1)
             else:
                 socket.bind("tcp://*:{}".format(remote_port))
-                free_port = remote_port
 
-            if free_port is not None:
-                return socket, client_ip, free_port
-            else:
-                logger.error(
-                    "cannot find any available port from 6000 to 8000")
-                sys.exit(1)
+            return socket, remote_ip, remote_port
 
         def _connect_server(self, server_ip, server_port,
                             remote_ip, remote_port):
@@ -116,26 +103,31 @@
             socket = self.zmq_context.socket(zmq.REQ)
             socket.connect("tcp://{}:{}".format(server_ip, server_port))
 
-            client_id = np.random.randint(int(1e18))
-            logger.info("client_id:{}".format(client_id))
             logger.info("connecting {}:{}".format(server_ip, server_port))
 
-            client_info = '{}:{} {}'.format(local_ip, local_port, client_id)
+            client_addr = '{}:{}'.format(local_ip, local_port)
             socket.send_multipart(
                 [remote_constants.CONNECT_TAG,
-                 to_byte(client_info)])
+                 to_byte(client_addr)])
             message = socket.recv_multipart()
-            logger.info("[connect_server] done, message from server:{}".format(
-                message))
+            self.client_id = message[1]
+            logger.info("connect server done, client_id: {}".format(
+                self.client_id))
             self.connect_socket = socket
             self.connect_socket.linger = 0
 
         def _exit_remote(self):
-            # Following release order matters
             self.poller.unregister(self.connect_socket)
-            self.zmq_context.destroy()
+            self.connect_socket.close()
+            self.reply_socket.close()
+
+            # The program may hang when destroying zmq context manually.
+            # It will be destroyed automatically by the garbage collection mechanism of python,
+            # though it may raise some exceptions in C++.
+
+            #self.zmq_context.destroy()
 
         def _heartbeat_loop(self):
             """
@@ -146,7 +138,7 @@
             while True:
                 self.connect_socket.send_multipart(
-                    [remote_constants.HEARTBEAT_TAG])
+                    [remote_constants.HEARTBEAT_TAG, self.client_id])
 
                 # wait for at most 10s to receive response
                 socks = dict(self.poller.poll(10000))
@@ -161,7 +153,7 @@
                     break
 
                 # HeartBeat interval 10s
-                time.sleep(10)
+                time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
 
         def __getattr__(self, attr):
             """
diff --git a/parl/remote/remote_manager.py b/parl/remote/remote_manager.py
index 0d6d3edd8c77a72beb360775fee87813c6d52c4f..6afe7002cd1e8f2acbca44d00ba992f0f0f9584c 100644
--- a/parl/remote/remote_manager.py
+++ b/parl/remote/remote_manager.py
@@ -14,6 +14,7 @@
 
 import queue
 import threading
+import time
 import zmq
 from parl.utils import logger, to_byte, to_str
 from parl.remote import remote_constants
@@ -48,11 +49,16 @@ class RemoteManager(object):
         self.socket.linger = 0
 
         self.remote_pool = queue.Queue()
+        self.remote_latest_timestamp = {}
 
         t = threading.Thread(target=self._wait_for_connection)
         t.setDaemon(True)  # The thread will exit when main thread exited
         t.start()
 
+        t = threading.Thread(target=self._check_remote_status)
+        t.setDaemon(True)  # The thread will exit when main thread exited
+        t.start()
+
     def _wait_for_connection(self):
         """
         A never-ending function keeps waiting for the connections from remote client.
@@ -61,6 +67,7 @@
 
         Note that this function has been called inside the `__init__` function.
         """
+        remote_id = 0
         while True:
             try:
                 message = self.socket.recv_multipart()
@@ -68,17 +75,23 @@
                 if tag == remote_constants.CONNECT_TAG:
                     self.socket.send_multipart([
-                        remote_constants.NORMAL_TAG, b'Connect server success.'
+                        remote_constants.NORMAL_TAG,
+                        to_byte('{}'.format(remote_id))
                     ])
-                    client_info = to_str(message[1])
-                    remote_client_address, remote_client_id = client_info.split(
-                    )
+                    remote_client_address = to_str(message[1])
                     remote_obj = RemoteObject(remote_client_address,
-                                              remote_client_id,
                                               self.zmq_context)
+
+                    self.remote_latest_timestamp[to_byte(
+                        str(remote_id))] = time.time()
+                    logger.info('[RemoteManager] Added a new remote object.')
                     self.remote_pool.put(remote_obj)
+
+                    remote_id += 1
+                elif tag == remote_constants.HEARTBEAT_TAG:
+                    self.remote_latest_timestamp[message[1]] = time.time()
                     self.socket.send_multipart(
                         [remote_constants.NORMAL_TAG, b'Server is alive.'])
                 else:
@@ -88,6 +101,17 @@
                 logger.warning('Zmq error, exiting server.')
                 break
 
+    def _check_remote_status(self):
+        while True:
+            for remote_id in list(self.remote_latest_timestamp.keys()):
+                if time.time() - self.remote_latest_timestamp[
+                        remote_id] > 3 * remote_constants.HEARTBEAT_INTERVAL_S:
+                    logger.error(
+                        'Remote object {} is lost, please check whether anything went wrong in the remote client'
+                        .format(remote_id))
+                    self.remote_latest_timestamp.pop(remote_id)
+            time.sleep(3 * remote_constants.HEARTBEAT_INTERVAL_S)
+
     def get_remote(self):
         """
         A blocking function to obtain a remote object.
diff --git a/parl/remote/remote_object.py b/parl/remote/remote_object.py
index 9f9e7b1fe5f8bf772e3c5801678521c3ef211fbe..319e905f9ff4c71c0ccf8251b7f8a4b86c8e2865 100644
--- a/parl/remote/remote_object.py
+++ b/parl/remote/remote_object.py
@@ -25,15 +25,11 @@ class RemoteObject(object):
     Provides interface to call functions of object in remote client.
     """
 
-    def __init__(self,
-                 remote_client_address,
-                 remote_client_id,
-                 zmq_context=None):
+    def __init__(self, remote_client_address, zmq_context=None):
         """
         Args:
             remote_client_address: address(ip:port) of remote client
-            remote_client_id: id of remote client
-
+            zmq_context: zmq.Context()
         """
         if zmq_context is None:
             self.zmq_context = zmq.Context()
@@ -44,7 +40,6 @@
         self.command_socket = None
         # lock for thread safety
         self.internal_lock = threading.Lock()
-        self.client_id = remote_client_id
         self._connect_remote_client(remote_client_address)
 
     def _connect_remote_client(self, remote_client_address):
diff --git a/parl/utils/utils.py b/parl/utils/utils.py
index 206f9bc4975b213d2f98cac327b448eccba2b6a0..c5ccc0de82b4943be368279df3a2449c3ee1b6b1 100644
--- a/parl/utils/utils.py
+++ b/parl/utils/utils.py
@@ -72,3 +72,9 @@ def is_PY3():
 
 
 MAX_INT32 = 0x7fffffff
+
+try:
+    from paddle import fluid
+    _HAS_FLUID = True
+except ImportError:
+    _HAS_FLUID = False
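Taken together, the remote_constants, remote_decorator, and remote_manager changes implement a heartbeat watchdog: each client pings the manager every HEARTBEAT_INTERVAL_S seconds, the manager stamps the arrival time per client id, and a daemon thread reports clients that have stayed silent for three intervals. The sketch below is an editorial illustration of that bookkeeping pattern, not code from the patch; the `HeartbeatMonitor` name and the `print` reporting are placeholders.

```python
# Standalone sketch of the heartbeat bookkeeping used by the RemoteManager
# changes above. HEARTBEAT_INTERVAL_S mirrors the constant added to
# remote_constants.py; everything else here is illustrative.
import threading
import time

HEARTBEAT_INTERVAL_S = 10


class HeartbeatMonitor(object):
    def __init__(self):
        self._latest = {}  # client_id -> time of the last heartbeat seen
        t = threading.Thread(target=self._check_loop)
        t.setDaemon(True)  # exits together with the main thread
        t.start()

    def record_heartbeat(self, client_id):
        # Call this whenever a heartbeat message arrives from `client_id`.
        self._latest[client_id] = time.time()

    def _check_loop(self):
        while True:
            now = time.time()
            for client_id in list(self._latest.keys()):
                if now - self._latest[client_id] > 3 * HEARTBEAT_INTERVAL_S:
                    print('client {} is lost'.format(client_id))
                    self._latest.pop(client_id)
            time.sleep(3 * HEARTBEAT_INTERVAL_S)
```

Sleeping for three intervals between checks keeps the watchdog cheap; the trade-off is that a lost client can go unreported for up to six intervals.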