From 5e013d8c57690311e532b18d154055c4d8bb7051 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 6 Jul 2023 20:05:34 +0800 Subject: [PATCH] refactor(xla): add xla acknowledgement GitOrigin-RevId: f2fafebedfc3b520684655cff8c86b1df7c2c1cd --- ACKNOWLEDGMENTS | 5 + imperative/python/megengine/xla/__init__.py | 2 + .../python/megengine/xla/lib/__init__.py | 46 ++- imperative/python/megengine/xla/lib/config.py | 295 +++++++++--------- .../python/megengine/xla/lib/xla_bridge.py | 136 ++++---- 5 files changed, 244 insertions(+), 240 deletions(-) diff --git a/ACKNOWLEDGMENTS b/ACKNOWLEDGMENTS index 14cf2e577..e0e425f57 100644 --- a/ACKNOWLEDGMENTS +++ b/ACKNOWLEDGMENTS @@ -755,6 +755,11 @@ Copyright 2014 Google Inc. All rights reserved. 5. MACE Copyright 2018 Xiaomi Inc. All rights reserved. +6. XLA +Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +7. JAX +Copyright 2018 The JAX Authors. Terms of Apache License Version 2.0 --------------------------------------------------- diff --git a/imperative/python/megengine/xla/__init__.py b/imperative/python/megengine/xla/__init__.py index 5dcd05e12..1f92460f0 100644 --- a/imperative/python/megengine/xla/__init__.py +++ b/imperative/python/megengine/xla/__init__.py @@ -1 +1,3 @@ +# some code of this directory is from jax: https://github.com/google/jax + from .build import build_xla diff --git a/imperative/python/megengine/xla/lib/__init__.py b/imperative/python/megengine/xla/lib/__init__.py index 2f06ffc69..c1b74c3da 100644 --- a/imperative/python/megengine/xla/lib/__init__.py +++ b/imperative/python/megengine/xla/lib/__init__.py @@ -1,3 +1,5 @@ +# code of this directory is mainly from jax: https://github.com/google/jax + try: import mge_xlalib as mge_xlalib except ModuleNotFoundError as err: @@ -9,8 +11,10 @@ except ModuleNotFoundError as err: raise ModuleNotFoundError(msg) import gc -import pathlib +import os import platform +import subprocess +import sys import warnings from typing import Optional @@ -43,7 +47,8 @@ cpu_feature_guard.check_cpu_features() xla_extension = xla_client._xla pytree = xla_client._xla.pytree -jax_jit = xla_client._xla.jax_jit +# we use some api in jaxlib +xla_jit = xla_client._xla.jax_jit pmap_lib = xla_client._xla.pmap_lib @@ -57,18 +62,29 @@ gc.callbacks.append(_xla_gc_callback) xla_extension_version: int = getattr(xla_client, "_version", 0) mlir_api_version = xla_client.mlir_api_version - -def _cuda_path() -> Optional[str]: - _mgexlalib_path = pathlib.Path(mge_xlalib.__file__).parent - path = _mgexlalib_path.parent / "nvidia" / "cuda_nvcc" - if path.is_dir(): - return str(path) - path = _mgexlalib_path / "cuda" - if path.is_dir(): - return str(path) - return None - - -cuda_path = _cuda_path() +# Finds the CUDA install path +def _find_cuda_root_dir() -> Optional[str]: + cuda_root_dir = os.environ.get("CUDA_ROOT_DIR") + if cuda_root_dir is None: + try: + which = "where" if sys.platform == "win32" else "which" + with open(os.devnull, "w") as devnull: + nvcc = ( + subprocess.check_output([which, "nvcc"], stderr=devnull) + .decode() + .rstrip("\r\n") + ) + cuda_root_dir = os.path.dirname(os.path.dirname(nvcc)) + except Exception: + if sys.platform == "win32": + assert False, "xla not supported on windows" + else: + cuda_root_dir = "/usr/local/cuda" + if not os.path.exists(cuda_root_dir): + cuda_root_dir = None + return cuda_root_dir + + +cuda_path = _find_cuda_root_dir() transfer_guard_lib = xla_client._xla.transfer_guard_lib diff --git a/imperative/python/megengine/xla/lib/config.py b/imperative/python/megengine/xla/lib/config.py index 83349246d..0d6f3d69c 100644 --- a/imperative/python/megengine/xla/lib/config.py +++ b/imperative/python/megengine/xla/lib/config.py @@ -1,3 +1,4 @@ +# code of this file is mainly from jax: https://github.com/google/jax import contextlib import functools import itertools @@ -10,9 +11,10 @@ from typing import Any, Callable, Hashable, Iterator, List, NamedTuple, Optional import mge_xlalib.xla_client as xla_client -from . import jax_jit as libjax_jit +from . import xla_jit as libxla_jit -jax_jit = xla_client._xla.jax_jit +# we use some api in jaxlib +xla_jit = xla_client._xla.jax_jit transfer_guard_lib = xla_client._xla.transfer_guard_lib @@ -46,10 +48,9 @@ def int_env(varname: str, default: int) -> int: UPGRADE_BOOL_HELP = ( - " This will be enabled by default in future versions of JAX, at which " + " This will be enabled by default in future versions of XLA, at which " "point all uses of the flag will be considered deprecated (following " "the `API compatibility policy " - "`_)." ) UPGRADE_BOOL_EXTRA_DESC = " (transient)" @@ -181,21 +182,21 @@ class Config: Example: enable_foo = config.define_bool_state( - name='jax_enable_foo', + name='xla_enable_foo', default=False, help='Enable foo.') - # Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo + # Now the XLA_ENABLE_FOO shell environment variable and --xla_enable_foo # command-line flag can be used to control the process-level value of # the configuration option, in addition to using e.g. - # ``config.update("jax_enable_foo", True)`` directly. We can also use a + # ``config.update("xla_enable_foo", True)`` directly. We can also use a # context manager: with enable_foo(True): ... The value of the thread-local state or flag can be accessed via - ``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is + ``config.xla_enable_foo``. Reading it via ``config.FLAGS.xla_enable_foo`` is an error. """ @@ -248,7 +249,7 @@ class Config: name = name.lower() default = os.getenv(name.upper(), default) if default is not None and default not in enum_values: - raise ValueError(f'Invalid value "{default}" for JAX flag {name}') + raise ValueError(f'Invalid value "{default}" for XLA flag {name}') self.DEFINE_enum( name, default, @@ -303,7 +304,7 @@ class Config: try: default = int(default_env) except ValueError: - raise ValueError(f'Invalid value "{default_env}" for JAX flag {name}') + raise ValueError(f'Invalid value "{default_env}" for XLA flag {name}') self.DEFINE_integer(name, default, help=help, update_hook=update_global_hook) self._contextmanager_flags.add(name) @@ -350,7 +351,7 @@ class Config: try: default = float(default_env) except ValueError: - raise ValueError(f'Invalid value "{default_env}" for JAX flag {name}') + raise ValueError(f'Invalid value "{default_env}" for XLA flag {name}') self.DEFINE_float(name, default, help=help, update_hook=update_global_hook) self._contextmanager_flags.add(name) @@ -465,7 +466,7 @@ class Config: Values included in this set should also most likely be included in the C++ JIT state, which is handled separately.""" - tls = jax_jit.thread_local_state() + tls = xla_jit.thread_local_state() axis_env_state = () mesh_context_manager = () context = tls.extra_jit_context @@ -477,15 +478,14 @@ class Config: axis_env_state, mesh_context_manager, self.x64_enabled, - self.jax_numpy_rank_promotion, - self.jax_default_matmul_precision, - self.jax_dynamic_shapes, - self.jax_numpy_dtype_promotion, - self.jax_default_device, - self.jax_array, - self.jax_threefry_partitionable, - # Technically this affects jaxpr->MHLO lowering, not tracing. - self.jax_hlo_source_file_canonicalization_regex, + self.xla_numpy_rank_promotion, + self.xla_default_matmul_precision, + self.xla_dynamic_shapes, + self.xla_numpy_dtype_promotion, + self.xla_default_device, + self.xla_array, + self.xla_threefry_partitionable, + self.xla_hlo_source_file_canonicalization_regex, ) @@ -507,7 +507,7 @@ class _StateContextManager: default_value: Any = no_default, ): self._name = name - self.__name__ = name[4:] if name.startswith("jax_") else name + self.__name__ = name[4:] if name.startswith("xla_") else name self.__doc__ = ( f"Context manager for `{name}` config option" f"{extra_description}.\n\n{help}" @@ -599,7 +599,7 @@ class _GlobalExtraJitContext(NamedTuple): def _update_global_jit_state(**kw): - gs = jax_jit.global_state() + gs = xla_jit.global_state() context = gs.extra_jit_context or _GlobalExtraJitContext() gs.extra_jit_context = context._replace(**kw) @@ -626,7 +626,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): class _ThreadLocalStateCache(threading.local): """"A thread local cache for _ThreadLocalExtraJitContext - The extra_jit_context in jax_jit.thread_local_state() may get updated and thus + The extra_jit_context in xla_jit.thread_local_state() may get updated and thus incurring dispatch overhead for comparing this python object during jit calls. We want to duduplicate the objects that have the same hash/equality to also have the same object ID, since the equality check is much faster if the object @@ -641,7 +641,7 @@ _thread_local_state_cache = _ThreadLocalStateCache() def update_thread_local_jit_state(**kw): - tls = jax_jit.thread_local_state() + tls = xla_jit.thread_local_state() # After xla_client._version >= 70, the thread_local object will necessarily # be initialized when accessed. The following line can be removed when the context = tls.extra_jit_context or _ThreadLocalExtraJitContext() @@ -650,25 +650,25 @@ def update_thread_local_jit_state(**kw): flags.DEFINE_integer( - "jax_tracer_error_num_traceback_frames", - int_env("JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES", 5), - help="Set the number of stack frames in JAX tracer error messages.", + "xla_tracer_error_num_traceback_frames", + int_env("XLA_TRACER_ERROR_NUM_TRACEBACK_FRAMES", 5), + help="Set the number of stack frames in XLA tracer error messages.", ) flags.DEFINE_bool( - "jax_pprint_use_color", - bool_env("JAX_PPRINT_USE_COLOR", True), - help="Enable jaxpr pretty-printing with colorful syntax highlighting.", + "xla_pprint_use_color", + bool_env("XLA_PPRINT_USE_COLOR", True), + help="Enable pretty-printing with colorful syntax highlighting.", ) flags.DEFINE_bool( - "jax_host_callback_inline", - bool_env("JAX_HOST_CALLBACK_INLINE", False), + "xla_host_callback_inline", + bool_env("XLA_HOST_CALLBACK_INLINE", False), help="Inline the host_callback, if not in a staged context.", ) flags.DEFINE_integer( - "jax_host_callback_max_queue_byte_size", - int_env("JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE", int(256 * 1e6)), + "xla_host_callback_max_queue_byte_size", + int_env("XLA_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE", int(256 * 1e6)), help=( "The size in bytes of the buffer used to hold outfeeds from each " "device. When this capacity is reached consuming outfeeds from the " @@ -678,8 +678,8 @@ flags.DEFINE_integer( lower_bound=int(16 * 1e6), ) flags.DEFINE_bool( - "jax_host_callback_outfeed", - bool_env("JAX_HOST_CALLBACK_OUTFEED", False), + "xla_host_callback_outfeed", + bool_env("XLA_HOST_CALLBACK_OUTFEED", False), help=( "Use outfeed implementation for host_callback, even on CPU and GPU. " "If false, use the CustomCall implementation. " @@ -687,8 +687,8 @@ flags.DEFINE_bool( ), ) flags.DEFINE_bool( - "jax_host_callback_ad_transforms", - bool_env("JAX_HOST_CALLBACK_AD_TRANSFORMS", False), + "xla_host_callback_ad_transforms", + bool_env("XLA_HOST_CALLBACK_AD_TRANSFORMS", False), help=( "Enable support for jvp/vjp for the host_callback primitives. Default is " "False, which means that host_callback operates only on primals. " @@ -696,65 +696,63 @@ flags.DEFINE_bool( ), ) -# TODO: remove flag when XLA:CPU is improved. -jax2tf_associative_scan_reductions = config.define_bool_state( - name="jax2tf_associative_scan_reductions", +xla2tf_associative_scan_reductions = config.define_bool_state( + name="xla2tf_associative_scan_reductions", default=False, help=( - "JAX has two separate lowering rules for the cumulative reduction " + "XLA has two separate lowering rules for the cumulative reduction " "primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses " "a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. " "The latter has a slow implementation on CPUs and GPUs. " - "By default, jax2tf uses the TPU lowering. Set this flag to True to " + "By default, xla2tf uses the TPU lowering. Set this flag to True to " "use the associative scan lowering usage, and only if it makes a difference " "for your application. " - "See the jax2tf README.md for more details." + "See the xla2tf README.md for more details." ), ) -jax2tf_default_native_serialization = config.define_bool_state( - name="jax2tf_default_native_serialization", - default=bool_env("JAX2TF_DEFAULT_NATIVE_SERIALIZATION", False), +xla2tf_default_native_serialization = config.define_bool_state( + name="xla2tf_default_native_serialization", + default=bool_env("XLA2TF_DEFAULT_NATIVE_SERIALIZATION", False), help=( "Sets the default value of the native_serialization parameter to " - "jax2tf.convert. Prefer using the parameter instead of the flag, the " + "xla2tf.convert. Prefer using the parameter instead of the flag, the " "flag may be removed in the future." ), ) -# TODO: remove jax2tf_default_experimental_native_lowering -jax2tf_default_experimental_native_lowering = config.define_bool_state( - name="jax2tf_default_experimental_native_lowering", - default=bool_env("JAX2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING", False), - help=("DO NOT USE, deprecated in favor of jax2tf_default_native_serialization."), +xla2tf_default_experimental_native_lowering = config.define_bool_state( + name="xla2tf_default_experimental_native_lowering", + default=bool_env("XLA2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING", False), + help=("DO NOT USE, deprecated in favor of xla2tf_default_native_serialization."), ) -jax_platforms = config.define_string_state( - name="jax_platforms", +xla_platforms = config.define_string_state( + name="xla_platforms", default=None, help=( - "Comma-separated list of platform names specifying which platforms jax " + "Comma-separated list of platform names specifying which platforms xla " "should initialize. If any of the platforms in this list are not successfully " "initialized, an exception will be raised and the program will be aborted. " "The first platform in the list will be the default platform. " - "For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends " + "For example, config.xla_platforms=cpu,tpu means that CPU and TPU backends " "will be initialized, and the CPU backend will be used unless otherwise " "specified. If TPU initialization fails, it will raise an exception. " - "By default, jax will try to initialize all available " + "By default, xla will try to initialize all available " "platforms and will default to GPU or TPU if available, and fallback to CPU " "otherwise." ), ) enable_checks = config.define_bool_state( - name="jax_enable_checks", + name="xla_enable_checks", default=False, - help="Turn on invariant checking for JAX internals. Makes things slower.", + help="Turn on invariant checking for XLA internals. Makes things slower.", ) check_tracer_leaks = config.define_bool_state( - name="jax_check_tracer_leaks", + name="xla_check_tracer_leaks", default=False, help=( "Turn on checking for leaked tracers as soon as a trace completes. " @@ -767,7 +765,7 @@ check_tracer_leaks = config.define_bool_state( checking_leaks = functools.partial(check_tracer_leaks, True) debug_nans = config.define_bool_state( - name="jax_debug_nans", + name="xla_debug_nans", default=False, help=( "Add nan checks to every operation. When a nan is detected on the " @@ -778,7 +776,7 @@ debug_nans = config.define_bool_state( ) debug_infs = config.define_bool_state( - name="jax_debug_infs", + name="xla_debug_infs", default=False, help=( "Add inf checks to every operation. When an inf is detected on the " @@ -789,7 +787,7 @@ debug_infs = config.define_bool_state( ) log_compiles = config.define_bool_state( - name="jax_log_compiles", + name="xla_log_compiles", default=False, help=( "Log a message each time every time `jit` or `pmap` compiles an XLA " @@ -800,80 +798,78 @@ log_compiles = config.define_bool_state( ) log_compiles = config.define_bool_state( - name="jax_log_checkpoint_residuals", + name="xla_log_checkpoint_residuals", default=False, help=( - "Log a message every time jax.checkpoint (aka jax.remat) is " + "Log a message every time xla.checkpoint (aka xla.remat) is " "partially evaluated (e.g. for autodiff), printing what residuals " "are saved." ), ) parallel_functions_output_gda = config.define_bool_state( - name="jax_parallel_functions_output_gda", + name="xla_parallel_functions_output_gda", default=False, help="If True, pjit will output GDAs.", ) -def _update_jax_array_global(val): +def _update_xla_array_global(val): if val is not None and not val: raise ValueError("not supported in current version, please downgrad") -def _update_jax_array_thread_local(val): +def _update_xla_array_thread_local(val): if val is not None and not val: raise ValueError("not supported in current version, please downgrad") -jax_array = config.define_bool_state( - name="jax_array", +xla_array = config.define_bool_state( + name="xla_array", default=True, upgrade=True, - update_global_hook=_update_jax_array_global, - update_thread_local_hook=_update_jax_array_thread_local, - help=( - "If True, new pjit behavior will be enabled and `jax.Array` will be " "used." - ), + update_global_hook=_update_xla_array_global, + update_thread_local_hook=_update_xla_array_thread_local, + help=("whether use xla array"), ) jit_pjit_api_merge = config.define_bool_state( - name="jax_jit_pjit_api_merge", + name="xla_jit_pjit_api_merge", default=True, upgrade=True, help=( "If True, jit and pjit API will be merged. You can only disable it via " - "the environment variable i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. " + "the environment variable i.e. `os.environ['XLA_JIT_PJIT_API_MERGE'] = '0'`. " "The merge must be disabled via an environment variable since it " - "affects JAX at import time so it needs to be disabled before jax is " + "affects XLA at import time so it needs to be disabled before xla is " "imported." ), ) spmd_mode = config.define_enum_state( - name="jax_spmd_mode", + name="xla_spmd_mode", enum_values=["allow_all", "allow_jit", "allow_pjit"], # TODO: Default to `allow_jit` when the training wheels come # off. default="allow_pjit", help=( - "Decides whether Math on `jax.Array`'s that are not fully addressable " + "Decides whether Math on `xla.Array`'s that are not fully addressable " "(i.e. spans across multiple processes) is allowed. The options are: " "* allow_pjit: Default, only `pjit` computations are allowed to " - " execute on non-fully addressable `jax.Array`s\n" - "* allow_jit: `pjit` and `jax.jit` computations are allowed to " - " execute on non-fully addressable `jax.Array`s\n" + " execute on non-fully addressable `xla.Array`s\n" + "* allow_jit: `pjit` and `xla.jit` computations are allowed to " + " execute on non-fully addressable `xla.Array`s\n" "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, " - " `jax.jit` and all other operations are allowed to " - " execute on non-fully addresable `jax.Array`s." + " `xla.jit` and all other operations are allowed to " + " execute on non-fully addresable `xla.Array`s." ), ) distributed_debug = config.define_bool_state( - name="jax_distributed_debug", + name="xla_distributed_debug", default=False, help=( "Enable logging useful for debugging multi-process distributed " @@ -884,7 +880,7 @@ distributed_debug = config.define_bool_state( enable_custom_prng = config.define_bool_state( - name="jax_enable_custom_prng", + name="xla_enable_custom_prng", default=False, upgrade=True, help=( @@ -894,7 +890,7 @@ enable_custom_prng = config.define_bool_state( ) default_prng_impl = config.define_enum_state( - name="jax_default_prng_impl", + name="xla_default_prng_impl", enum_values=["threefry2x32", "rbg", "unsafe_rbg"], default="threefry2x32", help=( @@ -904,14 +900,14 @@ default_prng_impl = config.define_enum_state( ) threefry_partitionable = config.define_bool_state( - name="jax_threefry_partitionable", + name="xla_threefry_partitionable", default=False, upgrade=True, help=( "Enables internal threefry PRNG implementation changes that " "render it automatically partitionable in some cases. For use " - "with pjit and/or jax_array=True. Without this flag, using the " - "standard jax.random pseudo-random number generation may result " + "with pjit and/or xla=True. Without this flag, using the " + "standard xla.random pseudo-random number generation may result " "in extraneous communication and/or redundant distributed " "computation. With this flag, the communication overheads disappear " "in some cases." @@ -923,17 +919,17 @@ threefry_partitionable = config.define_bool_state( ) enable_custom_vjp_by_custom_transpose = config.define_bool_state( - name="jax_enable_custom_vjp_by_custom_transpose", + name="xla_enable_custom_vjp_by_custom_transpose", default=False, upgrade=True, help=( - "Enables an internal upgrade that implements `jax.custom_vjp` by " - "reduction to `jax.custom_jvp` and `jax.custom_transpose`." + "Enables an internal upgrade that implements `xla.custom_vjp` by " + "reduction to `xla.custom_jvp` and `xla.custom_transpose`." ), ) raise_persistent_cache_errors = config.define_bool_state( - name="jax_raise_persistent_cache_errors", + name="xla_raise_persistent_cache_errors", default=False, help=( "If true, exceptions raised when reading or writing to the " @@ -946,7 +942,7 @@ raise_persistent_cache_errors = config.define_bool_state( ) persistent_cache_min_compile_time_secs = config.define_float_state( - name="jax_persistent_cache_min_compile_time_secs", + name="xla_persistent_cache_min_compile_time_secs", default=1, help=( "The minimum compile time of a computation to be written to the " @@ -956,7 +952,7 @@ persistent_cache_min_compile_time_secs = config.define_float_state( ) hlo_source_file_canonicalization_regex = config.define_string_state( - name="jax_hlo_source_file_canonicalization_regex", + name="xla_hlo_source_file_canonicalization_regex", default=None, help=( "Used to canonicalize the source_path metadata of HLO instructions " @@ -969,18 +965,18 @@ hlo_source_file_canonicalization_regex = config.define_string_state( ) config.define_enum_state( - name="jax_default_dtype_bits", + name="xla_default_dtype_bits", enum_values=["32", "64"], default="64", help=( "Specify bit width of default dtypes, either 32-bit or 64-bit. " "This is a temporary flag that will be used during the process " - "of deprecating the ``jax_enable_x64`` flag." + "of deprecating the ``xla_enable_x64`` flag." ), ) numpy_dtype_promotion = config.define_enum_state( - name="jax_numpy_dtype_promotion", + name="xla_numpy_dtype_promotion", enum_values=["standard", "strict"], default="standard", help=( @@ -997,39 +993,38 @@ numpy_dtype_promotion = config.define_enum_state( def _update_x64_global(val): - libjax_jit.global_state().enable_x64 = val + libxla_jit.global_state().enable_x64 = val def _update_x64_thread_local(val): - libjax_jit.thread_local_state().enable_x64 = val + libxla_jit.thread_local_state().enable_x64 = val enable_x64 = config.define_bool_state( - name="jax_enable_x64", + name="xla_enable_x64", default=False, help="Enable 64-bit types to be used", update_global_hook=_update_x64_global, update_thread_local_hook=_update_x64_thread_local, ) -# TODO: remove after fixing users of FLAGS.x64_enabled. -config._contextmanager_flags.remove("jax_enable_x64") +config._contextmanager_flags.remove("xla_enable_x64") -Config.x64_enabled = Config.jax_enable_x64 # type: ignore +Config.x64_enabled = Config.xla_enable_x64 def _update_default_device_global(val): - libjax_jit.global_state().default_device = val + libxla_jit.global_state().default_device = val def _update_default_device_thread_local(val): - libjax_jit.thread_local_state().default_device = val + libxla_jit.thread_local_state().default_device = val def _validate_default_device(val): if val is not None and not isinstance(val, xla_client.Device): # TODO: this is a workaround for non-PJRT Device types. Remove when - # all JAX backends use a single C++ device interface. + # all XLA backends use a single C++ device interface. if "Device" in str(type(val)): logger.info( "Allowing non-`xla_client.Device` default device: %s, type: %s", @@ -1038,20 +1033,20 @@ def _validate_default_device(val): ) return raise ValueError( - "jax.default_device must be passed a Device object (e.g. " - f"`jax.devices('cpu')[0]`), got: {repr(val)}" + "xla.default_device must be passed a Device object (e.g. " + f"`xla.devices('cpu')[0]`), got: {repr(val)}" ) # TODO: default_device only accepts devices for now. Make it work with -# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). +# platform names as well (e.g. "cpu" to mean the same as xla.devices("cpu")[0]). default_device = config.define_string_or_object_state( - name="jax_default_device", + name="xla_default_device", default=None, help=( - "Configure the default device for JAX operations. Set to a Device " - 'object (e.g. ``jax.devices("cpu")[0]``) to use that Device as the ' - "default device for JAX operations and jit'd function calls (there is " + "Configure the default device for XLA operations. Set to a Device " + 'object (e.g. ``xla.devices("cpu")[0]``) to use that Device as the ' + "default device for XLA operations and jit'd function calls (there is " "no effect on multi-device computations, e.g. pmapped function calls). " "Set to None to use the system default device. See " ":ref:`faq-data-placement` for more information on device placement." @@ -1063,15 +1058,15 @@ default_device = config.define_string_or_object_state( def _update_disable_jit_global(val): - libjax_jit.global_state().disable_jit = val + libxla_jit.global_state().disable_jit = val def _update_disable_jit_thread_local(val): - libjax_jit.thread_local_state().disable_jit = val + libxla_jit.thread_local_state().disable_jit = val disable_jit = config.define_bool_state( - name="jax_disable_jit", + name="xla_disable_jit", default=False, help=("Disable JIT compilation and just call original Python."), update_global_hook=_update_disable_jit_global, @@ -1080,7 +1075,7 @@ disable_jit = config.define_bool_state( numpy_rank_promotion = config.define_enum_state( - name="jax_numpy_rank_promotion", + name="xla_numpy_rank_promotion", enum_values=["allow", "warn", "raise"], default="allow", help=( @@ -1094,7 +1089,7 @@ numpy_rank_promotion = config.define_enum_state( ) default_matmul_precision = config.define_enum_state( - name="jax_default_matmul_precision", + name="xla_default_matmul_precision", enum_values=["bfloat16", "tensorfloat32", "float32"], default=None, help=( @@ -1102,8 +1097,8 @@ default_matmul_precision = config.define_enum_state( "Some platforms, like TPU, offer configurable precision levels for " "matrix multiplication and convolution computations, trading off " "accuracy for speed. The precision can be controlled for each " - "operation; for example, see the :func:`jax.lax.conv_general_dilated` " - "and :func:`jax.lax.dot` docstrings. But it can be useful to control " + "operation; for example, see the :func:`xla.lax.conv_general_dilated` " + "and :func:`xla.lax.dot` docstrings. But it can be useful to control " "the default behavior obtained when an operation is not given a " "specific precision.\n\n" "This option can be used to control the default precision " @@ -1122,10 +1117,10 @@ default_matmul_precision = config.define_enum_state( ) traceback_filtering = config.define_enum_state( - name="jax_traceback_filtering", + name="xla_traceback_filtering", enum_values=["off", "tracebackhide", "remove_frames", "auto"], default="auto", - help="Controls how JAX filters internal frames out of tracebacks.\n\n" + help="Controls how XLA filters internal frames out of tracebacks.\n\n" "Valid values are:\n" ' * "off": disables traceback filtering.\n' ' * "auto": use "tracebackhide" if running under a sufficiently ' @@ -1136,20 +1131,16 @@ traceback_filtering = config.define_enum_state( " the unfiltered traceback as a __cause__ of the exception.\n", ) -# This flag is for internal use. -# TODO: Removes once we always enable cusparse lowering. -# TODO: Set to true after bug is fixed + bcoo_cusparse_lowering = config.define_bool_state( - name="jax_bcoo_cusparse_lowering", + name="xla_bcoo_cusparse_lowering", default=False, help=("Enables lowering BCOO ops to cuSparse."), ) -# TODO: remove this flag when we ensure we only succeed at trace-staging -# if the intended backend can handle lowering the result config.define_bool_state( - name="jax_dynamic_shapes", - default=bool(os.getenv("JAX_DYNAMIC_SHAPES", "")), + name="xla_dynamic_shapes", + default=bool(os.getenv("XLA_DYNAMIC_SHAPES", "")), help=( "Enables experimental features for staging out computations with " "dynamic shapes." @@ -1163,14 +1154,14 @@ config.define_bool_state( # This flag is temporary during rollout of the remat barrier. # TODO: Remove if there are no complaints. config.define_bool_state( - name="jax_remat_opt_barrier", + name="xla_remat_opt_barrier", default=True, help=("Enables using optimization-barrier op for lowering remat."), ) # TODO: Remove flag once coordination service has rolled out. config.define_bool_state( - name="jax_coordination_service", + name="xla_coordination_service", default=True, help=( "Use coordination service (experimental) instead of the default PjRT " @@ -1180,17 +1171,17 @@ config.define_bool_state( # TODO: set default to True, then remove config.define_bool_state( - name="jax_eager_pmap", + name="xla_eager_pmap", default=True, upgrade=True, - help="Enable eager-mode pmap when jax_disable_jit is activated.", + help="Enable eager-mode pmap when xla_disable_jit is activated.", ) config.define_bool_state( - name="jax_experimental_unsafe_xla_runtime_errors", + name="xla_experimental_unsafe_xla_runtime_errors", default=False, help=( - "Enable XLA runtime errors for jax.experimental.checkify.checks " + "Enable XLA runtime errors for xla.experimental.checkify.checks " "on CPU and GPU. These errors are async, might get lost and are not " "very readable. But, they crash the computation and enable you " "to write jittable checks without needing to checkify. Does not " @@ -1242,10 +1233,10 @@ def _update_transfer_guard(state, key, val): transfer_guard_host_to_device = config.define_enum_state( - name="jax_transfer_guard_host_to_device", + name="xla_transfer_guard_host_to_device", enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard. + # accidentally overriding --xla_transfer_guard. default=None, help=( "Select the transfer guard level for host-to-device transfers. " @@ -1260,10 +1251,10 @@ transfer_guard_host_to_device = config.define_enum_state( ) transfer_guard_device_to_device = config.define_enum_state( - name="jax_transfer_guard_device_to_device", + name="xla_transfer_guard_device_to_device", enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard. + # accidentally overriding --xla_transfer_guard. default=None, help=( "Select the transfer guard level for device-to-device transfers. " @@ -1278,10 +1269,10 @@ transfer_guard_device_to_device = config.define_enum_state( ) transfer_guard_device_to_host = config.define_enum_state( - name="jax_transfer_guard_device_to_host", + name="xla_transfer_guard_device_to_host", enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard. + # accidentally overriding --xla_transfer_guard. default=None, help=( "Select the transfer guard level for device-to-host transfers. " @@ -1298,18 +1289,18 @@ transfer_guard_device_to_host = config.define_enum_state( def _update_all_transfer_guard_global(val): for name in ( - "jax_transfer_guard_host_to_device", - "jax_transfer_guard_device_to_device", - "jax_transfer_guard_device_to_host", + "xla_transfer_guard_host_to_device", + "xla_transfer_guard_device_to_device", + "xla_transfer_guard_device_to_host", ): config.update(name, val) _transfer_guard = config.define_enum_state( - name="jax_transfer_guard", + name="xla_transfer_guard", enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], # The default is applied by transfer_guard_lib. Use None here to avoid - # accidentally overriding --jax_transfer_guard_*. + # accidentally overriding --xla_transfer_guard_*. default=None, help=( "Select the transfer guard level for all transfers. This option is " diff --git a/imperative/python/megengine/xla/lib/xla_bridge.py b/imperative/python/megengine/xla/lib/xla_bridge.py index 949c5dd84..fa2dcadfd 100644 --- a/imperative/python/megengine/xla/lib/xla_bridge.py +++ b/imperative/python/megengine/xla/lib/xla_bridge.py @@ -1,3 +1,4 @@ +# code of this file is mainly from jax: https://github.com/google/jax import logging import os import platform as py_platform @@ -19,41 +20,41 @@ FLAGS = flags.FLAGS logger = logging.getLogger(__name__) flags.DEFINE_string( - "jax_xla_backend", "", "Deprecated, please use --jax_platforms instead." + "xla_backend", "", "Deprecated, please use --xla_platforms instead." ) flags.DEFINE_string( - "jax_backend_target", - os.getenv("JAX_BACKEND_TARGET", "").lower(), + "xla_backend_target", + os.getenv("XLA_BACKEND_TARGET", "").lower(), 'Either "local" or "rpc:address" to connect to a remote service target.', ) flags.DEFINE_string( - "jax_platform_name", - os.getenv("JAX_PLATFORM_NAME", "").lower(), - "Deprecated, please use --jax_platforms instead.", + "xla_platform_name", + os.getenv("XLA_PLATFORM_NAME", "").lower(), + "Deprecated, please use --xla_platforms instead.", ) flags.DEFINE_bool( - "jax_disable_most_optimizations", - bool_env("JAX_DISABLE_MOST_OPTIMIZATIONS", False), + "xla_disable_most_optimizations", + bool_env("XLA_DISABLE_MOST_OPTIMIZATIONS", False), "Try not to do much optimization work. This can be useful if the cost of " "optimization is greater than that of running a less-optimized program.", ) flags.DEFINE_integer( - "jax_xla_profile_version", - int_env("JAX_XLA_PROFILE_VERSION", 0), + "xla_profile_version", + int_env("XLA_PROFILE_VERSION", 0), "Optional profile version for XLA compilation. " "This is meaningful only when XLA is configured to " "support the remote compilation profile feature.", ) flags.DEFINE_string( - "jax_cuda_visible_devices", + "xla_cuda_visible_devices", "all", - 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' + 'Restricts the set of CUDA devices that XLA will use. Either "all", or a ' "comma-separate list of integer device IDs.", ) flags.DEFINE_string( - "jax_rocm_visible_devices", + "xla_rocm_visible_devices", "all", - 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' + 'Restricts the set of ROCM devices that XLA will use. Either "all", or a ' "comma-separate list of integer device IDs.", ) @@ -69,22 +70,22 @@ def get_compile_options( ) -> xla_client.CompileOptions: """Returns the compile options to use, as derived from flag values. - Args: - num_replicas: Number of replicas for which to compile. - num_partitions: Number of partitions for which to compile. - device_assignment: Optional ndarray of jax devices indicating the assignment - of logical replicas to physical devices (default inherited from - xla_client.CompileOptions). Must be consistent with `num_replicas` and - `num_partitions`. - use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD - partitioning in XLA. - use_auto_spmd_partitioning: boolean indicating whether to automatically - generate XLA shardings for SPMD partitioner. - auto_spmd_partitioning_mesh_shape: device mesh shape used to create - auto_spmd_partitioning search space. - auto_spmd_partitioning_mesh_ids: device ids used to create - auto_spmd_partitioning search space. - """ + Args: + num_replicas: Number of replicas for which to compile. + num_partitions: Number of partitions for which to compile. + device_assignment: Optional ndarray of xla devices indicating the assignment + of logical replicas to physical devices (default inherited from + xla_client.CompileOptions). Must be consistent with `num_replicas` and + `num_partitions`. + use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD + partitioning in XLA. + use_auto_spmd_partitioning: boolean indicating whether to automatically + generate XLA shardings for SPMD partitioner. + auto_spmd_partitioning_mesh_shape: device mesh shape used to create + auto_spmd_partitioning search space. + auto_spmd_partitioning_mesh_ids: device ids used to create + auto_spmd_partitioning search space. + """ compile_options = xla_client.CompileOptions() compile_options.num_replicas = num_replicas compile_options.num_partitions = num_partitions @@ -130,12 +131,12 @@ def get_compile_options( if cuda_path is not None: debug_options.xla_gpu_cuda_data_dir = cuda_path - if FLAGS.jax_disable_most_optimizations: + if FLAGS.xla_disable_most_optimizations: debug_options.xla_backend_optimization_level = 0 debug_options.xla_llvm_disable_expensive_passes = True debug_options.xla_test_all_input_layouts = False - compile_options.profile_version = FLAGS.jax_xla_profile_version + compile_options.profile_version = FLAGS.xla_profile_version return compile_options @@ -187,7 +188,7 @@ if hasattr(xla_client, "make_gpu_client"): partial( make_gpu_client, platform_name="cuda", - visible_devices_flag="jax_cuda_visible_devices", + visible_devices_flag="xla_cuda_visible_devices", ), priority=200, ) @@ -196,13 +197,13 @@ if hasattr(xla_client, "make_gpu_client"): partial( make_gpu_client, platform_name="rocm", - visible_devices_flag="jax_rocm_visible_devices", + visible_devices_flag="xla_rocm_visible_devices", ), priority=200, ) if hasattr(xla_client, "make_plugin_device_client"): - # It is assumed that if jax has been built with a plugin client, then the + # It is assumed that if xla has been built with a plugin client, then the # user wants to use the plugin client by default. Therefore, it gets the # highest priority. register_backend_factory( @@ -229,11 +230,11 @@ def is_known_platform(platform: str): def canonicalize_platform(platform: str) -> str: """Replaces platform aliases with their concrete equivalent. - In particular, replaces "gpu" with either "cuda" or "rocm", depending on which - hardware is actually present. We want to distinguish "cuda" and "rocm" for - purposes such as MLIR lowering rules, but in many cases we don't want to - force users to care. - """ + In particular, replaces "gpu" with either "cuda" or "rocm", depending on which + hardware is actually present. We want to distinguish "cuda" and "rocm" for + purposes such as MLIR lowering rules, but in many cases we don't want to + force users to care. + """ platforms = _alias_to_platforms.get(platform, None) if platforms is None: return platform @@ -252,9 +253,9 @@ def canonicalize_platform(platform: str) -> str: def expand_platform_alias(platform: str) -> List[str]: """Expands, e.g., "gpu" to ["cuda", "rocm"]. - This is used for convenience reasons: we expect cuda and rocm to act similarly - in many respects since they share most of the same code. - """ + This is used for convenience reasons: we expect cuda and rocm to act similarly + in many respects since they share most of the same code. + """ return _alias_to_platforms.get(platform, [platform]) @@ -270,11 +271,11 @@ def backends(): with _backend_lock: if _backends: return _backends - if config.jax_platforms: - jax_platforms = config.jax_platforms.split(",") + if config.xla_platforms: + xla_platforms = config.xla_platforms.split(",") platforms = [] # Allow platform aliases in the list of platforms. - for platform in jax_platforms: + for platform in xla_platforms: platforms.extend(expand_platform_alias(platform)) priorities = range(len(platforms), 0, -1) platforms_and_priorites = zip(platforms, priorities) @@ -303,8 +304,8 @@ def backends(): # If the backend isn't built into the binary, or if it has no devices, # we expect a RuntimeError. err_msg = f"Unable to initialize backend '{platform}': {err}" - if config.jax_platforms: - err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)" + if config.xla_platforms: + err_msg += " (set XLA_PLATFORMS='' to automatically choose an available backend)" raise RuntimeError(err_msg) else: _backends_errors[platform] = str(err) @@ -315,12 +316,9 @@ def backends(): if ( py_platform.system() != "Darwin" and _default_backend.platform == "cpu" - and FLAGS.jax_platform_name != "cpu" + and FLAGS.xla_platform_name != "cpu" ): - logger.warning( - "No GPU/TPU found, falling back to CPU. " - "(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)" - ) + logger.warning("No GPU/TPU found, falling back to CPU. ") return _backends @@ -329,7 +327,7 @@ def _clear_backends(): global _backends_errors global _default_backend - logger.info("Clearing JAX backend caches.") + logger.info("Clearing XLA backend caches.") with _backend_lock: _backends = {} _backends_errors = {} @@ -351,11 +349,6 @@ def _init_backend(platform): raise RuntimeError(f"Could not initialize backend '{platform}'") if backend.device_count() == 0: raise RuntimeError(f"Backend '{platform}' provides no devices.") - # ccq: disable distributed_debug_log - # util.distributed_debug_log(("Initialized backend", backend.platform), - # ("process_index", backend.process_index()), - # ("device_count", backend.device_count()), - # ("local_devices", backend.local_devices())) logger.debug("Backend '%s' initialized", platform) return backend @@ -366,7 +359,7 @@ def _get_backend_uncached(platform=None): if not isinstance(platform, (type(None), str)): return platform - platform = platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name or None + platform = platform or FLAGS.xla_backend or FLAGS.xla_platform_name or None bs = backends() if platform is not None: @@ -399,7 +392,7 @@ def get_device_backend(device=None): def device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int: """Returns the total number of devices. - On most platforms, this is the same as :py:func:`jax.local_device_count`. + On most platforms, this is the same as :py:func:`xla.local_device_count`. However, on multi-process platforms where different devices are associated with different processes, this will return the total number of devices across all processes. @@ -430,7 +423,7 @@ def devices( :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is equal to ``device_count(backend)``. Local devices can be identified by comparing :attr:`Device.process_index` to the value returned by - :py:func:`jax.process_index`. + :py:func:`xla.process_index`. If ``backend`` is ``None``, returns all the devices from the default backend. The default backend is generally ``'gpu'`` or ``'tpu'`` if available, @@ -457,13 +450,13 @@ def local_devices( backend: Optional[Union[str, XlaBackend]] = None, host_id: Optional[int] = None, ) -> List[xla_client.Device]: - """Like :py:func:`jax.devices`, but only returns devices local to a given process. + """Like :py:func:`xla.devices`, but only returns devices local to a given process. If ``process_index`` is ``None``, returns devices local to this process. Args: process_index: the integer index of the process. Process indices can be - retrieved via ``len(jax.process_count())``. + retrieved via ``len(xla.process_count())``. backend: This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or ``'tpu'``. @@ -473,7 +466,7 @@ def local_devices( """ if host_id is not None: warnings.warn( - "The argument to jax.local_devices has been renamed from `host_id` to " + "The argument to xla.local_devices has been renamed from `host_id` to " "`process_index`. This alias will eventually be removed; please update " "your code." ) @@ -502,34 +495,31 @@ def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int: return get_backend(backend).process_index() -# TODO: remove this sometime after jax 0.2.13 is released def host_id(backend=None): warnings.warn( - "jax.host_id has been renamed to jax.process_index. This alias " + "xla.host_id has been renamed to xla.process_index. This alias " "will eventually be removed; please update your code." ) return process_index(backend) def process_count(backend: Optional[Union[str, XlaBackend]] = None) -> int: - """Returns the number of JAX processes associated with the backend.""" + """Returns the number of XLA processes associated with the backend.""" return max(d.process_index for d in devices(backend)) + 1 -# TODO: remove this sometime after jax 0.2.13 is released def host_count(backend=None): warnings.warn( - "jax.host_count has been renamed to jax.process_count. This alias " + "xla.host_count has been renamed to xla.process_count. This alias " "will eventually be removed; please update your code." ) return process_count(backend) -# TODO: remove this sometime after jax 0.2.13 is released def host_ids(backend=None): warnings.warn( - "jax.host_ids has been deprecated; please use range(jax.process_count()) " - "instead. jax.host_ids will eventually be removed; please update your " + "xla.host_ids has been deprecated; please use range(xla.process_count()) " + "instead. xla.host_ids will eventually be removed; please update your " "code." ) return list(range(process_count(backend))) -- GitLab