提交 d6db4fea 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/module): set no_cache=true when loading state dict

GitOrigin-RevId: 83281a3d4756bb257a991454ccc4c3b477a21b4c
上级 fea1bba2
......@@ -600,7 +600,11 @@ class Module(metaclass=ABCMeta):
k, var_shape, to_be_load_shape
)
)
var._reset(type(var)(to_be_load, dtype=to_be_load.dtype, device=var.device))
var._reset(
type(var)(
to_be_load, dtype=to_be_load.dtype, device=var.device, no_cache=True
)
)
loaded.append(k)
return set(loaded), set(skipped)
......
......@@ -11,6 +11,7 @@ import numpy as np
import megengine as mge
import megengine.autodiff as ad
import megengine.module as M
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import Module
......@@ -26,6 +27,37 @@ class Simple(Module):
return x
class Net(Module):
def __init__(self):
super().__init__()
self.fc = M.Linear(1, 1)
def forward(self, images):
x = self.fc(images)
loss = x.mean() * 10000
return loss
def test_load_state_dict_no_cache(monkeypatch):
with monkeypatch.context() as mk:
mk.setenv("MEGENGINE_INPLACE_UPDATE", "1")
net = Net()
optim = optimizer.SGD(net.parameters(), lr=0.1)
gm = ad.GradManager().attach(net.parameters())
state = {
"fc.weight": np.array([[0]], dtype=np.float32),
"fc.bias": np.array([0.0], dtype=np.float32),
}
net.load_state_dict(state)
images = mge.tensor([[0]], dtype=np.float32)
with gm:
loss = net(images)
gm.backward(loss)
optim.step()
optim.clear_grad()
def test_save_load():
net = Simple()
......
......@@ -224,6 +224,9 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs){
mgb_assert(inputs[0]->blob().unique() && inputs[0]->blob()->storage().unique(),
"This inplace modification may change the elements of other tensors. "
"Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs correctly.");
auto dest = inputs[0], delta = inputs[1],
alpha = inputs[2], beta = inputs[3];
auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册