From 784178154e28bd9d23bdae624ae8cf89e2583acf Mon Sep 17 00:00:00 2001 From: Yuecheng Liu <52879090+liuyuecheng-github@users.noreply.github.com> Date: Tue, 18 Aug 2020 20:54:15 +0800 Subject: [PATCH] check remote attributes in local (#388) * check remote attribute in local * update the attribute keys after calling remote functions, add more unnittests * modify lock realease position * yapf.. * yapf * yapf * remove bracket --- parl/remote/job.py | 43 +++++++++------------ parl/remote/remote_constants.py | 5 +-- parl/remote/remote_decorator.py | 41 ++++++++------------ parl/remote/tests/get_set_attribute_test.py | 7 ++++ 4 files changed, 44 insertions(+), 52 deletions(-) diff --git a/parl/remote/job.py b/parl/remote/job.py index b13c868..aa677eb 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -325,7 +325,10 @@ class Job(object): to_byte(error_str + "\ntraceback:\n" + traceback_str) ]) return None - reply_socket.send_multipart([remote_constants.NORMAL_TAG]) + reply_socket.send_multipart([ + remote_constants.NORMAL_TAG, + dumps_return(set(obj.__dict__.keys())) + ]) else: logger.error("Message from job {}".format(message)) reply_socket.send_multipart([ @@ -397,25 +400,12 @@ class Job(object): message = reply_socket.recv_multipart() tag = message[0] if tag in [ - remote_constants.CALL_TAG, remote_constants.GET_ATTRIBUTE, - remote_constants.SET_ATTRIBUTE, - remote_constants.CHECK_ATTRIBUTE + remote_constants.CALL_TAG, + remote_constants.GET_ATTRIBUTE_TAG, + remote_constants.SET_ATTRIBUTE_TAG, ]: try: - if tag == remote_constants.CHECK_ATTRIBUTE: - attr = to_str(message[1]) - if attr in obj.__dict__: - reply_socket.send_multipart([ - remote_constants.NORMAL_TAG, - dumps_return(True) - ]) - else: - reply_socket.send_multipart([ - remote_constants.NORMAL_TAG, - dumps_return(False) - ]) - - elif tag == remote_constants.CALL_TAG: + if tag == remote_constants.CALL_TAG: function_name = to_str(message[1]) data = message[2] args, kwargs = loads_argument(data) @@ -426,11 +416,12 @@ class Job(object): ret = getattr(obj, function_name)(*args, **kwargs) ret = dumps_return(ret) + reply_socket.send_multipart([ + remote_constants.NORMAL_TAG, ret, + dumps_return(set(obj.__dict__.keys())) + ]) - reply_socket.send_multipart( - [remote_constants.NORMAL_TAG, ret]) - - elif tag == remote_constants.GET_ATTRIBUTE: + elif tag == remote_constants.GET_ATTRIBUTE_TAG: attribute_name = to_str(message[1]) logfile_path = os.path.join(self.log_dir, 'stdout.log') with redirect_stdout_to_file(logfile_path): @@ -438,14 +429,16 @@ class Job(object): ret = dumps_return(ret) reply_socket.send_multipart( [remote_constants.NORMAL_TAG, ret]) - elif tag == remote_constants.SET_ATTRIBUTE: + elif tag == remote_constants.SET_ATTRIBUTE_TAG: attribute_name = to_str(message[1]) attribute_value = loads_return(message[2]) logfile_path = os.path.join(self.log_dir, 'stdout.log') with redirect_stdout_to_file(logfile_path): setattr(obj, attribute_name, attribute_value) - reply_socket.send_multipart( - [remote_constants.NORMAL_TAG]) + reply_socket.send_multipart([ + remote_constants.NORMAL_TAG, + dumps_return(set(obj.__dict__.keys())) + ]) else: pass diff --git a/parl/remote/remote_constants.py b/parl/remote/remote_constants.py index 7ce2ae1..09f96ca 100644 --- a/parl/remote/remote_constants.py +++ b/parl/remote/remote_constants.py @@ -30,9 +30,8 @@ NEW_JOB_TAG = b'[NEW_JOB]' CHECK_VERSION_TAG = b'[CHECK_VERSION]' INIT_OBJECT_TAG = b'[INIT_OBJECT]' CALL_TAG = b'[CALL]' -GET_ATTRIBUTE = b'[GET_ATTRIBUTE]' -SET_ATTRIBUTE = b'[SET_ATTRIBUTE]' -CHECK_ATTRIBUTE = b'[CHECK_ATTRIBUTE]' +GET_ATTRIBUTE_TAG = b'[GET_ATTRIBUTE]' +SET_ATTRIBUTE_TAG = b'[SET_ATTRIBUTE]' EXCEPTION_TAG = b'[EXCEPTION]' ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]' diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py index 47909b1..7403f03 100644 --- a/parl/remote/remote_decorator.py +++ b/parl/remote/remote_decorator.py @@ -95,7 +95,7 @@ def remote_class(*args, **kwargs): class. """ self.GLOBAL_CLIENT = get_global_client() - + self.remote_attribute_keys_set = set() self.ctx = self.GLOBAL_CLIENT.ctx # GLOBAL_CLIENT will set `master_is_alive` to False when hearbeat @@ -142,10 +142,14 @@ def remote_class(*args, **kwargs): ]) message = self.job_socket.recv_multipart() tag = message[0] - if tag == remote_constants.EXCEPTION_TAG: + if tag == remote_constants.NORMAL_TAG: + self.remote_attribute_keys_set = loads_return(message[1]) + elif tag == remote_constants.EXCEPTION_TAG: traceback_str = to_str(message[1]) self.job_shutdown = True raise RemoteError('__init__', traceback_str) + else: + pass def __del__(self): """Delete the remote class object and release remote resources.""" @@ -190,33 +194,18 @@ def remote_class(*args, **kwargs): cnt -= 1 return None - def check_attribute(self, attr): - '''checkout if attr is a attribute or a function''' - self.internal_lock.acquire() - self.job_socket.send_multipart( - [remote_constants.CHECK_ATTRIBUTE, - to_byte(attr)]) - message = self.job_socket.recv_multipart() - self.internal_lock.release() - tag = message[0] - if tag == remote_constants.NORMAL_TAG: - return loads_return(message[1]) - else: - self.job_shutdown = True - raise NotImplementedError() - def set_remote_attr(self, attr, value): self.internal_lock.acquire() self.job_socket.send_multipart([ - remote_constants.SET_ATTRIBUTE, + remote_constants.SET_ATTRIBUTE_TAG, to_byte(attr), dumps_return(value) ]) message = self.job_socket.recv_multipart() tag = message[0] - self.internal_lock.release() if tag == remote_constants.NORMAL_TAG: - pass + self.remote_attribute_keys_set = loads_return(message[1]) + self.internal_lock.release() else: self.job_shutdown = True raise NotImplementedError() @@ -225,14 +214,15 @@ def remote_class(*args, **kwargs): def get_remote_attr(self, attr): """Call the function of the unwrapped class.""" #check if attr is a attribute or a function - is_attribute = self.check_attribute(attr) + is_attribute = attr in self.remote_attribute_keys_set def wrapper(*args, **kwargs): self.internal_lock.acquire() if is_attribute: - self.job_socket.send_multipart( - [remote_constants.GET_ATTRIBUTE, - to_byte(attr)]) + self.job_socket.send_multipart([ + remote_constants.GET_ATTRIBUTE_TAG, + to_byte(attr) + ]) else: if self.job_shutdown: raise RemoteError( @@ -248,6 +238,9 @@ def remote_class(*args, **kwargs): if tag == remote_constants.NORMAL_TAG: ret = loads_return(message[1]) + if not is_attribute: + self.remote_attribute_keys_set = loads_return( + message[2]) self.internal_lock.release() return ret diff --git a/parl/remote/tests/get_set_attribute_test.py b/parl/remote/tests/get_set_attribute_test.py index a233822..68eef3f 100644 --- a/parl/remote/tests/get_set_attribute_test.py +++ b/parl/remote/tests/get_set_attribute_test.py @@ -39,6 +39,9 @@ class Actor(object): def arg5(self): return 100 + def set_new_attr(self): + self.new_attr_1 = 200 + class Test_get_and_set_attribute(unittest.TestCase): def tearDown(self): @@ -148,6 +151,10 @@ class Test_get_and_set_attribute(unittest.TestCase): arg4 = 100 parl.connect('localhost:{}'.format(port5)) actor = Actor(arg1, arg2, arg3, arg4) + actor.new_attr_2 = 300 + self.assertEqual(300, actor.new_attr_2) + actor.set_new_attr() + self.assertEqual(200, actor.new_attr_1) self.assertTrue(callable(actor.arg5)) def call_non_existing_method(): -- GitLab