提交 348db1fb 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

new feature: parl.remote (#54)

* refine remote module, add heartbeat machanism and unittest

* yapf

* yapf

* support get ip address in CentOS, add dependence

* yapf

* add dependence in Dockerfile

* refine message_tag, Compatible with Python2 and python3

* refine unittest and comments

* remove ParlError, use to_pybytes api to compatible with Python 2 and python 3

* Not need to use to_pybytes

* use parl-test docker image for unittest, which has python2 and python3 env

* test different release order of sockets

* test for different closing way fo context and socket

* tmp commit for debug in teamcity

* tmp commit for debug in teamcity

* tmp commit for debug in teamcity

* use zmq.context destroy to close multi-thread socket, refine RemoteError

* set linger=0 for command socket in RemoteObject

* remove close context unittest

* fix codestyle

* fix codestyle

* rename parl.remote to parl.remote_class; will not exit client when having errors in function call; use sepereate server port in unittest to avoiding closing server manually

* rename parl.remote to parl.remote_class; will not exit client when having errors in function call; use sepereate server port in unittest to avoiding closing server manually

* fix typo

* remove unnecessary try/except in reply loop of client

* import RemoteManager to parl; refine comment
上级 e80604f8
......@@ -15,22 +15,7 @@
# A dev image based on paddle production image
FROM paddlepaddle/paddle:1.3.0-gpu-cuda9.0-cudnn7
Run apt-get update
RUN apt-get install --fix-missing -y cmake
# Prepare packages for Python
RUN apt-get install --fix-missing -y make build-essential libssl-dev zlib1g-dev libbz2-dev \
libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \
xz-utils tk-dev libffi-dev liblzma-dev
# Install python3.6 and pip3.6
RUN wget -q https://www.python.org/ftp/python/3.6.0/Python-3.6.0.tgz && \
tar -xzf Python-3.6.0.tgz && cd Python-3.6.0 && \
CFLAGS="-Wformat" ./configure --prefix=/usr/local/ --enable-shared > /dev/null && \
make -j8 > /dev/null && make altinstall > /dev/null && \
cp libpython3.6m.so.1.0 /usr/lib/ && cp libpython3.6m.so.1.0 /usr/local/lib/
FROM parl/parl-test:cuda9.0-cudnn7
COPY ./requirements.txt /root/
......@@ -39,4 +24,3 @@ RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r /root/requirement
# Requirements for python3
RUN pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple -r /root/requirements.txt
RUN pip3.6 install -i https://pypi.tuna.tsinghua.edu.cn/simple paddlepaddle-gpu==1.3.0.post97
paddlepaddle-gpu==1.3.0.post97
gym
details
termcolor
pyarrow
zmq
......@@ -15,5 +15,4 @@
generates new PARL python API
"""
from parl.framework import *
from parl.utils import *
from parl.plutils import *
from parl.remote import remote_class, RemoteManager
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.remote.exceptions import *
from parl.remote.remote_decorator import *
from parl.remote.remote_manager import *
from parl.remote.remote_object import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zmq
from parl.utils import logger
import numpy as np
from parl.utils.communication import dumps_argument, loads_argument
from parl.utils.communication import dumps_return, loads_return
from parl.utils.machine_info import get_ip_address
import pyarrow
import threading
"""
Three steps to create a remote class --
1. add a decroator(@virtual) before the definition of the class.
2. create an instance of remote class
3. call function `remote_run` with server address
@virtual
Class Simulator(object):
...
sim = Simulator()
sim.remote_run(server_ip='172.18.202.45', port=8001)
"""
def virtual(cls, location='client'):
"""
Class wrapper for wrapping a normal class as a remote class that can run in different machines.
Two kinds of wrapper are provided for the client as well as the server.
Args:
location(str): specify which wrapper to use, available locations: client/server.
users are expected to use `client`.
"""
assert location in ['client', 'server'], \
'Remote Class has to be placed at client side or server side.'
class ClientWrapper(object):
"""
Wrapper for remote class at client side. After the decoration,
the initial class is able to be called to run any function at sever side.
"""
def __init__(self, *args):
"""
Args:
args: arguments for the initialisation of the initial class.
"""
self.unwrapped = cls(*args)
self.conect_socket = None
self.reply_socket = None
def create_reply_socket(self):
"""
In fact, we have also a socket server in client side. This server keeps running
and waits for requests (e.g. call a function) from server side.
"""
client_ip = get_ip_address()
context = zmq.Context()
socket = context.socket(zmq.REP)
free_port = None
for port in range(6000, 8000):
try:
socket.bind("tcp://*:{}".format(port))
logger.info(
"[create_reply_socket] free_port:{}".format(port))
free_port = port
break
except zmq.error.ZMQError:
logger.warn(
"[create_reply_socket]cannot bind port:{}, retry".
format(port))
if free_port is not None:
return socket, client_ip, free_port
else:
logger.error(
"cannot find any available port from 6000 to 8000")
sys.exit(1)
def connect_server(self, server_ip, server_port):
"""
create the connection between client side and server side.
Args:
server_ip(str): the ip of the server.
server_port(int): the connection port of the server.
"""
self.reply_socket, local_ip, local_port = self.create_reply_socket(
)
logger.info("connecting {}:{}".format(server_ip, server_port))
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect("tcp://{}:{}".format(server_ip, server_port))
client_id = np.random.randint(int(1e18))
logger.info("client_id:{}".format(client_id))
socket.send_string('{}:{} {}'.format(local_ip, local_port,
client_id))
message = socket.recv_string()
logger.info("[connect_server] done, message from server:{}".format(
message))
self.connect_socket = socket
def __getattr__(self, attr):
"""
Call the function of the initial class. The wrapped class do not have
same functions as the unwrapped one.
We have to call the function of the function in unwrapped class,
This implementation utilise a function wrapper.
"""
if hasattr(self.unwrapped, attr):
def wrapper(*args, **kw):
return getattr(self.unwrapped, attr)(*args, **kw)
return wrapper
raise AttributeError(attr)
def remote_run(self, server_ip, server_port):
"""
connect server and wait for requires of running functions from server side.
Args:
server_ip(str): server's ip
server_port(int): server's port
"""
self.connect_server(server_ip, server_port)
while True:
function_name = self.reply_socket.recv_string()
self.reply_socket.send_string("OK")
data = self.reply_socket.recv()
args, kw = loads_argument(data)
ret = getattr(self.unwrapped, function_name)(*args, **kw)
ret = dumps_return(ret)
self.reply_socket.send(ret)
class ServerWrapper(object):
"""
Wrapper for remote class at server side.
"""
def __init__(self, *args):
"""
Args:
args: arguments used to initialize the initial class
"""
self.unwrapped = (cls(*args)).unwrapped
self.command_socket = None
self.internal_lock = threading.Lock()
def connect_client(self, client_info):
"""
build another connection with the client to send command to the client.
"""
client_address, client_id = client_info.split(' ')
context = zmq.Context()
socket = context.socket(zmq.REQ)
logger.info(
"[connect_client] client_address:{}".format(client_address))
socket.connect("tcp://{}".format(client_address))
self.command_socket = socket
self.client_id = client_id
def __getattr__(self, attr):
"""
Run the function at client side. we also implement this through a wrapper.
Args:
attr(str): a function name specify which function to run.
"""
if hasattr(self.unwrapped, attr):
def wrapper(*args, **kw):
self.internal_lock.acquire()
self.command_socket.send_string(attr)
self.command_socket.recv_string()
data = dumps_argument(*args, **kw)
self.command_socket.send(data)
ret = self.command_socket.recv()
ret = loads_return(ret)
self.internal_lock.release()
return ret
return wrapper
raise NotImplementedError()
if location == 'client':
return ClientWrapper
else:
return ServerWrapper
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class RemoteError(Exception):
"""
Super class of exceptions in remote module.
"""
def __init__(self, func_name, error_info):
self.error_info = "[PARL remote error when calling function `{}`]:\n{}".format(
func_name, error_info)
def __str__(self):
return self.error_info
class RemoteSerializeError(RemoteError):
"""
Serialize error from remote
"""
def __init__(self, func_name, error_info):
super(RemoteSerializeError, self).__init__(func_name, error_info)
def __str__(self):
return self.error_info
class RemoteDeserializeError(RemoteError):
"""
Deserialize error from remote
"""
def __init__(self, func_name, error_info):
super(RemoteDeserializeError, self).__init__(func_name, error_info)
def __str__(self):
return self.error_info
class RemoteAttributeError(RemoteError):
"""
Attribute error from remote
"""
def __init__(self, func_name, error_info):
super(RemoteAttributeError, self).__init__(func_name, error_info)
def __str__(self):
return self.error_info
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
CONNECT_TAG = b'[CONNECT]'
HEARTBEAT_TAG = b'[HEARTBEAT]'
EXCEPTION_TAG = b'[EXCEPTION]'
ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]'
SERIALIZE_EXCEPTION_TAG = b'[SERIALIZE_EXCEPTION]'
DESERIALIZE_EXCEPTION_TAG = b'[DESERIALIZE_EXCEPTION]'
NORMAL_TAG = b'[NORMAL]'
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pyarrow
import threading
import time
import zmq
from parl.remote import remote_constants
from parl.utils import get_ip_address, logger, to_str, to_byte
from parl.utils.exceptions import SerializeError, DeserializeError
from parl.utils.communication import loads_argument, dumps_return
"""
Three steps to create a remote class:
1. add a decroator(@parl.remote_class) before the definition of the class;
2. create an instance of remote class;
3. call function `as_remote` with server_ip and server_port.
@parl.remote_class
Class Simulator(object):
...
sim = Simulator()
sim.as_remote(server_ip='172.18.202.45', port=8001)
"""
def remote_class(cls):
class ClientWrapper(object):
"""
Wrapper for remote class in client side.
when as_remote function called, the object initialized in the client can
handle function call from server.
"""
def __init__(self, *args, **kwargs):
"""
Args:
args, kwargs: arguments for the initialisation of the unwrapped class.
"""
self.unwrapped = cls(*args, **kwargs)
self.zmq_context = None
self.poller = None
# socket for connecting server and telling ip and port of client to server
self.connect_socket = None
# socket for handle function call from server side
self.reply_socket = None
def _create_reply_socket(self, remote_ip, remote_port):
"""
In fact, we also have a socket server in client side. This server keeps running
and waits for requests (e.g. call a function) from server side.
"""
if remote_ip is None:
client_ip = get_ip_address()
else:
client_ip = remote_ip
self.zmq_context = zmq.Context()
socket = self.zmq_context.socket(zmq.REP)
free_port = None
if remote_port is None:
for port in range(6000, 8000):
try:
socket.bind("tcp://*:{}".format(port))
logger.info(
"[_create_reply_socket] free_port:{}".format(port))
free_port = port
break
except zmq.error.ZMQError:
logger.warning(
"[_create_reply_socket]cannot bind port:{}, retry".
format(port))
else:
socket.bind("tcp://*:{}".format(remote_port))
free_port = remote_port
if free_port is not None:
return socket, client_ip, free_port
else:
logger.error(
"cannot find any available port from 6000 to 8000")
sys.exit(1)
def _connect_server(self, server_ip, server_port, remote_ip,
remote_port):
"""
Create the connection between client side and server side.
Args:
server_ip(str): the ip of the server.
server_port(int): the connection port of the server.
remote_ip: the ip of the client itself.
remote_port: the port of the client itself,
which used to create reply socket.
"""
self.reply_socket, local_ip, local_port = self._create_reply_socket(
remote_ip, remote_port)
self.reply_socket.linger = 0
socket = self.zmq_context.socket(zmq.REQ)
socket.connect("tcp://{}:{}".format(server_ip, server_port))
client_id = np.random.randint(int(1e18))
logger.info("client_id:{}".format(client_id))
logger.info("connecting {}:{}".format(server_ip, server_port))
client_info = '{}:{} {}'.format(local_ip, local_port, client_id)
socket.send_multipart(
[remote_constants.CONNECT_TAG,
to_byte(client_info)])
message = socket.recv_multipart()
logger.info("[connect_server] done, message from server:{}".format(
message))
self.connect_socket = socket
self.connect_socket.linger = 0
def _exit_remote(self):
# Following release order matters
self.poller.unregister(self.connect_socket)
self.zmq_context.destroy()
def _heartbeat_loop(self):
"""
Periodically detect whether the server is alive or not
"""
self.poller = zmq.Poller()
self.poller.register(self.connect_socket, zmq.POLLIN)
while True:
self.connect_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
# wait for at most 10s to receive response
socks = dict(self.poller.poll(10000))
if socks.get(self.connect_socket) == zmq.POLLIN:
_ = self.connect_socket.recv_multipart()
else:
logger.warning(
'[HeartBeat] Server no response, will exit now!')
self._exit_remote()
break
# HeartBeat interval 10s
time.sleep(10)
def __getattr__(self, attr):
"""
Call the function of the unwrapped class.
"""
def wrapper(*args, **kwargs):
return getattr(self.unwrapped, attr)(*args, **kwargs)
return wrapper
def _reply_loop(self):
while True:
message = self.reply_socket.recv_multipart()
try:
function_name = to_str(message[1])
data = message[2]
args, kwargs = loads_argument(data)
ret = getattr(self.unwrapped, function_name)(*args,
**kwargs)
ret = dumps_return(ret)
except Exception as e:
error_str = str(e)
logger.error(e)
if type(e) == AttributeError:
self.reply_socket.send_multipart([
remote_constants.ATTRIBUTE_EXCEPTION_TAG,
to_byte(error_str)
])
elif type(e) == SerializeError:
self.reply_socket.send_multipart([
remote_constants.SERIALIZE_EXCEPTION_TAG,
to_byte(error_str)
])
elif type(e) == DeserializeError:
self.reply_socket.send_multipart([
remote_constants.DESERIALIZE_EXCEPTION_TAG,
to_byte(error_str)
])
else:
self.reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
to_byte(error_str)
])
continue
self.reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
def as_remote(self,
server_ip,
server_port,
remote_ip=None,
remote_port=None):
"""
Client will connect server and wait for function calls from server side.
Args:
server_ip(str): server's ip
server_port(int): server's port
remote_ip: the ip of the client itself.
remote_port: the port of the client itself,
which used to create reply socket.
"""
self._connect_server(server_ip, server_port, remote_ip,
remote_port)
reply_thread = threading.Thread(target=self._reply_loop)
reply_thread.setDaemon(True)
reply_thread.start()
self._heartbeat_loop()
def remote_closed(self):
"""
Check whether as_remote mode is closed
"""
assert self.reply_socket is not None, 'as_remote function should be called first!'
assert self.connect_socket is not None, 'as_remote function should be called first!'
return self.reply_socket.closed and self.connect_socket.closed
return ClientWrapper
......@@ -12,26 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import queue
from decorator import virtual
from parl.utils import logger
import threading
import zmq
from parl.utils import logger, to_byte, to_str
from parl.remote import remote_constants
from parl.remote.remote_object import RemoteObject
"""
3 steps to finish communication with remote clients.
1. Create a server:
2. Declare the type of remote client
3. Get remote clients
Two steps to build the communication with remote clients:
1. Create a RemoteManager;
2. Get remote objects by calling the function get_remote.
```python
server = Server()
server.bind(Simulator)
remote_client = server.get_client()
remote_manager = RemoteManager(port=[port])
remote_obj = remote_manager.get_remote()
```
"""
class Server(object):
class RemoteManager(object):
"""
Base class for network communcation.
"""
......@@ -39,48 +39,67 @@ class Server(object):
def __init__(self, port):
"""
Args:
port(int): a local port used for network communication.
port(int): a local port used for connections from remote clients.
"""
context = zmq.Context()
socket = context.socket(zmq.REP)
self.zmq_context = zmq.Context()
socket = self.zmq_context.socket(zmq.REP)
socket.bind("tcp://*:{}".format(port))
self.socket = socket
self.pool = queue.Queue()
self.cls = None
t = threading.Thread(target=self.wait_for_connection)
self.socket.linger = 0
self.remote_pool = queue.Queue()
t = threading.Thread(target=self._wait_for_connection)
t.setDaemon(True) # The thread will exit when main thread exited
t.start()
def wait_for_connection(self):
def _wait_for_connection(self):
"""
A never-ending function keeps waiting for the connection for remote client.
It will put an available remote client into an internel client pool, and clients
can be obtained by calling `get_client`.
A never-ending function keeps waiting for the connections from remote client.
It will put an available remote object in an internel pool, and remote object
can be obtained by calling `get_remote`.
Note that this function has been called inside the `__init__` function.
"""
while True:
client_info = self.socket.recv_string()
client_id = client_info.split(' ')[1]
new_client = virtual(self.cls, location='server')()
self.socket.send_string('Hello World! Client:{}'.format(client_id))
new_client.connect_client(client_info)
self.pool.put(new_client)
def get_client(self):
try:
message = self.socket.recv_multipart()
tag = message[0]
if tag == remote_constants.CONNECT_TAG:
self.socket.send_multipart([
remote_constants.NORMAL_TAG, b'Connect server success.'
])
client_info = to_str(message[1])
remote_client_address, remote_client_id = client_info.split(
)
remote_obj = RemoteObject(remote_client_address,
remote_client_id,
self.zmq_context)
logger.info('[RemoteManager] Added a new remote object.')
self.remote_pool.put(remote_obj)
elif tag == remote_constants.HEARTBEAT_TAG:
self.socket.send_multipart(
[remote_constants.NORMAL_TAG, b'Server is alive.'])
else:
raise NotImplementedError()
except zmq.ZMQError:
logger.warning('Zmq error, exiting server.')
break
def get_remote(self):
"""
A blocking function to obtain a connected client.
A blocking function to obtain a remote object.
Returns:
remote_client(self.cls): a **remote** instance that has all functions as the real one.
RemoteObject
"""
return self.pool.get()
return self.remote_pool.get()
def register_client(self, cls):
def close(self):
"""
Declare the type of remote class.
Let the server know which class to use as a remote client.
Args:
cls(Class): A class decorated by @virtual.
Close RemoteManager.
"""
self.cls = cls
self.zmq_context.destroy()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import zmq
from parl.remote import remote_constants
from parl.remote.exceptions import *
from parl.utils import logger, to_str, to_byte
from parl.utils.communication import dumps_argument, loads_return
class RemoteObject(object):
"""
Provides interface to call functions of object in remote client.
"""
def __init__(self,
remote_client_address,
remote_client_id,
zmq_context=None):
"""
Args:
remote_client_address: address(ip:port) of remote client
remote_client_id: id of remote client
"""
if zmq_context is None:
self.zmq_context = zmq.Context()
else:
self.zmq_context = zmq_context
# socket for sending function call to remote object and receiving result
self.command_socket = None
# lock for thread safety
self.internal_lock = threading.Lock()
self.client_id = remote_client_id
self._connect_remote_client(remote_client_address)
def _connect_remote_client(self, remote_client_address):
"""
Build connection with the remote client to send function call.
"""
socket = self.zmq_context.socket(zmq.REQ)
logger.info("[connect_remote_client] client_address:{}".format(
remote_client_address))
socket.connect("tcp://{}".format(remote_client_address))
self.command_socket = socket
self.command_socket.linger = 0
def __getattr__(self, attr):
"""
Provides interface to call functions of object in remote client.
1. send fucntion name and packed auguments to remote client;
2. remote clinet execute the function of the object really;
3. receive function return from remote client.
Args:
attr(str): a function name specify which function to run.
"""
def wrapper(*args, **kwargs):
self.internal_lock.acquire()
data = dumps_argument(*args, **kwargs)
self.command_socket.send_multipart(
[remote_constants.NORMAL_TAG,
to_byte(attr), data])
message = self.command_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
ret = loads_return(message[1])
elif tag == remote_constants.EXCEPTION_TAG:
error_str = to_str(message[1])
raise RemoteError(attr, error_str)
elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG:
error_str = to_str(message[1])
raise RemoteAttributeError(attr, error_str)
elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG:
error_str = to_str(message[1])
raise RemoteSerializeError(attr, error_str)
elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG:
error_str = to_str(message[1])
raise RemoteDeserializeError(attr, error_str)
else:
raise NotImplementedError()
self.internal_lock.release()
return ret
return wrapper
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import unittest
@parl.remote_class
class Simulator:
def __init__(self, arg1, arg2=None):
self.arg1 = arg1
self.arg2 = arg2
def get_arg1(self):
return self.arg1
def get_arg2(self):
return self.arg2
def set_arg1(self, value):
self.arg1 = value
def set_arg2(self, value):
self.arg2 = value
class TestRemoteDecorator(unittest.TestCase):
def test_instance_in_local(self):
local_sim = Simulator(1, 2)
self.assertEqual(local_sim.get_arg1(), 1)
self.assertEqual(local_sim.get_arg2(), 2)
local_sim.set_arg1(3)
local_sim.set_arg2(4)
self.assertEqual(local_sim.get_arg1(), 3)
self.assertEqual(local_sim.get_arg2(), 4)
def test_instance_in_local_with_wrong_getattr_get_variable(self):
local_sim = Simulator(1, 2)
try:
local_sim.get_arg3()
except AttributeError:
return
assert False # This line should not be executed.
def test_instance_in_local_with_wrong_getattr_set_variable(self):
local_sim = Simulator(1, 2)
try:
local_sim.set_arg3(3)
except AttributeError:
return
assert False # This line should not be executed.
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import parl
import threading
import unittest
from parl.remote import *
@parl.remote_class
class Simulator:
def __init__(self, arg1, arg2=None):
self.arg1 = arg1
self.arg2 = arg2
def get_arg1(self):
return self.arg1
def get_arg2(self):
return self.arg2
def set_arg1(self, value):
self.arg1 = value
def set_arg2(self, value):
self.arg2 = value
def get_unable_serialize_object(self):
return self
class TestRemote(unittest.TestCase):
def _setUp(self, server_port):
self.sim = Simulator(1, arg2=2)
# run client in a new thread to fake a remote client
self.client_thread = threading.Thread(
target=self.sim.as_remote, args=(
'localhost',
server_port,
))
self.client_thread.setDaemon(True)
self.client_thread.start()
self.remote_manager = RemoteManager(port=server_port)
def test_remote_object(self):
server_port = 17770
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
self.assertEqual(remote_sim.get_arg1(), 1)
self.assertEqual(remote_sim.get_arg2(), 2)
ret = remote_sim.set_arg1(3)
self.assertIsNone(ret)
ret = remote_sim.set_arg2(4)
self.assertIsNone(ret)
self.assertEqual(remote_sim.get_arg1(), 3)
self.assertEqual(remote_sim.get_arg2(), 4)
def test_remote_object_with_wrong_getattr_get_variable(self):
server_port = 17771
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.get_arg3()
except RemoteAttributeError:
# expected
return
assert False
def test_remote_object_with_wrong_getattr_set_variable(self):
server_port = 17772
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.set_arg3(3)
except RemoteAttributeError:
# expected
return
assert False
def test_remote_object_with_wrong_argument(self):
server_port = 17773
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.set_arg1(wrong_arg=1)
except RemoteError:
# expected
return
assert False
def test_remote_object_with_unable_serialize_argument(self):
server_port = 17774
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.set_arg1(wrong_arg=remote_sim)
except SerializeError:
# expected
return
assert False
def test_remote_object_with_unable_serialize_return(self):
server_port = 17775
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.get_unable_serialize_object()
except RemoteSerializeError:
# expected
return
assert False
def test_mutli_remote_object(self):
server_port = 17776
self._setUp(server_port)
time.sleep(1)
# run second client
sim2 = Simulator(11, arg2=22)
client_thread2 = threading.Thread(
target=sim2.as_remote, args=(
'localhost',
server_port,
))
client_thread2.setDaemon(True)
client_thread2.start()
time.sleep(1)
remote_sim1 = self.remote_manager.get_remote()
remote_sim2 = self.remote_manager.get_remote()
self.assertEqual(remote_sim1.get_arg1(), 1)
self.assertEqual(remote_sim2.get_arg1(), 11)
def test_mutli_remote_object_with_one_failed(self):
server_port = 17777
self._setUp(server_port)
time.sleep(1)
# run second client
sim2 = Simulator(11, arg2=22)
client_thread2 = threading.Thread(
target=sim2.as_remote, args=(
'localhost',
server_port,
))
client_thread2.setDaemon(True)
client_thread2.start()
time.sleep(1)
remote_sim1 = self.remote_manager.get_remote()
remote_sim2 = self.remote_manager.get_remote()
try:
# make remote sim1 failed
remote_sim1.get_arg3()
except:
pass
self.assertEqual(remote_sim2.get_arg1(), 11)
# Todo(@zenghongsheng):
# zmq will raise unexpected C++ exception when closing context,
# remove this unittest for now.
#def test_heartbeat_after_server_closed(self):
# server_port = 17778
# self._setUp(server_port)
# remote_sim = self.remote_manager.get_remote()
# time.sleep(1)
# self.remote_manager.close()
# # heartbeat interval (10s) + max waiting reply (10s)
# time.sleep(20)
# logger.info('check self.sim.remote_closed')
# self.assertTrue(self.sim.remote_closed())
def test_set_client_ip_port_manually(self):
server_port = 17779
self._setUp(server_port)
time.sleep(1)
# run second client
sim2 = Simulator(11, arg2=22)
client_thread2 = threading.Thread(
target=sim2.as_remote,
args=(
'localhost',
server_port,
'localhost',
6666,
))
client_thread2.setDaemon(True)
client_thread2.start()
time.sleep(1)
remote_sim1 = self.remote_manager.get_remote()
remote_sim2 = self.remote_manager.get_remote()
self.assertEqual(remote_sim1.get_arg1(), 1)
self.assertEqual(remote_sim2.get_arg1(), 11)
if __name__ == '__main__':
unittest.main()
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.utils.exceptions import *
from parl.utils.utils import *
from parl.utils.machine_info import *
from parl.utils.replay_memory import *
......@@ -13,6 +13,7 @@
# limitations under the License.
import pyarrow
from parl.utils import SerializeError, DeserializeError
__all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return']
......@@ -28,7 +29,12 @@ def dumps_argument(*args, **kwargs):
Returns:
Implementation-dependent object in bytes.
"""
return pyarrow.serialize([args, kwargs]).to_buffer()
try:
ret = pyarrow.serialize([args, kwargs]).to_buffer()
except Exception as e:
raise SerializeError(e)
return ret
def loads_argument(data):
......@@ -42,7 +48,12 @@ def loads_argument(data):
deserialized arguments [args, kwargs]
like the input of `dumps_argument`, args is a tuple, and kwargs is a dict
"""
return pyarrow.deserialize(data)
try:
ret = pyarrow.deserialize(data)
except Exception as e:
raise DeserializeError(e)
return ret
def dumps_return(data):
......@@ -55,7 +66,12 @@ def dumps_return(data):
Returns:
Implementation-dependent object in bytes.
"""
return pyarrow.serialize(data).to_buffer()
try:
ret = pyarrow.serialize(data).to_buffer()
except Exception as e:
raise SerializeError(e)
return ret
def loads_return(data):
......@@ -68,4 +84,9 @@ def loads_return(data):
Returns:
deserialized data
"""
return pyarrow.deserialize(data)
try:
ret = pyarrow.deserialize(data)
except Exception as e:
raise DeserializeError(e)
return ret
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class UtilsError(Exception):
"""
Super class of exceptions in utils module.
"""
def __init__(self, error_info):
self.error_info = '[PARL Utils Error]:\n{}'.format(error_info)
class SerializeError(UtilsError):
"""
Serialize error raised by pyarrow.
"""
def __init__(self, error_info):
super(SerializeError, self).__init__(error_info)
def __str__(self):
return self.error_info
class DeserializeError(UtilsError):
"""
Deserialize error raised by pyarrow.
"""
def __init__(self, error_info):
super(DeserializeError, self).__init__(error_info)
def __str__(self):
return self.error_info
......@@ -13,6 +13,7 @@
# limitations under the License.
import os
import platform
import subprocess
from parl.utils import logger
......@@ -23,11 +24,35 @@ def get_ip_address():
"""
get the IP address of the host.
"""
platform_sys = platform.system()
# Only support Linux and MacOS
if platform_sys != 'Linux' and platform_sys != 'Darwin':
logger.warning(
'get_ip_address only support Linux and MacOS, please set ip address manually.'
)
return None
local_ip = None
import socket
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
local_ip = s.getsockname()[0]
s.close()
try:
# First way, tested in Ubuntu and MacOS
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80))
local_ip = s.getsockname()[0]
s.close()
except:
# Second way, tested in CentOS
try:
local_ip = socket.gethostbyname(socket.gethostname())
except:
pass
if local_ip == None or local_ip == '127.0.0.1' or local_ip == '127.0.1.1':
logger.warning(
'get_ip_address failed, please set ip address manually.')
return None
return local_ip
......
......@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['has_func', 'action_mapping']
import sys
__all__ = [
'has_func', 'action_mapping', 'to_str', 'to_byte', 'is_PY2', 'is_PY3'
]
def has_func(obj, fun):
......@@ -44,3 +48,23 @@ def action_mapping(model_output_act, low_bound, high_bound):
action = low_bound + (model_output_act - (-1.0)) * (
(high_bound - low_bound) / 2.0)
return action
def to_str(byte):
""" convert byte to string in pytohn2/3
"""
return str(byte.decode())
def to_byte(string):
""" convert byte to string in pytohn2/3
"""
return string.encode()
def is_PY2():
return sys.version_info[0] == 2
def is_PY3():
return sys.version_info[0] == 3
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册