From 5ae89c799b20749885269f24d66e567a2220ac9b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 4 Sep 2020 23:46:51 +0800 Subject: [PATCH] refactor(mgb/grad): place grad at param.grad GitOrigin-RevId: fddeace40248f2edf7781458ef408529d75022ca --- .../python/megengine/autodiff/grad_manager.py | 13 +++++++------ imperative/python/megengine/optimizer/sgd.py | 6 ------ 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 00faba33d..489067b5e 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -42,14 +42,15 @@ class GradManager: self._recording = True self._grad = grad for params, callbacks in self._call_back_pair: + for p in params: - def callback(param, grad, callbacks=callbacks): - ret = grad - for cb in callbacks: - ret = cb(param, ret) - param.grad = ret + def callback(param, grad, callbacks=callbacks, p=p): + ret = grad + for cb in callbacks: + ret = cb(param, ret) + p.grad = ret - grad.wrt(*params, callback=callback) + grad.wrt(p, callback=callback) with grad: yield finally: diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index f74a2cc60..4e4dafb81 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -52,12 +52,6 @@ class SGD(Optimizer): momentum = param_group["momentum"] for param in param_group["params"]: - - if not isinstance(param.grad, Buffer): - raise TypeError( - "grad must be a Buffer, maybe you forget to call backward()?" - ) - if not param.requires_grad: continue -- GitLab