diff --git a/imperative/python/megengine/autodiff/__init__.py b/imperative/python/megengine/autodiff/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3def6f814714ee909a5ae722b87c99ae30b23633 --- /dev/null +++ b/imperative/python/megengine/autodiff/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +from .grad_manager import GradManager diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 84a8f55fa53d0585fa152cbe13509d6c78aa34f1..00faba33d7692a7cc140d1b0f47f47eae3b6c9e5 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -10,8 +10,8 @@ class GradManager: self._recording = False self._grad = None - def register(self, params, callback=None): - self._call_back_pair.append([params, callback]) + def register(self, params, callbacks=None): + self._call_back_pair.append([list(params), callbacks or []]) def backward(self, ys, dys=None): if not self._recording: @@ -24,7 +24,7 @@ class GradManager: if not isinstance(ys, (tuple, list)): ys = [ys] if dys is None: - dys = [tensor(1).broadcast(y.shape) for y in ys] + dys = [tensor(1.0) for y in ys] if not isinstance(dys, (tuple, list)): dys = [dys] try: @@ -42,7 +42,14 @@ class GradManager: self._recording = True self._grad = grad for params, callbacks in self._call_back_pair: - grad.wrt(*params, callback=callbacks) + + def callback(param, grad, callbacks=callbacks): + ret = grad + for cb in callbacks: + ret = cb(param, ret) + param.grad = ret + + grad.wrt(*params, callback=callback) with grad: 
yield finally: diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py index 42d8ae9c682d903b93d1bb4a54dbda7127f49d63..c30e4113f4fb9fe3f5f4870ff5cf2f79f746274c 100644 --- a/imperative/python/megengine/core/autodiff/grad.py +++ b/imperative/python/megengine/core/autodiff/grad.py @@ -260,13 +260,9 @@ class Grad: cache[v] = g if last_written_to[v] == (seqno, i): if v.callback: - grad = v.callback( + v.callback( v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] ) - if getattr(v.owner(), "grad", None) is None: - v.owner().grad = grad - else: - v.owner().grad += grad if v.opnode is None: # won't read by backward, mark consumed cache[v] = None diff --git a/imperative/python/megengine/optimizer/multi_step_lr.py b/imperative/python/megengine/optimizer/multi_step_lr.py index fc3a43f482269d8249d8c8c947645ebe1a023f35..602f9228c1632932400f75300d80bbf5a1c05f85 100644 --- a/imperative/python/megengine/optimizer/multi_step_lr.py +++ b/imperative/python/megengine/optimizer/multi_step_lr.py @@ -9,8 +9,8 @@ from bisect import bisect_right from typing import Iterable as Iter -from .optimizer import Optimizer from .lr_scheduler import LRScheduler +from .optimizer import Optimizer class MultiStepLR(LRScheduler): diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py index 88cffd0713143ba74c8350070c0604e9bea460b4..f74a2cc6087e29de550b37ae36c7b6b3ca93e816 100644 --- a/imperative/python/megengine/optimizer/sgd.py +++ b/imperative/python/megengine/optimizer/sgd.py @@ -53,10 +53,6 @@ class SGD(Optimizer): for param in param_group["params"]: - if param.__wrapped__ in self._grad_skip: - self._grad_skip.remove(param.__wrapped__) - continue - if not isinstance(param.grad, Buffer): raise TypeError( "grad must be a Buffer, maybe you forget to call backward()?" 
@@ -76,5 +72,3 @@ class SGD(Optimizer): self._state[param]["momentum_buffer"]._reset(v) else: param -= lr * grad - - assert len(self._grad_skip) == 0