未验证 提交 78417815 编写于 作者: Y Yuecheng Liu 提交者: GitHub

check remote attributes in local (#388)

* check remote attribute in local

* update the attribute keys after calling remote functions, add more unnittests

* modify lock realease position

* yapf..

* yapf

* yapf

* remove bracket
上级 c2c947c3
...@@ -325,7 +325,10 @@ class Job(object): ...@@ -325,7 +325,10 @@ class Job(object):
to_byte(error_str + "\ntraceback:\n" + traceback_str) to_byte(error_str + "\ntraceback:\n" + traceback_str)
]) ])
return None return None
reply_socket.send_multipart([remote_constants.NORMAL_TAG]) reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(set(obj.__dict__.keys()))
])
else: else:
logger.error("Message from job {}".format(message)) logger.error("Message from job {}".format(message))
reply_socket.send_multipart([ reply_socket.send_multipart([
...@@ -397,25 +400,12 @@ class Job(object): ...@@ -397,25 +400,12 @@ class Job(object):
message = reply_socket.recv_multipart() message = reply_socket.recv_multipart()
tag = message[0] tag = message[0]
if tag in [ if tag in [
remote_constants.CALL_TAG, remote_constants.GET_ATTRIBUTE, remote_constants.CALL_TAG,
remote_constants.SET_ATTRIBUTE, remote_constants.GET_ATTRIBUTE_TAG,
remote_constants.CHECK_ATTRIBUTE remote_constants.SET_ATTRIBUTE_TAG,
]: ]:
try: try:
if tag == remote_constants.CHECK_ATTRIBUTE: if tag == remote_constants.CALL_TAG:
attr = to_str(message[1])
if attr in obj.__dict__:
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(True)
])
else:
reply_socket.send_multipart([
remote_constants.NORMAL_TAG,
dumps_return(False)
])
elif tag == remote_constants.CALL_TAG:
function_name = to_str(message[1]) function_name = to_str(message[1])
data = message[2] data = message[2]
args, kwargs = loads_argument(data) args, kwargs = loads_argument(data)
...@@ -426,11 +416,12 @@ class Job(object): ...@@ -426,11 +416,12 @@ class Job(object):
ret = getattr(obj, function_name)(*args, **kwargs) ret = getattr(obj, function_name)(*args, **kwargs)
ret = dumps_return(ret) ret = dumps_return(ret)
reply_socket.send_multipart([
remote_constants.NORMAL_TAG, ret,
dumps_return(set(obj.__dict__.keys()))
])
reply_socket.send_multipart( elif tag == remote_constants.GET_ATTRIBUTE_TAG:
[remote_constants.NORMAL_TAG, ret])
elif tag == remote_constants.GET_ATTRIBUTE:
attribute_name = to_str(message[1]) attribute_name = to_str(message[1])
logfile_path = os.path.join(self.log_dir, 'stdout.log') logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path): with redirect_stdout_to_file(logfile_path):
...@@ -438,14 +429,16 @@ class Job(object): ...@@ -438,14 +429,16 @@ class Job(object):
ret = dumps_return(ret) ret = dumps_return(ret)
reply_socket.send_multipart( reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret]) [remote_constants.NORMAL_TAG, ret])
elif tag == remote_constants.SET_ATTRIBUTE: elif tag == remote_constants.SET_ATTRIBUTE_TAG:
attribute_name = to_str(message[1]) attribute_name = to_str(message[1])
attribute_value = loads_return(message[2]) attribute_value = loads_return(message[2])
logfile_path = os.path.join(self.log_dir, 'stdout.log') logfile_path = os.path.join(self.log_dir, 'stdout.log')
with redirect_stdout_to_file(logfile_path): with redirect_stdout_to_file(logfile_path):
setattr(obj, attribute_name, attribute_value) setattr(obj, attribute_name, attribute_value)
reply_socket.send_multipart( reply_socket.send_multipart([
[remote_constants.NORMAL_TAG]) remote_constants.NORMAL_TAG,
dumps_return(set(obj.__dict__.keys()))
])
else: else:
pass pass
......
...@@ -30,9 +30,8 @@ NEW_JOB_TAG = b'[NEW_JOB]' ...@@ -30,9 +30,8 @@ NEW_JOB_TAG = b'[NEW_JOB]'
CHECK_VERSION_TAG = b'[CHECK_VERSION]' CHECK_VERSION_TAG = b'[CHECK_VERSION]'
INIT_OBJECT_TAG = b'[INIT_OBJECT]' INIT_OBJECT_TAG = b'[INIT_OBJECT]'
CALL_TAG = b'[CALL]' CALL_TAG = b'[CALL]'
GET_ATTRIBUTE = b'[GET_ATTRIBUTE]' GET_ATTRIBUTE_TAG = b'[GET_ATTRIBUTE]'
SET_ATTRIBUTE = b'[SET_ATTRIBUTE]' SET_ATTRIBUTE_TAG = b'[SET_ATTRIBUTE]'
CHECK_ATTRIBUTE = b'[CHECK_ATTRIBUTE]'
EXCEPTION_TAG = b'[EXCEPTION]' EXCEPTION_TAG = b'[EXCEPTION]'
ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]' ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]'
......
...@@ -95,7 +95,7 @@ def remote_class(*args, **kwargs): ...@@ -95,7 +95,7 @@ def remote_class(*args, **kwargs):
class. class.
""" """
self.GLOBAL_CLIENT = get_global_client() self.GLOBAL_CLIENT = get_global_client()
self.remote_attribute_keys_set = set()
self.ctx = self.GLOBAL_CLIENT.ctx self.ctx = self.GLOBAL_CLIENT.ctx
# GLOBAL_CLIENT will set `master_is_alive` to False when hearbeat # GLOBAL_CLIENT will set `master_is_alive` to False when hearbeat
...@@ -142,10 +142,14 @@ def remote_class(*args, **kwargs): ...@@ -142,10 +142,14 @@ def remote_class(*args, **kwargs):
]) ])
message = self.job_socket.recv_multipart() message = self.job_socket.recv_multipart()
tag = message[0] tag = message[0]
if tag == remote_constants.EXCEPTION_TAG: if tag == remote_constants.NORMAL_TAG:
self.remote_attribute_keys_set = loads_return(message[1])
elif tag == remote_constants.EXCEPTION_TAG:
traceback_str = to_str(message[1]) traceback_str = to_str(message[1])
self.job_shutdown = True self.job_shutdown = True
raise RemoteError('__init__', traceback_str) raise RemoteError('__init__', traceback_str)
else:
pass
def __del__(self): def __del__(self):
"""Delete the remote class object and release remote resources.""" """Delete the remote class object and release remote resources."""
...@@ -190,33 +194,18 @@ def remote_class(*args, **kwargs): ...@@ -190,33 +194,18 @@ def remote_class(*args, **kwargs):
cnt -= 1 cnt -= 1
return None return None
def check_attribute(self, attr):
'''checkout if attr is a attribute or a function'''
self.internal_lock.acquire()
self.job_socket.send_multipart(
[remote_constants.CHECK_ATTRIBUTE,
to_byte(attr)])
message = self.job_socket.recv_multipart()
self.internal_lock.release()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
return loads_return(message[1])
else:
self.job_shutdown = True
raise NotImplementedError()
def set_remote_attr(self, attr, value): def set_remote_attr(self, attr, value):
self.internal_lock.acquire() self.internal_lock.acquire()
self.job_socket.send_multipart([ self.job_socket.send_multipart([
remote_constants.SET_ATTRIBUTE, remote_constants.SET_ATTRIBUTE_TAG,
to_byte(attr), to_byte(attr),
dumps_return(value) dumps_return(value)
]) ])
message = self.job_socket.recv_multipart() message = self.job_socket.recv_multipart()
tag = message[0] tag = message[0]
self.internal_lock.release()
if tag == remote_constants.NORMAL_TAG: if tag == remote_constants.NORMAL_TAG:
pass self.remote_attribute_keys_set = loads_return(message[1])
self.internal_lock.release()
else: else:
self.job_shutdown = True self.job_shutdown = True
raise NotImplementedError() raise NotImplementedError()
...@@ -225,14 +214,15 @@ def remote_class(*args, **kwargs): ...@@ -225,14 +214,15 @@ def remote_class(*args, **kwargs):
def get_remote_attr(self, attr): def get_remote_attr(self, attr):
"""Call the function of the unwrapped class.""" """Call the function of the unwrapped class."""
#check if attr is a attribute or a function #check if attr is a attribute or a function
is_attribute = self.check_attribute(attr) is_attribute = attr in self.remote_attribute_keys_set
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
self.internal_lock.acquire() self.internal_lock.acquire()
if is_attribute: if is_attribute:
self.job_socket.send_multipart( self.job_socket.send_multipart([
[remote_constants.GET_ATTRIBUTE, remote_constants.GET_ATTRIBUTE_TAG,
to_byte(attr)]) to_byte(attr)
])
else: else:
if self.job_shutdown: if self.job_shutdown:
raise RemoteError( raise RemoteError(
...@@ -248,6 +238,9 @@ def remote_class(*args, **kwargs): ...@@ -248,6 +238,9 @@ def remote_class(*args, **kwargs):
if tag == remote_constants.NORMAL_TAG: if tag == remote_constants.NORMAL_TAG:
ret = loads_return(message[1]) ret = loads_return(message[1])
if not is_attribute:
self.remote_attribute_keys_set = loads_return(
message[2])
self.internal_lock.release() self.internal_lock.release()
return ret return ret
......
...@@ -39,6 +39,9 @@ class Actor(object): ...@@ -39,6 +39,9 @@ class Actor(object):
def arg5(self): def arg5(self):
return 100 return 100
def set_new_attr(self):
self.new_attr_1 = 200
class Test_get_and_set_attribute(unittest.TestCase): class Test_get_and_set_attribute(unittest.TestCase):
def tearDown(self): def tearDown(self):
...@@ -148,6 +151,10 @@ class Test_get_and_set_attribute(unittest.TestCase): ...@@ -148,6 +151,10 @@ class Test_get_and_set_attribute(unittest.TestCase):
arg4 = 100 arg4 = 100
parl.connect('localhost:{}'.format(port5)) parl.connect('localhost:{}'.format(port5))
actor = Actor(arg1, arg2, arg3, arg4) actor = Actor(arg1, arg2, arg3, arg4)
actor.new_attr_2 = 300
self.assertEqual(300, actor.new_attr_2)
actor.set_new_attr()
self.assertEqual(200, actor.new_attr_1)
self.assertTrue(callable(actor.arg5)) self.assertTrue(callable(actor.arg5))
def call_non_existing_method(): def call_non_existing_method():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册