Commit 9faa32fc authored by Megvii Engine Team

fix(mge/imperative): fix grad callback

GitOrigin-RevId: 6f843b0106117ca24d08efb5685cd09171197430
Parent 6d4fd938
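For orientation before the diff: `register` now takes a list of `callbacks` instead of a single `callback`, the manager chains them and writes the result to `param.grad`, and the default seed gradient for `backward` becomes a scalar `tensor(1.0)`. A hedged usage sketch follows; `register`, `backward`, and the `(param, grad)` callback signature are taken from the diff, while the module path, the `record` context-manager name, and the toy parameter are assumptions.

import megengine as mge
from megengine.autodiff import GradManager  # module path assumed

w = mge.Parameter([1.0, 2.0, 3.0])
gm = GradManager()

# `callbacks` is a list; each callback receives (param, grad) and returns
# a (possibly transformed) grad. The chained result becomes param.grad.
gm.register([w], callbacks=[lambda p, g: g * 0.5])

with gm.record():          # context-manager name assumed
    loss = (w * w).sum()
    gm.backward(loss)      # seed gradient defaults to tensor(1.0)

print(w.grad)              # 0.5 * d(loss)/dw == w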
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .grad_manager import GradManager
@@ -10,8 +10,8 @@ class GradManager:
         self._recording = False
         self._grad = None
 
-    def register(self, params, callback=None):
-        self._call_back_pair.append([params, callback])
+    def register(self, params, callbacks=None):
+        self._call_back_pair.append([list(params), callbacks or []])
 
     def backward(self, ys, dys=None):
         if not self._recording:
@@ -24,7 +24,7 @@ class GradManager:
         if not isinstance(ys, (tuple, list)):
             ys = [ys]
         if dys is None:
-            dys = [tensor(1).broadcast(y.shape) for y in ys]
+            dys = [tensor(1.0) for y in ys]
         if not isinstance(dys, (tuple, list)):
             dys = [dys]
         try:
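The removed line materialized a ones tensor of `y.shape` up front; the new default is a 0-dim scalar seed, leaving broadcasting to the autodiff engine. A small illustration (the old `.broadcast` call is kept in a comment, since that method may not exist in later MegEngine releases):

from megengine import tensor

y = tensor([[1.0, 2.0], [3.0, 4.0]])

# Before: a full (2, 2) ones tensor was built eagerly.
#   dy = tensor(1).broadcast(y.shape)
# After: a 0-dim scalar seed; broadcasting happens during backward.
dy = tensor(1.0)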
@@ -42,7 +42,14 @@ class GradManager:
             self._recording = True
             self._grad = grad
             for params, callbacks in self._call_back_pair:
-                grad.wrt(*params, callback=callbacks)
+
+                def callback(param, grad, callbacks=callbacks):
+                    ret = grad
+                    for cb in callbacks:
+                        ret = cb(param, ret)
+                    param.grad = ret
+
+                grad.wrt(*params, callback=callback)
             with grad:
                 yield
         finally:
......
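The hunk above wraps each `(params, callbacks)` pair in a fresh closure and pins the current list via the `callbacks=callbacks` default argument. A self-contained sketch (plain Python, independent of MegEngine) of the late-binding pitfall this pattern avoids:

def make_closures_buggy(items):
    fns = []
    for item in items:
        fns.append(lambda: item)  # every closure sees the final `item`
    return fns

def make_closures_fixed(items):
    fns = []
    for item in items:
        fns.append(lambda item=item: item)  # default arg freezes the value
    return fns

print([f() for f in make_closures_buggy([1, 2, 3])])  # [3, 3, 3]
print([f() for f in make_closures_fixed([1, 2, 3])])  # [1, 2, 3]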
@@ -260,13 +260,9 @@ class Grad:
                 cache[v] = g
             if last_written_to[v] == (seqno, i):
                 if v.callback:
-                    grad = v.callback(
+                    v.callback(
                         v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
                     )
-                    if getattr(v.owner(), "grad", None) is None:
-                        v.owner().grad = grad
-                    else:
-                        v.owner().grad += grad
                 if v.opnode is None:
                     # won't read by backward, mark consumed
                     cache[v] = None
......
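After this hunk the core engine only invokes the callback; initializing or accumulating `param.grad` is now the job of the callback installed by GradManager, which assigns the chained result. If the removed `+=` accumulation were still wanted, a user callback along these lines could restore it; a hypothetical sketch, not code from this commit:

# Hypothetical accumulating callback; the (param, grad) signature follows
# the manager-installed callback above, which assigns the return value
# to param.grad.
def accumulating_callback(param, grad):
    if getattr(param, "grad", None) is None:
        return grad
    return param.grad + grad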
@@ -9,8 +9,8 @@
 from bisect import bisect_right
 from typing import Iterable as Iter
 
-from .optimizer import Optimizer
 from .lr_scheduler import LRScheduler
+from .optimizer import Optimizer
 
 
 class MultiStepLR(LRScheduler):
......
@@ -53,10 +53,6 @@ class SGD(Optimizer):
         for param in param_group["params"]:
 
-            if param.__wrapped__ in self._grad_skip:
-                self._grad_skip.remove(param.__wrapped__)
-                continue
-
             if not isinstance(param.grad, Buffer):
                 raise TypeError(
                     "grad must be a Buffer, maybe you forget to call backward()?"
@@ -76,5 +72,3 @@ class SGD(Optimizer):
                 self._state[param]["momentum_buffer"]._reset(v)
             else:
                 param -= lr * grad
-
-        assert len(self._grad_skip) == 0