diff --git a/.teamcity/Dockerfile b/.teamcity/Dockerfile index 612023ba86a5d65bed69df94ce8c317526698ba5..85d8a2f37e0db3472d1ad02ab06fb7a80fd72c3a 100644 --- a/.teamcity/Dockerfile +++ b/.teamcity/Dockerfile @@ -15,22 +15,7 @@ # A dev image based on paddle production image -FROM paddlepaddle/paddle:1.3.0-gpu-cuda9.0-cudnn7 - -Run apt-get update -RUN apt-get install --fix-missing -y cmake - -# Prepare packages for Python -RUN apt-get install --fix-missing -y make build-essential libssl-dev zlib1g-dev libbz2-dev \ - libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \ - xz-utils tk-dev libffi-dev liblzma-dev - -# Install python3.6 and pip3.6 -RUN wget -q https://www.python.org/ftp/python/3.6.0/Python-3.6.0.tgz && \ - tar -xzf Python-3.6.0.tgz && cd Python-3.6.0 && \ - CFLAGS="-Wformat" ./configure --prefix=/usr/local/ --enable-shared > /dev/null && \ - make -j8 > /dev/null && make altinstall > /dev/null && \ - cp libpython3.6m.so.1.0 /usr/lib/ && cp libpython3.6m.so.1.0 /usr/local/lib/ +FROM parl/parl-test:cuda9.0-cudnn7 COPY ./requirements.txt /root/ @@ -39,4 +24,3 @@ RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r /root/requirement # Requirements for python3 RUN pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple -r /root/requirements.txt -RUN pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle-gpu==1.3.0.post97 diff --git a/.teamcity/requirements.txt b/.teamcity/requirements.txt index 87a70c66fd1cc7e1a4d9d23674c219f32c43d1d3..47cf56b71616c94412b8a30801c8738455ad5045 100644 --- a/.teamcity/requirements.txt +++ b/.teamcity/requirements.txt @@ -1,3 +1,6 @@ +paddlepaddle-gpu==1.3.0.post97 gym details termcolor +pyarrow +zmq diff --git a/parl/__init__.py b/parl/__init__.py index cec8b257112c61ff2fb6a1bb80a9afa8fb698d9c..205164a3d0aa1bf35d0f75f2ef92431979b804a0 100644 --- a/parl/__init__.py +++ b/parl/__init__.py @@ -15,5 +15,4 @@ generates new PARL python API """ from parl.framework import * -from 
parl.utils import * -from parl.plutils import * +from parl.remote import remote_class, RemoteManager diff --git a/parl/remote/__init__.py b/parl/remote/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfec0c3bffe1802d4821ac6fa87ae945113fa557 --- /dev/null +++ b/parl/remote/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from parl.remote.exceptions import * +from parl.remote.remote_decorator import * +from parl.remote.remote_manager import * +from parl.remote.remote_object import * diff --git a/parl/remote/decorator.py b/parl/remote/decorator.py deleted file mode 100644 index 08ae7157fbef745c16ab3bf52b3f24ac6f975f15..0000000000000000000000000000000000000000 --- a/parl/remote/decorator.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import zmq -from parl.utils import logger -import numpy as np -from parl.utils.communication import dumps_argument, loads_argument -from parl.utils.communication import dumps_return, loads_return -from parl.utils.machine_info import get_ip_address -import pyarrow -import threading -""" -Three steps to create a remote class -- -1. add a decroator(@virtual) before the definition of the class. -2. create an instance of remote class -3. call function `remote_run` with server address - -@virtual -Class Simulator(object): - ... - -sim = Simulator() -sim.remote_run(server_ip='172.18.202.45', port=8001) - -""" - - -def virtual(cls, location='client'): - """ - Class wrapper for wrapping a normal class as a remote class that can run in different machines. - Two kinds of wrapper are provided for the client as well as the server. - - Args: - location(str): specify which wrapper to use, available locations: client/server. - users are expected to use `client`. - """ - assert location in ['client', 'server'], \ - 'Remote Class has to be placed at client side or server side.' - - class ClientWrapper(object): - """ - Wrapper for remote class at client side. After the decoration, - the initial class is able to be called to run any function at sever side. - """ - - def __init__(self, *args): - """ - Args: - args: arguments for the initialisation of the initial class. - """ - self.unwrapped = cls(*args) - self.conect_socket = None - self.reply_socket = None - - def create_reply_socket(self): - """ - In fact, we have also a socket server in client side. This server keeps running - and waits for requests (e.g. call a function) from server side. 
- """ - client_ip = get_ip_address() - context = zmq.Context() - socket = context.socket(zmq.REP) - free_port = None - for port in range(6000, 8000): - try: - socket.bind("tcp://*:{}".format(port)) - logger.info( - "[create_reply_socket] free_port:{}".format(port)) - free_port = port - break - except zmq.error.ZMQError: - logger.warn( - "[create_reply_socket]cannot bind port:{}, retry". - format(port)) - if free_port is not None: - return socket, client_ip, free_port - else: - logger.error( - "cannot find any available port from 6000 to 8000") - sys.exit(1) - - def connect_server(self, server_ip, server_port): - """ - create the connection between client side and server side. - - Args: - server_ip(str): the ip of the server. - server_port(int): the connection port of the server. - """ - self.reply_socket, local_ip, local_port = self.create_reply_socket( - ) - - logger.info("connecting {}:{}".format(server_ip, server_port)) - context = zmq.Context() - socket = context.socket(zmq.REQ) - socket.connect("tcp://{}:{}".format(server_ip, server_port)) - client_id = np.random.randint(int(1e18)) - logger.info("client_id:{}".format(client_id)) - socket.send_string('{}:{} {}'.format(local_ip, local_port, - client_id)) - message = socket.recv_string() - logger.info("[connect_server] done, message from server:{}".format( - message)) - self.connect_socket = socket - - def __getattr__(self, attr): - """ - Call the function of the initial class. The wrapped class do not have - same functions as the unwrapped one. - We have to call the function of the function in unwrapped class, - This implementation utilise a function wrapper. - """ - if hasattr(self.unwrapped, attr): - - def wrapper(*args, **kw): - return getattr(self.unwrapped, attr)(*args, **kw) - - return wrapper - raise AttributeError(attr) - - def remote_run(self, server_ip, server_port): - """ - connect server and wait for requires of running functions from server side. 
- - Args: - server_ip(str): server's ip - server_port(int): server's port - """ - self.connect_server(server_ip, server_port) - while True: - function_name = self.reply_socket.recv_string() - self.reply_socket.send_string("OK") - data = self.reply_socket.recv() - args, kw = loads_argument(data) - ret = getattr(self.unwrapped, function_name)(*args, **kw) - ret = dumps_return(ret) - self.reply_socket.send(ret) - - class ServerWrapper(object): - """ - Wrapper for remote class at server side. - """ - - def __init__(self, *args): - """ - Args: - args: arguments used to initialize the initial class - """ - self.unwrapped = (cls(*args)).unwrapped - self.command_socket = None - self.internal_lock = threading.Lock() - - def connect_client(self, client_info): - """ - build another connection with the client to send command to the client. - """ - client_address, client_id = client_info.split(' ') - context = zmq.Context() - socket = context.socket(zmq.REQ) - logger.info( - "[connect_client] client_address:{}".format(client_address)) - socket.connect("tcp://{}".format(client_address)) - self.command_socket = socket - self.client_id = client_id - - def __getattr__(self, attr): - """ - Run the function at client side. we also implement this through a wrapper. - - Args: - attr(str): a function name specify which function to run. 
class RemoteError(Exception):
    """
    Super class of exceptions in remote module.

    Args:
        func_name(str): name of the function that was being called remotely.
        error_info(str): description of the failure.

    Attributes:
        error_info(str): the formatted message, also returned by `str()`.
    """

    def __init__(self, func_name, error_info):
        # Initialise the base class explicitly with the original arguments so
        # that `args` always mirrors the constructor signature (this is what
        # copy/pickle use to rebuild the exception).
        super(RemoteError, self).__init__(func_name, error_info)
        self.error_info = "[PARL remote error when calling function `{}`]:\n{}".format(
            func_name, error_info)

    def __str__(self):
        return self.error_info


# The subclasses only exist so callers can catch a specific failure kind;
# construction and formatting are inherited unchanged from RemoteError.


class RemoteSerializeError(RemoteError):
    """
    Serialize error from remote
    """


class RemoteDeserializeError(RemoteError):
    """
    Deserialize error from remote
    """


class RemoteAttributeError(RemoteError):
    """
    Attribute error from remote
    """
def remote_class(cls):
    """
    Class decorator that wraps `cls` so its instances can serve remote calls.

    The returned wrapper builds a `cls` instance and forwards local method
    calls to it. Once `as_remote` has been called, the instance also answers
    function calls sent by the server side.

    Args:
        cls: the class to wrap.

    Returns:
        ClientWrapper: the wrapper class that replaces `cls`.
    """

    class ClientWrapper(object):
        """
        Wrapper for remote class in client side.
        when as_remote function called, the object initialized in the client can
        handle function call from server.
        """

        def __init__(self, *args, **kwargs):
            """
            Args:
                args, kwargs: arguments for the initialisation of the unwrapped class.
            """
            self.unwrapped = cls(*args, **kwargs)

            self.zmq_context = None
            self.poller = None

            # socket for connecting server and telling ip and port of client to server
            self.connect_socket = None
            # socket for handle function call from server side
            self.reply_socket = None

        def _create_reply_socket(self, remote_ip, remote_port):
            """
            In fact, we also have a socket server in client side. This server keeps running
            and waits for requests (e.g. call a function) from server side.

            Args:
                remote_ip: ip to advertise for this client; detected with
                    `get_ip_address` when None.
                remote_port: port to bind; when None the first bindable port
                    in [6000, 8000) is used.

            Returns:
                tuple: (reply socket, client ip, bound port).
            """
            # Imported here because the module-level imports do not provide
            # `sys`; without it the exit path below raised NameError.
            import sys

            client_ip = get_ip_address() if remote_ip is None else remote_ip

            self.zmq_context = zmq.Context()
            socket = self.zmq_context.socket(zmq.REP)

            if remote_port is not None:
                socket.bind("tcp://*:{}".format(remote_port))
                return socket, client_ip, remote_port

            for port in range(6000, 8000):
                try:
                    socket.bind("tcp://*:{}".format(port))
                except zmq.error.ZMQError:
                    logger.warning(
                        "[_create_reply_socket]cannot bind port:{}, retry".
                        format(port))
                    continue
                logger.info(
                    "[_create_reply_socket] free_port:{}".format(port))
                return socket, client_ip, port

            logger.error("cannot find any available port from 6000 to 8000")
            sys.exit(1)

        def _connect_server(self, server_ip, server_port, remote_ip,
                            remote_port):
            """
            Create the connection between client side and server side.

            Args:
                server_ip(str): the ip of the server.
                server_port(int): the connection port of the server.
                remote_ip: the ip of the client itself.
                remote_port: the port of the client itself,
                             which used to create reply socket.
            """
            self.reply_socket, local_ip, local_port = self._create_reply_socket(
                remote_ip, remote_port)
            self.reply_socket.linger = 0

            socket = self.zmq_context.socket(zmq.REQ)
            socket.connect("tcp://{}:{}".format(server_ip, server_port))

            # Request a 64-bit result explicitly: the upper bound does not
            # fit in a 32-bit integer, which is NumPy's default on some
            # platforms.
            client_id = np.random.randint(int(1e18), dtype=np.int64)
            logger.info("client_id:{}".format(client_id))
            logger.info("connecting {}:{}".format(server_ip, server_port))

            client_info = '{}:{} {}'.format(local_ip, local_port, client_id)
            socket.send_multipart(
                [remote_constants.CONNECT_TAG,
                 to_byte(client_info)])

            message = socket.recv_multipart()
            logger.info("[connect_server] done, message from server:{}".format(
                message))
            self.connect_socket = socket
            self.connect_socket.linger = 0

        def _exit_remote(self):
            """
            Release the heartbeat poller registration and the zmq context.
            """
            # Following release order matters
            self.poller.unregister(self.connect_socket)

            self.zmq_context.destroy()

        def _heartbeat_loop(self):
            """
            Periodically detect whether the server is alive or not
            """
            self.poller = zmq.Poller()
            self.poller.register(self.connect_socket, zmq.POLLIN)

            while True:
                self.connect_socket.send_multipart(
                    [remote_constants.HEARTBEAT_TAG])

                # wait for at most 10s to receive response
                socks = dict(self.poller.poll(10000))

                if socks.get(self.connect_socket) != zmq.POLLIN:
                    logger.warning(
                        '[HeartBeat] Server no response, will exit now!')
                    self._exit_remote()
                    break

                _ = self.connect_socket.recv_multipart()

                # HeartBeat interval 10s
                time.sleep(10)

        def __getattr__(self, attr):
            """
            Call the function of the unwrapped class.

            Args:
                attr(str): name of the attribute looked up on the wrapper.

            Returns:
                A function forwarding its arguments to `unwrapped.<attr>`.

            Raises:
                AttributeError: if the unwrapped object has no such attribute.
            """
            # `__getattr__` only runs when normal lookup failed. If that
            # happens for `unwrapped` itself the instance is not initialised
            # yet, and touching `self.unwrapped` below would recurse forever.
            if attr == 'unwrapped':
                raise AttributeError(attr)

            # Fail at lookup time for unknown names so that `hasattr()` and
            # friends report the truth instead of always getting a wrapper.
            getattr(self.unwrapped, attr)

            def wrapper(*args, **kwargs):
                return getattr(self.unwrapped, attr)(*args, **kwargs)

            return wrapper

        @staticmethod
        def _exception_tag(error):
            """
            Map an exception raised while serving a call to its reply tag.
            """
            if isinstance(error, AttributeError):
                return remote_constants.ATTRIBUTE_EXCEPTION_TAG
            if isinstance(error, SerializeError):
                return remote_constants.SERIALIZE_EXCEPTION_TAG
            if isinstance(error, DeserializeError):
                return remote_constants.DESERIALIZE_EXCEPTION_TAG
            return remote_constants.EXCEPTION_TAG

        def _reply_loop(self):
            """
            Serve function calls from the server side until the socket fails.

            Every request gets exactly one reply: `NORMAL_TAG` plus the
            serialized return value, or an exception tag plus the error text.
            """
            while True:
                try:
                    message = self.reply_socket.recv_multipart()
                except zmq.ZMQError:
                    # NOTE(review): presumably raised once `_exit_remote`
                    # destroyed the context; stop quietly like the server's
                    # connection loop does.
                    logger.warning(
                        '[ReplyLoop] Zmq error, exiting reply loop.')
                    break

                try:
                    function_name = to_str(message[1])
                    data = message[2]
                    args, kwargs = loads_argument(data)
                    ret = getattr(self.unwrapped, function_name)(*args,
                                                                 **kwargs)
                    ret = dumps_return(ret)
                except Exception as e:
                    # Boundary of the remote call: report the failure to the
                    # caller instead of killing the reply thread.
                    logger.error(e)
                    self.reply_socket.send_multipart(
                        [self._exception_tag(e),
                         to_byte(str(e))])
                    continue

                self.reply_socket.send_multipart(
                    [remote_constants.NORMAL_TAG, ret])

        def as_remote(self,
                      server_ip,
                      server_port,
                      remote_ip=None,
                      remote_port=None):
            """
            Client will connect server and wait for function calls from server side.

            This call blocks in the heartbeat loop until the server stops
            answering heartbeats.

            Args:
                server_ip(str): server's ip
                server_port(int): server's port
                remote_ip: the ip of the client itself.
                remote_port: the port of the client itself,
                             which used to create reply socket.
            """
            self._connect_server(server_ip, server_port, remote_ip,
                                 remote_port)

            reply_thread = threading.Thread(target=self._reply_loop)
            reply_thread.daemon = True
            reply_thread.start()

            self._heartbeat_loop()

        def remote_closed(self):
            """
            Check whether as_remote mode is closed

            Returns:
                bool: True when both the reply and the connect socket are closed.
            """
            assert self.reply_socket is not None, 'as_remote function should be called first!'
            assert self.connect_socket is not None, 'as_remote function should be called first!'
            return self.reply_socket.closed and self.connect_socket.closed

    return ClientWrapper
class RemoteManager(object):
    """
    Base class for network communication.

    Listens on a local port for remote clients and hands out one
    `RemoteObject` per connected client through `get_remote`.
    """

    def __init__(self, port):
        """
        Args:
            port(int): a local port used for connections from remote clients.
        """
        self.zmq_context = zmq.Context()
        socket = self.zmq_context.socket(zmq.REP)
        socket.bind("tcp://*:{}".format(port))
        self.socket = socket
        self.socket.linger = 0

        self.remote_pool = queue.Queue()

        t = threading.Thread(target=self._wait_for_connection)
        t.daemon = True  # The thread will exit when main thread exited
        t.start()

    def _wait_for_connection(self):
        """
        A never-ending function keeps waiting for the connections from remote client.
        It will put an available remote object in an internal pool, and remote object
        can be obtained by calling `get_remote`.

        Every received message is answered exactly once, because the REP
        socket cannot receive again before it has replied. Unknown tags and
        malformed connect requests are answered with `EXCEPTION_TAG` and the
        loop keeps serving other clients.

        Note that this function has been called inside the `__init__` function.
        """
        while True:
            try:
                message = self.socket.recv_multipart()
                tag = message[0]

                if tag == remote_constants.CONNECT_TAG:
                    try:
                        client_info = to_str(message[1])
                        remote_client_address, remote_client_id = \
                            client_info.split()
                    except (IndexError, ValueError) as e:
                        # Missing frame or not "<ip:port> <id>".
                        logger.warning(
                            '[RemoteManager] Malformed connect message: {}'.
                            format(e))
                        self.socket.send_multipart([
                            remote_constants.EXCEPTION_TAG,
                            b'Malformed connect message.'
                        ])
                        continue

                    self.socket.send_multipart([
                        remote_constants.NORMAL_TAG, b'Connect server success.'
                    ])
                    remote_obj = RemoteObject(remote_client_address,
                                              remote_client_id,
                                              self.zmq_context)
                    logger.info('[RemoteManager] Added a new remote object.')
                    self.remote_pool.put(remote_obj)
                elif tag == remote_constants.HEARTBEAT_TAG:
                    self.socket.send_multipart(
                        [remote_constants.NORMAL_TAG, b'Server is alive.'])
                else:
                    # Raising here would end this daemon thread silently and
                    # stop the server from accepting any further client.
                    logger.warning(
                        '[RemoteManager] Unknown message tag: {}'.format(tag))
                    self.socket.send_multipart([
                        remote_constants.EXCEPTION_TAG,
                        b'Unknown message tag.'
                    ])

            except zmq.ZMQError:
                logger.warning('Zmq error, exiting server.')
                break

    def get_remote(self):
        """
        A blocking function to obtain a remote object.

        Returns:
            RemoteObject
        """
        return self.remote_pool.get()

    def close(self):
        """
        Close RemoteManager.
        """

        self.zmq_context.destroy()
class RemoteObject(object):
    """
    Provides interface to call functions of object in remote client.
    """

    def __init__(self,
                 remote_client_address,
                 remote_client_id,
                 zmq_context=None):
        """
        Args:
            remote_client_address: address(ip:port) of remote client
            remote_client_id: id of remote client
            zmq_context: zmq context used to create the command socket;
                a new context is created when None.
        """
        if zmq_context is None:
            self.zmq_context = zmq.Context()
        else:
            self.zmq_context = zmq_context

        # socket for sending function call to remote object and receiving result
        self.command_socket = None
        # lock for thread safety
        self.internal_lock = threading.Lock()
        self.client_id = remote_client_id
        self._connect_remote_client(remote_client_address)

    def _connect_remote_client(self, remote_client_address):
        """
        Build connection with the remote client to send function call.

        Args:
            remote_client_address: address(ip:port) of remote client
        """
        socket = self.zmq_context.socket(zmq.REQ)
        logger.info("[connect_remote_client] client_address:{}".format(
            remote_client_address))
        socket.connect("tcp://{}".format(remote_client_address))
        self.command_socket = socket
        self.command_socket.linger = 0

    def __getattr__(self, attr):
        """
        Provides interface to call functions of object in remote client.
        1. send function name and packed arguments to remote client;
        2. remote client actually executes the function of the object;
        3. receive function return from remote client.

        Args:
            attr(str): a function name specify which function to run.

        Returns:
            A function that performs the remote call and returns its result.
            It raises RemoteError (or one of its subclasses) when the remote
            side reports a failure, and NotImplementedError for an unknown
            reply tag.
        """

        def wrapper(*args, **kwargs):
            # Hold the lock only for the request/reply exchange and release
            # it on every exit path. With manual acquire/release, any raise
            # below left the lock held and the next call on this object
            # blocked forever.
            with self.internal_lock:
                data = dumps_argument(*args, **kwargs)

                self.command_socket.send_multipart(
                    [remote_constants.NORMAL_TAG,
                     to_byte(attr), data])

                message = self.command_socket.recv_multipart()

            tag = message[0]
            if tag == remote_constants.NORMAL_TAG:
                return loads_return(message[1])

            error_classes = (
                (remote_constants.EXCEPTION_TAG, RemoteError),
                (remote_constants.ATTRIBUTE_EXCEPTION_TAG,
                 RemoteAttributeError),
                (remote_constants.SERIALIZE_EXCEPTION_TAG,
                 RemoteSerializeError),
                (remote_constants.DESERIALIZE_EXCEPTION_TAG,
                 RemoteDeserializeError),
            )
            for error_tag, error_class in error_classes:
                if tag == error_tag:
                    raise error_class(attr, to_str(message[1]))

            raise NotImplementedError()

        return wrapper
-# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -import queue -from decorator import virtual -from parl.utils import logger -""" -3 steps to finish communication with remote clients. -1. Create a server: -2. Declare the type of remote client -3. Get remote clients - -```python - server = Server() - server.bind(Simulator) - remote_client = server.get_client() -``` - -""" - - -class Server(object): - """ - Base class for network communcation. - """ - - def __init__(self, port): - """ - Args: - port(int): a local port used for network communication. - """ - context = zmq.Context() - socket = context.socket(zmq.REP) - socket.bind("tcp://*:{}".format(port)) - self.socket = socket - self.pool = queue.Queue() - self.cls = None - t = threading.Thread(target=self.wait_for_connection) - t.start() - - def wait_for_connection(self): - """ - A never-ending function keeps waiting for the connection for remote client. - It will put an available remote client into an internel client pool, and clients - can be obtained by calling `get_client`. - - Note that this function has been called inside the `__init__` function. - """ - while True: - client_info = self.socket.recv_string() - client_id = client_info.split(' ')[1] - new_client = virtual(self.cls, location='server')() - self.socket.send_string('Hello World! Client:{}'.format(client_id)) - new_client.connect_client(client_info) - self.pool.put(new_client) - - def get_client(self): - """ - A blocking function to obtain a connected client. - - Returns: - remote_client(self.cls): a **remote** instance that has all functions as the real one. - """ - return self.pool.get() - - def register_client(self, cls): - """ - Declare the type of remote class. - Let the server know which class to use as a remote client. - - Args: - cls(Class): A class decorated by @virtual. 
@parl.remote_class
class Simulator:
    """Minimal class used to exercise the `remote_class` wrapper locally."""

    def __init__(self, arg1, arg2=None):
        self.arg1 = arg1
        self.arg2 = arg2

    def get_arg1(self):
        return self.arg1

    def get_arg2(self):
        return self.arg2

    def set_arg1(self, value):
        self.arg1 = value

    def set_arg2(self, value):
        self.arg2 = value


class TestRemoteDecorator(unittest.TestCase):
    """Checks that a decorated class still behaves like the plain one locally."""

    def test_instance_in_local(self):
        local_sim = Simulator(1, 2)

        self.assertEqual(local_sim.get_arg1(), 1)
        self.assertEqual(local_sim.get_arg2(), 2)

        local_sim.set_arg1(3)
        local_sim.set_arg2(4)

        self.assertEqual(local_sim.get_arg1(), 3)
        self.assertEqual(local_sim.get_arg2(), 4)

    def test_instance_in_local_with_wrong_getattr_get_variable(self):
        local_sim = Simulator(1, 2)

        # assertRaises instead of `assert False`: a bare assert is stripped
        # under `python -O`, which would let this test pass vacuously.
        with self.assertRaises(AttributeError):
            local_sim.get_arg3()

    def test_instance_in_local_with_wrong_getattr_set_variable(self):
        local_sim = Simulator(1, 2)

        with self.assertRaises(AttributeError):
            local_sim.set_arg3(3)
import inspect
import parl
import threading
import time
import unittest
from parl.remote import *


@parl.remote_class
class Simulator:
    """Small two-attribute object used as the remote side in the tests."""

    def __init__(self, arg1, arg2=None):
        self.arg1 = arg1
        self.arg2 = arg2

    def get_arg1(self):
        return self.arg1

    def get_arg2(self):
        return self.arg2

    def set_arg1(self, value):
        self.arg1 = value

    def set_arg2(self, value):
        self.arg2 = value

    def get_unable_serialize_object(self):
        # Returns the instance itself; the tests expect this return value
        # to be rejected with RemoteSerializeError.
        return self


class TestRemote(unittest.TestCase):
    """Round-trip tests between a RemoteManager and in-process clients.

    Every test uses its own server port, and each client runs
    ``Simulator.as_remote`` in a daemon thread to fake a remote process.
    """

    def _start_client(self, sim, server_port, *client_addr):
        """Run ``sim.as_remote`` in a daemon thread and return the thread.

        Args:
            sim: the Simulator instance to expose.
            server_port: port of the RemoteManager to connect to.
            *client_addr: extra positional arguments forwarded to
                ``as_remote`` after ``('localhost', server_port)``.
        """
        thread = threading.Thread(
            target=sim.as_remote,
            args=('localhost', server_port) + tuple(client_addr))
        # ``daemon`` attribute instead of the deprecated setDaemon().
        thread.daemon = True
        thread.start()
        return thread

    def _setUp(self, server_port):
        """Start one client thread, then create the manager on that port."""
        self.sim = Simulator(1, arg2=2)

        # run client in a new thread to fake a remote client
        self.client_thread = self._start_client(self.sim, server_port)

        self.remote_manager = RemoteManager(port=server_port)

    def _connect_two_clients(self, server_port, *client_addr):
        """Set up the default client plus a second one; return both remotes.

        ``client_addr`` is forwarded to the second client's ``as_remote``.
        """
        self._setUp(server_port)

        # NOTE(review): the one-second pauses presumably let each client
        # register before the next step so that get_remote() hands the
        # clients out in start order -- confirm against RemoteManager.
        time.sleep(1)
        # run second client
        sim2 = Simulator(11, arg2=22)
        self._start_client(sim2, server_port, *client_addr)

        time.sleep(1)
        remote_sim1 = self.remote_manager.get_remote()
        remote_sim2 = self.remote_manager.get_remote()
        return remote_sim1, remote_sim2

    def test_remote_object(self):
        server_port = 17770
        self._setUp(server_port)

        remote_sim = self.remote_manager.get_remote()

        self.assertEqual(remote_sim.get_arg1(), 1)
        self.assertEqual(remote_sim.get_arg2(), 2)

        ret = remote_sim.set_arg1(3)
        self.assertIsNone(ret)
        ret = remote_sim.set_arg2(4)
        self.assertIsNone(ret)

        self.assertEqual(remote_sim.get_arg1(), 3)
        self.assertEqual(remote_sim.get_arg2(), 4)

    # The failure tests below use assertRaises rather than
    # ``try/except/return`` followed by ``assert False``: a bare assert is
    # stripped under ``python -O`` and would turn these into no-ops.

    def test_remote_object_with_wrong_getattr_get_variable(self):
        server_port = 17771
        self._setUp(server_port)

        remote_sim = self.remote_manager.get_remote()

        with self.assertRaises(RemoteAttributeError):
            remote_sim.get_arg3()

    def test_remote_object_with_wrong_getattr_set_variable(self):
        server_port = 17772
        self._setUp(server_port)

        remote_sim = self.remote_manager.get_remote()

        with self.assertRaises(RemoteAttributeError):
            remote_sim.set_arg3(3)

    def test_remote_object_with_wrong_argument(self):
        server_port = 17773
        self._setUp(server_port)

        remote_sim = self.remote_manager.get_remote()

        with self.assertRaises(RemoteError):
            remote_sim.set_arg1(wrong_arg=1)

    def test_remote_object_with_unable_serialize_argument(self):
        server_port = 17774
        self._setUp(server_port)

        remote_sim = self.remote_manager.get_remote()

        with self.assertRaises(SerializeError):
            remote_sim.set_arg1(wrong_arg=remote_sim)

    def test_remote_object_with_unable_serialize_return(self):
        server_port = 17775
        self._setUp(server_port)

        remote_sim = self.remote_manager.get_remote()

        with self.assertRaises(RemoteSerializeError):
            remote_sim.get_unable_serialize_object()

    def test_mutli_remote_object(self):
        remote_sim1, remote_sim2 = self._connect_two_clients(17776)

        self.assertEqual(remote_sim1.get_arg1(), 1)
        self.assertEqual(remote_sim2.get_arg1(), 11)

    def test_mutli_remote_object_with_one_failed(self):
        remote_sim1, remote_sim2 = self._connect_two_clients(17777)

        try:
            # make remote sim1 failed
            remote_sim1.get_arg3()
        except Exception:
            # Any error from the failing remote is tolerated here, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            pass

        # The second remote must keep working after the first one failed.
        self.assertEqual(remote_sim2.get_arg1(), 11)

    # TODO(@zenghongsheng):
    # zmq will raise unexpected C++ exception when closing context,
    # remove this unittest for now.

    #def test_heartbeat_after_server_closed(self):
    #    server_port = 17778
    #    self._setUp(server_port)

    #    remote_sim = self.remote_manager.get_remote()

    #    time.sleep(1)
    #    self.remote_manager.close()

    #    # heartbeat interval (10s) + max waiting reply (10s)
    #    time.sleep(20)

    #    logger.info('check self.sim.remote_closed')
    #    self.assertTrue(self.sim.remote_closed())

    def test_set_client_ip_port_manually(self):
        # The second client passes two extra arguments ('localhost', 6666)
        # to as_remote instead of relying on its defaults.
        remote_sim1, remote_sim2 = self._connect_two_clients(
            17779, 'localhost', 6666)

        self.assertEqual(remote_sim1.get_arg1(), 1)
        self.assertEqual(remote_sim2.get_arg1(), 11)


if __name__ == '__main__':
    unittest.main()
import pyarrow +from parl.utils import SerializeError, DeserializeError __all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return'] @@ -28,7 +29,12 @@ def dumps_argument(*args, **kwargs): Returns: Implementation-dependent object in bytes. """ - return pyarrow.serialize([args, kwargs]).to_buffer() + try: + ret = pyarrow.serialize([args, kwargs]).to_buffer() + except Exception as e: + raise SerializeError(e) + + return ret def loads_argument(data): @@ -42,7 +48,12 @@ def loads_argument(data): deserialized arguments [args, kwargs] like the input of `dumps_argument`, args is a tuple, and kwargs is a dict """ - return pyarrow.deserialize(data) + try: + ret = pyarrow.deserialize(data) + except Exception as e: + raise DeserializeError(e) + + return ret def dumps_return(data): @@ -55,7 +66,12 @@ def dumps_return(data): Returns: Implementation-dependent object in bytes. """ - return pyarrow.serialize(data).to_buffer() + try: + ret = pyarrow.serialize(data).to_buffer() + except Exception as e: + raise SerializeError(e) + + return ret def loads_return(data): @@ -68,4 +84,9 @@ def loads_return(data): Returns: deserialized data """ - return pyarrow.deserialize(data) + try: + ret = pyarrow.deserialize(data) + except Exception as e: + raise DeserializeError(e) + + return ret diff --git a/parl/utils/exceptions.py b/parl/utils/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..023cc5c1073a459b853a1652f43eccd640f70347 --- /dev/null +++ b/parl/utils/exceptions.py @@ -0,0 +1,46 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class UtilsError(Exception):
    """
    Super class of exceptions in utils module.

    Args:
        error_info: the underlying error, either a message or the original
            exception; it is rendered with ``str.format``.

    Attributes:
        error_info (str): ``error_info`` prefixed with the PARL utils tag;
            this is also what ``str(exc)`` returns.
    """

    def __init__(self, error_info):
        # Forward the raw argument to Exception so that ``args`` is filled
        # in. copy/pickle rebuild an exception as ``cls(*args)``, which
        # raised TypeError while ``args`` was left empty.
        super(UtilsError, self).__init__(error_info)
        self.error_info = '[PARL Utils Error]:\n{}'.format(error_info)

    def __str__(self):
        # Defined once on the base class so every subclass (and the base
        # itself) prints the tagged message.
        return self.error_info


class SerializeError(UtilsError):
    """
    Serialize error raised by pyarrow.
    """


class DeserializeError(UtilsError):
    """
    Deserialize error raised by pyarrow.
    """
def to_str(byte, encoding='utf-8'):
    """Convert a bytes object to the native ``str`` type in Python 2/3.

    Args:
        byte: object exposing ``decode`` (e.g. ``bytes``).
        encoding: codec passed to ``decode``; defaults to ``'utf-8'``.

    Returns:
        str: the decoded text.
    """
    return str(byte.decode(encoding))


def to_byte(string, encoding='utf-8'):
    """Convert a string to bytes in Python 2/3.

    Args:
        string: object exposing ``encode`` (e.g. ``str``).
        encoding: codec passed to ``encode``; defaults to ``'utf-8'``.

    Returns:
        The encoded byte string.
    """
    return string.encode(encoding)


def is_PY2():
    """Return True when the running interpreter's major version is 2."""
    return sys.version_info[0] == 2


def is_PY3():
    """Return True when the running interpreter's major version is 3."""
    return sys.version_info[0] == 3