未验证 提交 78417815 编写于 作者: Y Yuecheng Liu 提交者: GitHub

check remote attributes in local (#388)

* check remote attribute in local

* update the attribute keys after calling remote functions, add more unnittests

* modify lock realease position

* yapf..

* yapf

* yapf

* remove bracket
上级 c2c947c3
......@@ -325,7 +325,10 @@ class Job(object):
to_byte(error_str + "\ntraceback:\n" + traceback_str)
])
return None
reply_socket.send_multipart([remote_constants.NORMAL_TAG])
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(set(obj.__dict__.keys()))
])
else:
logger.error("Message from job {}".format(message))
reply_socket.send_multipart([
......@@ -397,25 +400,12 @@ class Job(object):
message = reply_socket.recv_multipart()
tag = message[0]
if tag in [
remote_constants.CALL_TAG, remote_constants.GET_ATTRIBUTE,
remote_constants.SET_ATTRIBUTE,
remote_constants.CHECK_ATTRIBUTE
remote_constants.CALL_TAG,
remote_constants.GET_ATTRIBUTE_TAG,
remote_constants.SET_ATTRIBUTE_TAG,
]:
try:
if tag == remote_constants.CHECK_ATTRIBUTE:
attr = to_str(message[1])
if attr in obj.__dict__:
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(True)
])
else:
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(False)
])
elif tag == remote_constants.CALL_TAG:
if tag == remote_constants.CALL_TAG:
function_name = to_str(message[1])
data = message[2]
args, kwargs = loads_argument(data)
......@@ -426,11 +416,12 @@ class Job(object):
ret = getattr(obj, function_name)(*args, **kwargs)
ret = dumps_return(ret)
reply_socket.send_multipart([
remote_constants.NORMAL_TAG, ret,
dumps_return(set(obj.__dict__.keys()))
])
reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
elif tag == remote_constants.GET_ATTRIBUTE:
elif tag == remote_constants.GET_ATTRIBUTE_TAG:
attribute_name = to_str(message[1])
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
......@@ -438,14 +429,16 @@ class Job(object):
ret = dumps_return(ret)
reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
elif tag == remote_constants.SET_ATTRIBUTE:
elif tag == remote_constants.SET_ATTRIBUTE_TAG:
attribute_name = to_str(message[1])
attribute_value = loads_return(message[2])
logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path):
setattr(obj, attribute_name, attribute_value)
reply_socket.send_multipart(
[remote_constants.NORMAL_TAG])
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(set(obj.__dict__.keys()))
])
else:
pass
......
......@@ -30,9 +30,8 @@ NEW_JOB_TAG = b'[NEW_JOB]'
CHECK_VERSION_TAG = b'[CHECK_VERSION]'
INIT_OBJECT_TAG = b'[INIT_OBJECT]'
CALL_TAG = b'[CALL]'
GET_ATTRIBUTE = b'[GET_ATTRIBUTE]'
SET_ATTRIBUTE = b'[SET_ATTRIBUTE]'
CHECK_ATTRIBUTE = b'[CHECK_ATTRIBUTE]'
GET_ATTRIBUTE_TAG = b'[GET_ATTRIBUTE]'
SET_ATTRIBUTE_TAG = b'[SET_ATTRIBUTE]'
EXCEPTION_TAG = b'[EXCEPTION]'
ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]'
......
......@@ -95,7 +95,7 @@ def remote_class(*args, **kwargs):
class.
"""
self.GLOBAL_CLIENT = get_global_client()
self.remote_attribute_keys_set = set()
self.ctx = self.GLOBAL_CLIENT.ctx
# GLOBAL_CLIENT will set `master_is_alive` to False when hearbeat
......@@ -142,10 +142,14 @@ def remote_class(*args, **kwargs):
])
message = self.job_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.EXCEPTION_TAG:
if tag == remote_constants.NORMAL_TAG:
self.remote_attribute_keys_set = loads_return(message[1])
elif tag == remote_constants.EXCEPTION_TAG:
traceback_str = to_str(message[1])
self.job_shutdown = True
raise RemoteError('__init__', traceback_str)
else:
pass
def __del__(self):
"""Delete the remote class object and release remote resources."""
......@@ -190,33 +194,18 @@ def remote_class(*args, **kwargs):
cnt -= 1
return None
def check_attribute(self, attr):
'''checkout if attr is a attribute or a function'''
self.internal_lock.acquire()
self.job_socket.send_multipart(
[remote_constants.CHECK_ATTRIBUTE,
to_byte(attr)])
message = self.job_socket.recv_multipart()
self.internal_lock.release()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
return loads_return(message[1])
else:
self.job_shutdown = True
raise NotImplementedError()
def set_remote_attr(self, attr, value):
self.internal_lock.acquire()
self.job_socket.send_multipart([
remote_constants.SET_ATTRIBUTE,
remote_constants.SET_ATTRIBUTE_TAG,
to_byte(attr),
dumps_return(value)
])
message = self.job_socket.recv_multipart()
tag = message[0]
self.internal_lock.release()
if tag == remote_constants.NORMAL_TAG:
pass
self.remote_attribute_keys_set = loads_return(message[1])
self.internal_lock.release()
else:
self.job_shutdown = True
raise NotImplementedError()
......@@ -225,14 +214,15 @@ def remote_class(*args, **kwargs):
def get_remote_attr(self, attr):
"""Call the function of the unwrapped class."""
#check if attr is a attribute or a function
is_attribute = self.check_attribute(attr)
is_attribute = attr in self.remote_attribute_keys_set
def wrapper(*args, **kwargs):
self.internal_lock.acquire()
if is_attribute:
self.job_socket.send_multipart(
[remote_constants.GET_ATTRIBUTE,
to_byte(attr)])
self.job_socket.send_multipart([
remote_constants.GET_ATTRIBUTE_TAG,
to_byte(attr)
])
else:
if self.job_shutdown:
raise RemoteError(
......@@ -248,6 +238,9 @@ def remote_class(*args, **kwargs):
if tag == remote_constants.NORMAL_TAG:
ret = loads_return(message[1])
if not is_attribute:
self.remote_attribute_keys_set = loads_return(
message[2])
self.internal_lock.release()
return ret
......
......@@ -39,6 +39,9 @@ class Actor(object):
def arg5(self):
return 100
def set_new_attr(self):
self.new_attr_1 = 200
class Test_get_and_set_attribute(unittest.TestCase):
def tearDown(self):
......@@ -148,6 +151,10 @@ class Test_get_and_set_attribute(unittest.TestCase):
arg4 = 100
parl.connect('localhost:{}'.format(port5))
actor = Actor(arg1, arg2, arg3, arg4)
actor.new_attr_2 = 300
self.assertEqual(300, actor.new_attr_2)
actor.set_new_attr()
self.assertEqual(200, actor.new_attr_1)
self.assertTrue(callable(actor.arg5))
def call_non_existing_method():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册