diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 3d569795941273d25910c7c64e4a61efa20e93ad..72c689fd4453bcb4143d6955dacc1eef8c0a77f4 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -1,7 +1,7 @@ import weakref from collections import defaultdict from contextlib import contextmanager -from typing import Callable +from typing import Callable, Iterable from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option from ..core.autodiff.grad import Grad @@ -121,7 +121,7 @@ class GradManager: self._after_backward_callback = [] self._gradients = {} - def attach(self, tensors: list, callbacks=None): + def attach(self, tensors: Iterable[Tensor], callbacks=None): r""" Instruct GradManager to track operations on tensors, so that gradients with respect to those tensors could be evaluated later. @@ -199,6 +199,7 @@ class GradManager: return spec for x in tensors: + assert isinstance(x, Tensor), "Object to be attached should be Tensor" spec = self._attach_specs.get(id(x)) new_attach = spec is None if spec is None: