refactor(xla): add xla acknowledgement

GitOrigin-RevId: f2fafebedfc3b520684655cff8c86b1df7c2c1cd
......@@ -755,6 +755,11 @@ Copyright 2014 Google Inc. All rights reserved.
5. MACE
Copyright 2018 Xiaomi Inc. All rights reserved.
6. XLA
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
7. JAX
Copyright 2018 The JAX Authors.
Terms of Apache License Version 2.0
---------------------------------------------------
......
# some code of this directory is from jax: https://github.com/google/jax
from .build import build_xla
# code of this directory is mainly from jax: https://github.com/google/jax
try:
import mge_xlalib as mge_xlalib
except ModuleNotFoundError as err:
......@@ -9,8 +11,10 @@ except ModuleNotFoundError as err:
raise ModuleNotFoundError(msg)
import gc
import pathlib
import os
import platform
import subprocess
import sys
import warnings
from typing import Optional
......@@ -43,7 +47,8 @@ cpu_feature_guard.check_cpu_features()
xla_extension = xla_client._xla
pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit
# we use some api in jaxlib
xla_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib
......@@ -57,18 +62,29 @@ gc.callbacks.append(_xla_gc_callback)
xla_extension_version: int = getattr(xla_client, "_version", 0)
mlir_api_version = xla_client.mlir_api_version
def _cuda_path() -> Optional[str]:
    """Return the CUDA directory shipped next to ``mge_xlalib``, if any.

    Two locations are probed, in this order:

    1. ``<parent of the mge_xlalib package dir>/nvidia/cuda_nvcc``
    2. ``<mge_xlalib package dir>/cuda``

    Returns:
        The first candidate that is an existing directory, as a string, or
        ``None`` when neither exists.
    """
    package_dir = pathlib.Path(mge_xlalib.__file__).parent
    candidates = (
        package_dir.parent / "nvidia" / "cuda_nvcc",
        package_dir / "cuda",
    )
    for candidate in candidates:
        if candidate.is_dir():
            return str(candidate)
    return None
cuda_path = _cuda_path()
# Finds the CUDA install path
def _find_cuda_root_dir() -> Optional[str]:
cuda_root_dir = os.environ.get("CUDA_ROOT_DIR")
if cuda_root_dir is None:
try:
which = "where" if sys.platform == "win32" else "which"
with open(os.devnull, "w") as devnull:
nvcc = (
subprocess.check_output([which, "nvcc"], stderr=devnull)
.decode()
.rstrip("\r\n")
)
cuda_root_dir = os.path.dirname(os.path.dirname(nvcc))
except Exception:
if sys.platform == "win32":
assert False, "xla not supported on windows"
else:
cuda_root_dir = "/usr/local/cuda"
if not os.path.exists(cuda_root_dir):
cuda_root_dir = None
return cuda_root_dir
cuda_path = _find_cuda_root_dir()
transfer_guard_lib = xla_client._xla.transfer_guard_lib
# code of this file is mainly from jax: https://github.com/google/jax
import contextlib
import functools
import itertools
......@@ -10,9 +11,10 @@ from typing import Any, Callable, Hashable, Iterator, List, NamedTuple, Optional
import mge_xlalib.xla_client as xla_client
from . import jax_jit as libjax_jit
from . import xla_jit as libxla_jit
jax_jit = xla_client._xla.jax_jit
# we use some api in jaxlib
xla_jit = xla_client._xla.jax_jit
transfer_guard_lib = xla_client._xla.transfer_guard_lib
......@@ -46,10 +48,9 @@ def int_env(varname: str, default: int) -> int:
UPGRADE_BOOL_HELP = (
" This will be enabled by default in future versions of JAX, at which "
" This will be enabled by default in future versions of XLA, at which "
"point all uses of the flag will be considered deprecated (following "
"the `API compatibility policy "
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_)."
)
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
......@@ -181,21 +182,21 @@ class Config:
Example:
enable_foo = config.define_bool_state(
name='jax_enable_foo',
name='xla_enable_foo',
default=False,
help='Enable foo.')
# Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo
# Now the XLA_ENABLE_FOO shell environment variable and --xla_enable_foo
# command-line flag can be used to control the process-level value of
# the configuration option, in addition to using e.g.
# ``config.update("jax_enable_foo", True)`` directly. We can also use a
# ``config.update("xla_enable_foo", True)`` directly. We can also use a
# context manager:
with enable_foo(True):
...
The value of the thread-local state or flag can be accessed via
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
``config.xla_enable_foo``. Reading it via ``config.FLAGS.xla_enable_foo`` is
an error.
"""
......@@ -248,7 +249,7 @@ class Config:
name = name.lower()
default = os.getenv(name.upper(), default)
if default is not None and default not in enum_values:
raise ValueError(f'Invalid value "{default}" for JAX flag {name}')
raise ValueError(f'Invalid value "{default}" for XLA flag {name}')
self.DEFINE_enum(
name,
default,
......@@ -303,7 +304,7 @@ class Config:
try:
default = int(default_env)
except ValueError:
raise ValueError(f'Invalid value "{default_env}" for JAX flag {name}')
raise ValueError(f'Invalid value "{default_env}" for XLA flag {name}')
self.DEFINE_integer(name, default, help=help, update_hook=update_global_hook)
self._contextmanager_flags.add(name)
......@@ -350,7 +351,7 @@ class Config:
try:
default = float(default_env)
except ValueError:
raise ValueError(f'Invalid value "{default_env}" for JAX flag {name}')
raise ValueError(f'Invalid value "{default_env}" for XLA flag {name}')
self.DEFINE_float(name, default, help=help, update_hook=update_global_hook)
self._contextmanager_flags.add(name)
......@@ -465,7 +466,7 @@ class Config:
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately."""
tls = jax_jit.thread_local_state()
tls = xla_jit.thread_local_state()
axis_env_state = ()
mesh_context_manager = ()
context = tls.extra_jit_context
......@@ -477,15 +478,14 @@ class Config:
axis_env_state,
mesh_context_manager,
self.x64_enabled,
self.jax_numpy_rank_promotion,
self.jax_default_matmul_precision,
self.jax_dynamic_shapes,
self.jax_numpy_dtype_promotion,
self.jax_default_device,
self.jax_array,
self.jax_threefry_partitionable,
# Technically this affects jaxpr->MHLO lowering, not tracing.
self.jax_hlo_source_file_canonicalization_regex,
self.xla_numpy_rank_promotion,
self.xla_default_matmul_precision,
self.xla_dynamic_shapes,
self.xla_numpy_dtype_promotion,
self.xla_default_device,
self.xla_array,
self.xla_threefry_partitionable,
self.xla_hlo_source_file_canonicalization_regex,
)
......@@ -507,7 +507,7 @@ class _StateContextManager:
default_value: Any = no_default,
):
self._name = name
self.__name__ = name[4:] if name.startswith("jax_") else name
self.__name__ = name[4:] if name.startswith("xla_") else name
self.__doc__ = (
f"Context manager for `{name}` config option"
f"{extra_description}.\n\n{help}"
......@@ -599,7 +599,7 @@ class _GlobalExtraJitContext(NamedTuple):
def _update_global_jit_state(**kw):
gs = jax_jit.global_state()
gs = xla_jit.global_state()
context = gs.extra_jit_context or _GlobalExtraJitContext()
gs.extra_jit_context = context._replace(**kw)
......@@ -626,7 +626,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
class _ThreadLocalStateCache(threading.local):
"""A thread local cache for _ThreadLocalExtraJitContext
The extra_jit_context in jax_jit.thread_local_state() may get updated and thus
The extra_jit_context in xla_jit.thread_local_state() may get updated and thus
incurring dispatch overhead for comparing this python object during jit calls.
We want to deduplicate the objects that have the same hash/equality to also
have the same object ID, since the equality check is much faster if the object
......@@ -641,7 +641,7 @@ _thread_local_state_cache = _ThreadLocalStateCache()
def update_thread_local_jit_state(**kw):
tls = jax_jit.thread_local_state()
tls = xla_jit.thread_local_state()
# After xla_client._version >= 70, the thread_local object will necessarily
# be initialized when accessed. The following line can be removed when the
context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
......@@ -650,25 +650,25 @@ def update_thread_local_jit_state(**kw):
flags.DEFINE_integer(
"jax_tracer_error_num_traceback_frames",
int_env("JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES", 5),
help="Set the number of stack frames in JAX tracer error messages.",
"xla_tracer_error_num_traceback_frames",
int_env("XLA_TRACER_ERROR_NUM_TRACEBACK_FRAMES", 5),
help="Set the number of stack frames in XLA tracer error messages.",
)
flags.DEFINE_bool(
"jax_pprint_use_color",
bool_env("JAX_PPRINT_USE_COLOR", True),
help="Enable jaxpr pretty-printing with colorful syntax highlighting.",
"xla_pprint_use_color",
bool_env("XLA_PPRINT_USE_COLOR", True),
help="Enable pretty-printing with colorful syntax highlighting.",
)
flags.DEFINE_bool(
"jax_host_callback_inline",
bool_env("JAX_HOST_CALLBACK_INLINE", False),
"xla_host_callback_inline",
bool_env("XLA_HOST_CALLBACK_INLINE", False),
help="Inline the host_callback, if not in a staged context.",
)
flags.DEFINE_integer(
"jax_host_callback_max_queue_byte_size",
int_env("JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE", int(256 * 1e6)),
"xla_host_callback_max_queue_byte_size",
int_env("XLA_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE", int(256 * 1e6)),
help=(
"The size in bytes of the buffer used to hold outfeeds from each "
"device. When this capacity is reached consuming outfeeds from the "
......@@ -678,8 +678,8 @@ flags.DEFINE_integer(
lower_bound=int(16 * 1e6),
)
flags.DEFINE_bool(
"jax_host_callback_outfeed",
bool_env("JAX_HOST_CALLBACK_OUTFEED", False),
"xla_host_callback_outfeed",
bool_env("XLA_HOST_CALLBACK_OUTFEED", False),
help=(
"Use outfeed implementation for host_callback, even on CPU and GPU. "
"If false, use the CustomCall implementation. "
......@@ -687,8 +687,8 @@ flags.DEFINE_bool(
),
)
flags.DEFINE_bool(
"jax_host_callback_ad_transforms",
bool_env("JAX_HOST_CALLBACK_AD_TRANSFORMS", False),
"xla_host_callback_ad_transforms",
bool_env("XLA_HOST_CALLBACK_AD_TRANSFORMS", False),
help=(
"Enable support for jvp/vjp for the host_callback primitives. Default is "
"False, which means that host_callback operates only on primals. "
......@@ -696,65 +696,63 @@ flags.DEFINE_bool(
),
)
# TODO: remove flag when XLA:CPU is improved.
jax2tf_associative_scan_reductions = config.define_bool_state(
name="jax2tf_associative_scan_reductions",
xla2tf_associative_scan_reductions = config.define_bool_state(
name="xla2tf_associative_scan_reductions",
default=False,
help=(
"JAX has two separate lowering rules for the cumulative reduction "
"XLA has two separate lowering rules for the cumulative reduction "
"primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses "
"a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. "
"The latter has a slow implementation on CPUs and GPUs. "
"By default, jax2tf uses the TPU lowering. Set this flag to True to "
"By default, xla2tf uses the TPU lowering. Set this flag to True to "
"use the associative scan lowering usage, and only if it makes a difference "
"for your application. "
"See the jax2tf README.md for more details."
"See the xla2tf README.md for more details."
),
)
jax2tf_default_native_serialization = config.define_bool_state(
name="jax2tf_default_native_serialization",
default=bool_env("JAX2TF_DEFAULT_NATIVE_SERIALIZATION", False),
xla2tf_default_native_serialization = config.define_bool_state(
name="xla2tf_default_native_serialization",
default=bool_env("XLA2TF_DEFAULT_NATIVE_SERIALIZATION", False),
help=(
"Sets the default value of the native_serialization parameter to "
"jax2tf.convert. Prefer using the parameter instead of the flag, the "
"xla2tf.convert. Prefer using the parameter instead of the flag, the "
"flag may be removed in the future."
),
)
# TODO: remove jax2tf_default_experimental_native_lowering
jax2tf_default_experimental_native_lowering = config.define_bool_state(
name="jax2tf_default_experimental_native_lowering",
default=bool_env("JAX2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING", False),
help=("DO NOT USE, deprecated in favor of jax2tf_default_native_serialization."),
xla2tf_default_experimental_native_lowering = config.define_bool_state(
name="xla2tf_default_experimental_native_lowering",
default=bool_env("XLA2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING", False),
help=("DO NOT USE, deprecated in favor of xla2tf_default_native_serialization."),
)
jax_platforms = config.define_string_state(
name="jax_platforms",
xla_platforms = config.define_string_state(
name="xla_platforms",
default=None,
help=(
"Comma-separated list of platform names specifying which platforms jax "
"Comma-separated list of platform names specifying which platforms xla "
"should initialize. If any of the platforms in this list are not successfully "
"initialized, an exception will be raised and the program will be aborted. "
"The first platform in the list will be the default platform. "
"For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends "
"For example, config.xla_platforms=cpu,tpu means that CPU and TPU backends "
"will be initialized, and the CPU backend will be used unless otherwise "
"specified. If TPU initialization fails, it will raise an exception. "
"By default, jax will try to initialize all available "
"By default, xla will try to initialize all available "
"platforms and will default to GPU or TPU if available, and fallback to CPU "
"otherwise."
),
)
enable_checks = config.define_bool_state(
name="jax_enable_checks",
name="xla_enable_checks",
default=False,
help="Turn on invariant checking for JAX internals. Makes things slower.",
help="Turn on invariant checking for XLA internals. Makes things slower.",
)
check_tracer_leaks = config.define_bool_state(
name="jax_check_tracer_leaks",
name="xla_check_tracer_leaks",
default=False,
help=(
"Turn on checking for leaked tracers as soon as a trace completes. "
......@@ -767,7 +765,7 @@ check_tracer_leaks = config.define_bool_state(
checking_leaks = functools.partial(check_tracer_leaks, True)
debug_nans = config.define_bool_state(
name="jax_debug_nans",
name="xla_debug_nans",
default=False,
help=(
"Add nan checks to every operation. When a nan is detected on the "
......@@ -778,7 +776,7 @@ debug_nans = config.define_bool_state(
)
debug_infs = config.define_bool_state(
name="jax_debug_infs",
name="xla_debug_infs",
default=False,
help=(
"Add inf checks to every operation. When an inf is detected on the "
......@@ -789,7 +787,7 @@ debug_infs = config.define_bool_state(
)
log_compiles = config.define_bool_state(
name="jax_log_compiles",
name="xla_log_compiles",
default=False,
help=(
"Log a message each time every time `jit` or `pmap` compiles an XLA "
......@@ -800,80 +798,78 @@ log_compiles = config.define_bool_state(
)
log_compiles = config.define_bool_state(
name="jax_log_checkpoint_residuals",
name="xla_log_checkpoint_residuals",
default=False,
help=(
"Log a message every time jax.checkpoint (aka jax.remat) is "
"Log a message every time xla.checkpoint (aka xla.remat) is "
"partially evaluated (e.g. for autodiff), printing what residuals "
"are saved."
),
)
parallel_functions_output_gda = config.define_bool_state(
name="jax_parallel_functions_output_gda",
name="xla_parallel_functions_output_gda",
default=False,
help="If True, pjit will output GDAs.",
)
def _update_jax_array_global(val):
def _update_xla_array_global(val):
if val is not None and not val:
raise ValueError("not supported in current version, please downgrad")
def _update_jax_array_thread_local(val):
def _update_xla_array_thread_local(val):
if val is not None and not val:
raise ValueError("not supported in current version, please downgrad")
jax_array = config.define_bool_state(
name="jax_array",
xla_array = config.define_bool_state(
name="xla_array",
default=True,
upgrade=True,
update_global_hook=_update_jax_array_global,
update_thread_local_hook=_update_jax_array_thread_local,
help=(
"If True, new pjit behavior will be enabled and `jax.Array` will be " "used."
),
update_global_hook=_update_xla_array_global,
update_thread_local_hook=_update_xla_array_thread_local,
help=("whether use xla array"),
)
jit_pjit_api_merge = config.define_bool_state(
name="jax_jit_pjit_api_merge",
name="xla_jit_pjit_api_merge",
default=True,
upgrade=True,
help=(
"If True, jit and pjit API will be merged. You can only disable it via "
"the environment variable i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. "
"the environment variable i.e. `os.environ['XLA_JIT_PJIT_API_MERGE'] = '0'`. "
"The merge must be disabled via an environment variable since it "
"affects JAX at import time so it needs to be disabled before jax is "
"affects XLA at import time so it needs to be disabled before xla is "
"imported."
),
)
spmd_mode = config.define_enum_state(
name="jax_spmd_mode",
name="xla_spmd_mode",
enum_values=["allow_all", "allow_jit", "allow_pjit"],
# TODO: Default to `allow_jit` when the training wheels come
# off.
default="allow_pjit",
help=(
"Decides whether Math on `jax.Array`'s that are not fully addressable "
"Decides whether Math on `xla.Array`'s that are not fully addressable "
"(i.e. spans across multiple processes) is allowed. The options are: "
"* allow_pjit: Default, only `pjit` computations are allowed to "
" execute on non-fully addressable `jax.Array`s\n"
"* allow_jit: `pjit` and `jax.jit` computations are allowed to "
" execute on non-fully addressable `jax.Array`s\n"
" execute on non-fully addressable `xla.Array`s\n"
"* allow_jit: `pjit` and `xla.jit` computations are allowed to "
" execute on non-fully addressable `xla.Array`s\n"
"* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
" `jax.jit` and all other operations are allowed to "
" execute on non-fully addresable `jax.Array`s."
" `xla.jit` and all other operations are allowed to "
" execute on non-fully addresable `xla.Array`s."
),
)
distributed_debug = config.define_bool_state(
name="jax_distributed_debug",
name="xla_distributed_debug",
default=False,
help=(
"Enable logging useful for debugging multi-process distributed "
......@@ -884,7 +880,7 @@ distributed_debug = config.define_bool_state(
enable_custom_prng = config.define_bool_state(
name="jax_enable_custom_prng",
name="xla_enable_custom_prng",
default=False,
upgrade=True,
help=(
......@@ -894,7 +890,7 @@ enable_custom_prng = config.define_bool_state(
)
default_prng_impl = config.define_enum_state(
name="jax_default_prng_impl",
name="xla_default_prng_impl",
enum_values=["threefry2x32", "rbg", "unsafe_rbg"],
default="threefry2x32",
help=(
......@@ -904,14 +900,14 @@ default_prng_impl = config.define_enum_state(
)
threefry_partitionable = config.define_bool_state(
name="jax_threefry_partitionable",
name="xla_threefry_partitionable",
default=False,
upgrade=True,
help=(
"Enables internal threefry PRNG implementation changes that "
"render it automatically partitionable in some cases. For use "
"with pjit and/or jax_array=True. Without this flag, using the "
"standard jax.random pseudo-random number generation may result "
"with pjit and/or xla=True. Without this flag, using the "
"standard xla.random pseudo-random number generation may result "
"in extraneous communication and/or redundant distributed "
"computation. With this flag, the communication overheads disappear "
"in some cases."
......@@ -923,17 +919,17 @@ threefry_partitionable = config.define_bool_state(
)
enable_custom_vjp_by_custom_transpose = config.define_bool_state(
name="jax_enable_custom_vjp_by_custom_transpose",
name="xla_enable_custom_vjp_by_custom_transpose",
default=False,
upgrade=True,
help=(
"Enables an internal upgrade that implements `jax.custom_vjp` by "
"reduction to `jax.custom_jvp` and `jax.custom_transpose`."
"Enables an internal upgrade that implements `xla.custom_vjp` by "
"reduction to `xla.custom_jvp` and `xla.custom_transpose`."
),
)
raise_persistent_cache_errors = config.define_bool_state(
name="jax_raise_persistent_cache_errors",
name="xla_raise_persistent_cache_errors",
default=False,
help=(
"If true, exceptions raised when reading or writing to the "
......@@ -946,7 +942,7 @@ raise_persistent_cache_errors = config.define_bool_state(
)
persistent_cache_min_compile_time_secs = config.define_float_state(
name="jax_persistent_cache_min_compile_time_secs",
name="xla_persistent_cache_min_compile_time_secs",
default=1,
help=(
"The minimum compile time of a computation to be written to the "
......@@ -956,7 +952,7 @@ persistent_cache_min_compile_time_secs = config.define_float_state(
)
hlo_source_file_canonicalization_regex = config.define_string_state(
name="jax_hlo_source_file_canonicalization_regex",
name="xla_hlo_source_file_canonicalization_regex",
default=None,
help=(
"Used to canonicalize the source_path metadata of HLO instructions "
......@@ -969,18 +965,18 @@ hlo_source_file_canonicalization_regex = config.define_string_state(
)
config.define_enum_state(
name="jax_default_dtype_bits",
name="xla_default_dtype_bits",
enum_values=["32", "64"],
default="64",
help=(
"Specify bit width of default dtypes, either 32-bit or 64-bit. "
"This is a temporary flag that will be used during the process "
"of deprecating the ``jax_enable_x64`` flag."
"of deprecating the ``xla_enable_x64`` flag."
),
)
numpy_dtype_promotion = config.define_enum_state(
name="jax_numpy_dtype_promotion",
name="xla_numpy_dtype_promotion",
enum_values=["standard", "strict"],
default="standard",
help=(
......@@ -997,39 +993,38 @@ numpy_dtype_promotion = config.define_enum_state(
def _update_x64_global(val):
libjax_jit.global_state().enable_x64 = val
libxla_jit.global_state().enable_x64 = val
def _update_x64_thread_local(val):
libjax_jit.thread_local_state().enable_x64 = val
libxla_jit.thread_local_state().enable_x64 = val
enable_x64 = config.define_bool_state(
name="jax_enable_x64",
name="xla_enable_x64",
default=False,
help="Enable 64-bit types to be used",
update_global_hook=_update_x64_global,
update_thread_local_hook=_update_x64_thread_local,
)
# TODO: remove after fixing users of FLAGS.x64_enabled.
config._contextmanager_flags.remove("jax_enable_x64")
config._contextmanager_flags.remove("xla_enable_x64")
Config.x64_enabled = Config.jax_enable_x64 # type: ignore
Config.x64_enabled = Config.xla_enable_x64
def _update_default_device_global(val):
libjax_jit.global_state().default_device = val
libxla_jit.global_state().default_device = val
def _update_default_device_thread_local(val):
libjax_jit.thread_local_state().default_device = val
libxla_jit.thread_local_state().default_device = val
def _validate_default_device(val):
if val is not None and not isinstance(val, xla_client.Device):
# TODO: this is a workaround for non-PJRT Device types. Remove when
# all JAX backends use a single C++ device interface.
# all XLA backends use a single C++ device interface.
if "Device" in str(type(val)):
logger.info(
"Allowing non-`xla_client.Device` default device: %s, type: %s",
......@@ -1038,20 +1033,20 @@ def _validate_default_device(val):
)
return
raise ValueError(
"jax.default_device must be passed a Device object (e.g. "
f"`jax.devices('cpu')[0]`), got: {repr(val)}"
"xla.default_device must be passed a Device object (e.g. "
f"`xla.devices('cpu')[0]`), got: {repr(val)}"
)
# TODO: default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
# platform names as well (e.g. "cpu" to mean the same as xla.devices("cpu")[0]).
default_device = config.define_string_or_object_state(
name="jax_default_device",
name="xla_default_device",
default=None,
help=(
"Configure the default device for JAX operations. Set to a Device "
'object (e.g. ``jax.devices("cpu")[0]``) to use that Device as the '
"default device for JAX operations and jit'd function calls (there is "
"Configure the default device for XLA operations. Set to a Device "
'object (e.g. ``xla.devices("cpu")[0]``) to use that Device as the '
"default device for XLA operations and jit'd function calls (there is "
"no effect on multi-device computations, e.g. pmapped function calls). "
"Set to None to use the system default device. See "
":ref:`faq-data-placement` for more information on device placement."
......@@ -1063,15 +1058,15 @@ default_device = config.define_string_or_object_state(
def _update_disable_jit_global(val):
libjax_jit.global_state().disable_jit = val
libxla_jit.global_state().disable_jit = val
def _update_disable_jit_thread_local(val):
libjax_jit.thread_local_state().disable_jit = val
libxla_jit.thread_local_state().disable_jit = val
disable_jit = config.define_bool_state(
name="jax_disable_jit",
name="xla_disable_jit",
default=False,
help=("Disable JIT compilation and just call original Python."),
update_global_hook=_update_disable_jit_global,
......@@ -1080,7 +1075,7 @@ disable_jit = config.define_bool_state(
numpy_rank_promotion = config.define_enum_state(
name="jax_numpy_rank_promotion",
name="xla_numpy_rank_promotion",
enum_values=["allow", "warn", "raise"],
default="allow",
help=(
......@@ -1094,7 +1089,7 @@ numpy_rank_promotion = config.define_enum_state(
)
default_matmul_precision = config.define_enum_state(
name="jax_default_matmul_precision",
name="xla_default_matmul_precision",
enum_values=["bfloat16", "tensorfloat32", "float32"],
default=None,
help=(
......@@ -1102,8 +1097,8 @@ default_matmul_precision = config.define_enum_state(
"Some platforms, like TPU, offer configurable precision levels for "
"matrix multiplication and convolution computations, trading off "
"accuracy for speed. The precision can be controlled for each "
"operation; for example, see the :func:`jax.lax.conv_general_dilated` "
"and :func:`jax.lax.dot` docstrings. But it can be useful to control "
"operation; for example, see the :func:`xla.lax.conv_general_dilated` "
"and :func:`xla.lax.dot` docstrings. But it can be useful to control "
"the default behavior obtained when an operation is not given a "
"specific precision.\n\n"
"This option can be used to control the default precision "
......@@ -1122,10 +1117,10 @@ default_matmul_precision = config.define_enum_state(
)
traceback_filtering = config.define_enum_state(
name="jax_traceback_filtering",
name="xla_traceback_filtering",
enum_values=["off", "tracebackhide", "remove_frames", "auto"],
default="auto",
help="Controls how JAX filters internal frames out of tracebacks.\n\n"
help="Controls how XLA filters internal frames out of tracebacks.\n\n"
"Valid values are:\n"
' * "off": disables traceback filtering.\n'
' * "auto": use "tracebackhide" if running under a sufficiently '
......@@ -1136,20 +1131,16 @@ traceback_filtering = config.define_enum_state(
" the unfiltered traceback as a __cause__ of the exception.\n",
)
# This flag is for internal use.
# TODO: Removes once we always enable cusparse lowering.
# TODO: Set to true after bug is fixed
bcoo_cusparse_lowering = config.define_bool_state(
name="jax_bcoo_cusparse_lowering",
name="xla_bcoo_cusparse_lowering",
default=False,
help=("Enables lowering BCOO ops to cuSparse."),
)
# TODO: remove this flag when we ensure we only succeed at trace-staging
# if the intended backend can handle lowering the result
config.define_bool_state(
name="jax_dynamic_shapes",
default=bool(os.getenv("JAX_DYNAMIC_SHAPES", "")),
name="xla_dynamic_shapes",
default=bool(os.getenv("XLA_DYNAMIC_SHAPES", "")),
help=(
"Enables experimental features for staging out computations with "
"dynamic shapes."
......@@ -1163,14 +1154,14 @@ config.define_bool_state(
# This flag is temporary during rollout of the remat barrier.
# TODO: Remove if there are no complaints.
config.define_bool_state(
name="jax_remat_opt_barrier",
name="xla_remat_opt_barrier",
default=True,
help=("Enables using optimization-barrier op for lowering remat."),
)
# TODO: Remove flag once coordination service has rolled out.
config.define_bool_state(
name="jax_coordination_service",
name="xla_coordination_service",
default=True,
help=(
"Use coordination service (experimental) instead of the default PjRT "
......@@ -1180,17 +1171,17 @@ config.define_bool_state(
# TODO: set default to True, then remove
config.define_bool_state(
name="jax_eager_pmap",
name="xla_eager_pmap",
default=True,
upgrade=True,
help="Enable eager-mode pmap when jax_disable_jit is activated.",
help="Enable eager-mode pmap when xla_disable_jit is activated.",
)
config.define_bool_state(
name="jax_experimental_unsafe_xla_runtime_errors",
name="xla_experimental_unsafe_xla_runtime_errors",
default=False,
help=(
"Enable XLA runtime errors for jax.experimental.checkify.checks "
"Enable XLA runtime errors for xla.experimental.checkify.checks "
"on CPU and GPU. These errors are async, might get lost and are not "
"very readable. But, they crash the computation and enable you "
"to write jittable checks without needing to checkify. Does not "
......@@ -1242,10 +1233,10 @@ def _update_transfer_guard(state, key, val):
transfer_guard_host_to_device = config.define_enum_state(
name="jax_transfer_guard_host_to_device",
name="xla_transfer_guard_host_to_device",
enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
# accidentally overriding --xla_transfer_guard.
default=None,
help=(
"Select the transfer guard level for host-to-device transfers. "
......@@ -1260,10 +1251,10 @@ transfer_guard_host_to_device = config.define_enum_state(
)
transfer_guard_device_to_device = config.define_enum_state(
name="jax_transfer_guard_device_to_device",
name="xla_transfer_guard_device_to_device",
enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
# accidentally overriding --xla_transfer_guard.
default=None,
help=(
"Select the transfer guard level for device-to-device transfers. "
......@@ -1278,10 +1269,10 @@ transfer_guard_device_to_device = config.define_enum_state(
)
transfer_guard_device_to_host = config.define_enum_state(
name="jax_transfer_guard_device_to_host",
name="xla_transfer_guard_device_to_host",
enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard.
# accidentally overriding --xla_transfer_guard.
default=None,
help=(
"Select the transfer guard level for device-to-host transfers. "
......@@ -1298,18 +1289,18 @@ transfer_guard_device_to_host = config.define_enum_state(
def _update_all_transfer_guard_global(val):
for name in (
"jax_transfer_guard_host_to_device",
"jax_transfer_guard_device_to_device",
"jax_transfer_guard_device_to_host",
"xla_transfer_guard_host_to_device",
"xla_transfer_guard_device_to_device",
"xla_transfer_guard_device_to_host",
):
config.update(name, val)
_transfer_guard = config.define_enum_state(
name="jax_transfer_guard",
name="xla_transfer_guard",
enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard_*.
# accidentally overriding --xla_transfer_guard_*.
default=None,
help=(
"Select the transfer guard level for all transfers. This option is "
......
# code of this file is mainly from jax: https://github.com/google/jax
import logging
import os
import platform as py_platform
......@@ -19,41 +20,41 @@ FLAGS = flags.FLAGS
logger = logging.getLogger(__name__)
# Command-line flags controlling backend selection and compilation.
# Each flag is registered exactly once under its ``xla_*`` name; where a flag
# has an environment-variable default, the variable uses the matching
# ``XLA_*`` spelling.

# Deprecated way of selecting a backend.
flags.DEFINE_string(
    "xla_backend", "", "Deprecated, please use --xla_platforms instead."
)
# Remote service target; default comes from the environment, lower-cased.
flags.DEFINE_string(
    "xla_backend_target",
    os.getenv("XLA_BACKEND_TARGET", "").lower(),
    'Either "local" or "rpc:address" to connect to a remote service target.',
)
# Deprecated way of selecting a platform; default comes from the environment.
flags.DEFINE_string(
    "xla_platform_name",
    os.getenv("XLA_PLATFORM_NAME", "").lower(),
    "Deprecated, please use --xla_platforms instead.",
)
flags.DEFINE_bool(
    "xla_disable_most_optimizations",
    bool_env("XLA_DISABLE_MOST_OPTIMIZATIONS", False),
    "Try not to do much optimization work. This can be useful if the cost of "
    "optimization is greater than that of running a less-optimized program.",
)
flags.DEFINE_integer(
    "xla_profile_version",
    int_env("XLA_PROFILE_VERSION", 0),
    "Optional profile version for XLA compilation. "
    "This is meaningful only when XLA is configured to "
    "support the remote compilation profile feature.",
)
flags.DEFINE_string(
    "xla_cuda_visible_devices",
    "all",
    'Restricts the set of CUDA devices that XLA will use. Either "all", or a '
    "comma-separate list of integer device IDs.",
)
flags.DEFINE_string(
    "xla_rocm_visible_devices",
    "all",
    'Restricts the set of ROCM devices that XLA will use. Either "all", or a '
    "comma-separate list of integer device IDs.",
)
......@@ -69,22 +70,22 @@ def get_compile_options(
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of jax devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of xla devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas
compile_options.num_partitions = num_partitions
......@@ -130,12 +131,12 @@ def get_compile_options(
if cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = cuda_path
if FLAGS.jax_disable_most_optimizations:
if FLAGS.xla_disable_most_optimizations:
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
compile_options.profile_version = FLAGS.jax_xla_profile_version
compile_options.profile_version = FLAGS.xla_profile_version
return compile_options
......@@ -187,7 +188,7 @@ if hasattr(xla_client, "make_gpu_client"):
partial(
make_gpu_client,
platform_name="cuda",
visible_devices_flag="jax_cuda_visible_devices",
visible_devices_flag="xla_cuda_visible_devices",
),
priority=200,
)
......@@ -196,13 +197,13 @@ if hasattr(xla_client, "make_gpu_client"):
partial(
make_gpu_client,
platform_name="rocm",
visible_devices_flag="jax_rocm_visible_devices",
visible_devices_flag="xla_rocm_visible_devices",
),
priority=200,
)
if hasattr(xla_client, "make_plugin_device_client"):
# It is assumed that if jax has been built with a plugin client, then the
# It is assumed that if xla has been built with a plugin client, then the
# user wants to use the plugin client by default. Therefore, it gets the
# highest priority.
register_backend_factory(
......@@ -229,11 +230,11 @@ def is_known_platform(platform: str):
def canonicalize_platform(platform: str) -> str:
"""Replaces platform aliases with their concrete equivalent.
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
platforms = _alias_to_platforms.get(platform, None)
if platforms is None:
return platform
......@@ -252,9 +253,9 @@ def canonicalize_platform(platform: str) -> str:
def expand_platform_alias(platform: str) -> List[str]:
    """Expands, e.g., "gpu" to ["cuda", "rocm"].

    This is used for convenience reasons: we expect cuda and rocm to act similarly
    in many respects since they share most of the same code.

    Args:
      platform: a platform name or alias.

    Returns:
      The entry stored for ``platform`` in ``_alias_to_platforms`` when one
      exists; otherwise a one-element list containing ``platform`` itself.
    """
    return _alias_to_platforms.get(platform, [platform])
......@@ -270,11 +271,11 @@ def backends():
with _backend_lock:
if _backends:
return _backends
if config.jax_platforms:
jax_platforms = config.jax_platforms.split(",")
if config.xla_platforms:
xla_platforms = config.xla_platforms.split(",")
platforms = []
# Allow platform aliases in the list of platforms.
for platform in jax_platforms:
for platform in xla_platforms:
platforms.extend(expand_platform_alias(platform))
priorities = range(len(platforms), 0, -1)
platforms_and_priorites = zip(platforms, priorities)
......@@ -303,8 +304,8 @@ def backends():
# If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError.
err_msg = f"Unable to initialize backend '{platform}': {err}"
if config.jax_platforms:
err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
if config.xla_platforms:
err_msg += " (set XLA_PLATFORMS='' to automatically choose an available backend)"
raise RuntimeError(err_msg)
else:
_backends_errors[platform] = str(err)
......@@ -315,12 +316,9 @@ def backends():
if (
py_platform.system() != "Darwin"
and _default_backend.platform == "cpu"
and FLAGS.jax_platform_name != "cpu"
and FLAGS.xla_platform_name != "cpu"
):
logger.warning(
"No GPU/TPU found, falling back to CPU. "
"(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)"
)
logger.warning("No GPU/TPU found, falling back to CPU. ")
return _backends
......@@ -329,7 +327,7 @@ def _clear_backends():
global _backends_errors
global _default_backend
logger.info("Clearing JAX backend caches.")
logger.info("Clearing XLA backend caches.")
with _backend_lock:
_backends = {}
_backends_errors = {}
......@@ -351,11 +349,6 @@ def _init_backend(platform):
raise RuntimeError(f"Could not initialize backend '{platform}'")
if backend.device_count() == 0:
raise RuntimeError(f"Backend '{platform}' provides no devices.")
# ccq: disable distributed_debug_log
# util.distributed_debug_log(("Initialized backend", backend.platform),
# ("process_index", backend.process_index()),
# ("device_count", backend.device_count()),
# ("local_devices", backend.local_devices()))
logger.debug("Backend '%s' initialized", platform)
return backend
......@@ -366,7 +359,7 @@ def _get_backend_uncached(platform=None):
if not isinstance(platform, (type(None), str)):
return platform
platform = platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name or None
platform = platform or FLAGS.xla_backend or FLAGS.xla_platform_name or None
bs = backends()
if platform is not None:
......@@ -399,7 +392,7 @@ def get_device_backend(device=None):
def device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the total number of devices.
On most platforms, this is the same as :py:func:`jax.local_device_count`.
On most platforms, this is the same as :py:func:`xla.local_device_count`.
However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across
all processes.
......@@ -430,7 +423,7 @@ def devices(
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by
comparing :attr:`Device.process_index` to the value returned by
:py:func:`jax.process_index`.
:py:func:`xla.process_index`.
If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
......@@ -457,13 +450,13 @@ def local_devices(
backend: Optional[Union[str, XlaBackend]] = None,
host_id: Optional[int] = None,
) -> List[xla_client.Device]:
"""Like :py:func:`jax.devices`, but only returns devices local to a given process.
"""Like :py:func:`xla.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process.
Args:
process_index: the integer index of the process. Process indices can be
retrieved via ``len(jax.process_count())``.
retrieved via ``len(xla.process_count())``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
......@@ -473,7 +466,7 @@ def local_devices(
"""
if host_id is not None:
warnings.warn(
"The argument to jax.local_devices has been renamed from `host_id` to "
"The argument to xla.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update "
"your code."
)
......@@ -502,34 +495,31 @@ def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int:
return get_backend(backend).process_index()
# TODO: remove this sometime after jax 0.2.13 is released
def host_id(backend=None):
    """Deprecated alias for ``process_index``.

    Issues a warning about the rename and then delegates to
    ``process_index``.

    Args:
      backend: forwarded unchanged to ``process_index``.

    Returns:
      The value returned by ``process_index(backend)``.
    """
    warnings.warn(
        "xla.host_id has been renamed to xla.process_index. This alias "
        "will eventually be removed; please update your code.",
        # Attribute the warning to the caller rather than to this shim.
        stacklevel=2,
    )
    return process_index(backend)
def process_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
    """Returns the number of XLA processes associated with the backend.

    The count is one more than the largest ``process_index`` found among the
    devices returned by ``devices(backend)``.

    Args:
      backend: forwarded unchanged to ``devices``.
    """
    return max(d.process_index for d in devices(backend)) + 1
# TODO: remove this sometime after jax 0.2.13 is released
def host_count(backend=None):
    """Deprecated alias for ``process_count``.

    Issues a warning about the rename and then delegates to
    ``process_count``.

    Args:
      backend: forwarded unchanged to ``process_count``.

    Returns:
      The value returned by ``process_count(backend)``.
    """
    warnings.warn(
        "xla.host_count has been renamed to xla.process_count. This alias "
        "will eventually be removed; please update your code.",
        # Attribute the warning to the caller rather than to this shim.
        stacklevel=2,
    )
    return process_count(backend)
# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(backend=None):
    """Deprecated helper returning every process index as a list.

    Issues a deprecation warning, then builds the list
    ``[0, 1, ..., process_count(backend) - 1]``.

    Args:
      backend: forwarded unchanged to ``process_count``.

    Returns:
      ``list(range(process_count(backend)))``.
    """
    warnings.warn(
        "xla.host_ids has been deprecated; please use range(xla.process_count()) "
        "instead. xla.host_ids will eventually be removed; please update your "
        "code.",
        # Attribute the warning to the caller rather than to this shim.
        stacklevel=2,
    )
    return list(range(process_count(backend)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部