From d6db4fea93eecc2b6382d9e6343e78b12a1c61af Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 1 Jul 2021 13:51:11 +0800 Subject: [PATCH] fix(mge/module): set no_cache=true when loading state dict GitOrigin-RevId: 83281a3d4756bb257a991454ccc4c3b477a21b4c --- imperative/python/megengine/module/module.py | 6 +++- .../python/test/integration/test_save_load.py | 32 +++++++++++++++++++ imperative/src/impl/ops/elemwise.cpp | 3 ++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py index 6560302fa..0838553fa 100644 --- a/imperative/python/megengine/module/module.py +++ b/imperative/python/megengine/module/module.py @@ -600,7 +600,11 @@ class Module(metaclass=ABCMeta): k, var_shape, to_be_load_shape ) ) - var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device)) + var._reset( + type(var)( + to_be_load, dtype=to_be_load.dtype, device=var.device, no_cache=True + ) + ) loaded.append(k) return set(loaded), set(skipped) diff --git a/imperative/python/test/integration/test_save_load.py b/imperative/python/test/integration/test_save_load.py index 96434d234..f80af3b8c 100644 --- a/imperative/python/test/integration/test_save_load.py +++ b/imperative/python/test/integration/test_save_load.py @@ -11,6 +11,7 @@ import numpy as np import megengine as mge import megengine.autodiff as ad +import megengine.module as M import megengine.optimizer as optimizer from megengine import Parameter, tensor from megengine.module import Module @@ -26,6 +27,37 @@ class Simple(Module): return x +class Net(Module): + def __init__(self): + super().__init__() + self.fc = M.Linear(1, 1) + + def forward(self, images): + x = self.fc(images) + loss = x.mean() * 10000 + return loss + + +def test_load_state_dict_no_cache(monkeypatch): + with monkeypatch.context() as mk: + mk.setenv("MEGENGINE_INPLACE_UPDATE", "1") + net = Net() + + optim = optimizer.SGD(net.parameters(), lr=0.1) + gm = ad.GradManager().attach(net.parameters()) + state = { + "fc.weight": np.array([[0]], dtype=np.float32), + "fc.bias": np.array([0.0], dtype=np.float32), + } + net.load_state_dict(state) + images = mge.tensor([[0]], dtype=np.float32) + with gm: + loss = net(images) + gm.backward(loss) + optim.step() + optim.clear_grad() + + def test_save_load(): net = Simple() diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp index 422e72b15..40c69d8c0 100644 --- a/imperative/src/impl/ops/elemwise.cpp +++ b/imperative/src/impl/ops/elemwise.cpp @@ -224,6 +224,9 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node( SmallVector apply_inplace_add_on_physical_tensor( const OpDef& def, const SmallVector& inputs){ + mgb_assert(inputs[0]->blob().unique() && inputs[0]->blob()->storage().unique(), + "This inplace modification may change the elements of other tensors. " + "Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs correctly."); auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3]; auto tensor_to_scalar = [](const TensorPtr& tensor) -> float { -- GitLab