diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index f9e182ef9e10f85f978812897455ba75f45847f7..06a401599c9abf2aae988ebef1f7ef9ae2e4b70d 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -178,8 +178,10 @@ class Function:
         return self._default_rule(*args), self.backward
 
     def __call__(self, *args):
+        from ...tensor import Tensor
+
         for arg in args:
-            if not isinstance(arg, core2.Tensor):
+            if not isinstance(arg, Tensor):
                 raise TypeError(
                     "op Function expect type Tensor as inputs, got {}".format(type(arg))
                 )
@@ -191,6 +193,8 @@ class Function:
         grad = Grad.key2grad[grad_key]
         group = [ref() for ref in grad._group]
 
+        origin_args = [Tensor(arg) for arg in args]
+
         for grad in group:
             grad.suppress()
         outputs, backward = self._grad_rule(*args)
@@ -199,13 +203,13 @@ class Function:
 
         def normalized_backward(*output_grads):
             input_grads = backward(*output_grads)
-            if isinstance(input_grads, core2.Tensor) or input_grads is None:
+            if isinstance(input_grads, Tensor) or input_grads is None:
                 input_grads = (input_grads,)
             return input_grads
 
         if self.__single_output:
             outputs = (outputs,)
-        outputs = core2.set_grad(normalized_backward, args, outputs)
+        outputs = core2.set_grad(normalized_backward, origin_args, outputs)
         if self.__single_output:
             (outputs,) = outputs
         return outputs
diff --git a/imperative/python/test/unit/core/test_function.py b/imperative/python/test/unit/core/test_function.py
index 8a5e9e8efd1356bda9ec6ff4ba48707d1ac2959e..dd02efeec940bd85102265f5a09c55e027d8c5c6 100644
--- a/imperative/python/test/unit/core/test_function.py
+++ b/imperative/python/test/unit/core/test_function.py
@@ -347,3 +347,37 @@ def test_multiple_grad():
 
     np.testing.assert_almost_equal(loss.numpy(), (av * 10))
     np.testing.assert_almost_equal(net.a.numpy(), (av - 20))
+
+
+def test_inplace_forward():
+    data_shape = (9, 2, 6)
+    av = np.random.random(data_shape).astype(np.float32)
+
+    class MulFunc(Function):
+        def forward(self, a):
+            self.a = a
+            a *= 10
+            return a
+
+        def backward(self, grad_o):
+            return grad_o * 10
+
+    class Simple(Module):
+        def __init__(self, a):
+            super().__init__()
+            self.a = Parameter(a, dtype=np.float32)
+            self.layer1 = MulFunc()
+
+        def forward(self):
+            x = self.layer1(self.a)
+            return x
+
+    net = Simple(av)
+    gm = ad.GradManager().attach(net.parameters())
+    opt = optimizer.SGD(net.parameters(), lr=1.0)
+
+    opt.clear_grad()
+    with gm:
+        loss = net()
+        gm.backward(loss.sum())
+    opt.step()