From cae8c8a4a40076cc7431459225903adde7fa7adc Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 10 Nov 2020 12:54:00 +0800
Subject: [PATCH] test(mge/parampack): add test for parampack

GitOrigin-RevId: f3e423dffce35fb6f6a9b62f96fe4635c0fecfd5
---
 .../test/integration/test_param_pack.py | 91 +++++++++++++++++++
 1 file changed, 91 insertions(+)
 create mode 100644 imperative/python/test/integration/test_param_pack.py

diff --git a/imperative/python/test/integration/test_param_pack.py b/imperative/python/test/integration/test_param_pack.py
new file mode 100644
index 00000000..35ac665c
--- /dev/null
+++ b/imperative/python/test/integration/test_param_pack.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import multiprocessing as mp
+import platform
+
+import numpy as np
+import pytest
+
+import megengine
+import megengine.autodiff as ad
+import megengine.distributed as dist
+import megengine.optimizer as optimizer
+from megengine import Parameter, tensor
+from megengine.distributed.helper import get_device_count_by_fork
+from megengine.module import Module
+from megengine.optimizer import SGD
+
+
+class Simple(Module):
+    def __init__(self):
+        super().__init__()
+        self.params = [Parameter(1.0, dtype=np.float32) for i in range(10)]
+
+    def forward(self, x):
+        for p in self.params:
+            x = x * p
+        return x
+
+
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need at least 2 gpu devices")
+@pytest.mark.isolated_distributed
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="MGB_ENABLE_OPR_MM is disabled on Windows"
+)
+def test_param_pack():
+    data = np.ones([1], dtype="float32")
+
+    @dist.launcher
+    def worker():
+        net = Simple()
+        opt = SGD(net.parameters(), lr=0.1)
+
+        gm = ad.GradManager().attach(
+            net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
+        )
+
+        opt.clear_grad()
+        with gm:
+            x = tensor(data)
+            loss = net(x)
+            loss = loss.sum()
+            gm.backward(loss)
+        for p in net.params:
+            np.testing.assert_equal(p.grad.numpy(), 1)
+
+    worker()
+
+
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need at least 2 gpu devices")
+@pytest.mark.isolated_distributed
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="MGB_ENABLE_OPR_MM is disabled on Windows"
+)
+def test_param_pack_with_no_param():
+    data = np.ones([1], dtype="float32")
+
+    @dist.launcher
+    def worker():
+        net = Simple()
+        opt = SGD(net.parameters(), lr=0.1)
+
+        allreduce_cb = dist.make_allreduce_cb("MEAN", dist.WORLD)
+        allreduce_cb._param_pack_thd = 0
+        gm = ad.GradManager().attach(net.parameters(), callbacks=[allreduce_cb])
+
+        opt.clear_grad()
+        with gm:
+            x = tensor(data)
+            loss = net(x)
+            loss = loss.sum()
+            gm.backward(loss)
+        for p in net.params:
+            np.testing.assert_equal(p.grad.numpy(), 1)
+
+    worker()
--
GitLab
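
Context on what "param pack" means in these tests: the allreduce callback batches many small gradient tensors into one flat buffer, so a single collective runs per pack instead of one per parameter. Setting `_param_pack_thd` to 0, as the second test does, appears to force each gradient through without batching, covering the degenerate-pack path. Below is a conceptual, NumPy-only sketch of the pack/allreduce/unpack round trip; the helpers `pack_gradients` and `unpack_gradients` are invented for illustration and are not MegEngine APIs.

    # Conceptual illustration of gradient packing (names are invented;
    # this is not MegEngine's implementation).
    import numpy as np


    def pack_gradients(grads):
        # Flatten and concatenate gradients into one contiguous buffer,
        # remembering the original shapes for the reverse trip.
        shapes = [g.shape for g in grads]
        packed = np.concatenate([g.ravel() for g in grads])
        return packed, shapes


    def unpack_gradients(packed, shapes):
        # Split the flat buffer back into tensors of the original shapes.
        grads, offset = [], 0
        for shape in shapes:
            size = int(np.prod(shape))
            grads.append(packed[offset : offset + size].reshape(shape))
            offset += size
        return grads


    # Ten scalar gradients (as in the Simple module above) become a single
    # 10-element buffer: one allreduce call instead of ten.
    grads = [np.ones([], dtype=np.float32) for _ in range(10)]
    packed, shapes = pack_gradients(grads)
    # allreduce(packed) would happen here, once for the whole pack.
    restored = unpack_gradients(packed, shapes)
    assert all(np.array_equal(a, b) for a, b in zip(grads, restored))

This mirrors the gradient-bucketing strategy common to data-parallel frameworks: fewer, larger collectives amortize the per-call launch and synchronization overhead, which matters most when a model has many small parameters.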