From 8c482b6709569537c3eefdd28011e36c3f38cc66 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 8 Sep 2020 18:18:42 +0800 Subject: [PATCH] fix(mge/grad): make register_after_backward_callback private GitOrigin-RevId: 8eb6c0e6284b6e3b926778833f3bda06397baac7 --- imperative/python/megengine/autodiff/grad_manager.py | 2 +- imperative/python/megengine/distributed/helper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 528b805e..fefd2a68 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -28,7 +28,7 @@ class GradManager: self._call_back_dict[id(p)].append(cb) return self - def register_after_backward_callback(self, callback): + def _register_after_backward_callback(self, callback): self._after_backward_callback.append(callback) return self diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 693da1e8..aebc9b08 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -104,7 +104,7 @@ class AllreduceCallback: gm = get_backwarding_grad_manager() assert isinstance(gm, GradManager) if gm not in self._marked_gm: - gm.register_after_backward_callback(self._flush) + gm._register_after_backward_callback(self._flush) self._marked_gm.add(gm) self._params.append(param) self._futures_dict[param] = FakeTensor(ack=False) -- GitLab