Commit 65432d3b authored by Megvii Engine Team

fix(mge/module): fix torch subgraph under jit.trace with symbolic=False

GitOrigin-RevId: a208ba79d964baf78bdd9d10264dcb9166bb8506
Parent 862de28a
......@@ -305,6 +305,8 @@ class PyTorchSubgraphImplOpr(mgb.craniotome.CraniotomeBase):
ret.__dict__["_last_forward_inputs"] = d0.pop("_last_forward_inputs")
ret.__dict__["_last_forward_outputs"] = d0.pop("_last_forward_outputs")
ret.__dict__["_last_forward_params"] = d0.pop("_last_forward_params")
ret.__dict__["_func"] = d0.pop("_func")
d0.pop("_grad_opr")
later_copy = self._grad_opr in _copy_dict
......
......@@ -13,8 +13,11 @@ from helpers import randomTorch
import megengine as mge
import megengine._internal as mgb
import megengine.functional
import megengine.optimizer as optimizer
from megengine import get_default_device, set_default_device
from megengine.core import Parameter, tensor
from megengine.jit import trace
from megengine.module import Module as MGEModule
from megengine.module.pytorch import PyTorchModule
from megengine.test import assertTensorClose
......@@ -72,3 +75,68 @@ def test_pytorch_backward():
return mge.functional.grad(mge_e, mge_a, use_virtual_grad=False)
assertTensorClose(get_pytorch_backward().numpy(), get_mge_backward().numpy())
def test_pytorch_mixed():
    """Train a model that mixes a PyTorch submodule with a MegEngine module.

    The model computes ``out = a * b * x`` where ``a`` lives in a wrapped
    ``torch.nn.Module`` and ``b`` is a MegEngine ``Parameter``.  One SGD step
    with loss ``out`` updates ``a -= lr * b * x`` and ``b -= lr * a * x``
    (gradients of the product w.r.t. each factor), which the loop checks
    against hand-computed values.  Runs eagerly and under ``jit.trace`` with
    both symbolic modes to cover the fixed code path.
    """
    # (torch-side multiplier a, mge-side multiplier b)
    init_param = (np.array([2.0], dtype=np.float32), np.array([3.0], dtype=np.float32))
    lr = 1.0

    class Mixed(MGEModule):
        class SubModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.multiplier = torch.nn.Parameter(torch.tensor(init_param[0]))

            def forward(self, inp):
                return inp * self.multiplier

        def __init__(self):
            super().__init__()
            # Wrap the torch module so MegEngine can trace/differentiate it.
            self.torch_module = PyTorchModule(self.SubModule())
            self.multiplier = Parameter(np.array(init_param[1]), dtype=np.float32)

        def forward(self, inp):
            return self.torch_module(inp) * self.multiplier

    def run(step, enable_trace, use_symbolic):
        def train_func(data, net=None, opt=None):
            pred = net(data)
            opt.backward(pred)
            return pred

        if enable_trace:
            train_func = trace(train_func, symbolic=use_symbolic)

        net = Mixed()
        data = tensor()
        opt = optimizer.SGD(net.parameters(), lr=lr)
        saved_param = init_param
        for i in range(step):
            opt.zero_grad()
            data.set_value([i + 1.0])
            output = train_func(data, net=net, opt=opt)
            opt.step()
            # Expected SGD update for loss = a * b * x:
            #   d/da = b * x,  d/db = a * x
            expect_param = (
                saved_param[0] - lr * saved_param[1] * data.numpy(),
                saved_param[1] - lr * saved_param[0] * data.numpy(),
            )
            assertTensorClose(
                output.numpy(), saved_param[0] * saved_param[1] * data.numpy()
            )
            torch_param = net.torch_module._torch_params[0].detach().cpu()
            assertTensorClose(torch_param.numpy(), expect_param[0])
            assertTensorClose(net.multiplier.numpy(), expect_param[1])
            saved_param = expect_param

    run(1, False, False)
    run(1, True, True)
    run(1, True, False)
    run(2, False, False)
    run(2, True, True)
    run(2, True, False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册