diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py index 528b805ee7a7c88a77b701de0938f25199d0deb5..fefd2a682a3923e2b2a0a202cf13827b6f8a71c9 100644 --- a/imperative/python/megengine/autodiff/grad_manager.py +++ b/imperative/python/megengine/autodiff/grad_manager.py @@ -28,7 +28,7 @@ class GradManager: self._call_back_dict[id(p)].append(cb) return self - def register_after_backward_callback(self, callback): + def _register_after_backward_callback(self, callback): self._after_backward_callback.append(callback) return self diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 693da1e8d0c8689cf6d1d844412e6b2360a0757f..aebc9b08702305c1cc341c72aa12b37a7483b1dd 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -104,7 +104,7 @@ class AllreduceCallback: gm = get_backwarding_grad_manager() assert isinstance(gm, GradManager) if gm not in self._marked_gm: - gm.register_after_backward_callback(self._flush) + gm._register_after_backward_callback(self._flush) self._marked_gm.add(gm) self._params.append(param) self._futures_dict[param] = FakeTensor(ack=False)