decorator.py 7.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import zmq
from parl.utils import logger
import numpy as np
from parl.utils.communication import dumps_argument, loads_argument
from parl.utils.communication import dumps_return, loads_return
from parl.utils.machine_info import get_ip_address
import pyarrow
import threading
"""
Three steps to create a remote class:
1. add a decorator (@virtual) before the definition of the class.
2. create an instance of the remote class.
3. call its `remote_run` method with the server address.

@virtual
class Simulator(object):
    ...

sim = Simulator()
sim.remote_run(server_ip='172.18.202.45', server_port=8001)

"""


def virtual(cls, location='client'):
    """
    Class wrapper for wrapping a normal class as a remote class that can run in different machines.
    Two kinds of wrapper are provided for the client as well as the server.

    Args:
        cls(type): the class to wrap. With `location='server'`, instances of `cls`
                   must expose an `unwrapped` attribute (e.g. `cls` is a class
                   already decorated with `@virtual` on the client side).
        location(str): specify which wrapper to use, available locations: client/server.
                       users are expected to use `client`.

    Returns:
        type: `ClientWrapper` if `location == 'client'`, otherwise `ServerWrapper`.
    """
    assert location in ['client', 'server'], \
        'Remote Class has to be placed at client side or server side.'

    class _RemoteAttributeError(AttributeError, NotImplementedError):
        """
        Raised by `ServerWrapper.__getattr__` for names the wrapped class lacks.

        It is an AttributeError so that `hasattr`, `getattr(obj, name, default)`
        and copy/pickle attribute probing behave normally, and a
        NotImplementedError so callers written against the previous behavior
        still catch it.
        """

    class ClientWrapper(object):
        """
        Wrapper for remote class at client side. After the decoration,
        the initial class is able to be called to run any function at server side.
        """

        def __init__(self, *args):
            """
            Args:
                args: arguments for the initialisation of the initial class.
            """
            self.unwrapped = cls(*args)
            # REQ socket to the server; assigned by `connect_server`.
            self.connect_socket = None
            # REP socket serving the server's calls; assigned by `connect_server`.
            self.reply_socket = None

        def create_reply_socket(self):
            """
            In fact, we have also a socket server in client side. This server keeps running
            and waits for requests (e.g. call a function) from server side.

            The REP socket is bound to the first free port in range(6000, 8000).
            If none is free, an error is logged and the process exits with status 1.

            Returns:
                tuple: (socket, client_ip, free_port)
            """
            import sys  # `sys` is not imported at module level.

            client_ip = get_ip_address()
            context = zmq.Context()
            socket = context.socket(zmq.REP)
            for port in range(6000, 8000):
                try:
                    socket.bind("tcp://*:{}".format(port))
                except zmq.error.ZMQError:
                    logger.warn(
                        "[create_reply_socket]cannot bind port:{}, retry".
                        format(port))
                    continue
                logger.info(
                    "[create_reply_socket] free_port:{}".format(port))
                return socket, client_ip, port
            logger.error(
                "cannot find any available port from 6000 to 8000")
            sys.exit(1)

        def connect_server(self, server_ip, server_port):
            """
            create the connection between client side and server side.

            A local reply socket is created first. Its address and a random
            client id are then sent to the server as "ip:port id".

            Args:
                server_ip(str): the ip of the server.
                server_port(int): the connection port of the server.
            """
            self.reply_socket, local_ip, local_port = self.create_reply_socket()

            logger.info("connecting {}:{}".format(server_ip, server_port))
            context = zmq.Context()
            socket = context.socket(zmq.REQ)
            socket.connect("tcp://{}:{}".format(server_ip, server_port))
            # Explicit int64: 1e18 does not fit where numpy's default int is 32-bit.
            client_id = np.random.randint(int(1e18), dtype=np.int64)
            logger.info("client_id:{}".format(client_id))
            socket.send_string('{}:{} {}'.format(local_ip, local_port,
                                                 client_id))
            message = socket.recv_string()
            logger.info("[connect_server] done, message from server:{}".format(
                message))
            self.connect_socket = socket

        def __getattr__(self, attr):
            """
            Call the function of the initial class. The wrapped class does not have
            the same functions as the unwrapped one, so attribute lookups that
            fail on the wrapper return a function that forwards the call to the
            unwrapped instance.

            Raises:
                AttributeError: if the unwrapped instance has no such attribute.
            """
            # Only reached when normal lookup fails. If `unwrapped` itself is
            # missing (instance created without __init__), fail instead of
            # recursing forever.
            if attr == 'unwrapped':
                raise AttributeError(attr)
            if hasattr(self.unwrapped, attr):

                def wrapper(*args, **kw):
                    return getattr(self.unwrapped, attr)(*args, **kw)

                return wrapper
            raise AttributeError(attr)

        def remote_run(self, server_ip, server_port):
            """
            connect server and wait for requests of running functions from server side.

            Each request is: function name -> "OK" ack -> serialized arguments ->
            serialized return value. This method loops forever.

            Args:
                server_ip(str): server's ip
                server_port(int): server's port
            """
            self.connect_server(server_ip, server_port)
            while True:
                function_name = self.reply_socket.recv_string()
                self.reply_socket.send_string("OK")
                data = self.reply_socket.recv()
                args, kw = loads_argument(data)
                ret = getattr(self.unwrapped, function_name)(*args, **kw)
                ret = dumps_return(ret)
                self.reply_socket.send(ret)

    class ServerWrapper(object):
        """
        Wrapper for remote class at server side.
        """

        def __init__(self, *args):
            """
            Args:
                args: arguments used to initialize the initial class
            """
            self.unwrapped = (cls(*args)).unwrapped
            # REQ socket to the client's reply socket; assigned by `connect_client`.
            self.command_socket = None
            # Client id string; assigned by `connect_client`.
            self.client_id = None
            # Serializes the multi-step request/reply exchange on `command_socket`.
            self.internal_lock = threading.Lock()

        def connect_client(self, client_info):
            """
            build another connection with the client to send command to the client.

            Args:
                client_info(str): "ip:port client_id", as sent by `connect_server`.
            """
            client_address, client_id = client_info.split(' ')
            context = zmq.Context()
            socket = context.socket(zmq.REQ)
            logger.info(
                "[connect_client] client_address:{}".format(client_address))
            socket.connect("tcp://{}".format(client_address))
            self.command_socket = socket
            self.client_id = client_id

        def __getattr__(self, attr):
            """
            Run the function at client side. we also implement this through a wrapper.

            Args:
                attr(str): a function name specify which function to run.

            Raises:
                AttributeError: (also a NotImplementedError) if the unwrapped
                    instance has no such attribute.
            """
            # Avoid infinite recursion when `unwrapped` has not been set.
            if attr == 'unwrapped':
                raise AttributeError(attr)
            if hasattr(self.unwrapped, attr):

                def wrapper(*args, **kw):
                    # `with` guarantees the lock is released even when a socket
                    # call or (de)serialization raises; otherwise every later
                    # remote call would block forever.
                    with self.internal_lock:
                        self.command_socket.send_string(attr)
                        self.command_socket.recv_string()
                        data = dumps_argument(*args, **kw)
                        self.command_socket.send(data)
                        ret = self.command_socket.recv()
                        return loads_return(ret)

                return wrapper
            raise _RemoteAttributeError(attr)

    if location == 'client':
        return ClientWrapper
    else:
        return ServerWrapper