提交 dadd5086 编写于 作者: M Megvii Engine Team

test(param_pack): more cases for param pack concat

GitOrigin-RevId: 0700b548ab300811528023a334e59c1c156dcd49
上级 346d2420
......@@ -22,9 +22,11 @@ from megengine.optimizer import SGD
class Simple(Module):
def __init__(self):
def __init__(self, param_shape):
super().__init__()
self.params = [Parameter(1.0, dtype=np.float32) for i in range(10)]
self.params = [
Parameter(np.ones(param_shape), dtype=np.float32) for i in range(10)
]
def forward(self, x):
for p in self.params:
......@@ -34,51 +36,35 @@ class Simple(Module):
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_param_pack():
data = np.ones([1], dtype="float32")
@dist.launcher
@pytest.mark.parametrize(
"threshold", [0, 128, None], ids=["no_pack", "small_pack", "large_pack"]
)
@pytest.mark.parametrize("param_shape", [(16,), (128, 256), (2, 1024, 1024)])
def test_param_pack(param_shape, threshold, n_iters=100):
data = np.ones(param_shape, dtype="float32")
@dist.launcher(n_gpus=2)
def worker():
net = Simple()
opt = SGD(net.parameters(), lr=0.1)
gm = ad.GradManager().attach(
net.parameters(), callbacks=[dist.make_allreduce_cb("MEAN", dist.WORLD)]
)
opt.clear_grad()
with gm:
x = tensor(data)
loss = net(x)
loss = loss.sum()
gm.backward(loss)
for p in net.params:
np.testing.assert_equal(p.grad.numpy(), 1)
worker()
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_param_pack_with_no_param():
data = np.ones([1], dtype="float32")
@dist.launcher
def worker():
net = Simple()
net = Simple(param_shape)
opt = SGD(net.parameters(), lr=0.1)
allreduce_cb = dist.make_allreduce_cb("MEAN", dist.WORLD)
allreduce_cb._param_pack_thd = 0
if threshold is not None:
allreduce_cb._param_pack_thd = threshold
gm = ad.GradManager().attach(net.parameters(), callbacks=[allreduce_cb])
opt.clear_grad()
with gm:
x = tensor(data)
loss = net(x)
loss = loss.sum()
gm.backward(loss)
def run():
opt.clear_grad()
with gm:
x = tensor(data)
loss = net(x)
loss = loss.sum()
gm.backward(loss)
for i in range(n_iters):
run()
for p in net.params:
np.testing.assert_equal(p.grad.numpy(), 1)
np.testing.assert_equal(p.grad.numpy(), np.ones_like(p.grad.numpy()))
worker()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册