Commit ff05667b authored by Megvii Engine Team

test(dist): refactor distributed test with fixtures

GitOrigin-RevId: e69acb72da257ce75d096e0568a43552930a26c3
Parent 9fb8444d
import os
import platform
import sys
import pytest
import megengine.functional
import megengine.module
from megengine import Parameter
from megengine.core._imperative_rt.core2 import sync
from megengine.distributed.helper import get_device_count_by_fork
from megengine.jit import trace as _trace
from megengine.module import Linear, Module
sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
_ngpu = get_device_count_by_fork("gpu")
@pytest.fixture(autouse=True)
def skip_by_ngpu(request):
    marker = request.node.get_closest_marker("require_ngpu")
    if marker:
        require_ngpu = int(marker.args[0])
        if require_ngpu > _ngpu:
            pytest.skip("skipped for ngpu unsatisfied: {}".format(require_ngpu))
@pytest.fixture(autouse=True)
def skip_distributed(request):
    if request.node.get_closest_marker("isolated_distributed"):
        if platform.system() in ("Windows", "Darwin"):
            pytest.skip(
                "skipped for distributed unsupported at platform: {}".format(
                    platform.system()
                )
            )
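With these two autouse fixtures, each test declares its requirements as markers and the skip logic lives in one place. Below is a minimal sketch of a test opting in to both markers; the test itself is hypothetical (not part of this commit), and it assumes the markers are registered in pytest.ini so pytest does not warn about unknown marks:

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine.functional.distributed import all_reduce_sum


@pytest.mark.require_ngpu(2)        # skip_by_ngpu skips this unless >= 2 GPUs exist
@pytest.mark.isolated_distributed   # skip_distributed skips this on Windows/macOS
def test_allreduce_smoke():
    @dist.launcher(n_gpus=2)
    def worker():
        # each rank contributes its rank id; every rank sees the sum 0 + 1 = 1
        x = mge.tensor(np.array([dist.get_rank()], dtype="float32"))
        y = all_reduce_sum(x)
        np.testing.assert_allclose(y.numpy(), [1.0])

    worker()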
@@ -21,7 +21,6 @@ import megengine.autodiff as ad
import megengine.distributed as dist
import megengine.functional as F
from megengine.device import get_default_device, set_default_device
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.debug_param import set_conv_execution_strategy
from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
from megengine.optimizer import SGD
@@ -194,11 +193,8 @@ def run_test(
worker(max_err)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
def test_dp_correctness():
model_name = "mnist_model_with_test.mge"
model_path = os.path.join(os.path.dirname(__file__), model_name)
@@ -32,11 +32,8 @@ class Simple(Module):
return x
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
def test_param_pack():
data = np.ones([1], dtype="float32")
@@ -61,11 +58,8 @@ def test_param_pack():
worker()
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
def test_param_pack_with_no_param():
data = np.ones([1], dtype="float32")
@@ -10,6 +10,7 @@ import itertools
import os
import numpy as np
import pytest
import megengine
import megengine.autodiff as ad
@@ -22,57 +23,19 @@ from megengine.module import Module
class Simple(Module):
def __init__(self):
super().__init__()
self.a = Parameter([1.23], dtype=np.float32)
self.a = Parameter([1.23], dtype="float32")
def forward(self, x):
x = x * self.a
return x
def test_sgd_momentum():
net = Simple()
@pytest.mark.parametrize("trace_mode", [True, False, None])
@pytest.mark.parametrize("inplace_mode", [True, False])
def test_sgd_momentum(monkeypatch, trace_mode, inplace_mode):
with monkeypatch.context() as mk:
mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
optim.clear_grad()
gm = ad.GradManager().attach(net.parameters())
data = tensor([2.34])
# do a step of train
with gm:
loss = net(data)
gm.backward(loss)
optim.step()
np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)
# do a step of infer
loss = net(data)
np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
np.testing.assert_almost_equal(optim._state[net.a]["momentum_buffer"].numpy(), 2.34)
# do a step of train
optim.clear_grad()
with gm:
loss = net(data)
gm.backward(loss)
optim.step()
np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
np.testing.assert_almost_equal(
optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34, 5
)
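For reference, the numbers asserted above follow from the SGD momentum update v <- m * v + g with lr = 1.0: the gradient of loss = x * a with respect to a at x = 2.34 is 2.34, so after the first training step v_1 = 0.9 * 0 + 2.34 = 2.34 and a becomes 1.23 - 2.34; the inference step leaves the buffer untouched; after the second training step v_2 = 0.9 * 2.34 + 2.34 = 4.446, which is exactly the value checked against the momentum buffer.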
def test_sgd_momentum_trace():
origin_inplace = os.getenv("MEGENGINE_INPLACE_UPDATE")
symbolic = (True, False)
inplace = (0, 1)
for symbolic, inplace in itertools.product(symbolic, inplace):
os.environ["MEGENGINE_INPLACE_UPDATE"] = str(inplace)
@trace(symbolic=symbolic)
def train_func(data, *, model=None, optim=None, gm=None):
optim.clear_grad()
with gm:
@@ -81,11 +44,16 @@ def test_sgd_momentum_trace():
optim.step()
return loss
@trace(symbolic=symbolic)
if trace_mode is not None:
train_func = trace(symbolic=trace_mode)(train_func)
def eval_func(data, *, model=None, optim=None, gm=None):
loss = net(data)
return loss
if trace_mode is not None:
eval_func = trace(symbolic=trace_mode)(eval_func)
net = Simple()
optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
gm = ad.GradManager().attach(net.parameters())
@@ -109,7 +77,3 @@ def test_sgd_momentum_trace():
np.testing.assert_almost_equal(
optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34, 5
)
if origin_inplace:
os.environ["MEGENGINE_INPLACE_UPDATE"] = origin_inplace
else:
del os.environ["MEGENGINE_INPLACE_UPDATE"]
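The manual save/restore of MEGENGINE_INPLACE_UPDATE deleted above leaks the modified value if an assertion fails before the cleanup runs. pytest's built-in monkeypatch fixture, which the new test uses, undoes environment changes automatically. A minimal self-contained sketch of the pattern (illustrative test only, not part of this commit):

import os

import pytest


@pytest.mark.parametrize("inplace_mode", [True, False])
def test_env_is_restored(monkeypatch, inplace_mode):
    before = os.getenv("MEGENGINE_INPLACE_UPDATE")
    with monkeypatch.context() as mk:
        # reverted when the context exits, even if the test body raises
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        assert os.environ["MEGENGINE_INPLACE_UPDATE"] == str(int(inplace_mode))
    # outside the context the original value (or its absence) is back
    assert os.getenv("MEGENGINE_INPLACE_UPDATE") == before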
@@ -133,15 +133,12 @@ def test_regression_1762():
gm.backward(loss)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_remote_grad():
@pytest.mark.parametrize(
"trace_mode", [True, False, None], ids=["symbolic", "trace", "no_trace"]
)
def test_remote_grad(trace_mode):
@dist.launcher
def worker():
rank = dist.get_rank()
@@ -166,14 +163,10 @@ def test_remote_grad():
gm.backward(y)
opt.step().clear_grad()
train_funcs = [
train_func,
trace(symbolic=False)(train_func),
trace(symbolic=True)(train_func),
]
if trace_mode is not None:
train_func = trace(symbolic=trace_mode)(train_func)
for func in train_funcs:
for i in range(3):
func(x)
for i in range(3):
train_func(x)
worker()
@@ -51,13 +51,7 @@ def save_to(self, name="grad"):
return callback
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_dist_grad():
world_size = 2
@@ -37,20 +37,15 @@ def _assert_q_val(q, val):
assert ret == val
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("backend", ["nccl"])
@pytest.mark.isolated_distributed
def test_init_process_group():
def test_init_process_group(backend):
world_size = 2
server = dist.Server()
port = server.py_server_port
def worker(rank, backend):
def worker(rank):
dist.init_process_group("localhost", port, world_size, rank, rank, backend)
assert dist.is_distributed() == True
assert dist.get_rank() == rank
@@ -67,27 +62,18 @@ def test_init_process_group():
assert isinstance(dist.get_client(), dist.Client)
def check(backend):
procs = []
for rank in range(world_size):
p = mp.Process(target=worker, args=(rank, backend))
p.start()
procs.append(p)
for p in procs:
p.join(20)
assert p.exitcode == 0
procs = []
for rank in range(world_size):
p = mp.Process(target=worker, args=(rank,))
p.start()
procs.append(p)
check("nccl")
for p in procs:
p.join(20)
assert p.exitcode == 0
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 3, reason="need more gpu device")
@pytest.mark.require_ngpu(3)
@pytest.mark.isolated_distributed
def test_new_group():
world_size = 3
@@ -106,13 +92,7 @@ def test_new_group():
worker()
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_group_barrier():
world_size = 2
@@ -142,13 +122,7 @@ def test_group_barrier():
assert p.exitcode == 0
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_synchronized():
world_size = 2
@@ -186,17 +160,9 @@ def test_synchronized():
assert p.exitcode == 0
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_user_set_get():
world_size = 2
@dist.launcher
def worker():
# set in race condition
@@ -33,15 +33,10 @@ from megengine.functional.distributed import (
)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (8, 10), (99, 77)], ids=str)
@pytest.mark.isolated_distributed
def test_reduce_sum():
def test_reduce_sum(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -52,27 +47,18 @@ def test_reduce_sum():
else:
assert np.allclose(output.numpy(), 0)
def check(shape):
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = x + y
data = (x, y)
expect = (z, None)
worker(data, expect)
x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
z = x + y
data = (x, y)
expect = (z, None)
worker(data, expect)
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
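A note on the data construction adopted across these collective tests: np.random.rand(*shape) cannot produce a 0-d array (with no arguments it returns a plain Python float), whereas np.random.random_sample(shape) accepts shape == () and returns a 0-d ndarray, and the explicit .astype("float32") matches MegEngine's default tensor dtype rather than NumPy's float64. A small illustration of the difference:

import numpy as np

x = np.random.random_sample(())                  # 0-d ndarray, shape ()
assert isinstance(x, np.ndarray) and x.shape == ()

y = np.random.rand()                             # plain float: no .shape, no .astype
assert isinstance(y, float)

z = np.random.random_sample((2, 3)).astype("float32")
assert z.dtype == np.float32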
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (8, 10), (99, 77)], ids=str)
@pytest.mark.isolated_distributed
def test_broadcast():
def test_broadcast(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -80,26 +66,17 @@ def test_broadcast():
output = broadcast(inp)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape)
y = x + 1
data = (x, y)
expect = (x, x)
worker(data, expect)
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
x = np.random.random_sample(shape).astype("float32")
y = x + 1
data = (x, y)
expect = (x, x)
worker(data, expect)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(1,), (2, 3), (8, 10), (99, 77)], ids=str)
@pytest.mark.isolated_distributed
def test_all_gather():
def test_all_gather(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -107,27 +84,18 @@ def test_all_gather():
output = all_gather(inp)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = np.concatenate((x, y))
data = (x, y)
expect = (z, z)
worker(data, expect)
x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
z = np.concatenate((x, y))
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
check(shape)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(2, 3), (8, 10), (88, 44)], ids=str)
@pytest.mark.isolated_distributed
def test_reduce_scatter_sum():
def test_reduce_scatter_sum(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -135,27 +103,18 @@ def test_reduce_scatter_sum():
output = reduce_scatter_sum(inp)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = x + y
data = (x, y)
expect = (z[: shape[0] // 2], z[shape[0] // 2 :])
worker(data, expect)
for shape in [(2, 4), (8, 10), (88, 44)]:
check(shape)
x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
z = x + y
data = (x, y)
expect = (z[: shape[0] // 2], z[shape[0] // 2 :])
worker(data, expect)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (8, 10), (99, 77)], ids=str)
@pytest.mark.isolated_distributed
def test_all_reduce_sum():
def test_all_reduce_sum(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -163,27 +122,18 @@ def test_all_reduce_sum():
output = all_reduce_sum(inp)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = x + y
data = (x, y)
expect = (z, z)
worker(data, expect)
x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
z = x + y
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (8, 10), (99, 77)], ids=str)
@pytest.mark.isolated_distributed
def test_all_reduce_max():
def test_all_reduce_max(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -191,27 +141,18 @@ def test_all_reduce_max():
output = all_reduce_max(inp)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = np.maximum(x, y)
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
z = np.maximum(x, y)
data = (x, y)
expect = (z, z)
worker(data, expect)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(), (1,), (2, 3), (8, 10), (99, 77)], ids=str)
@pytest.mark.isolated_distributed
def test_all_reduce_min():
def test_all_reduce_min(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -219,27 +160,18 @@ def test_all_reduce_min():
output = all_reduce_min(inp)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = np.minimum(x, y)
data = (x, y)
expect = (z, z)
worker(data, expect)
x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
z = np.minimum(x, y)
data = (x, y)
expect = (z, z)
worker(data, expect)
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(2, 3), (8, 10), (99, 77)], ids=str)
@pytest.mark.isolated_distributed
def test_gather():
def test_gather(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -250,27 +182,18 @@ def test_gather():
else:
assert np.allclose(output.numpy(), 0)
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
z = np.concatenate((x, y))
data = (x, y)
expect = (z, None)
worker(data, expect)
for shape in [(2, 3), (8, 10), (99, 77)]:
check(shape)
x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
z = np.concatenate((x, y))
data = (x, y)
expect = (z, None)
worker(data, expect)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(2, 3), (8, 10), (100, 77)], ids=str)
@pytest.mark.isolated_distributed
def test_scatter():
def test_scatter(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -278,26 +201,17 @@ def test_scatter():
output = scatter(inp)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = x + 1
data = (x, y)
expect = (x[: shape[0] // 2], x[shape[0] // 2 :])
worker(data, expect)
x = np.random.random_sample(shape).astype("float32")
y = x + 1
data = (x, y)
expect = (x[: shape[0] // 2], x[shape[0] // 2 :])
worker(data, expect)
for shape in [(2, 3), (8, 10), (100, 77)]:
check(shape)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(2, 3), (8, 10), (100, 77)], ids=str)
@pytest.mark.isolated_distributed
def test_all_to_all():
def test_all_to_all(shape):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
@@ -305,28 +219,19 @@ def test_all_to_all():
output = all_to_all(inp)
assert np.allclose(output.numpy(), expect[rank])
def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
a = np.concatenate((x[: shape[0] // 2], y[: shape[0] // 2]))
b = np.concatenate((x[shape[0] // 2 :], y[shape[0] // 2 :]))
data = (x, y)
expect = (a, b)
worker(data, expect)
for shape in [(2, 3), (8, 10), (100, 77)]:
check(shape)
x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
a = np.concatenate((x[: shape[0] // 2], y[: shape[0] // 2]))
b = np.concatenate((x[shape[0] // 2 :], y[shape[0] // 2 :]))
data = (x, y)
expect = (a, b)
worker(data, expect)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_io_remote():
@pytest.mark.parametrize("shape", [(), (1,), (4, 5)], ids=str)
def test_io_remote(shape):
@dist.launcher(n_gpus=2)
def worker(val, shape):
rank = dist.get_rank()
@@ -339,6 +244,5 @@ def test_io_remote():
assert y.device == "gpu1"
np.testing.assert_almost_equal(val, y.numpy())
for shape in [(), (1,), (4, 5)]:
val = np.random.rand(*shape)
worker(val, shape)
val = np.random.random_sample(shape).astype("float32")
worker(val, shape)
@@ -355,26 +355,17 @@ def copy_test(dst, src):
assert np.allclose(data, z.numpy())
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") == 0, reason="CUDA is disabled")
@pytest.mark.require_ngpu(1)
def test_copy_h2d():
copy_test("cpu0", "gpu0")
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") == 0, reason="CUDA is disabled")
@pytest.mark.require_ngpu(1)
def test_copy_d2h():
copy_test("gpu0", "cpu0")
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
def test_copy_d2d():
copy_test("gpu0", "gpu1")
copy_test("gpu0:0", "gpu0:1")
@@ -22,13 +22,7 @@ from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
_assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_syncbn():
nr_chan = 8
@@ -125,9 +119,6 @@ def test_batchnorm():
_assert_allclose(yv1.numpy(), yv_expect)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
def test_syncbn1d():
nr_chan = 8
data_shape = (3, nr_chan, 4)
@@ -215,9 +206,6 @@ def test_batchnorm2d():
_assert_allclose(yv1.numpy(), yv_expect)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
def test_syncbn2d():
nr_chan = 8
data_shape = (3, nr_chan, 16, 16)
@@ -285,9 +273,6 @@ def test_batchnorm_no_stats():
_assert_allclose(yv.numpy(), yv_expect)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
def test_syncbn_no_stats():
nr_chan = 8
data_shape = (3, nr_chan, 4)
@@ -333,9 +318,6 @@ def test_batchnorm2d_no_stats():
_assert_allclose(yv.numpy(), yv_expect)
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
def test_syncbn2d_no_stats():
nr_chan = 8
data_shape = (3, nr_chan, 16, 16)
@@ -65,13 +65,7 @@ def test_passive_observer():
assert m.get_qparams() == {"scale": mge.tensor(2.0)}
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_sync_min_max_observer():
word_size = get_device_count_by_fork("gpu")
@@ -89,13 +83,7 @@ def test_sync_min_max_observer():
worker()
@pytest.mark.skipif(
platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
)
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_sync_exponential_moving_average_observer():
word_size = get_device_count_by_fork("gpu")