提交 65432d3b 编写于 作者: M Megvii Engine Team

fix(mge/module): fix torch subgraph under jit.trace with symbolic=False

GitOrigin-RevId: a208ba79d964baf78bdd9d10264dcb9166bb8506
上级 862de28a
...@@ -305,6 +305,8 @@ class PyTorchSubgraphImplOpr(mgb.craniotome.CraniotomeBase): ...@@ -305,6 +305,8 @@ class PyTorchSubgraphImplOpr(mgb.craniotome.CraniotomeBase):
ret.__dict__["_last_forward_inputs"] = d0.pop("_last_forward_inputs") ret.__dict__["_last_forward_inputs"] = d0.pop("_last_forward_inputs")
ret.__dict__["_last_forward_outputs"] = d0.pop("_last_forward_outputs") ret.__dict__["_last_forward_outputs"] = d0.pop("_last_forward_outputs")
ret.__dict__["_last_forward_params"] = d0.pop("_last_forward_params")
ret.__dict__["_func"] = d0.pop("_func")
d0.pop("_grad_opr") d0.pop("_grad_opr")
later_copy = self._grad_opr in _copy_dict later_copy = self._grad_opr in _copy_dict
......
...@@ -13,8 +13,11 @@ from helpers import randomTorch ...@@ -13,8 +13,11 @@ from helpers import randomTorch
import megengine as mge import megengine as mge
import megengine._internal as mgb import megengine._internal as mgb
import megengine.functional import megengine.functional
import megengine.optimizer as optimizer
from megengine import get_default_device, set_default_device from megengine import get_default_device, set_default_device
from megengine.core import Parameter, tensor from megengine.core import Parameter, tensor
from megengine.jit import trace
from megengine.module import Module as MGEModule
from megengine.module.pytorch import PyTorchModule from megengine.module.pytorch import PyTorchModule
from megengine.test import assertTensorClose from megengine.test import assertTensorClose
...@@ -72,3 +75,68 @@ def test_pytorch_backward(): ...@@ -72,3 +75,68 @@ def test_pytorch_backward():
return mge.functional.grad(mge_e, mge_a, use_virtual_grad=False) return mge.functional.grad(mge_e, mge_a, use_virtual_grad=False)
assertTensorClose(get_pytorch_backward().numpy(), get_mge_backward().numpy()) assertTensorClose(get_pytorch_backward().numpy(), get_mge_backward().numpy())
def test_pytorch_mixed():
    """Train a model mixing a MegEngine Parameter with a wrapped PyTorch
    submodule, and check that both parameters receive correct SGD updates.

    The model computes ``out = torch_multiplier * mge_multiplier * x``, so one
    SGD step with learning rate ``lr`` should update each multiplier by
    ``-lr * other_multiplier * x``.  Runs eager, symbolic-traced and
    non-symbolic-traced variants for 1 and 2 steps.
    """
    init_param = (np.array([2.0], dtype=np.float32), np.array([3.0], dtype=np.float32))
    lr = 1.0

    class Mixed(MGEModule):
        class SubModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # PyTorch-side learnable scalar, initialized to init_param[0].
                self.multiplier = torch.nn.Parameter(torch.tensor(init_param[0]))

            def forward(self, inp):
                return inp * self.multiplier

        def __init__(self):
            super().__init__()
            # Wrap the torch module so MegEngine can drive its forward/backward.
            self.torch_module = PyTorchModule(self.SubModule())
            # MegEngine-side learnable scalar, initialized to init_param[1].
            self.multiplier = Parameter(np.array(init_param[1]), dtype=np.float32)

        def forward(self, inp):
            return self.torch_module(inp) * self.multiplier

    def run(step, enable_trace, use_symbolic):
        # One training-loop configuration: `step` iterations, optionally
        # wrapped in jit.trace with the given symbolic mode.
        def train_func(data, net=None, opt=None):
            pred = net(data)
            opt.backward(pred)
            return pred

        if enable_trace:
            train_func = trace(train_func, symbolic=use_symbolic)

        net = Mixed()
        data = tensor()
        opt = optimizer.SGD(net.parameters(), lr=lr)
        saved_param = init_param
        for i in range(step):
            opt.zero_grad()
            data.set_value([i + 1.0])
            output = train_func(data, net=net, opt=opt)
            opt.step()
            # Hand-computed SGD update: grads of m0*m1*x are (m1*x, m0*x),
            # both evaluated at the pre-update parameter values.
            expect_param = (
                saved_param[0] - lr * saved_param[1] * data.numpy(),
                saved_param[1] - lr * saved_param[0] * data.numpy(),
            )
            assertTensorClose(
                output.numpy(), saved_param[0] * saved_param[1] * data.numpy()
            )
            torch_param = net.torch_module._torch_params[0].detach().cpu()
            assertTensorClose(torch_param.numpy(), expect_param[0])
            assertTensorClose(net.multiplier.numpy(), expect_param[1])
            saved_param = expect_param

    run(1, False, False)
    run(1, True, True)
    run(1, True, False)
    run(2, False, False)
    run(2, True, True)
    run(2, True, False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册