未验证 提交 d6e82f01 编写于 作者: L liuyuecheng-github 提交者: GitHub

add the function to get and set attributes of remote models (#381)

* add support of RegExp input of distributed_files[]

* add support of RegExp input of distributed_files[]

* add the function to get and set attributes of remote models

* get_set_attributes, not finished

* get_set_atributes, not finieshed yet

* add function to get and set attributes of remote model

* add function to get and set attributes of remote models

* add more unnitest cases, together with several comments

* make codes reusable
上级 ad8d0ced
......@@ -53,7 +53,7 @@ class Model(ModelBase):
copied_policy = copy.deepcopy(model)
Attributes:
model_id(str): each model instance has its uniqe model_id.
model_id(str): each model instance has its unique model_id.
Public Functions:
- ``sync_weights_to``: synchronize parameters of the current model to another model.
......
......@@ -395,24 +395,59 @@ class Job(object):
while True:
message = reply_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.CALL_TAG:
if tag in [
remote_constants.CALL_TAG, remote_constants.GET_ATTRIBUTE,
remote_constants.SET_ATTRIBUTE,
remote_constants.CHECK_ATTRIBUTE
]:
try:
function_name = to_str(message[1])
data = message[2]
args, kwargs = loads_argument(data)
# Redirect stdout to stdout.log temporarily
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
ret = getattr(obj, function_name)(*args, **kwargs)
ret = dumps_return(ret)
reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
if tag == remote_constants.CHECK_ATTRIBUTE:
attr = to_str(message[1])
if attr in obj.__dict__:
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(True)
])
else:
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(False)
])
elif tag == remote_constants.CALL_TAG:
function_name = to_str(message[1])
data = message[2]
args, kwargs = loads_argument(data)
# Redirect stdout to stdout.log temporarily
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
ret = getattr(obj, function_name)(*args, **kwargs)
ret = dumps_return(ret)
reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
elif tag == remote_constants.GET_ATTRIBUTE:
attribute_name = to_str(message[1])
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
ret = getattr(obj, attribute_name)
ret = dumps_return(ret)
reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
elif tag == remote_constants.SET_ATTRIBUTE:
attribute_name = to_str(message[1])
attribute_value = loads_return(message[2])
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
setattr(obj, attribute_name, attribute_value)
reply_socket.send_multipart(
[remote_constants.NORMAL_TAG])
else:
pass
except Exception as e:
# reset the job
......
......@@ -29,6 +29,9 @@ NEW_JOB_TAG = b'[NEW_JOB]'
INIT_OBJECT_TAG = b'[INIT_OBJECT]'
CALL_TAG = b'[CALL]'
GET_ATTRIBUTE = b'[GET_ATTRIBUTE]'
SET_ATTRIBUTE = b'[SET_ATTRIBUTE]'
CHECK_ATTRIBUTE = b'[CHECK_ATTRIBUTE]'
EXCEPTION_TAG = b'[EXCEPTION]'
ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]'
......
......@@ -190,25 +190,66 @@ def remote_class(*args, **kwargs):
cnt -= 1
return None
def __getattr__(self, attr):
def check_attribute(self, attr):
'''checkout if attr is a attribute or a function'''
self.internal_lock.acquire()
self.job_socket.send_multipart(
[remote_constants.CHECK_ATTRIBUTE,
to_byte(attr)])
message = self.job_socket.recv_multipart()
self.internal_lock.release()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
return loads_return(message[1])
else:
self.job_shutdown = True
raise NotImplementedError()
def set_remote_attr(self, attr, value):
self.internal_lock.acquire()
self.job_socket.send_multipart([
remote_constants.SET_ATTRIBUTE,
to_byte(attr),
dumps_return(value)
])
message = self.job_socket.recv_multipart()
tag = message[0]
self.internal_lock.release()
if tag == remote_constants.NORMAL_TAG:
pass
else:
self.job_shutdown = True
raise NotImplementedError()
return
def get_remote_attr(self, attr):
"""Call the function of the unwrapped class."""
#check if attr is a attribute or a function
is_attribute = self.check_attribute(attr)
def wrapper(*args, **kwargs):
if self.job_shutdown:
raise RemoteError(
attr, "This actor losts connection with the job.")
self.internal_lock.acquire()
data = dumps_argument(*args, **kwargs)
self.job_socket.send_multipart(
[remote_constants.CALL_TAG,
to_byte(attr), data])
if is_attribute:
self.job_socket.send_multipart(
[remote_constants.GET_ATTRIBUTE,
to_byte(attr)])
else:
if self.job_shutdown:
raise RemoteError(
attr,
"This actor losts connection with the job.")
data = dumps_argument(*args, **kwargs)
self.job_socket.send_multipart(
[remote_constants.CALL_TAG,
to_byte(attr), data])
message = self.job_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
ret = loads_return(message[1])
self.internal_lock.release()
return ret
elif tag == remote_constants.EXCEPTION_TAG:
error_str = to_str(message[1])
......@@ -234,13 +275,38 @@ def remote_class(*args, **kwargs):
self.job_shutdown = True
raise NotImplementedError()
self.internal_lock.release()
return ret
return wrapper() if is_attribute else wrapper
def proxy_wrapper_func(remote_wrapper):
'''
The 'proxy_wrapper_func' is defined on the top of class 'RemoteWrapper'
in order to set and get attributes of 'remoted_wrapper' and the corresponding
remote models individually.
With 'proxy_wrapper_func', it is allowed to define a attribute (or method) of
the same name in 'RemoteWrapper' and remote models.
'''
class ProxyWrapper(object):
def __init__(self, *args, **kwargs):
self.xparl_remote_wrapper_obj = remote_wrapper(
*args, **kwargs)
def __getattr__(self, attr):
return self.xparl_remote_wrapper_obj.get_remote_attr(attr)
def __setattr__(self, attr, value):
if attr == 'xparl_remote_wrapper_obj':
super(ProxyWrapper, self).__setattr__(attr, value)
else:
self.xparl_remote_wrapper_obj.set_remote_attr(
attr, value)
return wrapper
return ProxyWrapper
RemoteWrapper._original = cls
return RemoteWrapper
proxy_wrapper = proxy_wrapper_func(RemoteWrapper)
return proxy_wrapper
max_memory = kwargs.get('max_memory')
if len(args) == 1 and callable(args[0]):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
import numpy as np
from parl.remote.client import disconnect
from parl.utils import logger
from parl.remote.master import Master
from parl.remote.worker import Worker
import time
import threading
import random
@parl.remote_class
class Actor(object):
def __init__(self, arg1, arg2, arg3, arg4):
self.arg1 = arg1
self.arg2 = arg2
self.arg3 = arg3
self.GLOBAL_CLIENT = arg4
def arg1(self, x, y):
time.sleep(0.2)
return x + y
def arg5(self):
return 100
class Test_get_and_set_attribute(unittest.TestCase):
def tearDown(self):
disconnect()
def test_get_attribute(self):
port1 = random.randint(6100, 6200)
logger.info("running:test_get_attirbute")
master = Master(port=port1)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port1), 1)
arg1 = np.random.randint(100)
arg2 = np.random.randn()
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port1))
actor = Actor(arg1, arg2, arg3, arg4)
self.assertTrue(arg1 == actor.arg1)
self.assertTrue(arg2 == actor.arg2)
self.assertTrue((arg3 == actor.arg3).all())
self.assertTrue(arg4 == actor.GLOBAL_CLIENT)
master.exit()
worker1.exit()
def test_set_attribute(self):
port2 = random.randint(6200, 6300)
logger.info("running:test_set_attirbute")
master = Master(port=port2)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port2), 1)
arg1 = 3
arg2 = 3.5
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port2))
actor = Actor(arg1, arg2, arg3, arg4)
actor.arg1 = arg1
actor.arg2 = arg2
actor.arg3 = arg3
actor.GLOBAL_CLIENT = arg4
self.assertTrue(arg1 == actor.arg1)
self.assertTrue(arg2 == actor.arg2)
self.assertTrue((arg3 == actor.arg3).all())
self.assertTrue(arg4 == actor.GLOBAL_CLIENT)
master.exit()
worker1.exit()
def test_create_new_attribute_same_with_wrapper(self):
port3 = random.randint(6400, 6500)
logger.info("running:test_create_new_attribute_same_with_wrapper")
master = Master(port=port3)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port3), 1)
arg1 = np.random.randint(100)
arg2 = np.random.randn()
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port3))
actor = Actor(arg1, arg2, arg3, arg4)
actor.internal_lock = 50
self.assertTrue(actor.internal_lock == 50)
master.exit()
worker1.exit()
def test_same_name_of_attribute_and_method(self):
port4 = random.randint(6500, 6600)
logger.info("running:test_same_name_of_attribute_and_method")
master = Master(port=port4)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port4), 1)
arg1 = np.random.randint(100)
arg2 = np.random.randn()
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port4))
actor = Actor(arg1, arg2, arg3, arg4)
self.assertEqual(arg1, actor.arg1)
def call_method():
return actor.arg1(1, 2)
self.assertRaises(TypeError, call_method)
master.exit()
worker1.exit()
def test_non_existing_attribute_same_with_existing_method(self):
port5 = random.randint(6600, 6700)
logger.info(
"running:test_non_existing_attribute_same_with_existing_method")
master = Master(port=port5)
th = threading.Thread(target=master.run)
th.start()
time.sleep(3)
worker1 = Worker('localhost:{}'.format(port5), 1)
arg1 = np.random.randint(100)
arg2 = np.random.randn()
arg3 = np.random.randn(3, 3)
arg4 = 100
parl.connect('localhost:{}'.format(port5))
actor = Actor(arg1, arg2, arg3, arg4)
self.assertTrue(callable(actor.arg5))
def call_non_existing_method():
return actor.arg2(10)
self.assertRaises(TypeError, call_non_existing_method)
master.exit()
worker1.exit()
if __name__ == '__main__':
unittest.main()
......@@ -72,7 +72,6 @@ class Worker(object):
self.master_is_alive = True
self.worker_is_alive = True
self.worker_status = None # initialized at `self._create_jobs`
self.lock = threading.Lock()
self._set_cpu_num(cpu_num)
self.job_buffer = queue.Queue(maxsize=self.cpu_num)
self._create_sockets()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册