提交 a49d5bf9 编写于 作者: M Megvii Engine Team

fix(autodiff): fix inplace operation on autodiff.Function

GitOrigin-RevId: a658680f35dead712faab6f0b094cda9caa56e76
上级 7252825c
......@@ -178,8 +178,10 @@ class Function:
return self._default_rule(*args), self.backward
def __call__(self, *args):
from ...tensor import Tensor
for arg in args:
if not isinstance(arg, core2.Tensor):
if not isinstance(arg, Tensor):
raise TypeError(
"op Function expect type Tensor as inputs, got {}".format(type(arg))
)
......@@ -191,6 +193,8 @@ class Function:
grad = Grad.key2grad[grad_key]
group = [ref() for ref in grad._group]
origin_args = [Tensor(arg) for arg in args]
for grad in group:
grad.suppress()
outputs, backward = self._grad_rule(*args)
......@@ -199,13 +203,13 @@ class Function:
def normalized_backward(*output_grads):
input_grads = backward(*output_grads)
if isinstance(input_grads, core2.Tensor) or input_grads is None:
if isinstance(input_grads, Tensor) or input_grads is None:
input_grads = (input_grads,)
return input_grads
if self.__single_output:
outputs = (outputs,)
outputs = core2.set_grad(normalized_backward, args, outputs)
outputs = core2.set_grad(normalized_backward, origin_args, outputs)
if self.__single_output:
(outputs,) = outputs
return outputs
......
......@@ -347,3 +347,37 @@ def test_multiple_grad():
np.testing.assert_almost_equal(loss.numpy(), (av * 10))
np.testing.assert_almost_equal(net.a.numpy(), (av - 20))
def test_inplace_forward():
    """Regression test: autograd through a ``Function`` whose ``forward``
    mutates its input tensor in place (``a *= 10``) must still run and
    propagate gradients without error.

    The custom op keeps a reference to the input (``self.a = a``) before
    mutating it, which is exactly the situation the in-place fix targets.
    """
    shape = (9, 2, 6)
    init_vals = np.random.random(shape).astype(np.float32)

    class MulFunc(Function):
        def forward(self, a):
            # Hold a reference to the input, then mutate it in place.
            self.a = a
            a *= 10
            return a

        def backward(self, grad_o):
            # d(10*a)/da == 10
            return grad_o * 10

    class Simple(Module):
        def __init__(self, a):
            super().__init__()
            self.a = Parameter(a, dtype=np.float32)
            self.layer1 = MulFunc()

        def forward(self):
            return self.layer1(self.a)

    net = Simple(init_vals)
    gm = ad.GradManager().attach(net.parameters())
    opt = optimizer.SGD(net.parameters(), lr=1.0)
    opt.clear_grad()
    with gm:
        loss = net()
        gm.backward(loss.sum())
    opt.step()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册