提交 bbde58fb 编写于 作者: B Bo Zhou 提交者: Hongsheng Zeng

first version of network communication (#49)

* first version of network communication

* fix code style problems

* add a script to get machine's information

* code style problems #2

* fix unit test problems

* update Dockerfile to fix the installation issue of cmake

* thread-safety assurance & copyright

* resolve comments
上级 65ad2a4e
......@@ -103,9 +103,6 @@ def main(argv=None):
first_line = fd.readline()
second_line = fd.readline()
if "COPYRIGHT (C)" in first_line.upper(): continue
if first_line.startswith("#!") or PYTHON_ENCODE.match(
second_line) != None or PYTHON_ENCODE.match(first_line) != None:
continue
original_contents = io.open(filename, encoding="utf-8").read()
new_contents = generate_copyright(
COPYRIGHT, lang_type(filename)) + original_contents
......
......@@ -17,6 +17,7 @@
FROM paddlepaddle/paddle:1.1.0-gpu-cuda9.0-cudnn7
Run apt-get update
RUN apt-get install -y cmake
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple gym
RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple details
......
......@@ -41,7 +41,7 @@ class AtariModel(parl.Model):
def __init__(self, img_shape, action_dim):
# define your layers
self.cnn1 = layers.conv_2d(num_filters=32, filter_size=5,
stride=[1, 1], padding=[2, 2], act='relu')
stride=1, padding=2, act='relu')
...
self.fc1 = layers.fc(action_dim)
def value(self, img):
......
......@@ -19,7 +19,7 @@ from paddle import fluid
from parl.framework.agent_base import Agent
from parl.framework.algorithm_base import Algorithm
from parl.framework.model_base import Model
from parl.utils import gputils
from parl.utils.machine_info import get_gpu_count
class TestModel(Model):
......@@ -66,7 +66,7 @@ class AgentBaseTest(unittest.TestCase):
self.algorithm = TestAlgorithm(self.model)
def test_agent_with_gpu(self):
if gputils.get_gpu_count() > 0:
if get_gpu_count() > 0:
agent = TestAgent(self.algorithm, gpu_id=0)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zmq
from parl.utils import logger
import numpy as np
from parl.utils.communication import dumps_argument, loads_argument
from parl.utils.communication import dumps_return, loads_return
from parl.utils.machine_info import get_ip_address
import pyarrow
import threading
"""
Three steps to create a remote class --
1. add a decorator (@virtual) before the definition of the class.
2. create an instance of the remote class
3. call function `remote_run` with the server address
@virtual
class Simulator(object):
    ...
sim = Simulator()
sim.remote_run(server_ip='172.18.202.45', server_port=8001)
"""
def virtual(cls, location='client'):
    """
    Class wrapper for wrapping a normal class as a remote class that can run
    on a different machine.

    Two kinds of wrappers are provided: one for the client side and one for
    the server side.

    Args:
        cls(Class): the class to wrap. For location='server' the class must
            produce instances that expose an `unwrapped` attribute (i.e. a
            class that has already been decorated by @virtual).
        location(str): specify which wrapper to use, available locations:
            client/server. Users are expected to use `client`.

    Returns:
        The ClientWrapper class when location is 'client', otherwise the
        ServerWrapper class.
    """
    assert location in ['client', 'server'], \
        'Remote Class has to be placed at client side or server side.'

    class ClientWrapper(object):
        """
        Wrapper for remote class at client side. After the decoration,
        the initial class can be driven by requests coming from the server side.
        """

        def __init__(self, *args):
            """
            Args:
                args: arguments for the initialisation of the initial class.
            """
            self.unwrapped = cls(*args)
            # Both sockets are created lazily by `connect_server`.
            self.connect_socket = None
            self.reply_socket = None

        def create_reply_socket(self):
            """
            Create the REP socket on which this client waits for requests
            (e.g. call a function) coming from the server side.

            Ports 6000..7999 are tried in order and the first one that can be
            bound is used. If none can be bound the process exits with
            status 1.

            Returns:
                (socket, client_ip, free_port): the bound socket, the local IP
                address and the port the socket is bound to.
            """
            # Local import: `sys` is only needed for the failure path below.
            import sys

            client_ip = get_ip_address()
            context = zmq.Context()
            socket = context.socket(zmq.REP)
            free_port = None
            for port in range(6000, 8000):
                try:
                    socket.bind("tcp://*:{}".format(port))
                    logger.info(
                        "[create_reply_socket] free_port:{}".format(port))
                    free_port = port
                    break
                except zmq.error.ZMQError:
                    logger.warn(
                        "[create_reply_socket]cannot bind port:{}, retry".
                        format(port))
            if free_port is not None:
                return socket, client_ip, free_port
            else:
                logger.error(
                    "cannot find any available port from 6000 to 8000")
                sys.exit(1)

        def connect_server(self, server_ip, server_port):
            """
            Create the connection between client side and server side.

            The client first opens its own reply socket, then announces
            "<local_ip>:<local_port> <client_id>" to the server and waits for
            the server's greeting.

            Args:
                server_ip(str): the ip of the server.
                server_port(int): the connection port of the server.
            """
            self.reply_socket, local_ip, local_port = \
                self.create_reply_socket()

            logger.info("connecting {}:{}".format(server_ip, server_port))
            context = zmq.Context()
            socket = context.socket(zmq.REQ)
            socket.connect("tcp://{}:{}".format(server_ip, server_port))

            client_id = np.random.randint(int(1e18))
            logger.info("client_id:{}".format(client_id))
            socket.send_string('{}:{} {}'.format(local_ip, local_port,
                                                 client_id))

            message = socket.recv_string()
            logger.info("[connect_server] done, message from server:{}".format(
                message))
            self.connect_socket = socket

        def __getattr__(self, attr):
            """
            Forward calls to the initial (unwrapped) class.

            The wrapper class does not define the functions of the unwrapped
            class itself, so any attribute found on the unwrapped instance is
            returned as a function that calls it.

            Raises:
                AttributeError: if the unwrapped instance has no such attribute.
            """
            # `__getattr__` only runs when normal lookup fails. If `unwrapped`
            # itself is missing (instance created without `__init__`), reading
            # `self.unwrapped` below would re-enter this method forever.
            if attr == 'unwrapped':
                raise AttributeError(attr)

            if hasattr(self.unwrapped, attr):

                def wrapper(*args, **kw):
                    return getattr(self.unwrapped, attr)(*args, **kw)

                return wrapper
            raise AttributeError(attr)

        def remote_run(self, server_ip, server_port):
            """
            Connect to the server and serve its requests to run functions.

            This function never returns. Each request is handled in two
            round trips: the function name is received and acknowledged with
            "OK", then the serialized arguments are received and the
            serialized return value is sent back.

            Args:
                server_ip(str): server's ip
                server_port(int): server's port
            """
            self.connect_server(server_ip, server_port)

            while True:
                function_name = self.reply_socket.recv_string()
                self.reply_socket.send_string("OK")

                data = self.reply_socket.recv()
                args, kw = loads_argument(data)
                ret = getattr(self.unwrapped, function_name)(*args, **kw)
                ret = dumps_return(ret)
                self.reply_socket.send(ret)

    class ServerWrapper(object):
        """
        Wrapper for remote class at server side.
        """

        def __init__(self, *args):
            """
            Args:
                args: arguments used to initialize the initial class
            """
            # `cls` is itself a wrapper class here; only the inner instance is
            # kept, and it is used to check which functions exist.
            self.unwrapped = (cls(*args)).unwrapped
            self.command_socket = None
            self.client_id = None
            # Serializes the two-step request/reply exchange on the socket.
            self.internal_lock = threading.Lock()

        def connect_client(self, client_info):
            """
            Build another connection with the client to send commands to it.

            Args:
                client_info(str): "<client_address> <client_id>", where
                    client_address is "<ip>:<port>".
            """
            client_address, client_id = client_info.split(' ')
            context = zmq.Context()
            socket = context.socket(zmq.REQ)
            logger.info(
                "[connect_client] client_address:{}".format(client_address))
            socket.connect("tcp://{}".format(client_address))
            self.command_socket = socket
            self.client_id = client_id

        def __getattr__(self, attr):
            """
            Run the function at client side. This is also implemented through
            a wrapper function.

            Args:
                attr(str): a function name specifying which function to run.

            Raises:
                NotImplementedError: if the unwrapped instance has no such
                    attribute.
            """
            # Guard against infinite recursion when `unwrapped` is missing
            # (see the same guard in ClientWrapper.__getattr__).
            if attr == 'unwrapped':
                raise AttributeError(attr)

            if hasattr(self.unwrapped, attr):

                def wrapper(*args, **kw):
                    # `with` guarantees the lock is released even when the
                    # exchange raises, so later calls cannot deadlock.
                    with self.internal_lock:
                        self.command_socket.send_string(attr)
                        self.command_socket.recv_string()
                        data = dumps_argument(*args, **kw)
                        self.command_socket.send(data)
                        ret = self.command_socket.recv()
                        return loads_return(ret)

                return wrapper
            # NOTE(review): AttributeError is the conventional exception for
            # `__getattr__`; NotImplementedError is kept so that existing
            # callers catching it keep working.
            raise NotImplementedError()

    if location == 'client':
        return ClientWrapper
    else:
        return ServerWrapper
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import queue
from decorator import virtual
from parl.utils import logger
"""
3 steps to finish communication with remote clients.
1. Create a server
2. Declare the type of remote client
3. Get remote clients
```python
server = Server(port=8001)
server.register_client(Simulator)
remote_client = server.get_client()
```
"""
class Server(object):
    """
    Base class for network communication.

    The server listens on a REP socket, turns every connecting client into a
    server-side remote instance and stores it in an internal pool from which
    it can be fetched with `get_client`.
    """

    def __init__(self, port):
        """
        Bind the listening socket and start the thread that accepts clients.

        Args:
            port(int): a local port used for network communication.
        """
        # Imported here because the import list preceding this class does not
        # bring `zmq` into scope.
        import zmq

        context = zmq.Context()
        socket = context.socket(zmq.REP)
        socket.bind("tcp://*:{}".format(port))
        self.socket = socket
        self.pool = queue.Queue()
        self.cls = None
        # Set once `register_client` has provided the remote class.
        self._cls_registered = threading.Event()
        # The thread is not marked as daemon and its loop never ends.
        t = threading.Thread(target=self.wait_for_connection)
        t.start()

    def wait_for_connection(self):
        """
        A never-ending function that keeps waiting for connections from
        remote clients.

        It puts every connected remote client into an internal client pool,
        and clients can be obtained by calling `get_client`.
        Note that this function is started in a thread by `__init__`.
        """
        while True:
            client_info = self.socket.recv_string()
            # This thread starts before `register_client` can be called, so a
            # client may arrive while `self.cls` is still None. Hold the
            # request until a class is available instead of crashing the
            # thread; the timeout also covers `self.cls` being set directly.
            while self.cls is None:
                self._cls_registered.wait(0.1)
            client_id = client_info.split(' ')[1]
            new_client = virtual(self.cls, location='server')()
            self.socket.send_string('Hello World! Client:{}'.format(client_id))
            new_client.connect_client(client_info)
            self.pool.put(new_client)

    def get_client(self):
        """
        A blocking function to obtain a connected client.

        Returns:
            remote_client(self.cls): a **remote** instance that has all
            functions as the real one.
        """
        return self.pool.get()

    def register_client(self, cls):
        """
        Declare the type of remote class.
        Let the server know which class to use as a remote client.

        Args:
            cls(Class): A class decorated by @virtual.
        """
        self.cls = cls
        self._cls_registered.set()
......@@ -13,5 +13,5 @@
# limitations under the License.
from parl.utils.utils import *
from parl.utils.gputils import *
from parl.utils.machine_info import *
from parl.utils.replay_memory import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pyarrow
__all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return']
def dumps_argument(*args, **kwargs):
    """
    Pack the arguments of a function call into a single buffer.

    Args:
        *args: positional arguments of the call.
        **kwargs: keyword arguments of the call.

    Returns:
        The buffer produced by pyarrow for the pair [args, kwargs].
    """
    call_arguments = [args, kwargs]
    serialized = pyarrow.serialize(call_arguments)
    return serialized.to_buffer()
def loads_argument(data):
    """
    Recover the arguments packed by `dumps_argument`.

    Args:
        data: the output of `dumps_argument`.

    Returns:
        The deserialized [args, kwargs] pair that was given to
        `dumps_argument`.
    """
    call_arguments = pyarrow.deserialize(data)
    return call_arguments
def dumps_return(data):
    """
    Pack the return value of a function into a buffer.

    Args:
        data: the value returned by a function.

    Returns:
        The buffer produced by pyarrow for `data`.
    """
    serialized = pyarrow.serialize(data)
    return serialized.to_buffer()
def loads_return(data):
    """
    Recover a return value packed by `dumps_return`.

    Args:
        data: the output of `dumps_return`.

    Returns:
        The deserialized value.
    """
    value = pyarrow.deserialize(data)
    return value
......@@ -16,11 +16,24 @@ import os
import subprocess
from parl.utils import logger
__all__ = ['get_gpu_count']
__all__ = ['get_gpu_count', 'get_ip_address']
def get_ip_address():
    """
    Get the IP address of the host.

    A UDP socket is connected to 8.8.8.8:80 and the local address chosen for
    that socket is returned.

    Returns:
        local_ip(str): the local address reported by the socket.

    Raises:
        Any error raised by the socket calls (e.g. when no route to the
        target address exists) is propagated to the caller.
    """
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("8.8.8.8", 80))
        return s.getsockname()[0]
    finally:
        # Always release the descriptor, even when connect() fails.
        s.close()
def get_gpu_count():
""" get avaliable gpu count
"""
get avaliable gpu count
Returns:
gpu_count: int
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册