提交 80d3a06f 编写于 作者: M Megvii Engine Team

fix(mge/param_pack): fix param pack with just one param

GitOrigin-RevId: 0d28a12e594c43e42a7486857dace201da2a3dfc
上级 876e799e
......@@ -82,7 +82,7 @@ class ParamPack(Module):
group = group[idx:]
if idx == 1:
# ignore param packs with only one item
self._packed_params.append(params[0])
self._packed_params.append(params[0]['tensor'])
self._grouped_params.append(params)
continue
......
......@@ -110,6 +110,42 @@ def test_static_graph_parampack():
pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"
@pytest.mark.slow
def test_nopack_parampack():
    """End-to-end XOR training with ParamPack configured to pack nothing.

    Both group limits are set to 0 so every parameter falls through the
    "only one item" path and stays unpacked; training must still converge
    and inference must still be accurate.
    """
    net = XORNet()
    # Zero-sized groups disable packing entirely — exercises the no-pack path.
    net = ParamPack(net, max_size_per_group=0, max_nr_params_per_group=0)
    opt = SGD(
        net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4
    )

    @trace(symbolic=True)
    def train(data, label):
        pred = net(data)
        opt.zero_grad()
        loss = cross_entropy_with_softmax(pred, label)
        opt.backward(loss)
        return loss

    @trace(symbolic=True)
    def infer(data):
        return net(data)

    dataset = minibatch_generator()
    loss_history = []
    for data, label in itertools.islice(dataset, 2000):
        step_loss = train(data, label)
        # trace returns a nested container; unwrap to the scalar loss tensor.
        step_loss = step_loss[0][0]
        opt.step()
        loss_history.append(step_loss.numpy())
    assert np.mean(loss_history[-100:]) < 0.1, "Final training Loss must be low enough"

    data, _ = next(dataset)
    pred = infer(data).numpy()
    assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"
@pytest.mark.slow
def test_dynamic_graph_parampack():
net = XORNet()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册