提交 89ed7ab2，作者：Megvii Engine Team

test(imperative): speed up dtr test

GitOrigin-RevId: 57f092d7299d26292a62b7094e03f71128d00984
上级 a2a09ef9
@@ -90,7 +90,7 @@ class ResNet(M.Module):
@pytest.mark.require_ngpu(1) @pytest.mark.require_ngpu(1)
def test_dtr_resnet1202(): def test_dtr_resnet1202():
batch_size = 64 batch_size = 8
resnet1202 = ResNet(BasicBlock, [200, 200, 200]) resnet1202 = ResNet(BasicBlock, [200, 200, 200])
opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4) opt = optim.SGD(resnet1202.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
gm = GradManager().attach(resnet1202.parameters()) gm = GradManager().attach(resnet1202.parameters())
@@ -103,12 +103,24 @@ def test_dtr_resnet1202():
gm.backward(loss) gm.backward(loss)
return pred, loss return pred, loss
_, free_mem = mge.device.get_mem_status_bytes()
tensor_mem = free_mem - (2 ** 30)
if tensor_mem > 0:
x = np.ones((1, int(tensor_mem / 4)), dtype=np.float32)
else:
x = np.ones((1,), dtype=np.float32)
t = mge.tensor(x)
mge.dtr.enable() mge.dtr.enable()
mge.dtr.enable_sqrt_sampling = True
data = np.random.randn(batch_size, 3, 32, 32).astype("float32") data = np.random.randn(batch_size, 3, 32, 32).astype("float32")
label = np.random.randint(0, 10, size=(batch_size,)).astype("int32") label = np.random.randint(0, 10, size=(batch_size,)).astype("int32")
for step in range(10): for _ in range(2):
opt.clear_grad() opt.clear_grad()
_, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm) _, loss = train_func(mge.tensor(data), mge.tensor(label), net=resnet1202, gm=gm)
opt.step() opt.step()
loss.item() loss.item()
t.numpy()
mge.dtr.disable()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册