提交 ee4ea7fd 编写于 作者: M Megvii Engine Team

test(distributed/test): make distributed tests stronger

GitOrigin-RevId: 085fd1dcfd3a80467e84ffe463b5dcedf615bd48
上级 3ecded74
...@@ -45,9 +45,15 @@ def launcher(func): ...@@ -45,9 +45,15 @@ def launcher(func):
while len(ranks) > 0: while len(ranks) > 0:
left = [] left = []
# check all processes in one second
time_to_wait = 1.0 / len(ranks)
for rank in ranks: for rank in ranks:
procs[rank].join(1) procs[rank].join(time_to_wait)
code = procs[rank].exitcode code = procs[rank].exitcode
# terminate processes if one of them has failed
if code != 0 and code != None:
for i in ranks:
procs[i].terminate()
assert ( assert (
code == 0 or code == None code == 0 or code == None
), "subprocess {} exit with code {}".format(rank, code) ), "subprocess {} exit with code {}".format(rank, code)
......
...@@ -133,18 +133,22 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): ...@@ -133,18 +133,22 @@ class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass pass
def start_server(py_server_port, mm_server_port, queue): def _start_server(py_server_port, mm_server_port, queue):
""" """
Start python distributed server and multiple machine server. Start python distributed server and multiple machine server.
:param py_server_port: python server port. :param py_server_port: python server port.
:param mm_server_port: multiple machine server port. :param mm_server_port: multiple machine server port.
:param queue: server port will put in this queue, puts exception when process fails.
""" """
server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) try:
server.register_instance(Methods(mm_server_port)) server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False)
_, port = server.server_address server.register_instance(Methods(mm_server_port))
queue.put(port) _, port = server.server_address
server.serve_forever() queue.put(port)
server.serve_forever()
except Exception as e:
queue.put(e)
class Server: class Server:
...@@ -159,10 +163,14 @@ class Server: ...@@ -159,10 +163,14 @@ class Server:
self.mm_server_port = create_mm_server("0.0.0.0", 0) self.mm_server_port = create_mm_server("0.0.0.0", 0)
q = Queue() q = Queue()
self.proc = threading.Thread( self.proc = threading.Thread(
target=start_server, args=(port, self.mm_server_port, q), daemon=True, target=_start_server, args=(port, self.mm_server_port, q), daemon=True,
) )
self.proc.start() self.proc.start()
self.py_server_port = q.get() ret = q.get()
if isinstance(ret, Exception):
raise ret
else:
self.py_server_port = ret
class Client: class Client:
......
...@@ -159,11 +159,9 @@ def run_test( ...@@ -159,11 +159,9 @@ def run_test(
checkpoint = mge.load(model_path) checkpoint = mge.load(model_path)
data = checkpoint["data"] data = checkpoint["data"]
label = checkpoint["label"] label = checkpoint["label"]
port = dist.get_free_ports(1)[0]
server = dist.Server(port)
def worker(rank, max_err): @dist.launcher
dist.init_process_group("localhost", port, p_num, rank, rank) def worker(max_err):
net = MnistNet(has_bn=True) net = MnistNet(has_bn=True)
net.load_state_dict(checkpoint["net_init"]) net.load_state_dict(checkpoint["net_init"])
lr = checkpoint["sgd_lr"] lr = checkpoint["sgd_lr"]
...@@ -194,15 +192,7 @@ def run_test( ...@@ -194,15 +192,7 @@ def run_test(
else: else:
np.testing.assert_allclose(param[1], param_ref[1], atol=max_err) np.testing.assert_allclose(param[1], param_ref[1], atol=max_err)
procs = [] worker(max_err)
for rank in range(p_num):
p = mp.Process(target=worker, args=(rank, max_err,))
p.start()
procs.append(p)
for p in procs:
p.join(20)
assert p.exitcode == 0
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device") @pytest.mark.skipif(get_device_count_by_fork("gpu") < 4, reason="need more gpu device")
......
...@@ -23,6 +23,7 @@ from megengine.core.ops.builtin import Elemwise ...@@ -23,6 +23,7 @@ from megengine.core.ops.builtin import Elemwise
from megengine.core.tensor.raw_tensor import as_raw_tensor from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.core.tensor.tensor import Tensor, apply from megengine.core.tensor.tensor import Tensor, apply
from megengine.core.tensor.tensor_wrapper import TensorWrapper from megengine.core.tensor.tensor_wrapper import TensorWrapper
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.distributed import remote_recv, remote_send from megengine.functional.distributed import remote_recv, remote_send
...@@ -53,15 +54,19 @@ def save_to(self, name="grad"): ...@@ -53,15 +54,19 @@ def save_to(self, name="grad"):
return callback return callback
@pytest.mark.isolated_distributed @pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif( @pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM" platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
) )
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_dist_grad(): def test_dist_grad():
world_size = 2 world_size = 2
x_np = np.random.rand(10).astype("float32") x_np = np.random.rand(10).astype("float32")
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker0(): def worker0():
dist.init_process_group("localhost", port, world_size, 0, 0) dist.init_process_group("localhost", port, world_size, 0, 0)
......
...@@ -47,8 +47,8 @@ def _assert_q_val(q, val): ...@@ -47,8 +47,8 @@ def _assert_q_val(q, val):
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_init_process_group(): def test_init_process_group():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, backend): def worker(rank, backend):
dist.init_process_group("localhost", port, world_size, rank, rank, backend) dist.init_process_group("localhost", port, world_size, rank, rank, backend)
...@@ -92,11 +92,10 @@ def test_init_process_group(): ...@@ -92,11 +92,10 @@ def test_init_process_group():
def test_new_group(): def test_new_group():
world_size = 3 world_size = 3
ranks = [2, 0] ranks = [2, 0]
port = dist.get_free_ports(1)[0]
server = dist.Server(port)
def worker(rank): @dist.launcher
dist.init_process_group("localhost", port, world_size, rank, rank) def worker():
rank = dist.get_rank()
if rank in ranks: if rank in ranks:
group = dist.new_group(ranks) group = dist.new_group(ranks)
assert group.size == 2 assert group.size == 2
...@@ -104,15 +103,7 @@ def test_new_group(): ...@@ -104,15 +103,7 @@ def test_new_group():
assert group.rank == ranks.index(rank) assert group.rank == ranks.index(rank)
assert group.comp_node == "gpu{}:2".format(rank) assert group.comp_node == "gpu{}:2".format(rank)
procs = [] worker()
for rank in range(world_size):
p = mp.Process(target=worker, args=(rank,))
p.start()
procs.append(p)
for p in procs:
p.join(20)
assert p.exitcode == 0
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -125,8 +116,8 @@ def test_new_group(): ...@@ -125,8 +116,8 @@ def test_new_group():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_group_barrier(): def test_group_barrier():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, q): def worker(rank, q):
dist.init_process_group("localhost", port, world_size, rank, rank) dist.init_process_group("localhost", port, world_size, rank, rank)
...@@ -161,8 +152,8 @@ def test_group_barrier(): ...@@ -161,8 +152,8 @@ def test_group_barrier():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_synchronized(): def test_synchronized():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
@dist.synchronized @dist.synchronized
def func(rank, q): def func(rank, q):
...@@ -205,26 +196,16 @@ def test_synchronized(): ...@@ -205,26 +196,16 @@ def test_synchronized():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_user_set_get(): def test_user_set_get():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0]
server = dist.Server(port)
def worker(rank): @dist.launcher
dist.init_process_group("localhost", port, world_size, rank, rank) def worker():
# set in race condition # set in race condition
dist.get_client().user_set("foo", 1) dist.get_client().user_set("foo", 1)
# get in race condition # get in race condition
ret = dist.get_client().user_get("foo") ret = dist.get_client().user_get("foo")
assert ret == 1 assert ret == 1
procs = [] worker()
for rank in range(world_size):
p = mp.Process(target=worker, args=(rank,))
p.start()
procs.append(p)
for p in procs:
p.join(20)
assert p.exitcode == 0
def test_oprmm_hashable(): def test_oprmm_hashable():
......
...@@ -41,8 +41,8 @@ from megengine.functional.distributed import ( ...@@ -41,8 +41,8 @@ from megengine.functional.distributed import (
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_reduce_sum(): def test_reduce_sum():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -83,8 +83,8 @@ def test_reduce_sum(): ...@@ -83,8 +83,8 @@ def test_reduce_sum():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_broadcast(): def test_broadcast():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -121,8 +121,8 @@ def test_broadcast(): ...@@ -121,8 +121,8 @@ def test_broadcast():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_all_gather(): def test_all_gather():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -160,8 +160,8 @@ def test_all_gather(): ...@@ -160,8 +160,8 @@ def test_all_gather():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_reduce_scatter_sum(): def test_reduce_scatter_sum():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -199,8 +199,8 @@ def test_reduce_scatter_sum(): ...@@ -199,8 +199,8 @@ def test_reduce_scatter_sum():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_all_reduce_sum(): def test_all_reduce_sum():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -238,8 +238,8 @@ def test_all_reduce_sum(): ...@@ -238,8 +238,8 @@ def test_all_reduce_sum():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_all_reduce_max(): def test_all_reduce_max():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -277,8 +277,8 @@ def test_all_reduce_max(): ...@@ -277,8 +277,8 @@ def test_all_reduce_max():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_all_reduce_min(): def test_all_reduce_min():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -316,8 +316,8 @@ def test_all_reduce_min(): ...@@ -316,8 +316,8 @@ def test_all_reduce_min():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_gather(): def test_gather():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -358,8 +358,8 @@ def test_gather(): ...@@ -358,8 +358,8 @@ def test_gather():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_scatter(): def test_scatter():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -396,8 +396,8 @@ def test_scatter(): ...@@ -396,8 +396,8 @@ def test_scatter():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_all_to_all(): def test_all_to_all():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
def worker(rank, data, expect, port): def worker(rank, data, expect, port):
if mge.get_device_count("gpu") < world_size: if mge.get_device_count("gpu") < world_size:
...@@ -436,8 +436,8 @@ def test_all_to_all(): ...@@ -436,8 +436,8 @@ def test_all_to_all():
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_io_remote(): def test_io_remote():
world_size = 2 world_size = 2
port = dist.get_free_ports(1)[0] server = dist.Server()
server = dist.Server(port) port = server.py_server_port
val = np.random.rand(4, 5).astype(np.float32) val = np.random.rand(4, 5).astype(np.float32)
def worker(rank): def worker(rank):
......
...@@ -38,7 +38,7 @@ def test_syncbn(): ...@@ -38,7 +38,7 @@ def test_syncbn():
running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32) running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
steps = 4 steps = 4
nr_ranks = 2 nr_ranks = 2
server = dist.Server(0) server = dist.Server()
port = server.py_server_port port = server.py_server_port
def worker(rank, data, yv_expect, running_mean, running_var): def worker(rank, data, yv_expect, running_mean, running_var):
......
...@@ -28,25 +28,16 @@ def test_min_max_observer(): ...@@ -28,25 +28,16 @@ def test_min_max_observer():
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed @pytest.mark.isolated_distributed
def test_sync_min_max_observer(): def test_sync_min_max_observer():
x = np.random.rand(6, 3, 3, 3).astype("float32") word_size = get_device_count_by_fork("gpu")
x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
np_min, np_max = x.min(), x.max() np_min, np_max = x.min(), x.max()
world_size = 2
port = dist.get_free_ports(1)[0]
server = dist.Server(port)
def worker(rank, slc): @dist.launcher
dist.init_process_group("localhost", port, world_size, rank, rank) def worker():
rank = dist.get_rank()
m = ob.SyncMinMaxObserver() m = ob.SyncMinMaxObserver()
y = mge.tensor(x[slc]) y = mge.tensor(x[rank * 3 : (rank + 1) * 3])
m(y) m(y)
assert m.min_val == np_min and m.max_val == np_max assert m.min_val == np_min and m.max_val == np_max
procs = [] worker()
for rank in range(world_size):
slc = slice(rank * 3, (rank + 1) * 3)
p = mp.Process(target=worker, args=(rank, slc,), daemon=True)
p.start()
procs.append(p)
for p in procs:
p.join(20)
assert p.exitcode == 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册