提交 5e013d8c 编写于 作者: M Megvii Engine Team

refactor(xla): add xla acknowledgement

GitOrigin-RevId: f2fafebedfc3b520684655cff8c86b1df7c2c1cd
上级 4c7905f3
...@@ -755,6 +755,11 @@ Copyright 2014 Google Inc. All rights reserved. ...@@ -755,6 +755,11 @@ Copyright 2014 Google Inc. All rights reserved.
5. MACE 5. MACE
Copyright 2018 Xiaomi Inc. All rights reserved. Copyright 2018 Xiaomi Inc. All rights reserved.
6. XLA
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
7. JAX
Copyright 2018 The JAX Authors.
Terms of Apache License Version 2.0 Terms of Apache License Version 2.0
--------------------------------------------------- ---------------------------------------------------
......
# some code of this directory is from jax: https://github.com/google/jax
from .build import build_xla from .build import build_xla
# code of this directory is mainly from jax: https://github.com/google/jax
try: try:
import mge_xlalib as mge_xlalib import mge_xlalib as mge_xlalib
except ModuleNotFoundError as err: except ModuleNotFoundError as err:
...@@ -9,8 +11,10 @@ except ModuleNotFoundError as err: ...@@ -9,8 +11,10 @@ except ModuleNotFoundError as err:
raise ModuleNotFoundError(msg) raise ModuleNotFoundError(msg)
import gc import gc
import pathlib import os
import platform import platform
import subprocess
import sys
import warnings import warnings
from typing import Optional from typing import Optional
...@@ -43,7 +47,8 @@ cpu_feature_guard.check_cpu_features() ...@@ -43,7 +47,8 @@ cpu_feature_guard.check_cpu_features()
xla_extension = xla_client._xla xla_extension = xla_client._xla
pytree = xla_client._xla.pytree pytree = xla_client._xla.pytree
jax_jit = xla_client._xla.jax_jit # we use some api in jaxlib
xla_jit = xla_client._xla.jax_jit
pmap_lib = xla_client._xla.pmap_lib pmap_lib = xla_client._xla.pmap_lib
...@@ -57,18 +62,29 @@ gc.callbacks.append(_xla_gc_callback) ...@@ -57,18 +62,29 @@ gc.callbacks.append(_xla_gc_callback)
xla_extension_version: int = getattr(xla_client, "_version", 0) xla_extension_version: int = getattr(xla_client, "_version", 0)
mlir_api_version = xla_client.mlir_api_version mlir_api_version = xla_client.mlir_api_version
# Finds the CUDA install path
def _cuda_path() -> Optional[str]: def _find_cuda_root_dir() -> Optional[str]:
_mgexlalib_path = pathlib.Path(mge_xlalib.__file__).parent cuda_root_dir = os.environ.get("CUDA_ROOT_DIR")
path = _mgexlalib_path.parent / "nvidia" / "cuda_nvcc" if cuda_root_dir is None:
if path.is_dir(): try:
return str(path) which = "where" if sys.platform == "win32" else "which"
path = _mgexlalib_path / "cuda" with open(os.devnull, "w") as devnull:
if path.is_dir(): nvcc = (
return str(path) subprocess.check_output([which, "nvcc"], stderr=devnull)
return None .decode()
.rstrip("\r\n")
)
cuda_path = _cuda_path() cuda_root_dir = os.path.dirname(os.path.dirname(nvcc))
except Exception:
if sys.platform == "win32":
assert False, "xla not supported on windows"
else:
cuda_root_dir = "/usr/local/cuda"
if not os.path.exists(cuda_root_dir):
cuda_root_dir = None
return cuda_root_dir
cuda_path = _find_cuda_root_dir()
transfer_guard_lib = xla_client._xla.transfer_guard_lib transfer_guard_lib = xla_client._xla.transfer_guard_lib
# code of this file is mainly from jax: https://github.com/google/jax
import contextlib import contextlib
import functools import functools
import itertools import itertools
...@@ -10,9 +11,10 @@ from typing import Any, Callable, Hashable, Iterator, List, NamedTuple, Optional ...@@ -10,9 +11,10 @@ from typing import Any, Callable, Hashable, Iterator, List, NamedTuple, Optional
import mge_xlalib.xla_client as xla_client import mge_xlalib.xla_client as xla_client
from . import jax_jit as libjax_jit from . import xla_jit as libxla_jit
jax_jit = xla_client._xla.jax_jit # we use some api in jaxlib
xla_jit = xla_client._xla.jax_jit
transfer_guard_lib = xla_client._xla.transfer_guard_lib transfer_guard_lib = xla_client._xla.transfer_guard_lib
...@@ -46,10 +48,9 @@ def int_env(varname: str, default: int) -> int: ...@@ -46,10 +48,9 @@ def int_env(varname: str, default: int) -> int:
UPGRADE_BOOL_HELP = ( UPGRADE_BOOL_HELP = (
" This will be enabled by default in future versions of JAX, at which " " This will be enabled by default in future versions of XLA, at which "
"point all uses of the flag will be considered deprecated (following " "point all uses of the flag will be considered deprecated (following "
"the `API compatibility policy " "the `API compatibility policy "
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_)."
) )
UPGRADE_BOOL_EXTRA_DESC = " (transient)" UPGRADE_BOOL_EXTRA_DESC = " (transient)"
...@@ -181,21 +182,21 @@ class Config: ...@@ -181,21 +182,21 @@ class Config:
Example: Example:
enable_foo = config.define_bool_state( enable_foo = config.define_bool_state(
name='jax_enable_foo', name='xla_enable_foo',
default=False, default=False,
help='Enable foo.') help='Enable foo.')
# Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo # Now the XLA_ENABLE_FOO shell environment variable and --xla_enable_foo
# command-line flag can be used to control the process-level value of # command-line flag can be used to control the process-level value of
# the configuration option, in addition to using e.g. # the configuration option, in addition to using e.g.
# ``config.update("jax_enable_foo", True)`` directly. We can also use a # ``config.update("xla_enable_foo", True)`` directly. We can also use a
# context manager: # context manager:
with enable_foo(True): with enable_foo(True):
... ...
The value of the thread-local state or flag can be accessed via The value of the thread-local state or flag can be accessed via
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is ``config.xla_enable_foo``. Reading it via ``config.FLAGS.xla_enable_foo`` is
an error. an error.
""" """
...@@ -248,7 +249,7 @@ class Config: ...@@ -248,7 +249,7 @@ class Config:
name = name.lower() name = name.lower()
default = os.getenv(name.upper(), default) default = os.getenv(name.upper(), default)
if default is not None and default not in enum_values: if default is not None and default not in enum_values:
raise ValueError(f'Invalid value "{default}" for JAX flag {name}') raise ValueError(f'Invalid value "{default}" for XLA flag {name}')
self.DEFINE_enum( self.DEFINE_enum(
name, name,
default, default,
...@@ -303,7 +304,7 @@ class Config: ...@@ -303,7 +304,7 @@ class Config:
try: try:
default = int(default_env) default = int(default_env)
except ValueError: except ValueError:
raise ValueError(f'Invalid value "{default_env}" for JAX flag {name}') raise ValueError(f'Invalid value "{default_env}" for XLA flag {name}')
self.DEFINE_integer(name, default, help=help, update_hook=update_global_hook) self.DEFINE_integer(name, default, help=help, update_hook=update_global_hook)
self._contextmanager_flags.add(name) self._contextmanager_flags.add(name)
...@@ -350,7 +351,7 @@ class Config: ...@@ -350,7 +351,7 @@ class Config:
try: try:
default = float(default_env) default = float(default_env)
except ValueError: except ValueError:
raise ValueError(f'Invalid value "{default_env}" for JAX flag {name}') raise ValueError(f'Invalid value "{default_env}" for XLA flag {name}')
self.DEFINE_float(name, default, help=help, update_hook=update_global_hook) self.DEFINE_float(name, default, help=help, update_hook=update_global_hook)
self._contextmanager_flags.add(name) self._contextmanager_flags.add(name)
...@@ -465,7 +466,7 @@ class Config: ...@@ -465,7 +466,7 @@ class Config:
Values included in this set should also most likely be included in Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately.""" the C++ JIT state, which is handled separately."""
tls = jax_jit.thread_local_state() tls = xla_jit.thread_local_state()
axis_env_state = () axis_env_state = ()
mesh_context_manager = () mesh_context_manager = ()
context = tls.extra_jit_context context = tls.extra_jit_context
...@@ -477,15 +478,14 @@ class Config: ...@@ -477,15 +478,14 @@ class Config:
axis_env_state, axis_env_state,
mesh_context_manager, mesh_context_manager,
self.x64_enabled, self.x64_enabled,
self.jax_numpy_rank_promotion, self.xla_numpy_rank_promotion,
self.jax_default_matmul_precision, self.xla_default_matmul_precision,
self.jax_dynamic_shapes, self.xla_dynamic_shapes,
self.jax_numpy_dtype_promotion, self.xla_numpy_dtype_promotion,
self.jax_default_device, self.xla_default_device,
self.jax_array, self.xla_array,
self.jax_threefry_partitionable, self.xla_threefry_partitionable,
# Technically this affects jaxpr->MHLO lowering, not tracing. self.xla_hlo_source_file_canonicalization_regex,
self.jax_hlo_source_file_canonicalization_regex,
) )
...@@ -507,7 +507,7 @@ class _StateContextManager: ...@@ -507,7 +507,7 @@ class _StateContextManager:
default_value: Any = no_default, default_value: Any = no_default,
): ):
self._name = name self._name = name
self.__name__ = name[4:] if name.startswith("jax_") else name self.__name__ = name[4:] if name.startswith("xla_") else name
self.__doc__ = ( self.__doc__ = (
f"Context manager for `{name}` config option" f"Context manager for `{name}` config option"
f"{extra_description}.\n\n{help}" f"{extra_description}.\n\n{help}"
...@@ -599,7 +599,7 @@ class _GlobalExtraJitContext(NamedTuple): ...@@ -599,7 +599,7 @@ class _GlobalExtraJitContext(NamedTuple):
def _update_global_jit_state(**kw): def _update_global_jit_state(**kw):
gs = jax_jit.global_state() gs = xla_jit.global_state()
context = gs.extra_jit_context or _GlobalExtraJitContext() context = gs.extra_jit_context or _GlobalExtraJitContext()
gs.extra_jit_context = context._replace(**kw) gs.extra_jit_context = context._replace(**kw)
...@@ -626,7 +626,7 @@ class _ThreadLocalExtraJitContext(NamedTuple): ...@@ -626,7 +626,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
class _ThreadLocalStateCache(threading.local): class _ThreadLocalStateCache(threading.local):
""""A thread local cache for _ThreadLocalExtraJitContext """"A thread local cache for _ThreadLocalExtraJitContext
The extra_jit_context in jax_jit.thread_local_state() may get updated and thus The extra_jit_context in xla_jit.thread_local_state() may get updated and thus
incurring dispatch overhead for comparing this python object during jit calls. incurring dispatch overhead for comparing this python object during jit calls.
We want to deduplicate the objects that have the same hash/equality to also We want to deduplicate the objects that have the same hash/equality to also
have the same object ID, since the equality check is much faster if the object have the same object ID, since the equality check is much faster if the object
...@@ -641,7 +641,7 @@ _thread_local_state_cache = _ThreadLocalStateCache() ...@@ -641,7 +641,7 @@ _thread_local_state_cache = _ThreadLocalStateCache()
def update_thread_local_jit_state(**kw): def update_thread_local_jit_state(**kw):
tls = jax_jit.thread_local_state() tls = xla_jit.thread_local_state()
# After xla_client._version >= 70, the thread_local object will necessarily # After xla_client._version >= 70, the thread_local object will necessarily
# be initialized when accessed. The following line can be removed when the # be initialized when accessed. The following line can be removed when the
context = tls.extra_jit_context or _ThreadLocalExtraJitContext() context = tls.extra_jit_context or _ThreadLocalExtraJitContext()
...@@ -650,25 +650,25 @@ def update_thread_local_jit_state(**kw): ...@@ -650,25 +650,25 @@ def update_thread_local_jit_state(**kw):
flags.DEFINE_integer( flags.DEFINE_integer(
"jax_tracer_error_num_traceback_frames", "xla_tracer_error_num_traceback_frames",
int_env("JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES", 5), int_env("XLA_TRACER_ERROR_NUM_TRACEBACK_FRAMES", 5),
help="Set the number of stack frames in JAX tracer error messages.", help="Set the number of stack frames in XLA tracer error messages.",
) )
flags.DEFINE_bool( flags.DEFINE_bool(
"jax_pprint_use_color", "xla_pprint_use_color",
bool_env("JAX_PPRINT_USE_COLOR", True), bool_env("XLA_PPRINT_USE_COLOR", True),
help="Enable jaxpr pretty-printing with colorful syntax highlighting.", help="Enable pretty-printing with colorful syntax highlighting.",
) )
flags.DEFINE_bool( flags.DEFINE_bool(
"jax_host_callback_inline", "xla_host_callback_inline",
bool_env("JAX_HOST_CALLBACK_INLINE", False), bool_env("XLA_HOST_CALLBACK_INLINE", False),
help="Inline the host_callback, if not in a staged context.", help="Inline the host_callback, if not in a staged context.",
) )
flags.DEFINE_integer( flags.DEFINE_integer(
"jax_host_callback_max_queue_byte_size", "xla_host_callback_max_queue_byte_size",
int_env("JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE", int(256 * 1e6)), int_env("XLA_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE", int(256 * 1e6)),
help=( help=(
"The size in bytes of the buffer used to hold outfeeds from each " "The size in bytes of the buffer used to hold outfeeds from each "
"device. When this capacity is reached consuming outfeeds from the " "device. When this capacity is reached consuming outfeeds from the "
...@@ -678,8 +678,8 @@ flags.DEFINE_integer( ...@@ -678,8 +678,8 @@ flags.DEFINE_integer(
lower_bound=int(16 * 1e6), lower_bound=int(16 * 1e6),
) )
flags.DEFINE_bool( flags.DEFINE_bool(
"jax_host_callback_outfeed", "xla_host_callback_outfeed",
bool_env("JAX_HOST_CALLBACK_OUTFEED", False), bool_env("XLA_HOST_CALLBACK_OUTFEED", False),
help=( help=(
"Use outfeed implementation for host_callback, even on CPU and GPU. " "Use outfeed implementation for host_callback, even on CPU and GPU. "
"If false, use the CustomCall implementation. " "If false, use the CustomCall implementation. "
...@@ -687,8 +687,8 @@ flags.DEFINE_bool( ...@@ -687,8 +687,8 @@ flags.DEFINE_bool(
), ),
) )
flags.DEFINE_bool( flags.DEFINE_bool(
"jax_host_callback_ad_transforms", "xla_host_callback_ad_transforms",
bool_env("JAX_HOST_CALLBACK_AD_TRANSFORMS", False), bool_env("XLA_HOST_CALLBACK_AD_TRANSFORMS", False),
help=( help=(
"Enable support for jvp/vjp for the host_callback primitives. Default is " "Enable support for jvp/vjp for the host_callback primitives. Default is "
"False, which means that host_callback operates only on primals. " "False, which means that host_callback operates only on primals. "
...@@ -696,65 +696,63 @@ flags.DEFINE_bool( ...@@ -696,65 +696,63 @@ flags.DEFINE_bool(
), ),
) )
# TODO: remove flag when XLA:CPU is improved. xla2tf_associative_scan_reductions = config.define_bool_state(
jax2tf_associative_scan_reductions = config.define_bool_state( name="xla2tf_associative_scan_reductions",
name="jax2tf_associative_scan_reductions",
default=False, default=False,
help=( help=(
"JAX has two separate lowering rules for the cumulative reduction " "XLA has two separate lowering rules for the cumulative reduction "
"primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses " "primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses "
"a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. " "a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. "
"The latter has a slow implementation on CPUs and GPUs. " "The latter has a slow implementation on CPUs and GPUs. "
"By default, jax2tf uses the TPU lowering. Set this flag to True to " "By default, xla2tf uses the TPU lowering. Set this flag to True to "
"use the associative scan lowering usage, and only if it makes a difference " "use the associative scan lowering usage, and only if it makes a difference "
"for your application. " "for your application. "
"See the jax2tf README.md for more details." "See the xla2tf README.md for more details."
), ),
) )
jax2tf_default_native_serialization = config.define_bool_state( xla2tf_default_native_serialization = config.define_bool_state(
name="jax2tf_default_native_serialization", name="xla2tf_default_native_serialization",
default=bool_env("JAX2TF_DEFAULT_NATIVE_SERIALIZATION", False), default=bool_env("XLA2TF_DEFAULT_NATIVE_SERIALIZATION", False),
help=( help=(
"Sets the default value of the native_serialization parameter to " "Sets the default value of the native_serialization parameter to "
"jax2tf.convert. Prefer using the parameter instead of the flag, the " "xla2tf.convert. Prefer using the parameter instead of the flag, the "
"flag may be removed in the future." "flag may be removed in the future."
), ),
) )
# TODO: remove jax2tf_default_experimental_native_lowering xla2tf_default_experimental_native_lowering = config.define_bool_state(
jax2tf_default_experimental_native_lowering = config.define_bool_state( name="xla2tf_default_experimental_native_lowering",
name="jax2tf_default_experimental_native_lowering", default=bool_env("XLA2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING", False),
default=bool_env("JAX2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING", False), help=("DO NOT USE, deprecated in favor of xla2tf_default_native_serialization."),
help=("DO NOT USE, deprecated in favor of jax2tf_default_native_serialization."),
) )
jax_platforms = config.define_string_state( xla_platforms = config.define_string_state(
name="jax_platforms", name="xla_platforms",
default=None, default=None,
help=( help=(
"Comma-separated list of platform names specifying which platforms jax " "Comma-separated list of platform names specifying which platforms xla "
"should initialize. If any of the platforms in this list are not successfully " "should initialize. If any of the platforms in this list are not successfully "
"initialized, an exception will be raised and the program will be aborted. " "initialized, an exception will be raised and the program will be aborted. "
"The first platform in the list will be the default platform. " "The first platform in the list will be the default platform. "
"For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends " "For example, config.xla_platforms=cpu,tpu means that CPU and TPU backends "
"will be initialized, and the CPU backend will be used unless otherwise " "will be initialized, and the CPU backend will be used unless otherwise "
"specified. If TPU initialization fails, it will raise an exception. " "specified. If TPU initialization fails, it will raise an exception. "
"By default, jax will try to initialize all available " "By default, xla will try to initialize all available "
"platforms and will default to GPU or TPU if available, and fallback to CPU " "platforms and will default to GPU or TPU if available, and fallback to CPU "
"otherwise." "otherwise."
), ),
) )
enable_checks = config.define_bool_state( enable_checks = config.define_bool_state(
name="jax_enable_checks", name="xla_enable_checks",
default=False, default=False,
help="Turn on invariant checking for JAX internals. Makes things slower.", help="Turn on invariant checking for XLA internals. Makes things slower.",
) )
check_tracer_leaks = config.define_bool_state( check_tracer_leaks = config.define_bool_state(
name="jax_check_tracer_leaks", name="xla_check_tracer_leaks",
default=False, default=False,
help=( help=(
"Turn on checking for leaked tracers as soon as a trace completes. " "Turn on checking for leaked tracers as soon as a trace completes. "
...@@ -767,7 +765,7 @@ check_tracer_leaks = config.define_bool_state( ...@@ -767,7 +765,7 @@ check_tracer_leaks = config.define_bool_state(
checking_leaks = functools.partial(check_tracer_leaks, True) checking_leaks = functools.partial(check_tracer_leaks, True)
debug_nans = config.define_bool_state( debug_nans = config.define_bool_state(
name="jax_debug_nans", name="xla_debug_nans",
default=False, default=False,
help=( help=(
"Add nan checks to every operation. When a nan is detected on the " "Add nan checks to every operation. When a nan is detected on the "
...@@ -778,7 +776,7 @@ debug_nans = config.define_bool_state( ...@@ -778,7 +776,7 @@ debug_nans = config.define_bool_state(
) )
debug_infs = config.define_bool_state( debug_infs = config.define_bool_state(
name="jax_debug_infs", name="xla_debug_infs",
default=False, default=False,
help=( help=(
"Add inf checks to every operation. When an inf is detected on the " "Add inf checks to every operation. When an inf is detected on the "
...@@ -789,7 +787,7 @@ debug_infs = config.define_bool_state( ...@@ -789,7 +787,7 @@ debug_infs = config.define_bool_state(
) )
log_compiles = config.define_bool_state( log_compiles = config.define_bool_state(
name="jax_log_compiles", name="xla_log_compiles",
default=False, default=False,
help=( help=(
"Log a message each time every time `jit` or `pmap` compiles an XLA " "Log a message each time every time `jit` or `pmap` compiles an XLA "
...@@ -800,80 +798,78 @@ log_compiles = config.define_bool_state( ...@@ -800,80 +798,78 @@ log_compiles = config.define_bool_state(
) )
log_compiles = config.define_bool_state( log_compiles = config.define_bool_state(
name="jax_log_checkpoint_residuals", name="xla_log_checkpoint_residuals",
default=False, default=False,
help=( help=(
"Log a message every time jax.checkpoint (aka jax.remat) is " "Log a message every time xla.checkpoint (aka xla.remat) is "
"partially evaluated (e.g. for autodiff), printing what residuals " "partially evaluated (e.g. for autodiff), printing what residuals "
"are saved." "are saved."
), ),
) )
parallel_functions_output_gda = config.define_bool_state( parallel_functions_output_gda = config.define_bool_state(
name="jax_parallel_functions_output_gda", name="xla_parallel_functions_output_gda",
default=False, default=False,
help="If True, pjit will output GDAs.", help="If True, pjit will output GDAs.",
) )
def _update_jax_array_global(val): def _update_xla_array_global(val):
if val is not None and not val: if val is not None and not val:
raise ValueError("not supported in current version, please downgrad") raise ValueError("not supported in current version, please downgrad")
def _update_jax_array_thread_local(val): def _update_xla_array_thread_local(val):
if val is not None and not val: if val is not None and not val:
raise ValueError("not supported in current version, please downgrad") raise ValueError("not supported in current version, please downgrad")
jax_array = config.define_bool_state( xla_array = config.define_bool_state(
name="jax_array", name="xla_array",
default=True, default=True,
upgrade=True, upgrade=True,
update_global_hook=_update_jax_array_global, update_global_hook=_update_xla_array_global,
update_thread_local_hook=_update_jax_array_thread_local, update_thread_local_hook=_update_xla_array_thread_local,
help=( help=("whether use xla array"),
"If True, new pjit behavior will be enabled and `jax.Array` will be " "used."
),
) )
jit_pjit_api_merge = config.define_bool_state( jit_pjit_api_merge = config.define_bool_state(
name="jax_jit_pjit_api_merge", name="xla_jit_pjit_api_merge",
default=True, default=True,
upgrade=True, upgrade=True,
help=( help=(
"If True, jit and pjit API will be merged. You can only disable it via " "If True, jit and pjit API will be merged. You can only disable it via "
"the environment variable i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. " "the environment variable i.e. `os.environ['XLA_JIT_PJIT_API_MERGE'] = '0'`. "
"The merge must be disabled via an environment variable since it " "The merge must be disabled via an environment variable since it "
"affects JAX at import time so it needs to be disabled before jax is " "affects XLA at import time so it needs to be disabled before xla is "
"imported." "imported."
), ),
) )
spmd_mode = config.define_enum_state( spmd_mode = config.define_enum_state(
name="jax_spmd_mode", name="xla_spmd_mode",
enum_values=["allow_all", "allow_jit", "allow_pjit"], enum_values=["allow_all", "allow_jit", "allow_pjit"],
# TODO: Default to `allow_jit` when the training wheels come # TODO: Default to `allow_jit` when the training wheels come
# off. # off.
default="allow_pjit", default="allow_pjit",
help=( help=(
"Decides whether Math on `jax.Array`'s that are not fully addressable " "Decides whether Math on `xla.Array`'s that are not fully addressable "
"(i.e. spans across multiple processes) is allowed. The options are: " "(i.e. spans across multiple processes) is allowed. The options are: "
"* allow_pjit: Default, only `pjit` computations are allowed to " "* allow_pjit: Default, only `pjit` computations are allowed to "
" execute on non-fully addressable `jax.Array`s\n" " execute on non-fully addressable `xla.Array`s\n"
"* allow_jit: `pjit` and `jax.jit` computations are allowed to " "* allow_jit: `pjit` and `xla.jit` computations are allowed to "
" execute on non-fully addressable `jax.Array`s\n" " execute on non-fully addressable `xla.Array`s\n"
"* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, " "* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
" `jax.jit` and all other operations are allowed to " " `xla.jit` and all other operations are allowed to "
" execute on non-fully addresable `jax.Array`s." " execute on non-fully addresable `xla.Array`s."
), ),
) )
distributed_debug = config.define_bool_state( distributed_debug = config.define_bool_state(
name="jax_distributed_debug", name="xla_distributed_debug",
default=False, default=False,
help=( help=(
"Enable logging useful for debugging multi-process distributed " "Enable logging useful for debugging multi-process distributed "
...@@ -884,7 +880,7 @@ distributed_debug = config.define_bool_state( ...@@ -884,7 +880,7 @@ distributed_debug = config.define_bool_state(
enable_custom_prng = config.define_bool_state( enable_custom_prng = config.define_bool_state(
name="jax_enable_custom_prng", name="xla_enable_custom_prng",
default=False, default=False,
upgrade=True, upgrade=True,
help=( help=(
...@@ -894,7 +890,7 @@ enable_custom_prng = config.define_bool_state( ...@@ -894,7 +890,7 @@ enable_custom_prng = config.define_bool_state(
) )
default_prng_impl = config.define_enum_state( default_prng_impl = config.define_enum_state(
name="jax_default_prng_impl", name="xla_default_prng_impl",
enum_values=["threefry2x32", "rbg", "unsafe_rbg"], enum_values=["threefry2x32", "rbg", "unsafe_rbg"],
default="threefry2x32", default="threefry2x32",
help=( help=(
...@@ -904,14 +900,14 @@ default_prng_impl = config.define_enum_state( ...@@ -904,14 +900,14 @@ default_prng_impl = config.define_enum_state(
) )
threefry_partitionable = config.define_bool_state( threefry_partitionable = config.define_bool_state(
name="jax_threefry_partitionable", name="xla_threefry_partitionable",
default=False, default=False,
upgrade=True, upgrade=True,
help=( help=(
"Enables internal threefry PRNG implementation changes that " "Enables internal threefry PRNG implementation changes that "
"render it automatically partitionable in some cases. For use " "render it automatically partitionable in some cases. For use "
"with pjit and/or jax_array=True. Without this flag, using the " "with pjit and/or xla=True. Without this flag, using the "
"standard jax.random pseudo-random number generation may result " "standard xla.random pseudo-random number generation may result "
"in extraneous communication and/or redundant distributed " "in extraneous communication and/or redundant distributed "
"computation. With this flag, the communication overheads disappear " "computation. With this flag, the communication overheads disappear "
"in some cases." "in some cases."
...@@ -923,17 +919,17 @@ threefry_partitionable = config.define_bool_state( ...@@ -923,17 +919,17 @@ threefry_partitionable = config.define_bool_state(
) )
enable_custom_vjp_by_custom_transpose = config.define_bool_state( enable_custom_vjp_by_custom_transpose = config.define_bool_state(
name="jax_enable_custom_vjp_by_custom_transpose", name="xla_enable_custom_vjp_by_custom_transpose",
default=False, default=False,
upgrade=True, upgrade=True,
help=( help=(
"Enables an internal upgrade that implements `jax.custom_vjp` by " "Enables an internal upgrade that implements `xla.custom_vjp` by "
"reduction to `jax.custom_jvp` and `jax.custom_transpose`." "reduction to `xla.custom_jvp` and `xla.custom_transpose`."
), ),
) )
raise_persistent_cache_errors = config.define_bool_state( raise_persistent_cache_errors = config.define_bool_state(
name="jax_raise_persistent_cache_errors", name="xla_raise_persistent_cache_errors",
default=False, default=False,
help=( help=(
"If true, exceptions raised when reading or writing to the " "If true, exceptions raised when reading or writing to the "
...@@ -946,7 +942,7 @@ raise_persistent_cache_errors = config.define_bool_state( ...@@ -946,7 +942,7 @@ raise_persistent_cache_errors = config.define_bool_state(
) )
persistent_cache_min_compile_time_secs = config.define_float_state( persistent_cache_min_compile_time_secs = config.define_float_state(
name="jax_persistent_cache_min_compile_time_secs", name="xla_persistent_cache_min_compile_time_secs",
default=1, default=1,
help=( help=(
"The minimum compile time of a computation to be written to the " "The minimum compile time of a computation to be written to the "
...@@ -956,7 +952,7 @@ persistent_cache_min_compile_time_secs = config.define_float_state( ...@@ -956,7 +952,7 @@ persistent_cache_min_compile_time_secs = config.define_float_state(
) )
hlo_source_file_canonicalization_regex = config.define_string_state( hlo_source_file_canonicalization_regex = config.define_string_state(
name="jax_hlo_source_file_canonicalization_regex", name="xla_hlo_source_file_canonicalization_regex",
default=None, default=None,
help=( help=(
"Used to canonicalize the source_path metadata of HLO instructions " "Used to canonicalize the source_path metadata of HLO instructions "
...@@ -969,18 +965,18 @@ hlo_source_file_canonicalization_regex = config.define_string_state( ...@@ -969,18 +965,18 @@ hlo_source_file_canonicalization_regex = config.define_string_state(
) )
config.define_enum_state( config.define_enum_state(
name="jax_default_dtype_bits", name="xla_default_dtype_bits",
enum_values=["32", "64"], enum_values=["32", "64"],
default="64", default="64",
help=( help=(
"Specify bit width of default dtypes, either 32-bit or 64-bit. " "Specify bit width of default dtypes, either 32-bit or 64-bit. "
"This is a temporary flag that will be used during the process " "This is a temporary flag that will be used during the process "
"of deprecating the ``jax_enable_x64`` flag." "of deprecating the ``xla_enable_x64`` flag."
), ),
) )
numpy_dtype_promotion = config.define_enum_state( numpy_dtype_promotion = config.define_enum_state(
name="jax_numpy_dtype_promotion", name="xla_numpy_dtype_promotion",
enum_values=["standard", "strict"], enum_values=["standard", "strict"],
default="standard", default="standard",
help=( help=(
...@@ -997,39 +993,38 @@ numpy_dtype_promotion = config.define_enum_state( ...@@ -997,39 +993,38 @@ numpy_dtype_promotion = config.define_enum_state(
def _update_x64_global(val): def _update_x64_global(val):
libjax_jit.global_state().enable_x64 = val libxla_jit.global_state().enable_x64 = val
def _update_x64_thread_local(val): def _update_x64_thread_local(val):
libjax_jit.thread_local_state().enable_x64 = val libxla_jit.thread_local_state().enable_x64 = val
enable_x64 = config.define_bool_state( enable_x64 = config.define_bool_state(
name="jax_enable_x64", name="xla_enable_x64",
default=False, default=False,
help="Enable 64-bit types to be used", help="Enable 64-bit types to be used",
update_global_hook=_update_x64_global, update_global_hook=_update_x64_global,
update_thread_local_hook=_update_x64_thread_local, update_thread_local_hook=_update_x64_thread_local,
) )
# TODO: remove after fixing users of FLAGS.x64_enabled. config._contextmanager_flags.remove("xla_enable_x64")
config._contextmanager_flags.remove("jax_enable_x64")
Config.x64_enabled = Config.jax_enable_x64 # type: ignore Config.x64_enabled = Config.xla_enable_x64
def _update_default_device_global(val): def _update_default_device_global(val):
libjax_jit.global_state().default_device = val libxla_jit.global_state().default_device = val
def _update_default_device_thread_local(val):
    """Set ``default_device`` on ``libxla_jit.thread_local_state()`` to ``val``."""
    thread_state = libxla_jit.thread_local_state()
    thread_state.default_device = val
def _validate_default_device(val): def _validate_default_device(val):
if val is not None and not isinstance(val, xla_client.Device): if val is not None and not isinstance(val, xla_client.Device):
# TODO: this is a workaround for non-PJRT Device types. Remove when # TODO: this is a workaround for non-PJRT Device types. Remove when
# all JAX backends use a single C++ device interface. # all XLA backends use a single C++ device interface.
if "Device" in str(type(val)): if "Device" in str(type(val)):
logger.info( logger.info(
"Allowing non-`xla_client.Device` default device: %s, type: %s", "Allowing non-`xla_client.Device` default device: %s, type: %s",
...@@ -1038,20 +1033,20 @@ def _validate_default_device(val): ...@@ -1038,20 +1033,20 @@ def _validate_default_device(val):
) )
return return
raise ValueError( raise ValueError(
"jax.default_device must be passed a Device object (e.g. " "xla.default_device must be passed a Device object (e.g. "
f"`jax.devices('cpu')[0]`), got: {repr(val)}" f"`xla.devices('cpu')[0]`), got: {repr(val)}"
) )
# TODO: default_device only accepts devices for now. Make it work with # TODO: default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). # platform names as well (e.g. "cpu" to mean the same as xla.devices("cpu")[0]).
default_device = config.define_string_or_object_state( default_device = config.define_string_or_object_state(
name="jax_default_device", name="xla_default_device",
default=None, default=None,
help=( help=(
"Configure the default device for JAX operations. Set to a Device " "Configure the default device for XLA operations. Set to a Device "
'object (e.g. ``jax.devices("cpu")[0]``) to use that Device as the ' 'object (e.g. ``xla.devices("cpu")[0]``) to use that Device as the '
"default device for JAX operations and jit'd function calls (there is " "default device for XLA operations and jit'd function calls (there is "
"no effect on multi-device computations, e.g. pmapped function calls). " "no effect on multi-device computations, e.g. pmapped function calls). "
"Set to None to use the system default device. See " "Set to None to use the system default device. See "
":ref:`faq-data-placement` for more information on device placement." ":ref:`faq-data-placement` for more information on device placement."
...@@ -1063,15 +1058,15 @@ default_device = config.define_string_or_object_state( ...@@ -1063,15 +1058,15 @@ default_device = config.define_string_or_object_state(
def _update_disable_jit_global(val):
    """Store ``val`` as ``disable_jit`` on ``libxla_jit.global_state()``.

    Passed as ``update_global_hook`` when the ``xla_disable_jit`` state is
    defined.
    """
    global_state = libxla_jit.global_state()
    global_state.disable_jit = val
def _update_disable_jit_thread_local(val):
    """Store ``val`` as ``disable_jit`` on ``libxla_jit.thread_local_state()``."""
    thread_state = libxla_jit.thread_local_state()
    thread_state.disable_jit = val
disable_jit = config.define_bool_state( disable_jit = config.define_bool_state(
name="jax_disable_jit", name="xla_disable_jit",
default=False, default=False,
help=("Disable JIT compilation and just call original Python."), help=("Disable JIT compilation and just call original Python."),
update_global_hook=_update_disable_jit_global, update_global_hook=_update_disable_jit_global,
...@@ -1080,7 +1075,7 @@ disable_jit = config.define_bool_state( ...@@ -1080,7 +1075,7 @@ disable_jit = config.define_bool_state(
numpy_rank_promotion = config.define_enum_state( numpy_rank_promotion = config.define_enum_state(
name="jax_numpy_rank_promotion", name="xla_numpy_rank_promotion",
enum_values=["allow", "warn", "raise"], enum_values=["allow", "warn", "raise"],
default="allow", default="allow",
help=( help=(
...@@ -1094,7 +1089,7 @@ numpy_rank_promotion = config.define_enum_state( ...@@ -1094,7 +1089,7 @@ numpy_rank_promotion = config.define_enum_state(
) )
default_matmul_precision = config.define_enum_state( default_matmul_precision = config.define_enum_state(
name="jax_default_matmul_precision", name="xla_default_matmul_precision",
enum_values=["bfloat16", "tensorfloat32", "float32"], enum_values=["bfloat16", "tensorfloat32", "float32"],
default=None, default=None,
help=( help=(
...@@ -1102,8 +1097,8 @@ default_matmul_precision = config.define_enum_state( ...@@ -1102,8 +1097,8 @@ default_matmul_precision = config.define_enum_state(
"Some platforms, like TPU, offer configurable precision levels for " "Some platforms, like TPU, offer configurable precision levels for "
"matrix multiplication and convolution computations, trading off " "matrix multiplication and convolution computations, trading off "
"accuracy for speed. The precision can be controlled for each " "accuracy for speed. The precision can be controlled for each "
"operation; for example, see the :func:`jax.lax.conv_general_dilated` " "operation; for example, see the :func:`xla.lax.conv_general_dilated` "
"and :func:`jax.lax.dot` docstrings. But it can be useful to control " "and :func:`xla.lax.dot` docstrings. But it can be useful to control "
"the default behavior obtained when an operation is not given a " "the default behavior obtained when an operation is not given a "
"specific precision.\n\n" "specific precision.\n\n"
"This option can be used to control the default precision " "This option can be used to control the default precision "
...@@ -1122,10 +1117,10 @@ default_matmul_precision = config.define_enum_state( ...@@ -1122,10 +1117,10 @@ default_matmul_precision = config.define_enum_state(
) )
traceback_filtering = config.define_enum_state( traceback_filtering = config.define_enum_state(
name="jax_traceback_filtering", name="xla_traceback_filtering",
enum_values=["off", "tracebackhide", "remove_frames", "auto"], enum_values=["off", "tracebackhide", "remove_frames", "auto"],
default="auto", default="auto",
help="Controls how JAX filters internal frames out of tracebacks.\n\n" help="Controls how XLA filters internal frames out of tracebacks.\n\n"
"Valid values are:\n" "Valid values are:\n"
' * "off": disables traceback filtering.\n' ' * "off": disables traceback filtering.\n'
' * "auto": use "tracebackhide" if running under a sufficiently ' ' * "auto": use "tracebackhide" if running under a sufficiently '
...@@ -1136,20 +1131,16 @@ traceback_filtering = config.define_enum_state( ...@@ -1136,20 +1131,16 @@ traceback_filtering = config.define_enum_state(
" the unfiltered traceback as a __cause__ of the exception.\n", " the unfiltered traceback as a __cause__ of the exception.\n",
) )
# This flag is for internal use.
# TODO: Removes once we always enable cusparse lowering.
# TODO: Set to true after bug is fixed
# Off by default; switches on lowering of BCOO ops to cuSparse.
bcoo_cusparse_lowering = config.define_bool_state(
    name="xla_bcoo_cusparse_lowering",
    help="Enables lowering BCOO ops to cuSparse.",
    default=False,
)
# TODO: remove this flag when we ensure we only succeed at trace-staging
# if the intended backend can handle lowering the result
config.define_bool_state( config.define_bool_state(
name="jax_dynamic_shapes", name="xla_dynamic_shapes",
default=bool(os.getenv("JAX_DYNAMIC_SHAPES", "")), default=bool(os.getenv("XLA_DYNAMIC_SHAPES", "")),
help=( help=(
"Enables experimental features for staging out computations with " "Enables experimental features for staging out computations with "
"dynamic shapes." "dynamic shapes."
...@@ -1163,14 +1154,14 @@ config.define_bool_state( ...@@ -1163,14 +1154,14 @@ config.define_bool_state(
# Temporary switch kept for the rollout of the remat barrier.
# TODO: Remove if there are no complaints.
config.define_bool_state(
    name="xla_remat_opt_barrier",
    help="Enables using optimization-barrier op for lowering remat.",
    default=True,
)
# TODO: Remove flag once coordination service has rolled out. # TODO: Remove flag once coordination service has rolled out.
config.define_bool_state( config.define_bool_state(
name="jax_coordination_service", name="xla_coordination_service",
default=True, default=True,
help=( help=(
"Use coordination service (experimental) instead of the default PjRT " "Use coordination service (experimental) instead of the default PjRT "
...@@ -1180,17 +1171,17 @@ config.define_bool_state( ...@@ -1180,17 +1171,17 @@ config.define_bool_state(
# TODO: set default to True, then remove
config.define_bool_state(
    name="xla_eager_pmap",
    help="Enable eager-mode pmap when xla_disable_jit is activated.",
    default=True,
    upgrade=True,
)
config.define_bool_state( config.define_bool_state(
name="jax_experimental_unsafe_xla_runtime_errors", name="xla_experimental_unsafe_xla_runtime_errors",
default=False, default=False,
help=( help=(
"Enable XLA runtime errors for jax.experimental.checkify.checks " "Enable XLA runtime errors for xla.experimental.checkify.checks "
"on CPU and GPU. These errors are async, might get lost and are not " "on CPU and GPU. These errors are async, might get lost and are not "
"very readable. But, they crash the computation and enable you " "very readable. But, they crash the computation and enable you "
"to write jittable checks without needing to checkify. Does not " "to write jittable checks without needing to checkify. Does not "
...@@ -1242,10 +1233,10 @@ def _update_transfer_guard(state, key, val): ...@@ -1242,10 +1233,10 @@ def _update_transfer_guard(state, key, val):
transfer_guard_host_to_device = config.define_enum_state( transfer_guard_host_to_device = config.define_enum_state(
name="jax_transfer_guard_host_to_device", name="xla_transfer_guard_host_to_device",
enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"],
# The default is applied by transfer_guard_lib. Use None here to avoid # The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard. # accidentally overriding --xla_transfer_guard.
default=None, default=None,
help=( help=(
"Select the transfer guard level for host-to-device transfers. " "Select the transfer guard level for host-to-device transfers. "
...@@ -1260,10 +1251,10 @@ transfer_guard_host_to_device = config.define_enum_state( ...@@ -1260,10 +1251,10 @@ transfer_guard_host_to_device = config.define_enum_state(
) )
transfer_guard_device_to_device = config.define_enum_state( transfer_guard_device_to_device = config.define_enum_state(
name="jax_transfer_guard_device_to_device", name="xla_transfer_guard_device_to_device",
enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"],
# The default is applied by transfer_guard_lib. Use None here to avoid # The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard. # accidentally overriding --xla_transfer_guard.
default=None, default=None,
help=( help=(
"Select the transfer guard level for device-to-device transfers. " "Select the transfer guard level for device-to-device transfers. "
...@@ -1278,10 +1269,10 @@ transfer_guard_device_to_device = config.define_enum_state( ...@@ -1278,10 +1269,10 @@ transfer_guard_device_to_device = config.define_enum_state(
) )
transfer_guard_device_to_host = config.define_enum_state( transfer_guard_device_to_host = config.define_enum_state(
name="jax_transfer_guard_device_to_host", name="xla_transfer_guard_device_to_host",
enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"],
# The default is applied by transfer_guard_lib. Use None here to avoid # The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard. # accidentally overriding --xla_transfer_guard.
default=None, default=None,
help=( help=(
"Select the transfer guard level for device-to-host transfers. " "Select the transfer guard level for device-to-host transfers. "
...@@ -1298,18 +1289,18 @@ transfer_guard_device_to_host = config.define_enum_state( ...@@ -1298,18 +1289,18 @@ transfer_guard_device_to_host = config.define_enum_state(
def _update_all_transfer_guard_global(val):
    """Apply ``val`` to each of the three per-direction transfer guard options.

    Args:
        val: value handed to ``config.update`` for every per-direction option.
    """
    per_direction_options = (
        "xla_transfer_guard_host_to_device",
        "xla_transfer_guard_device_to_device",
        "xla_transfer_guard_device_to_host",
    )
    for option_name in per_direction_options:
        config.update(option_name, val)
_transfer_guard = config.define_enum_state( _transfer_guard = config.define_enum_state(
name="jax_transfer_guard", name="xla_transfer_guard",
enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"], enum_values=["allow", "log", "disallow", "log_explicit", "disallow_explicit"],
# The default is applied by transfer_guard_lib. Use None here to avoid # The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --jax_transfer_guard_*. # accidentally overriding --xla_transfer_guard_*.
default=None, default=None,
help=( help=(
"Select the transfer guard level for all transfers. This option is " "Select the transfer guard level for all transfers. This option is "
......
# code of this file is mainly from jax: https://github.com/google/jax
import logging import logging
import os import os
import platform as py_platform import platform as py_platform
...@@ -19,41 +20,41 @@ FLAGS = flags.FLAGS ...@@ -19,41 +20,41 @@ FLAGS = flags.FLAGS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Flags consulted by the backend-selection and compile-option code in this
# module. Where a default comes from the environment, the variable is read
# once, at import time.

# Read by _get_backend_uncached when no platform is passed explicitly.
flags.DEFINE_string(
    "xla_backend", "", "Deprecated, please use --xla_platforms instead."
)
flags.DEFINE_string(
    "xla_backend_target",
    os.getenv("XLA_BACKEND_TARGET", "").lower(),
    'Either "local" or "rpc:address" to connect to a remote service target.',
)
# Also read by _get_backend_uncached, after xla_backend.
flags.DEFINE_string(
    "xla_platform_name",
    os.getenv("XLA_PLATFORM_NAME", "").lower(),
    "Deprecated, please use --xla_platforms instead.",
)
# Read by get_compile_options to lower the XLA optimization level.
flags.DEFINE_bool(
    "xla_disable_most_optimizations",
    bool_env("XLA_DISABLE_MOST_OPTIMIZATIONS", False),
    "Try not to do much optimization work. This can be useful if the cost of "
    "optimization is greater than that of running a less-optimized program.",
)
# Copied into CompileOptions.profile_version by get_compile_options.
flags.DEFINE_integer(
    "xla_profile_version",
    int_env("XLA_PROFILE_VERSION", 0),
    "Optional profile version for XLA compilation. "
    "This is meaningful only when XLA is configured to "
    "support the remote compilation profile feature.",
)
# The two *_visible_devices flag names are handed to make_gpu_client as
# ``visible_devices_flag`` when the GPU backend factories are registered.
flags.DEFINE_string(
    "xla_cuda_visible_devices",
    "all",
    'Restricts the set of CUDA devices that XLA will use. Either "all", or a '
    "comma-separated list of integer device IDs.",
)
flags.DEFINE_string(
    "xla_rocm_visible_devices",
    "all",
    'Restricts the set of ROCM devices that XLA will use. Either "all", or a '
    "comma-separated list of integer device IDs.",
)
...@@ -69,22 +70,22 @@ def get_compile_options( ...@@ -69,22 +70,22 @@ def get_compile_options(
) -> xla_client.CompileOptions: ) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values. """Returns the compile options to use, as derived from flag values.
Args: Args:
num_replicas: Number of replicas for which to compile. num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile. num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of jax devices indicating the assignment device_assignment: Optional ndarray of xla devices indicating the assignment
of logical replicas to physical devices (default inherited from of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`. `num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA. partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner. generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space. auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space. auto_spmd_partitioning search space.
""" """
compile_options = xla_client.CompileOptions() compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas compile_options.num_replicas = num_replicas
compile_options.num_partitions = num_partitions compile_options.num_partitions = num_partitions
...@@ -130,12 +131,12 @@ def get_compile_options( ...@@ -130,12 +131,12 @@ def get_compile_options(
if cuda_path is not None: if cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = cuda_path debug_options.xla_gpu_cuda_data_dir = cuda_path
if FLAGS.jax_disable_most_optimizations: if FLAGS.xla_disable_most_optimizations:
debug_options.xla_backend_optimization_level = 0 debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False debug_options.xla_test_all_input_layouts = False
compile_options.profile_version = FLAGS.jax_xla_profile_version compile_options.profile_version = FLAGS.xla_profile_version
return compile_options return compile_options
...@@ -187,7 +188,7 @@ if hasattr(xla_client, "make_gpu_client"): ...@@ -187,7 +188,7 @@ if hasattr(xla_client, "make_gpu_client"):
partial( partial(
make_gpu_client, make_gpu_client,
platform_name="cuda", platform_name="cuda",
visible_devices_flag="jax_cuda_visible_devices", visible_devices_flag="xla_cuda_visible_devices",
), ),
priority=200, priority=200,
) )
...@@ -196,13 +197,13 @@ if hasattr(xla_client, "make_gpu_client"): ...@@ -196,13 +197,13 @@ if hasattr(xla_client, "make_gpu_client"):
partial( partial(
make_gpu_client, make_gpu_client,
platform_name="rocm", platform_name="rocm",
visible_devices_flag="jax_rocm_visible_devices", visible_devices_flag="xla_rocm_visible_devices",
), ),
priority=200, priority=200,
) )
if hasattr(xla_client, "make_plugin_device_client"): if hasattr(xla_client, "make_plugin_device_client"):
# It is assumed that if jax has been built with a plugin client, then the # It is assumed that if xla has been built with a plugin client, then the
# user wants to use the plugin client by default. Therefore, it gets the # user wants to use the plugin client by default. Therefore, it gets the
# highest priority. # highest priority.
register_backend_factory( register_backend_factory(
...@@ -229,11 +230,11 @@ def is_known_platform(platform: str): ...@@ -229,11 +230,11 @@ def is_known_platform(platform: str):
def canonicalize_platform(platform: str) -> str: def canonicalize_platform(platform: str) -> str:
"""Replaces platform aliases with their concrete equivalent. """Replaces platform aliases with their concrete equivalent.
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care. force users to care.
""" """
platforms = _alias_to_platforms.get(platform, None) platforms = _alias_to_platforms.get(platform, None)
if platforms is None: if platforms is None:
return platform return platform
...@@ -252,9 +253,9 @@ def canonicalize_platform(platform: str) -> str: ...@@ -252,9 +253,9 @@ def canonicalize_platform(platform: str) -> str:
def expand_platform_alias(platform: str) -> List[str]:
    """Return the concrete platforms that ``platform`` stands for.

    An alias such as "gpu" expands to ["cuda", "rocm"]; cuda and rocm are
    expected to behave alike in many respects because they share most of the
    same code. A name with no entry in ``_alias_to_platforms`` is returned as
    a one-element list.
    """
    if platform in _alias_to_platforms:
        return _alias_to_platforms[platform]
    return [platform]
...@@ -270,11 +271,11 @@ def backends(): ...@@ -270,11 +271,11 @@ def backends():
with _backend_lock: with _backend_lock:
if _backends: if _backends:
return _backends return _backends
if config.jax_platforms: if config.xla_platforms:
jax_platforms = config.jax_platforms.split(",") xla_platforms = config.xla_platforms.split(",")
platforms = [] platforms = []
# Allow platform aliases in the list of platforms. # Allow platform aliases in the list of platforms.
for platform in jax_platforms: for platform in xla_platforms:
platforms.extend(expand_platform_alias(platform)) platforms.extend(expand_platform_alias(platform))
priorities = range(len(platforms), 0, -1) priorities = range(len(platforms), 0, -1)
platforms_and_priorites = zip(platforms, priorities) platforms_and_priorites = zip(platforms, priorities)
...@@ -303,8 +304,8 @@ def backends(): ...@@ -303,8 +304,8 @@ def backends():
# If the backend isn't built into the binary, or if it has no devices, # If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError. # we expect a RuntimeError.
err_msg = f"Unable to initialize backend '{platform}': {err}" err_msg = f"Unable to initialize backend '{platform}': {err}"
if config.jax_platforms: if config.xla_platforms:
err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)" err_msg += " (set XLA_PLATFORMS='' to automatically choose an available backend)"
raise RuntimeError(err_msg) raise RuntimeError(err_msg)
else: else:
_backends_errors[platform] = str(err) _backends_errors[platform] = str(err)
...@@ -315,12 +316,9 @@ def backends(): ...@@ -315,12 +316,9 @@ def backends():
if ( if (
py_platform.system() != "Darwin" py_platform.system() != "Darwin"
and _default_backend.platform == "cpu" and _default_backend.platform == "cpu"
and FLAGS.jax_platform_name != "cpu" and FLAGS.xla_platform_name != "cpu"
): ):
logger.warning( logger.warning("No GPU/TPU found, falling back to CPU. ")
"No GPU/TPU found, falling back to CPU. "
"(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)"
)
return _backends return _backends
...@@ -329,7 +327,7 @@ def _clear_backends(): ...@@ -329,7 +327,7 @@ def _clear_backends():
global _backends_errors global _backends_errors
global _default_backend global _default_backend
logger.info("Clearing JAX backend caches.") logger.info("Clearing XLA backend caches.")
with _backend_lock: with _backend_lock:
_backends = {} _backends = {}
_backends_errors = {} _backends_errors = {}
...@@ -351,11 +349,6 @@ def _init_backend(platform): ...@@ -351,11 +349,6 @@ def _init_backend(platform):
raise RuntimeError(f"Could not initialize backend '{platform}'") raise RuntimeError(f"Could not initialize backend '{platform}'")
if backend.device_count() == 0: if backend.device_count() == 0:
raise RuntimeError(f"Backend '{platform}' provides no devices.") raise RuntimeError(f"Backend '{platform}' provides no devices.")
# ccq: disable distributed_debug_log
# util.distributed_debug_log(("Initialized backend", backend.platform),
# ("process_index", backend.process_index()),
# ("device_count", backend.device_count()),
# ("local_devices", backend.local_devices()))
logger.debug("Backend '%s' initialized", platform) logger.debug("Backend '%s' initialized", platform)
return backend return backend
...@@ -366,7 +359,7 @@ def _get_backend_uncached(platform=None): ...@@ -366,7 +359,7 @@ def _get_backend_uncached(platform=None):
if not isinstance(platform, (type(None), str)): if not isinstance(platform, (type(None), str)):
return platform return platform
platform = platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name or None platform = platform or FLAGS.xla_backend or FLAGS.xla_platform_name or None
bs = backends() bs = backends()
if platform is not None: if platform is not None:
...@@ -399,7 +392,7 @@ def get_device_backend(device=None): ...@@ -399,7 +392,7 @@ def get_device_backend(device=None):
def device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int: def device_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
"""Returns the total number of devices. """Returns the total number of devices.
On most platforms, this is the same as :py:func:`jax.local_device_count`. On most platforms, this is the same as :py:func:`xla.local_device_count`.
However, on multi-process platforms where different devices are associated However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across with different processes, this will return the total number of devices across
all processes. all processes.
...@@ -430,7 +423,7 @@ def devices( ...@@ -430,7 +423,7 @@ def devices(
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by equal to ``device_count(backend)``. Local devices can be identified by
comparing :attr:`Device.process_index` to the value returned by comparing :attr:`Device.process_index` to the value returned by
:py:func:`jax.process_index`. :py:func:`xla.process_index`.
If ``backend`` is ``None``, returns all the devices from the default backend. If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available, The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
...@@ -457,13 +450,13 @@ def local_devices( ...@@ -457,13 +450,13 @@ def local_devices(
backend: Optional[Union[str, XlaBackend]] = None, backend: Optional[Union[str, XlaBackend]] = None,
host_id: Optional[int] = None, host_id: Optional[int] = None,
) -> List[xla_client.Device]: ) -> List[xla_client.Device]:
"""Like :py:func:`jax.devices`, but only returns devices local to a given process. """Like :py:func:`xla.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process. If ``process_index`` is ``None``, returns devices local to this process.
Args: Args:
process_index: the integer index of the process. Process indices can be process_index: the integer index of the process. Process indices can be
retrieved via ``len(jax.process_count())``. retrieved via ``len(xla.process_count())``.
backend: This is an experimental feature and the API is likely to change. backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``. ``'tpu'``.
...@@ -473,7 +466,7 @@ def local_devices( ...@@ -473,7 +466,7 @@ def local_devices(
""" """
if host_id is not None: if host_id is not None:
warnings.warn( warnings.warn(
"The argument to jax.local_devices has been renamed from `host_id` to " "The argument to xla.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update " "`process_index`. This alias will eventually be removed; please update "
"your code." "your code."
) )
...@@ -502,34 +495,31 @@ def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int: ...@@ -502,34 +495,31 @@ def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int:
return get_backend(backend).process_index() return get_backend(backend).process_index()
# TODO: remove this sometime after jax 0.2.13 is released
def host_id(backend=None):
    """Deprecated alias of :func:`process_index`.

    Emits a warning and then delegates to ``process_index``.

    Args:
        backend: forwarded unchanged to ``process_index``.

    Returns:
        Whatever ``process_index(backend)`` returns.
    """
    warnings.warn(
        "xla.host_id has been renamed to xla.process_index. This alias "
        "will eventually be removed; please update your code.",
        # Attribute the warning to the caller's line rather than to this shim.
        stacklevel=2,
    )
    return process_index(backend)
def process_count(backend: Optional[Union[str, XlaBackend]] = None) -> int:
    """Returns the number of XLA processes associated with the backend."""
    # Computed as the largest process index over ``devices(backend)``, plus one.
    highest_index = max(dev.process_index for dev in devices(backend))
    return highest_index + 1
# TODO: remove this sometime after jax 0.2.13 is released
def host_count(backend=None):
    """Deprecated alias of :func:`process_count`.

    Emits a warning and then delegates to ``process_count``.

    Args:
        backend: forwarded unchanged to ``process_count``.

    Returns:
        Whatever ``process_count(backend)`` returns.
    """
    warnings.warn(
        "xla.host_count has been renamed to xla.process_count. This alias "
        "will eventually be removed; please update your code.",
        # Attribute the warning to the caller's line rather than to this shim.
        stacklevel=2,
    )
    return process_count(backend)
# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(backend=None):
    """Deprecated helper listing the process indices of a backend.

    Emits a warning and then returns ``list(range(process_count(backend)))``.

    Args:
        backend: forwarded unchanged to ``process_count``.

    Returns:
        A list of the integers ``0 .. process_count(backend) - 1``.
    """
    warnings.warn(
        "xla.host_ids has been deprecated; please use range(xla.process_count()) "
        "instead. xla.host_ids will eventually be removed; please update your "
        "code.",
        # Attribute the warning to the caller's line rather than to this shim.
        stacklevel=2,
    )
    return list(range(process_count(backend)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册