diff --git a/python_module/megengine/module/parampack.py b/python_module/megengine/module/parampack.py index d91aee83ea1cea41a4a11e33feecb3ea8a1c372e..fad1e27df0b04e8decc5474678cfc6a88454621e 100644 --- a/python_module/megengine/module/parampack.py +++ b/python_module/megengine/module/parampack.py @@ -82,7 +82,7 @@ class ParamPack(Module): group = group[idx:] if idx == 1: # ignore param packs with only one item - self._packed_params.append(params[0]) + self._packed_params.append(params[0]['tensor']) self._grouped_params.append(params) continue diff --git a/python_module/test/integration/test_parampack.py b/python_module/test/integration/test_parampack.py index 678971a2f8e9fb8baa7c5c79a057b4bbf6b55850..0a1b3780c6dd593e5338b6800cc845f018ad9fa8 100644 --- a/python_module/test/integration/test_parampack.py +++ b/python_module/test/integration/test_parampack.py @@ -110,6 +110,42 @@ def test_static_graph_parampack(): pred = infer(data).numpy() assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" +@pytest.mark.slow +def test_nopack_parampack(): + net = XORNet() + net = ParamPack(net, + max_size_per_group=0, + max_nr_params_per_group=0) + opt = SGD( + net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4 + ) + + @trace(symbolic=True) + def train(data, label): + pred = net(data) + opt.zero_grad() + loss = cross_entropy_with_softmax(pred, label) + opt.backward(loss) + return loss + + @trace(symbolic=True) + def infer(data): + return net(data) + + train_dataset = minibatch_generator() + losses = [] + + for data, label in itertools.islice(train_dataset, 2000): + loss = train(data, label) + loss = loss[0][0] + opt.step() + losses.append(loss.numpy()) + assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" + + data, _ = next(train_dataset) + pred = infer(data).numpy() + assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + @pytest.mark.slow def test_dynamic_graph_parampack(): net = XORNet()