From 3556c7869b8cd286f4267a8eb76c2787cba884a4 Mon Sep 17 00:00:00 2001 From: Hongsheng Zeng Date: Thu, 18 Apr 2019 20:16:46 +0800 Subject: [PATCH] Refine (#67) * fix typo * Update README.md * Update README.md * Update README.md * soft depend on fluid; add module to monitor client status * improve performance of IMPALA example * fix bug of some client cannot exit normally * refine comment * . --- README.md | 60 ++++++++++++------------- examples/A2C/train.py | 1 + examples/GA3C/atari_agent.py | 2 +- examples/GA3C/train.py | 1 + examples/IMPALA/atari_agent.py | 2 +- examples/IMPALA/impala_config.py | 1 - examples/IMPALA/learner.py | 76 +++++++++++++------------------- examples/IMPALA/train.py | 5 +-- parl/__init__.py | 11 ++++- parl/framework/algorithm_base.py | 12 +++-- parl/remote/remote_constants.py | 3 ++ parl/remote/remote_decorator.py | 56 ++++++++++------------- parl/remote/remote_manager.py | 34 +++++++++++--- parl/remote/remote_object.py | 9 +--- parl/utils/utils.py | 6 +++ 15 files changed, 145 insertions(+), 134 deletions(-) diff --git a/README.md b/README.md index d7546ab..4e7a956 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,9 @@ English | [简体中文](./README.cn.md) # Features **Reproducible**. We provide algorithms that stably reproduce the result of many influential reinforcement learning algorithms. -**Large Scale**. Ability to support high performance parallelization of training with thousands of CPUs and multi-GPUs. +**Large Scale**. Ability to support high-performance parallelization of training with thousands of CPUs and multi-GPUs. -**Reusable**. Algorithms provided in repository could be directly adapted to a new task by defining a forward network and training mechanism will be built automatically. +**Reusable**. Algorithms provided in the repository could be directly adapted to a new task by defining a forward network and training mechanism will be built automatically. **Extensible**. Build new algorithms quickly by inheriting the abstract class in the framework. @@ -30,35 +30,35 @@ The main abstractions introduced by PARL that are used to build an agent recursi ### Agent `Agent`, a data bridge between environment and algorithm, is responsible for data I/O with the outside environment and describes data preprocessing before feeding data into the training process. -Here is an example of building an agent with DQN algorithm for atari games. +Here is an example of building an agent with DQN algorithm for Atari games. ```python import parl from parl.algorithms import DQN, DDQN class AtariModel(parl.Model): - """AtariModel - This class defines the forward part for an algorithm, - its input is state observed on environment. - """ - def __init__(self, img_shape, action_dim): - # define your layers - self.cnn1 = layers.conv_2d(num_filters=32, filter_size=5, - stride=1, padding=2, act='relu') - ... - self.fc1 = layers.fc(action_dim) - - def value(self, img): - # define how to estimate the Q value based on the image of atari games. - img = img / 255.0 - l = self.cnn1(img) - ... - Q = self.fc1(l) - return Q + """AtariModel + This class defines the forward part for an algorithm, + its input is state observed on the environment. + """ + def __init__(self, img_shape, action_dim): + # define your layers + self.cnn1 = layers.conv_2d(num_filters=32, filter_size=5, + stride=1, padding=2, act='relu') + ... + self.fc1 = layers.fc(action_dim) + + def value(self, img): + # define how to estimate the Q value based on the image of atari games. + img = img / 255.0 + l = self.cnn1(img) + ... + Q = self.fc1(l) + return Q """ three steps to build an agent 1. define a forward model which is critic_model in this example 2. a. to build a DQN algorithm, just pass the critic_model to `DQN` - b. to build a DDQN algorithm, just replace DQN in following line with DDQN + b. to build a DDQN algorithm, just replace DQN in the following line with DDQN 3. define the I/O part in AtariAgent so that it could update the algorithm based on the interactive data """ @@ -69,17 +69,17 @@ agent = AtariAgent(algorithm) # Parallelization PARL provides a compact API for distributed training, allowing users to transfer the code into a parallelized version by simply adding a decorator. -Here is a `Hello World!` example to demonstrate how easily it is to leverage outer computation resources. +Here is a `Hello World` example to demonstrate how easily it is to leverage outer computation resources. ```python #============Agent.py================= @parl.remote_class class Agent(object): - def say_hello(self): - print("Hello World!") + def say_hello(self): + print("Hello World!") - def sum(self, a, b): - return a+b + def sum(self, a, b): + return a+b # launch `Agent.py` at any computation platforms such as a CPU cluster. if __main__ == '__main__': @@ -91,14 +91,14 @@ if __main__ == '__main__': remote_manager = parl.RemoteManager() agent = remote_manager.get_remote() agent.say_hello() -ans = agent.sum(1,5) # run remotely and not comsume any local computation resources +ans = agent.sum(1,5) # run remotely and not consume any local computation resources ``` Two steps to use outer computation resources: -1. use the `parl.remote_class` to decorate a class at first, after which it is transfered to be a new class that can run in other CPUs or machines. +1. use the `parl.remote_class` to decorate a class at first, after which it is transferred to be a new class that can run in other CPUs or machines. 2. Get remote objects from the `RemoteManager`, and these objects have the same functions as the real ones. However, calling any function of these objects **does not** consume local computation resources since they are executed elsewhere. PARL -As shown in the above figure, real actors(orange circle) are running at the cpu cluster, while the learner(bule circle) is running at the local gpu with several remote actors(yellow circle with dotted edge). +As shown in the above figure, real actors(orange circle) are running at the cpu cluster, while the learner(blue circle) is running at the local gpu with several remote actors(yellow circle with dotted edge). For users, they can write code in a simple way, just like writing multi-thread code, but with actors consuming remote resources. We have also provided examples of parallized algorithms like [IMPALA](examples/IMPALA), [A2C](examples/A2C) and [GA3C](examples/GA3C). For more details in usage please refer to these examples. diff --git a/examples/A2C/train.py b/examples/A2C/train.py index bc93472..b57bd16 100644 --- a/examples/A2C/train.py +++ b/examples/A2C/train.py @@ -18,6 +18,7 @@ from learner import Learner def main(config): learner = Learner(config) + assert config['log_metrics_interval_s'] > 0 try: while True: diff --git a/examples/GA3C/atari_agent.py b/examples/GA3C/atari_agent.py index c27311f..c50be0b 100644 --- a/examples/GA3C/atari_agent.py +++ b/examples/GA3C/atari_agent.py @@ -82,7 +82,7 @@ class AtariAgent(Agent): name='entropy_coeff', shape=[], dtype='float32') self.learn_reader = fluid.layers.create_py_reader_by_data( - capacity=self.config['train_batch_size'], + capacity=32, feed_list=[ obs, actions, advantages, target_values, lr, entropy_coeff ]) diff --git a/examples/GA3C/train.py b/examples/GA3C/train.py index ade4389..a84f377 100644 --- a/examples/GA3C/train.py +++ b/examples/GA3C/train.py @@ -18,6 +18,7 @@ from learner import Learner def main(config): learner = Learner(config) + assert config['log_metrics_interval_s'] > 0 try: while True: diff --git a/examples/IMPALA/atari_agent.py b/examples/IMPALA/atari_agent.py index 6b7d345..4e07770 100644 --- a/examples/IMPALA/atari_agent.py +++ b/examples/IMPALA/atari_agent.py @@ -73,7 +73,7 @@ class AtariAgent(Agent): name='entropy_coeff', shape=[], dtype='float32') self.learn_reader = fluid.layers.create_py_reader_by_data( - capacity=self.config['train_batch_size'], + capacity=32, feed_list=[ obs, actions, behaviour_logits, rewards, dones, lr, entropy_coeff diff --git a/examples/IMPALA/impala_config.py b/examples/IMPALA/impala_config.py index b324d96..dcfc2f5 100644 --- a/examples/IMPALA/impala_config.py +++ b/examples/IMPALA/impala_config.py @@ -29,7 +29,6 @@ config = { #========== learner config ========== 'train_batch_size': 1000, - 'learner_queue_max_size': 16, 'sample_queue_max_size': 8, 'gamma': 0.99, diff --git a/examples/IMPALA/learner.py b/examples/IMPALA/learner.py index 272413c..2560642 100644 --- a/examples/IMPALA/learner.py +++ b/examples/IMPALA/learner.py @@ -32,9 +32,8 @@ from parl.utils.window_stat import WindowStat class Learner(object): def __init__(self, config): self.config = config - - self.learner_queue = queue.Queue( - maxsize=config['learner_queue_max_size']) + self.sample_data_queue = queue.Queue( + maxsize=config['sample_queue_max_size']) #=========== Create Agent ========== env = gym.make(config['env_name']) @@ -75,8 +74,7 @@ class Learner(object): #========== Remote Actor =========== self.remote_count = 0 - self.sample_data_queue = queue.Queue( - maxsize=config['sample_queue_max_size']) + self.batch_buffer = [] self.remote_metrics_queue = queue.Queue() self.sample_total_steps = 0 @@ -93,21 +91,33 @@ class Learner(object): """ Data generator for fluid.layers.py_reader """ while True: - batch = self.learner_queue.get() - - obs_np = batch['obs'].astype('float32') - actions_np = batch['actions'].astype('int64') - behaviour_logits_np = batch['behaviour_logits'].astype('float32') - rewards_np = batch['rewards'].astype('float32') - dones_np = batch['dones'].astype('float32') - - self.lr = self.lr_scheduler.step() - self.entropy_coeff = self.entropy_coeff_scheduler.step() - - yield [ - obs_np, actions_np, behaviour_logits_np, rewards_np, dones_np, - self.lr, self.entropy_coeff - ] + sample_data = self.sample_data_queue.get() + self.sample_total_steps += sample_data['obs'].shape[0] + self.batch_buffer.append(sample_data) + + buffer_size = sum( + [data['obs'].shape[0] for data in self.batch_buffer]) + if buffer_size >= self.config['train_batch_size']: + batch = {} + for key in self.batch_buffer[0].keys(): + batch[key] = np.concatenate( + [data[key] for data in self.batch_buffer]) + self.batch_buffer = [] + + obs_np = batch['obs'].astype('float32') + actions_np = batch['actions'].astype('int64') + behaviour_logits_np = batch['behaviour_logits'].astype( + 'float32') + rewards_np = batch['rewards'].astype('float32') + dones_np = batch['dones'].astype('float32') + + self.lr = self.lr_scheduler.step() + self.entropy_coeff = self.entropy_coeff_scheduler.step() + + yield [ + obs_np, actions_np, behaviour_logits_np, rewards_np, + dones_np, self.lr, self.entropy_coeff + ] def run_learn(self): """ Learn loop @@ -170,31 +180,6 @@ class Learner(object): remote_actor.set_params(self.cache_params) - def step(self): - """ Merge and generate batch learn data from sample_data_queue, - and put it in learner_queue. - """ - assert self.learn_thread.is_alive() - - while True: - try: - sample_data = self.sample_data_queue.get_nowait() - self.sample_total_steps += sample_data['obs'].shape[0] - self.batch_buffer.append(sample_data) - - buffer_size = sum( - [data['obs'].shape[0] for data in self.batch_buffer]) - if buffer_size >= self.config['train_batch_size']: - train_batch = {} - for key in self.batch_buffer[0].keys(): - train_batch[key] = np.concatenate( - [data[key] for data in self.batch_buffer]) - self.learner_queue.put(train_batch) - self.batch_buffer = [] - - except queue.Empty: - break - def log_metrics(self): """ Log metrics of learner and actors """ @@ -233,7 +218,6 @@ class Learner(object): 'max_episode_steps': max_episode_steps, 'mean_episode_steps': mean_episode_steps, 'min_episode_steps': min_episode_steps, - 'learner_queue_size': self.learner_queue.qsize(), 'sample_queue_size': self.sample_data_queue.qsize(), 'total_params_sync': self.total_params_sync, 'cache_params_sent_cnt': self.cache_params_sent_cnt, diff --git a/examples/IMPALA/train.py b/examples/IMPALA/train.py index f150751..bf38bb3 100644 --- a/examples/IMPALA/train.py +++ b/examples/IMPALA/train.py @@ -22,9 +22,8 @@ def main(config): try: while True: - start = time.time() - while time.time() - start < config['log_metrics_interval_s']: - learner.step() + time.sleep(config['log_metrics_interval_s']) + learner.log_metrics() except KeyboardInterrupt: diff --git a/parl/__init__.py b/parl/__init__.py index 205164a..c04b1a4 100644 --- a/parl/__init__.py +++ b/parl/__init__.py @@ -14,5 +14,14 @@ """ generates new PARL python API """ -from parl.framework import * + +from parl.utils.utils import _HAS_FLUID + +if _HAS_FLUID: + from parl.framework import * +else: + print( + "WARNING:PARL: Failed to import paddle. Only APIs for parallelization are available." + ) + from parl.remote import remote_class, RemoteManager diff --git a/parl/framework/algorithm_base.py b/parl/framework/algorithm_base.py index 9672134..8b606d8 100644 --- a/parl/framework/algorithm_base.py +++ b/parl/framework/algorithm_base.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle.fluid as fluid -import parl.layers as layers from abc import ABCMeta, abstractmethod -from parl.framework.model_base import Network, Model +from parl.framework.model_base import Model __all__ = ['Algorithm'] @@ -23,13 +21,13 @@ __all__ = ['Algorithm'] class Algorithm(object): """ Algorithm defines the way how we update the model. For example, - after defining forward network in `Network` class, you should define how to update the model here. + after defining forward network in `Model` class, you should define how to update the model here. Before creating a customized algorithm, please do check algorithms of PARL. - Most common used algorithms like DQN/DDPG/PPO have been providing in algorithms, go and have a try. - It's easy to use them and just try pl.algorithms.DQN. + Most common used algorithms like DQN/DDPG/PPO/A3C have been providing in algorithms, go and have a try. + It's easy to use them and just try parl.algorithms.DQN. An Algorithm implements two functions: - 1. define_predict() build forward process which was defined in Network + 1. define_predict() build forward process which was defined in `Model` 2. define_learn() computes a cost for optimization An algorithm should be updating part of a network. The user only needs to diff --git a/parl/remote/remote_constants.py b/parl/remote/remote_constants.py index 6012f52..b30aad1 100644 --- a/parl/remote/remote_constants.py +++ b/parl/remote/remote_constants.py @@ -21,3 +21,6 @@ SERIALIZE_EXCEPTION_TAG = b'[SERIALIZE_EXCEPTION]' DESERIALIZE_EXCEPTION_TAG = b'[DESERIALIZE_EXCEPTION]' NORMAL_TAG = b'[NORMAL]' + +# interval of heartbeat mechanism in the unit of second +HEARTBEAT_INTERVAL_S = 10 diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py index 41a9e2b..2e88a55 100644 --- a/parl/remote/remote_decorator.py +++ b/parl/remote/remote_decorator.py @@ -66,36 +66,23 @@ def remote_class(cls): and waits for requests (e.g. call a function) from server side. """ if remote_ip is None: - client_ip = get_ip_address() - else: - client_ip = remote_ip + remote_ip = get_ip_address() self.zmq_context = zmq.Context() socket = self.zmq_context.socket(zmq.REP) - free_port = None if remote_port is None: - for port in range(6000, 8000): - try: - socket.bind("tcp://*:{}".format(port)) - logger.info( - "[_create_reply_socket] free_port:{}".format(port)) - free_port = port - break - except zmq.error.ZMQError: - logger.warning( - "[_create_reply_socket]cannot bind port:{}, retry". - format(port)) + try: + remote_port = socket.bind_to_random_port(addr="tcp://*") + except zmq.ZMQBindError: + logger.error( + 'Can not bind to a random port, please set remote_port manually.' + ) + sys.exit(1) else: socket.bind("tcp://*:{}".format(remote_port)) - free_port = remote_port - if free_port is not None: - return socket, client_ip, free_port - else: - logger.error( - "cannot find any available port from 6000 to 8000") - sys.exit(1) + return socket, remote_ip, remote_port def _connect_server(self, server_ip, server_port, remote_ip, remote_port): @@ -116,26 +103,31 @@ def remote_class(cls): socket = self.zmq_context.socket(zmq.REQ) socket.connect("tcp://{}:{}".format(server_ip, server_port)) - client_id = np.random.randint(int(1e18)) - logger.info("client_id:{}".format(client_id)) logger.info("connecting {}:{}".format(server_ip, server_port)) - client_info = '{}:{} {}'.format(local_ip, local_port, client_id) + client_addr = '{}:{}'.format(local_ip, local_port) socket.send_multipart( [remote_constants.CONNECT_TAG, - to_byte(client_info)]) + to_byte(client_addr)]) message = socket.recv_multipart() - logger.info("[connect_server] done, message from server:{}".format( - message)) + self.client_id = message[1] + logger.info("connect server done, client_id: {}".format( + self.client_id)) self.connect_socket = socket self.connect_socket.linger = 0 def _exit_remote(self): - # Following release order matters self.poller.unregister(self.connect_socket) - self.zmq_context.destroy() + self.connect_socket.close() + self.reply_socket.close() + + # The program may hang when destroying zmq context manually. + # It will be destroyed automatically by the garbage collection mechanism of python, + # though it may raise some exceptions in C++. + + #self.zmq_context.destroy() def _heartbeat_loop(self): """ @@ -146,7 +138,7 @@ def remote_class(cls): while True: self.connect_socket.send_multipart( - [remote_constants.HEARTBEAT_TAG]) + [remote_constants.HEARTBEAT_TAG, self.client_id]) # wait for at most 10s to receive response socks = dict(self.poller.poll(10000)) @@ -161,7 +153,7 @@ def remote_class(cls): break # HeartBeat interval 10s - time.sleep(10) + time.sleep(remote_constants.HEARTBEAT_INTERVAL_S) def __getattr__(self, attr): """ diff --git a/parl/remote/remote_manager.py b/parl/remote/remote_manager.py index 0d6d3ed..6afe700 100644 --- a/parl/remote/remote_manager.py +++ b/parl/remote/remote_manager.py @@ -14,6 +14,7 @@ import queue import threading +import time import zmq from parl.utils import logger, to_byte, to_str from parl.remote import remote_constants @@ -48,11 +49,16 @@ class RemoteManager(object): self.socket.linger = 0 self.remote_pool = queue.Queue() + self.remote_latest_timestamp = {} t = threading.Thread(target=self._wait_for_connection) t.setDaemon(True) # The thread will exit when main thread exited t.start() + t = threading.Thread(target=self._check_remote_status) + t.setDaemon(True) # The thread will exit when main thread exited + t.start() + def _wait_for_connection(self): """ A never-ending function keeps waiting for the connections from remote client. @@ -61,6 +67,7 @@ class RemoteManager(object): Note that this function has been called inside the `__init__` function. """ + remote_id = 0 while True: try: message = self.socket.recv_multipart() @@ -68,17 +75,23 @@ class RemoteManager(object): if tag == remote_constants.CONNECT_TAG: self.socket.send_multipart([ - remote_constants.NORMAL_TAG, b'Connect server success.' + remote_constants.NORMAL_TAG, + to_byte('{}'.format(remote_id)) ]) - client_info = to_str(message[1]) - remote_client_address, remote_client_id = client_info.split( - ) + remote_client_address = to_str(message[1]) remote_obj = RemoteObject(remote_client_address, - remote_client_id, self.zmq_context) + + self.remote_latest_timestamp[to_byte( + str(remote_id))] = time.time() + logger.info('[RemoteManager] Added a new remote object.') self.remote_pool.put(remote_obj) + + remote_id += 1 + elif tag == remote_constants.HEARTBEAT_TAG: + self.remote_latest_timestamp[message[1]] = time.time() self.socket.send_multipart( [remote_constants.NORMAL_TAG, b'Server is alive.']) else: @@ -88,6 +101,17 @@ class RemoteManager(object): logger.warning('Zmq error, exiting server.') break + def _check_remote_status(self): + while True: + for remote_id in list(self.remote_latest_timestamp.keys()): + if time.time() - self.remote_latest_timestamp[ + remote_id] > 3 * remote_constants.HEARTBEAT_INTERVAL_S: + logger.error( + 'Remote object {} is lost, please check if anything wrong happens in the remote client' + .format(remote_id)) + self.remote_latest_timestamp.pop(remote_id) + time.sleep(3 * remote_constants.HEARTBEAT_INTERVAL_S) + def get_remote(self): """ A blocking function to obtain a remote object. diff --git a/parl/remote/remote_object.py b/parl/remote/remote_object.py index 9f9e7b1..319e905 100644 --- a/parl/remote/remote_object.py +++ b/parl/remote/remote_object.py @@ -25,15 +25,11 @@ class RemoteObject(object): Provides interface to call functions of object in remote client. """ - def __init__(self, - remote_client_address, - remote_client_id, - zmq_context=None): + def __init__(self, remote_client_address, zmq_context=None): """ Args: remote_client_address: address(ip:port) of remote client - remote_client_id: id of remote client - + zmq_context: zmq.Context() """ if zmq_context is None: self.zmq_context = zmq.Context() @@ -44,7 +40,6 @@ class RemoteObject(object): self.command_socket = None # lock for thread safety self.internal_lock = threading.Lock() - self.client_id = remote_client_id self._connect_remote_client(remote_client_address) def _connect_remote_client(self, remote_client_address): diff --git a/parl/utils/utils.py b/parl/utils/utils.py index 206f9bc..c5ccc0d 100644 --- a/parl/utils/utils.py +++ b/parl/utils/utils.py @@ -72,3 +72,9 @@ def is_PY3(): MAX_INT32 = 0x7fffffff + +try: + from paddle import fluid + _HAS_FLUID = True +except ImportError: + _HAS_FLUID = False -- GitLab