diff --git a/imperative/python/megengine/jit/xla_backend.py b/imperative/python/megengine/jit/xla_backend.py index 6e9e48e70854f9b0eb3323d17ec9058ad2e560b5..992caedcbd9d1b03aef894a9e43f221f13c58e42 100644 --- a/imperative/python/megengine/jit/xla_backend.py +++ b/imperative/python/megengine/jit/xla_backend.py @@ -4,7 +4,6 @@ from ..core._imperative_rt import CompNode from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._trace_option import set_use_xla_backend from ..device import get_default_device -from ..distributed import get_mm_server_addr, is_distributed from ..utils.dlpack import from_dlpack, to_dlpack from .tracing import trace @@ -65,6 +64,7 @@ class xla_trace(trace): from ..utils.module_utils import get_expand_structure from ..xla.device import get_xla_backend_and_device from ..tensor import Tensor + from ..distributed import get_mm_server_addr, is_distributed assert self.traced if self.overall: