From 696d2c2ef116451a95cd5848e7a64c7e0c75aff9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 13 May 2021 14:56:53 +0800 Subject: [PATCH] fix(mge/autodiff): check tensors to be attached GitOrigin-RevId: c4f3f808764bac2d852e34bc118a76cde313b0c7 --- imperative/python/megengine/autodiff/grad_manager.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 3d5697959..72c689fd4 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -1,7 +1,7 @@ import weakref from collections import defaultdict from contextlib import contextmanager -from typing import Callable +from typing import Callable, Iterable from ..core._imperative_rt.core2 import pop_scope, push_scope, set_option from ..core.autodiff.grad import Grad @@ -121,7 +121,7 @@ class GradManager: self._after_backward_callback = [] self._gradients = {} - def attach(self, tensors: list, callbacks=None): + def attach(self, tensors: Iterable[Tensor], callbacks=None): r""" Instruct GradManager to track operations on tensors, so that gradients with respect to those tensors could be evaluated later. @@ -199,6 +199,7 @@ class GradManager: return spec for x in tensors: + assert isinstance(x, Tensor), "Object to be attached should be Tensor" spec = self._attach_specs.get(id(x)) new_attach = spec is None if spec is None: -- GitLab