提交 f2a1fda6 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

make remote client support multiprocessing (#133)

* make remote client support multiprocessing

* refine unittest

* refine unittest

* refine comment
上级 dcb16294
......@@ -41,11 +41,15 @@ class Client(object):
"""
def __init__(self, master_address):
def __init__(self, master_address, process_id):
"""
Args:
master_addr (str): ip address of the master node.
process_id (str): id of the process that created the Client.
Should use os.getpid() to get the process id.
"""
self.master_address = master_address
self.process_id = process_id
self.ctx = zmq.Context()
self.lock = threading.Lock()
self.heartbeat_socket_initialized = threading.Event()
......@@ -272,13 +276,19 @@ def connect(master_address):
assert len(master_address.split(":")) == 2, "please input address in " +\
"{ip}:{port} format"
global GLOBAL_CLIENT
cur_process_id = os.getpid()
if GLOBAL_CLIENT is None:
GLOBAL_CLIENT = Client(master_address)
GLOBAL_CLIENT = Client(master_address, cur_process_id)
else:
if GLOBAL_CLIENT.process_id != cur_process_id:
GLOBAL_CLIENT = Client(master_address, cur_process_id)
def get_global_client():
"""Get the global client.
To support process-based programming, we will create a new global client in the new process.
Returns:
The global client.
"""
......@@ -286,6 +296,10 @@ def get_global_client():
assert GLOBAL_CLIENT is not None, "Cannot get the client to submit the" +\
" job, have you connected to the cluster by calling " +\
"parl.connect(master_ip, master_port)?"
cur_process_id = os.getpid()
if GLOBAL_CLIENT.process_id != cur_process_id:
GLOBAL_CLIENT = Client(GLOBAL_CLIENT.master_address, cur_process_id)
return GLOBAL_CLIENT
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import parl
import time
import threading
import timeout_decorator
import multiprocessing
from parl.remote.master import Master
from parl.remote.worker import Worker
from parl.remote.client import disconnect
@parl.remote_class
class Actor(object):
def __init__(self, arg1=None, arg2=None):
self.arg1 = arg1
self.arg2 = arg2
def add_one(self, value):
value += 1
return value
class TestCluster(unittest.TestCase):
def tearDown(self):
disconnect()
def _connect_and_create_actor(self, cluster_addr):
parl.connect(cluster_addr)
for _ in range(2):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
def _create_actor(self):
for _ in range(2):
actor = Actor()
ret = actor.add_one(1)
self.assertEqual(ret, 2)
@timeout_decorator.timeout(seconds=60)
def test_connect_and_create_actor_in_multiprocessing_with_connected_in_main_process(
self):
# start the master
master = Master(port=8238)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:8238', 4)
parl.connect('localhost:8238')
proc1 = multiprocessing.Process(
target=self._connect_and_create_actor, args=('localhost:8238', ))
proc2 = multiprocessing.Process(
target=self._connect_and_create_actor, args=('localhost:8238', ))
proc1.start()
proc2.start()
proc1.join()
proc2.join()
# make sure that the client of the main process still works
self._create_actor()
worker1.exit()
master.exit()
@timeout_decorator.timeout(seconds=60)
def test_connect_and_create_actor_in_multiprocessing_without_connected_in_main_process(
self):
# start the master
master = Master(port=8239)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:8239', 4)
proc1 = multiprocessing.Process(
target=self._connect_and_create_actor, args=('localhost:8239', ))
proc2 = multiprocessing.Process(
target=self._connect_and_create_actor, args=('localhost:8239', ))
proc1.start()
proc2.start()
proc1.join()
proc2.join()
self.assertRaises(AssertionError, self._create_actor)
worker1.exit()
master.exit()
@timeout_decorator.timeout(seconds=60)
def test_create_actor_in_multiprocessing(self):
# start the master
master = Master(port=8240)
th = threading.Thread(target=master.run)
th.start()
time.sleep(1)
worker1 = Worker('localhost:8240', 4)
parl.connect('localhost:8240')
proc1 = multiprocessing.Process(target=self._create_actor)
proc2 = multiprocessing.Process(target=self._create_actor)
proc1.start()
proc2.start()
proc1.join()
proc2.join()
# make sure that the client of the main process still works
self._create_actor()
worker1.exit()
master.exit()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册