# Patch summary (three files):
#   * imperative/python/megengine/module/module.py
#       The replacement tensor passed to var._reset() is now constructed with
#       the extra keyword argument no_cache=True.
#   * imperative/python/test/integration/test_save_load.py
#       Imports megengine.module as M, adds a one-layer `Net` module
#       (M.Linear(1, 1); forward returns mean * 10000) and the test
#       `test_load_state_dict_no_cache`, which sets MEGENGINE_INPLACE_UPDATE=1,
#       loads a state dict built from numpy arrays, then runs one
#       forward / backward / SGD step / clear_grad.
#   * imperative/src/impl/ops/elemwise.cpp
#       apply_inplace_add_on_physical_tensor now starts with an mgb_assert that
#       inputs[0]->blob() and its storage() are both unique().
#
# NOTE(review): line breaks and indentation below were reconstructed from a
#   whitespace-collapsed copy of this patch; the hunk line counts were used to
#   place blank lines. Confirm context-line indentation against the target
#   files before applying.
# NOTE(review): `SmallVector` appears without template arguments in the
#   elemwise.cpp context lines -- presumably lost in extraction; verify against
#   the real file.
# NOTE(review): the last hunk may be missing its final context line (the header
#   counts suggest one more line than is visible here) -- TODO confirm.
diff --git a/imperative/python/megengine/module/module.py b/imperative/python/megengine/module/module.py
index 6560302fa403f072f3671ed3f1fe5c35eeee5e6c..0838553fa3f12fd329f3985b6d4fd79851194e1b 100644
--- a/imperative/python/megengine/module/module.py
+++ b/imperative/python/megengine/module/module.py
@@ -600,7 +600,11 @@ class Module(metaclass=ABCMeta):
                         k, var_shape, to_be_load_shape
                     )
                 )
-            var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
+            var._reset(
+                type(var)(
+                    to_be_load, dtype=to_be_load.dtype, device=var.device, no_cache=True
+                )
+            )
             loaded.append(k)
 
         return set(loaded), set(skipped)
diff --git a/imperative/python/test/integration/test_save_load.py b/imperative/python/test/integration/test_save_load.py
index 96434d234fdc77e1a28071e660f066c42218dff0..f80af3b8c072e2d5557660c4e7f22725270b2bdb 100644
--- a/imperative/python/test/integration/test_save_load.py
+++ b/imperative/python/test/integration/test_save_load.py
@@ -11,6 +11,7 @@ import numpy as np
 
 import megengine as mge
 import megengine.autodiff as ad
+import megengine.module as M
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
 from megengine.module import Module
@@ -26,6 +27,37 @@ class Simple(Module):
         return x
 
 
+class Net(Module):
+    def __init__(self):
+        super().__init__()
+        self.fc = M.Linear(1, 1)
+
+    def forward(self, images):
+        x = self.fc(images)
+        loss = x.mean() * 10000
+        return loss
+
+
+def test_load_state_dict_no_cache(monkeypatch):
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", "1")
+        net = Net()
+
+        optim = optimizer.SGD(net.parameters(), lr=0.1)
+        gm = ad.GradManager().attach(net.parameters())
+        state = {
+            "fc.weight": np.array([[0]], dtype=np.float32),
+            "fc.bias": np.array([0.0], dtype=np.float32),
+        }
+        net.load_state_dict(state)
+        images = mge.tensor([[0]], dtype=np.float32)
+        with gm:
+            loss = net(images)
+            gm.backward(loss)
+            optim.step()
+            optim.clear_grad()
+
+
 def test_save_load():
     net = Simple()
 
diff --git a/imperative/src/impl/ops/elemwise.cpp b/imperative/src/impl/ops/elemwise.cpp
index 422e72b15aeb3c2cca7ea84604fb0bc292e3ccfa..40c69d8c05c0b2bdd9d5ee15577d540e91003a94 100644
--- a/imperative/src/impl/ops/elemwise.cpp
+++ b/imperative/src/impl/ops/elemwise.cpp
@@ -224,6 +224,9 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
 SmallVector apply_inplace_add_on_physical_tensor(
     const OpDef& def,
     const SmallVector& inputs){
+    mgb_assert(inputs[0]->blob().unique() && inputs[0]->blob()->storage().unique(),
+        "This inplace modification may change the elements of other tensors. "
+        "Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs correctly.");
     auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
     auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {