提交 696d2c2e 编写于 作者: M Megvii Engine Team

fix(mge/autodiff): check tensors to be attached

GitOrigin-RevId: c4f3f808764bac2d852e34bc118a76cde313b0c7
上级 984d85ca
import weakref
from collections import defaultdict
from contextlib import contextmanager
from typing import Callable
from typing import Callable, Iterable
from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option
from ..core.autodiff.grad import Grad
......@@ -121,7 +121,7 @@ class GradManager:
self._after_backward_callback = []
self._gradients = {}
def attach(self, tensors: list, callbacks=None):
def attach(self, tensors: Iterable[Tensor], callbacks=None):
r"""
Instruct GradManager to track operations on tensors, so that gradients with respect
to those tensors could be evaluated later.
......@@ -199,6 +199,7 @@ class GradManager:
return spec
for x in tensors:
assert isinstance(x, Tensor), "Object to be attached should be Tensor"
spec = self._attach_specs.get(id(x))
new_attach = spec is None
if spec is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册