diff --git a/parl/remote/job.py b/parl/remote/job.py index d2f2d54ac231e3a05ceff65371751abfbe37f392..f70e71f434ae7836febf7e9caf90988ba5f7f9ef 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -36,6 +36,7 @@ from parl.utils.communication import loads_argument, loads_return,\ from parl.remote import remote_constants from parl.utils.exceptions import SerializeError, DeserializeError from parl.remote.message import InitializedJob +from parl.remote.utils import load_remote_class class Job(object): @@ -301,12 +302,12 @@ class Job(object): if tag == remote_constants.INIT_OBJECT_TAG: try: - file_name, class_name = cloudpickle.loads(message[1]) + file_name, class_name, end_of_file = cloudpickle.loads( + message[1]) #/home/nlp-ol/Firework/baidu/nlp/evokit/python_api/es_agent -> es_agent file_name = file_name.split(os.sep)[-1] + cls = load_remote_class(file_name, class_name, end_of_file) args, kwargs = cloudpickle.loads(message[2]) - mod = __import__(file_name) - cls = getattr(mod, class_name)._original obj = cls(*args, **kwargs) except Exception as e: traceback_str = str(traceback.format_exc()) diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py index 9b31d8e370db42e291ce1ab08e7ef7542b447389..f4a498bf0169d1322f5a76c6b8b0978c8f61546e 100644 --- a/parl/remote/remote_decorator.py +++ b/parl/remote/remote_decorator.py @@ -75,6 +75,12 @@ def remote_class(*args, **kwargs): """ def decorator(cls): + # we are not going to create a remote actor in job.py + if 'XPARL' in os.environ and os.environ['XPARL'] == 'True': + logger.warning( + "Note: this object will be runnning as a local object") + return cls + class RemoteWrapper(object): """ Wrapper for remote class in client side. @@ -115,10 +121,12 @@ def remote_class(*args, **kwargs): self.send_file(self.job_socket) file_name = inspect.getfile(cls)[:-3] + cls_source = inspect.getsourcelines(cls) + end_of_file = cls_source[1] + len(cls_source[0]) class_name = cls.__name__ self.job_socket.send_multipart([ remote_constants.INIT_OBJECT_TAG, - cloudpickle.dumps([file_name, class_name]), + cloudpickle.dumps([file_name, class_name, end_of_file]), cloudpickle.dumps([args, kwargs]), ]) message = self.job_socket.recv_multipart() @@ -130,7 +138,10 @@ def remote_class(*args, **kwargs): def __del__(self): """Delete the remote class object and release remote resources.""" - self.job_socket.setsockopt(zmq.RCVTIMEO, 1 * 1000) + try: + self.job_socket.setsockopt(zmq.RCVTIMEO, 1 * 1000) + except AttributeError: + pass if not self.job_shutdown: try: self.job_socket.send_multipart( diff --git a/parl/remote/tests/local_actor.py b/parl/remote/tests/local_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..0435ed233153ec9efee548e012eb70ead11e2dd5 --- /dev/null +++ b/parl/remote/tests/local_actor.py @@ -0,0 +1,38 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +os.environ['XPARL'] = 'True' +import parl +import unittest + + +@parl.remote_class(max_memory=350) +class Actor(object): + def __init__(self, x=10): + self.x = x + self.data = [] + + def add_500mb(self): + self.data.append(os.urandom(500 * 1024**2)) + self.x += 1 + return self.x + + +class TestLocalActor(unittest.TestCase): + def test_create_actors_without_pre_connection(self): + actor = Actor() + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/remote/tests/recursive_actor.py b/parl/remote/tests/recursive_actor.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0665c04f84af62eee59f2530ba331a30d45f5d --- /dev/null +++ b/parl/remote/tests/recursive_actor.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +from parl.utils import logger +import parl +from parl.remote.client import disconnect +from parl.remote.master import Master +from parl.remote.worker import Worker +import time +import threading + +c = 10 +port = 3002 +master = Master(port=port) +th = threading.Thread(target=master.run) +th.setDaemon(True) +th.start() +time.sleep(5) +cluster_addr = 'localhost:{}'.format(port) +parl.connect(cluster_addr) +worker = Worker(cluster_addr, 1) + + +@parl.remote_class +class Actor(object): + def add(self, a, b): + return a + b + c + + +actor = Actor() + + +class TestRecursive_actor(unittest.TestCase): + def tearDown(self): + disconnect() + + def test_global_running(self): + self.assertEqual(actor.add(1, 2), 13) + master.exit() + worker.exit() + + +if __name__ == '__main__': + unittest.main() diff --git a/parl/remote/utils.py b/parl/remote/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..537880f37cc0d7fc7043404f5ba7d6ea1d212adc --- /dev/null +++ b/parl/remote/utils.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['load_remote_class'] + + +def simplify_code(code, end_of_file): + """ + @parl.remote_actor has to use this function to simplify the code. + To create a remote object, PARL has to import the module that contains the decorated class. + It may run some unnecessary code when importing the module, and this is why we use this function + to simplify the code. + + For example. + @parl.remote_actor + class A(object): + def add(self, a, b): + return a + b + def data_process(): + XXXX + ------------------> + line 25 and 26 will be removed. + """ + to_write_lines = [] + for i, line in enumerate(code): + if line.startswith('parl.connect'): + continue + if i < end_of_file - 1: + to_write_lines.append(line) + else: + break + return to_write_lines + + +def load_remote_class(file_name, class_name, end_of_file): + with open(file_name) as t_file: + code = t_file.readlines() + code = simplify_code(code) + tmp_file_name = 'parl_' + file_name + with open(tmp_file_name, 'w') as t_file: + for line in code: + t_file.write(line) + mod = __import__(tmp_file_name) + cls = getattr(mod, class_name)._original + return cls