提交 5ae89c79 编写于 作者: M Megvii Engine Team

refactor(mgb/grad): place grad at param.grad

GitOrigin-RevId: fddeace40248f2edf7781458ef408529d75022ca
上级 9faa32fc
......@@ -42,14 +42,15 @@ class GradManager:
self._recording = True
self._grad = grad
for params, callbacks in self._call_back_pair:
for p in params:
def callback(param, grad, callbacks=callbacks):
ret = grad
for cb in callbacks:
ret = cb(param, ret)
param.grad = ret
def callback(param, grad, callbacks=callbacks, p=p):
ret = grad
for cb in callbacks:
ret = cb(param, ret)
p.grad = ret
grad.wrt(*params, callback=callback)
grad.wrt(p, callback=callback)
with grad:
yield
finally:
......
......@@ -52,12 +52,6 @@ class SGD(Optimizer):
momentum = param_group["momentum"]
for param in param_group["params"]:
if not isinstance(param.grad, Buffer):
raise TypeError(
"grad must be a Buffer, maybe you forget to call backward()?"
)
if not param.requires_grad:
continue
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册