From d4ffdb89772a2382f8a8c7fa595166986c7063fd Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 28 Apr 2021 16:52:13 +0800
Subject: [PATCH] test(param_pack): more cases for param pack concat

GitOrigin-RevId: 0700b548ab300811528023a334e59c1c156dcd49
---
 .../test/integration/test_param_pack.py      | 68 ++++++++-----------
 1 file changed, 27 insertions(+), 41 deletions(-)

diff --git a/imperative/python/test/integration/test_param_pack.py b/imperative/python/test/integration/test_param_pack.py
index 8c867fb58..42ffc1557 100644
--- a/imperative/python/test/integration/test_param_pack.py
+++ b/imperative/python/test/integration/test_param_pack.py
@@ -22,9 +22,11 @@ from megengine.optimizer import SGD
 
 
 class Simple(Module):
-    def __init__(self):
+    def __init__(self, param_shape):
         super().__init__()
-        self.params = [Parameter(1.0, dtype=np.float32) for i in range(10)]
+        self.params = [
+            Parameter(np.ones(param_shape), dtype=np.float32) for i in range(10)
+        ]
 
     def forward(self, x):
         for p in self.params:
@@ -34,51 +36,35 @@ class Simple(Module):
 
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
-def test_param_pack():
-    data = np.ones([1], dtype="float32")
-
-    @dist.launcher
+@pytest.mark.parametrize(
+    "threshold", [0, 128, None], ids=["no_pack", "small_pack", "large_pack"]
+)
+@pytest.mark.parametrize("param_shape", [(16,), (128, 256), (2, 1024, 1024)])
+def test_param_pack(param_shape, threshold, n_iters=100):
+    data = np.ones(param_shape, dtype="float32")
+
+    @dist.launcher(n_gpus=2)
     def worker():
-        net = Simple()
-        opt = SGD(net.parameters(), lr=0.1)
-
-        gm = ad.GradManager().attach(
-            net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
-        )
-
-        opt.clear_grad()
-        with gm:
-            x = tensor(data)
-            loss = net(x)
-            loss = loss.sum()
-            gm.backward(loss)
-        for p in net.params:
-            np.testing.assert_equal(p.grad.numpy(), 1)
-
-    worker()
-
-
-@pytest.mark.require_ngpu(2)
-@pytest.mark.isolated_distributed
-def test_param_pack_with_no_param():
-    data = np.ones([1], dtype="float32")
-
-    @dist.launcher
-    def worker():
-        net = Simple()
+        net = Simple(param_shape)
         opt = SGD(net.parameters(), lr=0.1)
         allreduce_cb = dist.make_allreduce_cb("MEAN", dist.WORLD)
-        allreduce_cb._param_pack_thd = 0
+        if threshold is not None:
+            allreduce_cb._param_pack_thd = threshold
         gm = ad.GradManager().attach(net.parameters(), callbacks=[allreduce_cb])
 
-        opt.clear_grad()
-        with gm:
-            x = tensor(data)
-            loss = net(x)
-            loss = loss.sum()
-            gm.backward(loss)
+        def run():
+            opt.clear_grad()
+            with gm:
+                x = tensor(data)
+                loss = net(x)
+                loss = loss.sum()
+                gm.backward(loss)
+
+        for i in range(n_iters):
+            run()
+
         for p in net.params:
-            np.testing.assert_equal(p.grad.numpy(), 1)
+            np.testing.assert_equal(p.grad.numpy(), np.ones_like(p.grad.numpy()))
 
     worker()
-- 
GitLab
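
Note (illustrative, not part of the patch): the final assertion holds because
Simple.forward computes x + p_1 + ... + p_10 and the loss is sum(x), so each
parameter's gradient is an all-ones tensor of shape param_shape; averaging
identical per-rank gradients with the "MEAN" allreduce leaves them unchanged,
whether or not the callback packs them. A standalone NumPy sketch of that
reasoning follows, where param_shape and n_ranks are stand-ins for one of the
parametrized shapes and the two GPUs dist.launcher(n_gpus=2) spawns:

    import numpy as np

    # Hypothetical stand-ins for illustration only.
    param_shape = (128, 256)
    n_ranks = 2

    # loss = (x + p_1 + ... + p_10).sum(), so d(loss)/d(p_i) is all-ones
    # on every rank (each rank feeds identical all-ones data).
    grad_per_rank = [np.ones(param_shape, dtype="float32") for _ in range(n_ranks)]

    # A "MEAN" allreduce averages gradients across ranks; identical
    # per-rank gradients are unchanged by the average.
    mean_grad = sum(grad_per_rank) / n_ranks

    np.testing.assert_equal(mean_grad, np.ones(param_shape, dtype="float32"))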