提交 8c482b67 编写于 作者: M Megvii Engine Team

fix(mge/grad): make register_after_backward_callback private

GitOrigin-RevId: 8eb6c0e6284b6e3b926778833f3bda06397baac7
上级 66b6daf7
......@@ -28,7 +28,7 @@ class GradManager:
self._call_back_dict[id(p)].append(cb)
return self
def register_after_backward_callback(self, callback):
def _register_after_backward_callback(self, callback):
self._after_backward_callback.append(callback)
return self
......
......@@ -104,7 +104,7 @@ class AllreduceCallback:
gm = get_backwarding_grad_manager()
assert isinstance(gm, GradManager)
if gm not in self._marked_gm:
gm.register_after_backward_callback(self._flush)
gm._register_after_backward_callback(self._flush)
self._marked_gm.add(gm)
self._params.append(param)
self._futures_dict[param] = FakeTensor(ack=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册