diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 00faba33d7692a7cc140d1b0f47f47eae3b6c9e5..489067b5e384c55ee2691057552af26e853259eb 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -42,14 +42,15 @@ class GradManager: self._recording = True self._grad = grad for params, callbacks in self._call_back_pair: + for p in params: - def callback(param, grad, callbacks=callbacks): - ret = grad - for cb in callbacks: - ret = cb(param, ret) - param.grad = ret + def callback(param, grad, callbacks=callbacks, p=p): + ret = grad + for cb in callbacks: + ret = cb(param, ret) + p.grad = ret - grad.wrt(*params, callback=callback) + grad.wrt(p, callback=callback) with grad: yield finally: diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index f74a2cc6087e29de550b37ae36c7b6b3ca93e816..4e4dafb81e57cfdef3cb9b78e6d9a4cdeb110188 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -52,12 +52,6 @@ class SGD(Optimizer): momentum = param_group["momentum"] for param in param_group["params"]: - - if not isinstance(param.grad, Buffer): - raise TypeError( - "grad must be a Buffer, maybe you forget to call backward()?" - ) - if not param.requires_grad: continue