# test_save_load.py: save/load round-trip test for a module's and an optimizer's state dicts.
import os
import tempfile

import numpy as np

import megengine as mge
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.core.tensor.raw_tensor import RawTensor
from megengine.module import Module


class Simple(Module):
    """Minimal module holding one learnable scalar ``a``; ``forward`` returns ``x * a``."""

    def __init__(self):
        super().__init__()
        # Single trainable parameter, initialised to 1.23 and stored as float32.
        self.a = Parameter(1.23, dtype=np.float32)

    def forward(self, x):
        """Return the input *x* scaled by the learnable parameter ``a``."""
        x = x * self.a
        return x


def test_save_load():
    """Round-trip a trained ``Simple`` module and its SGD optimizer through save/load.

    One SGD step is taken before saving. The module and optimizer state dicts
    are then written with ``mge.save`` to a temporary file and read back with
    ``mge.load`` onto ``cpu0``. They are restored into a fresh module and
    optimizer, and one more training step is run on ``cpu0``. The process-wide
    default device is always restored, even if a step fails.
    """
    net = Simple()

    optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
    optim.zero_grad()

    data = tensor([2.34])

    with optim.record():
        loss = net(data)
        optim.backward(loss)

    optim.step()

    # Snapshot the trained value so the reloaded module can be checked against it.
    saved_a = net.a.numpy()

    # Write the checkpoint into a throwaway directory instead of the CWD,
    # so the test leaves no files behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_name = os.path.join(tmp_dir, "simple.pkl")
        print("save to {}".format(model_name))

        mge.save(
            {
                "name": "simple",
                "state_dict": net.state_dict(),
                "opt_state": optim.state_dict(),
            },
            model_name,
        )

        # Load param to cpu
        checkpoint = mge.load(model_name, map_location="cpu0")

    assert checkpoint["name"] == "simple"

    device_save = mge.get_default_device()
    mge.set_default_device("cpu0")
    # The default device is global state; restore it in ``finally`` so a
    # failure below cannot leak "cpu0" into subsequent tests.
    try:
        net = Simple()
        net.load_state_dict(checkpoint["state_dict"])
        np.testing.assert_allclose(net.a.numpy(), saved_a)

        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
        optim.load_state_dict(checkpoint["opt_state"])
        print("load done")

        with optim.record():
            # Build the input as a tensor, matching ``data`` above, rather than a raw list.
            loss = net(tensor([1.23]))
            optim.backward(loss)

        optim.step()
    finally:
        # Restore device
        mge.set_default_device(device_save)