# test_save_load.py: save/load round-trip test for a MegEngine module and its optimizer state.
import os
import tempfile

import numpy as np

import megengine as mge
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.core.tensor.raw_tensor import RawTensor
from megengine.module import Module


class Simple(Module):
    """Minimal module with one learnable parameter ``a``.

    ``forward(x)`` returns ``x * a``.
    """

    def __init__(self):
        super().__init__()
        # One-element float32 parameter initialised to 1.23.
        self.a = Parameter([1.23], dtype=np.float32)

    def forward(self, x):
        """Return the input scaled by the learnable parameter ``a``."""
        x = x * self.a
        return x


def test_save_load():
    """Round-trip a module and its SGD optimizer state through ``mge.save``/``mge.load``.

    The test runs these steps:

    - Train ``Simple`` for one step.
    - Save its ``state_dict`` and the optimizer ``state_dict`` to a temporary file.
    - Reload both with ``map_location="cpu0"`` into fresh objects.
    - Check that the parameter value survived the round-trip.
    - Run one more training step with the restored optimizer state on ``cpu0``.
    """
    net = Simple()

    optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
    optim.clear_grad()
    gm = ad.GradManager().attach(net.parameters())

    data = tensor([2.34])

    with gm:
        loss = net(data)
        gm.backward(loss)

    optim.step()

    # Snapshot the trained parameter so the reloaded value can be checked.
    saved_a = net.a.numpy()

    # Use a private temporary directory so the test leaves no file behind
    # and cannot clash with other runs sharing the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_name = os.path.join(tmp_dir, "simple.pkl")
        print("save to {}".format(model_name))

        mge.save(
            {
                "name": "simple",
                "state_dict": net.state_dict(),
                "opt_state": optim.state_dict(),
            },
            model_name,
        )

        # Load param to cpu
        checkpoint = mge.load(model_name, map_location="cpu0")

    assert checkpoint["name"] == "simple"

    device_save = mge.get_default_device()
    mge.set_default_device("cpu0")
    try:
        net = Simple()
        net.load_state_dict(checkpoint["state_dict"])
        np.testing.assert_equal(net.a.numpy(), saved_a)
        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
        optim.load_state_dict(checkpoint["opt_state"])
        print("load done")

        # The first GradManager is attached to the *old* net's parameters.
        # Gradients are only recorded for attached tensors, so a fresh
        # manager is needed for the new net's parameters.
        gm = ad.GradManager().attach(net.parameters())
        with gm:
            loss = net(tensor([1.23]))
            gm.backward(loss)

        optim.step()
    finally:
        # Restore the device even if loading or an assertion fails, so later
        # tests do not inherit "cpu0" as the default device.
        mge.set_default_device(device_save)