# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os

import numpy as np

import megengine as mge
import megengine.autodiff as ad
import megengine.module as M
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import Module


class Simple(Module):
    """Minimal module with a single float32 parameter ``a`` initialised to 1.23.

    ``forward`` returns its input multiplied element-wise by ``a``.
    """

    def __init__(self):
        super().__init__()
        self.a = Parameter([1.23], dtype=np.float32)

    def forward(self, x):
        x = x * self.a
        return x


class Net(Module):
    """One ``Linear(1, 1)`` layer whose "loss" is the mean output scaled by 10000."""

    def __init__(self):
        super().__init__()
        self.fc = M.Linear(1, 1)

    def forward(self, images):
        # Large scale factor makes any gradient step on the layer clearly visible.
        scaled_mean = self.fc(images).mean() * 10000
        return scaled_mean


def test_load_state_dict_no_cache(monkeypatch):
    """Run one SGD step on ``Net`` after loading an all-zero state dict.

    ``MEGENGINE_INPLACE_UPDATE`` is set to ``"1"`` only while the
    ``monkeypatch`` context is active. The optimizer and grad manager are
    created *before* ``load_state_dict`` is called. The test has no explicit
    assertions; it passes if this sequence completes without raising.
    """
    with monkeypatch.context() as patcher:
        patcher.setenv("MEGENGINE_INPLACE_UPDATE", "1")
        model = Net()

        sgd = optimizer.SGD(model.parameters(), lr=0.1)
        grad_mgr = ad.GradManager().attach(model.parameters())

        # Insertion order (weight, then bias) matches the original literal.
        zeros = {}
        zeros["fc.weight"] = np.array([[0]], dtype=np.float32)
        zeros["fc.bias"] = np.array([0.0], dtype=np.float32)
        model.load_state_dict(zeros)

        batch = mge.tensor([[0]], dtype=np.float32)
        with grad_mgr:
            grad_mgr.backward(model(batch))
            sgd.step()
            sgd.clear_grad()


def test_save_load():
    """Round-trip a trained ``Simple`` model and its SGD state via ``mge.save``/``mge.load``.

    Steps:

    1. Train one step.
    2. Save the parameters and optimizer state to ``simple.pkl``.
    3. Reload them with ``map_location="cpu0"`` into a fresh model and optimizer
       on ``cpu0``.
    4. Check that the reloaded parameter equals the saved one.
    5. Run one more training step with the restored optimizer.

    The checkpoint file is always removed and the default device is always
    restored, even if an assertion or step fails.
    """
    net = Simple()

    optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
    optim.clear_grad()
    gm = ad.GradManager().attach(net.parameters())

    data = tensor([2.34])

    with gm:
        loss = net(data)
        gm.backward(loss)

    optim.step()

    model_name = "simple.pkl"
    print("save to {}".format(model_name))

    # Snapshot the trained parameter so the save/load round trip can be verified.
    expected_a = net.a.numpy()

    mge.save(
        {
            "name": "simple",
            "state_dict": net.state_dict(),
            "opt_state": optim.state_dict(),
        },
        model_name,
    )

    device_save = mge.get_default_device()
    try:
        # Load param to cpu
        checkpoint = mge.load(model_name, map_location="cpu0")
        mge.set_default_device("cpu0")
        net = Simple()
        net.load_state_dict(checkpoint["state_dict"])
        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
        optim.load_state_dict(checkpoint["opt_state"])
        print("load done")

        np.testing.assert_equal(net.a.numpy(), expected_a)

        # The first ``gm`` is attached to the *old* net's parameters. Attach a
        # fresh one so the reloaded net (and the restored optimizer state)
        # actually receives gradients in this step.
        gm = ad.GradManager().attach(net.parameters())
        with gm:
            loss = net([1.23])
            gm.backward(loss)

        optim.step()
    finally:
        # Clean up even on failure so later tests see neither the checkpoint
        # file nor a leaked "cpu0" default device.
        os.remove(model_name)
        # Restore device
        mge.set_default_device(device_save)