提交 696d2c2e 编写于 作者: M Megvii Engine Team

fix(mge/autodiff): check tensors to be attached

GitOrigin-RevId: c4f3f808764bac2d852e34bc118a76cde313b0c7
上级 984d85ca
import weakref import weakref
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable from typing import Callable, Iterable
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core.autodiff.grad import Grad from ..core.autodiff.grad import Grad
...@@ -121,7 +121,7 @@ class GradManager: ...@@ -121,7 +121,7 @@ class GradManager:
self._after_backward_callback = [] self._after_backward_callback = []
self._gradients = {} self._gradients = {}
def attach(self, tensors: list, callbacks=None): def attach(self, tensors: Iterable[Tensor], callbacks=None):
r""" r"""
Instruct GradManager to track operations on tensors, so that gradients with respect Instruct GradManager to track operations on tensors, so that gradients with respect
to those tensors could be evaluated later. to those tensors could be evaluated later.
...@@ -199,6 +199,7 @@ class GradManager: ...@@ -199,6 +199,7 @@ class GradManager:
return spec return spec
for x in tensors: for x in tensors:
assert isinstance(x, Tensor), "Object to be attached should be Tensor"
spec = self._attach_specs.get(id(x)) spec = self._attach_specs.get(id(x))
new_attach = spec is None new_attach = spec is None
if spec is None: if spec is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册