Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5e013d8c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5e013d8c
编写于
1年前
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(xla): add xla acknowledgement
GitOrigin-RevId: f2fafebedfc3b520684655cff8c86b1df7c2c1cd
上级
4c7905f3
master
release-1.13.0
release-1.13.1
try-import
v1.13.1
v1.13.0
无相关合并请求
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
244 addition
and
240 deletion
+244
-240
ACKNOWLEDGMENTS
ACKNOWLEDGMENTS
+5
-0
imperative/python/megengine/xla/__init__.py
imperative/python/megengine/xla/__init__.py
+2
-0
imperative/python/megengine/xla/lib/__init__.py
imperative/python/megengine/xla/lib/__init__.py
+31
-15
imperative/python/megengine/xla/lib/config.py
imperative/python/megengine/xla/lib/config.py
+143
-152
imperative/python/megengine/xla/lib/xla_bridge.py
imperative/python/megengine/xla/lib/xla_bridge.py
+63
-73
未找到文件。
ACKNOWLEDGMENTS
浏览文件 @
5e013d8c
...
...
@@ -755,6 +755,11 @@ Copyright 2014 Google Inc. All rights reserved.
5. MACE
Copyright 2018 Xiaomi Inc. All rights reserved.
6. XLA
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
7. JAX
Copyright 2018 The JAX Authors.
Terms of Apache License Version 2.0
---------------------------------------------------
...
...
This diff is collapsed.
Click to expand it.
imperative/python/megengine/xla/__init__.py
浏览文件 @
5e013d8c
# some code of this directory is from jax: https://github.com/google/jax
from
.build
import
build_xla
This diff is collapsed.
Click to expand it.
imperative/python/megengine/xla/lib/__init__.py
浏览文件 @
5e013d8c
# code of this directory is mainly from jax: https://github.com/google/jax
try
:
import
mge_xlalib
as
mge_xlalib
except
ModuleNotFoundError
as
err
:
...
...
@@ -9,8 +11,10 @@ except ModuleNotFoundError as err:
raise
ModuleNotFoundError
(
msg
)
import
gc
import
pathlib
import
os
import
platform
import
subprocess
import
sys
import
warnings
from
typing
import
Optional
...
...
@@ -43,7 +47,8 @@ cpu_feature_guard.check_cpu_features()
xla_extension
=
xla_client
.
_xla
pytree
=
xla_client
.
_xla
.
pytree
jax_jit
=
xla_client
.
_xla
.
jax_jit
# we use some api in jaxlib
xla_jit
=
xla_client
.
_xla
.
jax_jit
pmap_lib
=
xla_client
.
_xla
.
pmap_lib
...
...
@@ -57,18 +62,29 @@ gc.callbacks.append(_xla_gc_callback)
xla_extension_version
:
int
=
getattr
(
xla_client
,
"_version"
,
0
)
mlir_api_version
=
xla_client
.
mlir_api_version
def
_cuda_path
()
->
Optional
[
str
]:
_mgexlalib_path
=
pathlib
.
Path
(
mge_xlalib
.
__file__
).
parent
path
=
_mgexlalib_path
.
parent
/
"nvidia"
/
"cuda_nvcc"
if
path
.
is_dir
():
return
str
(
path
)
path
=
_mgexlalib_path
/
"cuda"
if
path
.
is_dir
():
return
str
(
path
)
return
None
cuda_path
=
_cuda_path
()
# Finds the CUDA install path
def
_find_cuda_root_dir
()
->
Optional
[
str
]:
cuda_root_dir
=
os
.
environ
.
get
(
"CUDA_ROOT_DIR"
)
if
cuda_root_dir
is
None
:
try
:
which
=
"where"
if
sys
.
platform
==
"win32"
else
"which"
with
open
(
os
.
devnull
,
"w"
)
as
devnull
:
nvcc
=
(
subprocess
.
check_output
([
which
,
"nvcc"
],
stderr
=
devnull
)
.
decode
()
.
rstrip
(
"
\r\n
"
)
)
cuda_root_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
nvcc
))
except
Exception
:
if
sys
.
platform
==
"win32"
:
assert
False
,
"xla not supported on windows"
else
:
cuda_root_dir
=
"/usr/local/cuda"
if
not
os
.
path
.
exists
(
cuda_root_dir
):
cuda_root_dir
=
None
return
cuda_root_dir
cuda_path
=
_find_cuda_root_dir
()
transfer_guard_lib
=
xla_client
.
_xla
.
transfer_guard_lib
This diff is collapsed.
Click to expand it.
imperative/python/megengine/xla/lib/config.py
浏览文件 @
5e013d8c
# code of this file is mainly from jax: https://github.com/google/jax
import
contextlib
import
functools
import
itertools
...
...
@@ -10,9 +11,10 @@ from typing import Any, Callable, Hashable, Iterator, List, NamedTuple, Optional
import
mge_xlalib.xla_client
as
xla_client
from
.
import
jax_jit
as
libjax
_jit
from
.
import
xla_jit
as
libxla
_jit
jax_jit
=
xla_client
.
_xla
.
jax_jit
# we use some api in jaxlib
xla_jit
=
xla_client
.
_xla
.
jax_jit
transfer_guard_lib
=
xla_client
.
_xla
.
transfer_guard_lib
...
...
@@ -46,10 +48,9 @@ def int_env(varname: str, default: int) -> int:
UPGRADE_BOOL_HELP
=
(
" This will be enabled by default in future versions of
JAX
, at which "
" This will be enabled by default in future versions of
XLA
, at which "
"point all uses of the flag will be considered deprecated (following "
"the `API compatibility policy "
"<https://jax.readthedocs.io/en/latest/api_compatibility.html>`_)."
)
UPGRADE_BOOL_EXTRA_DESC
=
" (transient)"
...
...
@@ -181,21 +182,21 @@ class Config:
Example:
enable_foo = config.define_bool_state(
name='
jax
_enable_foo',
name='
xla
_enable_foo',
default=False,
help='Enable foo.')
# Now the
JAX_ENABLE_FOO shell environment variable and --jax
_enable_foo
# Now the
XLA_ENABLE_FOO shell environment variable and --xla
_enable_foo
# command-line flag can be used to control the process-level value of
# the configuration option, in addition to using e.g.
# ``config.update("
jax
_enable_foo", True)`` directly. We can also use a
# ``config.update("
xla
_enable_foo", True)`` directly. We can also use a
# context manager:
with enable_foo(True):
...
The value of the thread-local state or flag can be accessed via
``config.
jax_enable_foo``. Reading it via ``config.FLAGS.jax
_enable_foo`` is
``config.
xla_enable_foo``. Reading it via ``config.FLAGS.xla
_enable_foo`` is
an error.
"""
...
...
@@ -248,7 +249,7 @@ class Config:
name
=
name
.
lower
()
default
=
os
.
getenv
(
name
.
upper
(),
default
)
if
default
is
not
None
and
default
not
in
enum_values
:
raise
ValueError
(
f
'Invalid value "
{
default
}
" for
JAX
flag
{
name
}
'
)
raise
ValueError
(
f
'Invalid value "
{
default
}
" for
XLA
flag
{
name
}
'
)
self
.
DEFINE_enum
(
name
,
default
,
...
...
@@ -303,7 +304,7 @@ class Config:
try
:
default
=
int
(
default_env
)
except
ValueError
:
raise
ValueError
(
f
'Invalid value "
{
default_env
}
" for
JAX
flag
{
name
}
'
)
raise
ValueError
(
f
'Invalid value "
{
default_env
}
" for
XLA
flag
{
name
}
'
)
self
.
DEFINE_integer
(
name
,
default
,
help
=
help
,
update_hook
=
update_global_hook
)
self
.
_contextmanager_flags
.
add
(
name
)
...
...
@@ -350,7 +351,7 @@ class Config:
try
:
default
=
float
(
default_env
)
except
ValueError
:
raise
ValueError
(
f
'Invalid value "
{
default_env
}
" for
JAX
flag
{
name
}
'
)
raise
ValueError
(
f
'Invalid value "
{
default_env
}
" for
XLA
flag
{
name
}
'
)
self
.
DEFINE_float
(
name
,
default
,
help
=
help
,
update_hook
=
update_global_hook
)
self
.
_contextmanager_flags
.
add
(
name
)
...
...
@@ -465,7 +466,7 @@ class Config:
Values included in this set should also most likely be included in
the C++ JIT state, which is handled separately."""
tls
=
jax
_jit
.
thread_local_state
()
tls
=
xla
_jit
.
thread_local_state
()
axis_env_state
=
()
mesh_context_manager
=
()
context
=
tls
.
extra_jit_context
...
...
@@ -477,15 +478,14 @@ class Config:
axis_env_state
,
mesh_context_manager
,
self
.
x64_enabled
,
self
.
jax_numpy_rank_promotion
,
self
.
jax_default_matmul_precision
,
self
.
jax_dynamic_shapes
,
self
.
jax_numpy_dtype_promotion
,
self
.
jax_default_device
,
self
.
jax_array
,
self
.
jax_threefry_partitionable
,
# Technically this affects jaxpr->MHLO lowering, not tracing.
self
.
jax_hlo_source_file_canonicalization_regex
,
self
.
xla_numpy_rank_promotion
,
self
.
xla_default_matmul_precision
,
self
.
xla_dynamic_shapes
,
self
.
xla_numpy_dtype_promotion
,
self
.
xla_default_device
,
self
.
xla_array
,
self
.
xla_threefry_partitionable
,
self
.
xla_hlo_source_file_canonicalization_regex
,
)
...
...
@@ -507,7 +507,7 @@ class _StateContextManager:
default_value
:
Any
=
no_default
,
):
self
.
_name
=
name
self
.
__name__
=
name
[
4
:]
if
name
.
startswith
(
"
jax
_"
)
else
name
self
.
__name__
=
name
[
4
:]
if
name
.
startswith
(
"
xla
_"
)
else
name
self
.
__doc__
=
(
f
"Context manager for `
{
name
}
` config option"
f
"
{
extra_description
}
.
\n\n
{
help
}
"
...
...
@@ -599,7 +599,7 @@ class _GlobalExtraJitContext(NamedTuple):
def
_update_global_jit_state
(
**
kw
):
gs
=
jax
_jit
.
global_state
()
gs
=
xla
_jit
.
global_state
()
context
=
gs
.
extra_jit_context
or
_GlobalExtraJitContext
()
gs
.
extra_jit_context
=
context
.
_replace
(
**
kw
)
...
...
@@ -626,7 +626,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
class
_ThreadLocalStateCache
(
threading
.
local
):
""""A thread local cache for _ThreadLocalExtraJitContext
The extra_jit_context in
jax
_jit.thread_local_state() may get updated and thus
The extra_jit_context in
xla
_jit.thread_local_state() may get updated and thus
incurring dispatch overhead for comparing this python object during jit calls.
We want to duduplicate the objects that have the same hash/equality to also
have the same object ID, since the equality check is much faster if the object
...
...
@@ -641,7 +641,7 @@ _thread_local_state_cache = _ThreadLocalStateCache()
def
update_thread_local_jit_state
(
**
kw
):
tls
=
jax
_jit
.
thread_local_state
()
tls
=
xla
_jit
.
thread_local_state
()
# After xla_client._version >= 70, the thread_local object will necessarily
# be initialized when accessed. The following line can be removed when the
context
=
tls
.
extra_jit_context
or
_ThreadLocalExtraJitContext
()
...
...
@@ -650,25 +650,25 @@ def update_thread_local_jit_state(**kw):
flags
.
DEFINE_integer
(
"
jax
_tracer_error_num_traceback_frames"
,
int_env
(
"
JAX
_TRACER_ERROR_NUM_TRACEBACK_FRAMES"
,
5
),
help
=
"Set the number of stack frames in
JAX
tracer error messages."
,
"
xla
_tracer_error_num_traceback_frames"
,
int_env
(
"
XLA
_TRACER_ERROR_NUM_TRACEBACK_FRAMES"
,
5
),
help
=
"Set the number of stack frames in
XLA
tracer error messages."
,
)
flags
.
DEFINE_bool
(
"
jax
_pprint_use_color"
,
bool_env
(
"
JAX
_PPRINT_USE_COLOR"
,
True
),
help
=
"Enable
jaxpr
pretty-printing with colorful syntax highlighting."
,
"
xla
_pprint_use_color"
,
bool_env
(
"
XLA
_PPRINT_USE_COLOR"
,
True
),
help
=
"Enable pretty-printing with colorful syntax highlighting."
,
)
flags
.
DEFINE_bool
(
"
jax
_host_callback_inline"
,
bool_env
(
"
JAX
_HOST_CALLBACK_INLINE"
,
False
),
"
xla
_host_callback_inline"
,
bool_env
(
"
XLA
_HOST_CALLBACK_INLINE"
,
False
),
help
=
"Inline the host_callback, if not in a staged context."
,
)
flags
.
DEFINE_integer
(
"
jax
_host_callback_max_queue_byte_size"
,
int_env
(
"
JAX
_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE"
,
int
(
256
*
1e6
)),
"
xla
_host_callback_max_queue_byte_size"
,
int_env
(
"
XLA
_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE"
,
int
(
256
*
1e6
)),
help
=
(
"The size in bytes of the buffer used to hold outfeeds from each "
"device. When this capacity is reached consuming outfeeds from the "
...
...
@@ -678,8 +678,8 @@ flags.DEFINE_integer(
lower_bound
=
int
(
16
*
1e6
),
)
flags
.
DEFINE_bool
(
"
jax
_host_callback_outfeed"
,
bool_env
(
"
JAX
_HOST_CALLBACK_OUTFEED"
,
False
),
"
xla
_host_callback_outfeed"
,
bool_env
(
"
XLA
_HOST_CALLBACK_OUTFEED"
,
False
),
help
=
(
"Use outfeed implementation for host_callback, even on CPU and GPU. "
"If false, use the CustomCall implementation. "
...
...
@@ -687,8 +687,8 @@ flags.DEFINE_bool(
),
)
flags
.
DEFINE_bool
(
"
jax
_host_callback_ad_transforms"
,
bool_env
(
"
JAX
_HOST_CALLBACK_AD_TRANSFORMS"
,
False
),
"
xla
_host_callback_ad_transforms"
,
bool_env
(
"
XLA
_HOST_CALLBACK_AD_TRANSFORMS"
,
False
),
help
=
(
"Enable support for jvp/vjp for the host_callback primitives. Default is "
"False, which means that host_callback operates only on primals. "
...
...
@@ -696,65 +696,63 @@ flags.DEFINE_bool(
),
)
# TODO: remove flag when XLA:CPU is improved.
jax2tf_associative_scan_reductions
=
config
.
define_bool_state
(
name
=
"jax2tf_associative_scan_reductions"
,
xla2tf_associative_scan_reductions
=
config
.
define_bool_state
(
name
=
"xla2tf_associative_scan_reductions"
,
default
=
False
,
help
=
(
"
JAX
has two separate lowering rules for the cumulative reduction "
"
XLA
has two separate lowering rules for the cumulative reduction "
"primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses "
"a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. "
"The latter has a slow implementation on CPUs and GPUs. "
"By default,
jax
2tf uses the TPU lowering. Set this flag to True to "
"By default,
xla
2tf uses the TPU lowering. Set this flag to True to "
"use the associative scan lowering usage, and only if it makes a difference "
"for your application. "
"See the
jax
2tf README.md for more details."
"See the
xla
2tf README.md for more details."
),
)
jax
2tf_default_native_serialization
=
config
.
define_bool_state
(
name
=
"
jax
2tf_default_native_serialization"
,
default
=
bool_env
(
"
JAX
2TF_DEFAULT_NATIVE_SERIALIZATION"
,
False
),
xla
2tf_default_native_serialization
=
config
.
define_bool_state
(
name
=
"
xla
2tf_default_native_serialization"
,
default
=
bool_env
(
"
XLA
2TF_DEFAULT_NATIVE_SERIALIZATION"
,
False
),
help
=
(
"Sets the default value of the native_serialization parameter to "
"
jax
2tf.convert. Prefer using the parameter instead of the flag, the "
"
xla
2tf.convert. Prefer using the parameter instead of the flag, the "
"flag may be removed in the future."
),
)
# TODO: remove jax2tf_default_experimental_native_lowering
jax2tf_default_experimental_native_lowering
=
config
.
define_bool_state
(
name
=
"jax2tf_default_experimental_native_lowering"
,
default
=
bool_env
(
"JAX2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING"
,
False
),
help
=
(
"DO NOT USE, deprecated in favor of jax2tf_default_native_serialization."
),
xla2tf_default_experimental_native_lowering
=
config
.
define_bool_state
(
name
=
"xla2tf_default_experimental_native_lowering"
,
default
=
bool_env
(
"XLA2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING"
,
False
),
help
=
(
"DO NOT USE, deprecated in favor of xla2tf_default_native_serialization."
),
)
jax
_platforms
=
config
.
define_string_state
(
name
=
"
jax
_platforms"
,
xla
_platforms
=
config
.
define_string_state
(
name
=
"
xla
_platforms"
,
default
=
None
,
help
=
(
"Comma-separated list of platform names specifying which platforms
jax
"
"Comma-separated list of platform names specifying which platforms
xla
"
"should initialize. If any of the platforms in this list are not successfully "
"initialized, an exception will be raised and the program will be aborted. "
"The first platform in the list will be the default platform. "
"For example, config.
jax
_platforms=cpu,tpu means that CPU and TPU backends "
"For example, config.
xla
_platforms=cpu,tpu means that CPU and TPU backends "
"will be initialized, and the CPU backend will be used unless otherwise "
"specified. If TPU initialization fails, it will raise an exception. "
"By default,
jax
will try to initialize all available "
"By default,
xla
will try to initialize all available "
"platforms and will default to GPU or TPU if available, and fallback to CPU "
"otherwise."
),
)
enable_checks
=
config
.
define_bool_state
(
name
=
"
jax
_enable_checks"
,
name
=
"
xla
_enable_checks"
,
default
=
False
,
help
=
"Turn on invariant checking for
JAX
internals. Makes things slower."
,
help
=
"Turn on invariant checking for
XLA
internals. Makes things slower."
,
)
check_tracer_leaks
=
config
.
define_bool_state
(
name
=
"
jax
_check_tracer_leaks"
,
name
=
"
xla
_check_tracer_leaks"
,
default
=
False
,
help
=
(
"Turn on checking for leaked tracers as soon as a trace completes. "
...
...
@@ -767,7 +765,7 @@ check_tracer_leaks = config.define_bool_state(
checking_leaks
=
functools
.
partial
(
check_tracer_leaks
,
True
)
debug_nans
=
config
.
define_bool_state
(
name
=
"
jax
_debug_nans"
,
name
=
"
xla
_debug_nans"
,
default
=
False
,
help
=
(
"Add nan checks to every operation. When a nan is detected on the "
...
...
@@ -778,7 +776,7 @@ debug_nans = config.define_bool_state(
)
debug_infs
=
config
.
define_bool_state
(
name
=
"
jax
_debug_infs"
,
name
=
"
xla
_debug_infs"
,
default
=
False
,
help
=
(
"Add inf checks to every operation. When an inf is detected on the "
...
...
@@ -789,7 +787,7 @@ debug_infs = config.define_bool_state(
)
log_compiles
=
config
.
define_bool_state
(
name
=
"
jax
_log_compiles"
,
name
=
"
xla
_log_compiles"
,
default
=
False
,
help
=
(
"Log a message each time every time `jit` or `pmap` compiles an XLA "
...
...
@@ -800,80 +798,78 @@ log_compiles = config.define_bool_state(
)
log_compiles
=
config
.
define_bool_state
(
name
=
"
jax
_log_checkpoint_residuals"
,
name
=
"
xla
_log_checkpoint_residuals"
,
default
=
False
,
help
=
(
"Log a message every time
jax.checkpoint (aka jax
.remat) is "
"Log a message every time
xla.checkpoint (aka xla
.remat) is "
"partially evaluated (e.g. for autodiff), printing what residuals "
"are saved."
),
)
parallel_functions_output_gda
=
config
.
define_bool_state
(
name
=
"
jax
_parallel_functions_output_gda"
,
name
=
"
xla
_parallel_functions_output_gda"
,
default
=
False
,
help
=
"If True, pjit will output GDAs."
,
)
def
_update_
jax
_array_global
(
val
):
def
_update_
xla
_array_global
(
val
):
if
val
is
not
None
and
not
val
:
raise
ValueError
(
"not supported in current version, please downgrad"
)
def
_update_
jax
_array_thread_local
(
val
):
def
_update_
xla
_array_thread_local
(
val
):
if
val
is
not
None
and
not
val
:
raise
ValueError
(
"not supported in current version, please downgrad"
)
jax
_array
=
config
.
define_bool_state
(
name
=
"
jax
_array"
,
xla
_array
=
config
.
define_bool_state
(
name
=
"
xla
_array"
,
default
=
True
,
upgrade
=
True
,
update_global_hook
=
_update_jax_array_global
,
update_thread_local_hook
=
_update_jax_array_thread_local
,
help
=
(
"If True, new pjit behavior will be enabled and `jax.Array` will be "
"used."
),
update_global_hook
=
_update_xla_array_global
,
update_thread_local_hook
=
_update_xla_array_thread_local
,
help
=
(
"whether use xla array"
),
)
jit_pjit_api_merge
=
config
.
define_bool_state
(
name
=
"
jax
_jit_pjit_api_merge"
,
name
=
"
xla
_jit_pjit_api_merge"
,
default
=
True
,
upgrade
=
True
,
help
=
(
"If True, jit and pjit API will be merged. You can only disable it via "
"the environment variable i.e. `os.environ['
JAX
_JIT_PJIT_API_MERGE'] = '0'`. "
"the environment variable i.e. `os.environ['
XLA
_JIT_PJIT_API_MERGE'] = '0'`. "
"The merge must be disabled via an environment variable since it "
"affects
JAX at import time so it needs to be disabled before jax
is "
"affects
XLA at import time so it needs to be disabled before xla
is "
"imported."
),
)
spmd_mode
=
config
.
define_enum_state
(
name
=
"
jax
_spmd_mode"
,
name
=
"
xla
_spmd_mode"
,
enum_values
=
[
"allow_all"
,
"allow_jit"
,
"allow_pjit"
],
# TODO: Default to `allow_jit` when the training wheels come
# off.
default
=
"allow_pjit"
,
help
=
(
"Decides whether Math on `
jax
.Array`'s that are not fully addressable "
"Decides whether Math on `
xla
.Array`'s that are not fully addressable "
"(i.e. spans across multiple processes) is allowed. The options are: "
"* allow_pjit: Default, only `pjit` computations are allowed to "
" execute on non-fully addressable `
jax
.Array`s
\n
"
"* allow_jit: `pjit` and `
jax
.jit` computations are allowed to "
" execute on non-fully addressable `
jax
.Array`s
\n
"
" execute on non-fully addressable `
xla
.Array`s
\n
"
"* allow_jit: `pjit` and `
xla
.jit` computations are allowed to "
" execute on non-fully addressable `
xla
.Array`s
\n
"
"* allow_all: `jnp`, normal math (like `a + b`, etc), `pjit`, "
" `
jax
.jit` and all other operations are allowed to "
" execute on non-fully addresable `
jax
.Array`s."
" `
xla
.jit` and all other operations are allowed to "
" execute on non-fully addresable `
xla
.Array`s."
),
)
distributed_debug
=
config
.
define_bool_state
(
name
=
"
jax
_distributed_debug"
,
name
=
"
xla
_distributed_debug"
,
default
=
False
,
help
=
(
"Enable logging useful for debugging multi-process distributed "
...
...
@@ -884,7 +880,7 @@ distributed_debug = config.define_bool_state(
enable_custom_prng
=
config
.
define_bool_state
(
name
=
"
jax
_enable_custom_prng"
,
name
=
"
xla
_enable_custom_prng"
,
default
=
False
,
upgrade
=
True
,
help
=
(
...
...
@@ -894,7 +890,7 @@ enable_custom_prng = config.define_bool_state(
)
default_prng_impl
=
config
.
define_enum_state
(
name
=
"
jax
_default_prng_impl"
,
name
=
"
xla
_default_prng_impl"
,
enum_values
=
[
"threefry2x32"
,
"rbg"
,
"unsafe_rbg"
],
default
=
"threefry2x32"
,
help
=
(
...
...
@@ -904,14 +900,14 @@ default_prng_impl = config.define_enum_state(
)
threefry_partitionable
=
config
.
define_bool_state
(
name
=
"
jax
_threefry_partitionable"
,
name
=
"
xla
_threefry_partitionable"
,
default
=
False
,
upgrade
=
True
,
help
=
(
"Enables internal threefry PRNG implementation changes that "
"render it automatically partitionable in some cases. For use "
"with pjit and/or
jax_array
=True. Without this flag, using the "
"standard
jax
.random pseudo-random number generation may result "
"with pjit and/or
xla
=True. Without this flag, using the "
"standard
xla
.random pseudo-random number generation may result "
"in extraneous communication and/or redundant distributed "
"computation. With this flag, the communication overheads disappear "
"in some cases."
...
...
@@ -923,17 +919,17 @@ threefry_partitionable = config.define_bool_state(
)
enable_custom_vjp_by_custom_transpose
=
config
.
define_bool_state
(
name
=
"
jax
_enable_custom_vjp_by_custom_transpose"
,
name
=
"
xla
_enable_custom_vjp_by_custom_transpose"
,
default
=
False
,
upgrade
=
True
,
help
=
(
"Enables an internal upgrade that implements `
jax
.custom_vjp` by "
"reduction to `
jax.custom_jvp` and `jax
.custom_transpose`."
"Enables an internal upgrade that implements `
xla
.custom_vjp` by "
"reduction to `
xla.custom_jvp` and `xla
.custom_transpose`."
),
)
raise_persistent_cache_errors
=
config
.
define_bool_state
(
name
=
"
jax
_raise_persistent_cache_errors"
,
name
=
"
xla
_raise_persistent_cache_errors"
,
default
=
False
,
help
=
(
"If true, exceptions raised when reading or writing to the "
...
...
@@ -946,7 +942,7 @@ raise_persistent_cache_errors = config.define_bool_state(
)
persistent_cache_min_compile_time_secs
=
config
.
define_float_state
(
name
=
"
jax
_persistent_cache_min_compile_time_secs"
,
name
=
"
xla
_persistent_cache_min_compile_time_secs"
,
default
=
1
,
help
=
(
"The minimum compile time of a computation to be written to the "
...
...
@@ -956,7 +952,7 @@ persistent_cache_min_compile_time_secs = config.define_float_state(
)
hlo_source_file_canonicalization_regex
=
config
.
define_string_state
(
name
=
"
jax
_hlo_source_file_canonicalization_regex"
,
name
=
"
xla
_hlo_source_file_canonicalization_regex"
,
default
=
None
,
help
=
(
"Used to canonicalize the source_path metadata of HLO instructions "
...
...
@@ -969,18 +965,18 @@ hlo_source_file_canonicalization_regex = config.define_string_state(
)
config
.
define_enum_state
(
name
=
"
jax
_default_dtype_bits"
,
name
=
"
xla
_default_dtype_bits"
,
enum_values
=
[
"32"
,
"64"
],
default
=
"64"
,
help
=
(
"Specify bit width of default dtypes, either 32-bit or 64-bit. "
"This is a temporary flag that will be used during the process "
"of deprecating the ``
jax
_enable_x64`` flag."
"of deprecating the ``
xla
_enable_x64`` flag."
),
)
numpy_dtype_promotion
=
config
.
define_enum_state
(
name
=
"
jax
_numpy_dtype_promotion"
,
name
=
"
xla
_numpy_dtype_promotion"
,
enum_values
=
[
"standard"
,
"strict"
],
default
=
"standard"
,
help
=
(
...
...
@@ -997,39 +993,38 @@ numpy_dtype_promotion = config.define_enum_state(
def
_update_x64_global
(
val
):
lib
jax
_jit
.
global_state
().
enable_x64
=
val
lib
xla
_jit
.
global_state
().
enable_x64
=
val
def
_update_x64_thread_local
(
val
):
lib
jax
_jit
.
thread_local_state
().
enable_x64
=
val
lib
xla
_jit
.
thread_local_state
().
enable_x64
=
val
enable_x64
=
config
.
define_bool_state
(
name
=
"
jax
_enable_x64"
,
name
=
"
xla
_enable_x64"
,
default
=
False
,
help
=
"Enable 64-bit types to be used"
,
update_global_hook
=
_update_x64_global
,
update_thread_local_hook
=
_update_x64_thread_local
,
)
# TODO: remove after fixing users of FLAGS.x64_enabled.
config
.
_contextmanager_flags
.
remove
(
"jax_enable_x64"
)
config
.
_contextmanager_flags
.
remove
(
"xla_enable_x64"
)
Config
.
x64_enabled
=
Config
.
jax_enable_x64
# type: ignore
Config
.
x64_enabled
=
Config
.
xla_enable_x64
def
_update_default_device_global
(
val
):
lib
jax
_jit
.
global_state
().
default_device
=
val
lib
xla
_jit
.
global_state
().
default_device
=
val
def
_update_default_device_thread_local
(
val
):
lib
jax
_jit
.
thread_local_state
().
default_device
=
val
lib
xla
_jit
.
thread_local_state
().
default_device
=
val
def
_validate_default_device
(
val
):
if
val
is
not
None
and
not
isinstance
(
val
,
xla_client
.
Device
):
# TODO: this is a workaround for non-PJRT Device types. Remove when
# all
JAX
backends use a single C++ device interface.
# all
XLA
backends use a single C++ device interface.
if
"Device"
in
str
(
type
(
val
)):
logger
.
info
(
"Allowing non-`xla_client.Device` default device: %s, type: %s"
,
...
...
@@ -1038,20 +1033,20 @@ def _validate_default_device(val):
)
return
raise
ValueError
(
"
jax
.default_device must be passed a Device object (e.g. "
f
"`
jax
.devices('cpu')[0]`), got:
{
repr
(
val
)
}
"
"
xla
.default_device must be passed a Device object (e.g. "
f
"`
xla
.devices('cpu')[0]`), got:
{
repr
(
val
)
}
"
)
# TODO: default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as
jax
.devices("cpu")[0]).
# platform names as well (e.g. "cpu" to mean the same as
xla
.devices("cpu")[0]).
default_device
=
config
.
define_string_or_object_state
(
name
=
"
jax
_default_device"
,
name
=
"
xla
_default_device"
,
default
=
None
,
help
=
(
"Configure the default device for
JAX
operations. Set to a Device "
'object (e.g. ``
jax
.devices("cpu")[0]``) to use that Device as the '
"default device for
JAX
operations and jit'd function calls (there is "
"Configure the default device for
XLA
operations. Set to a Device "
'object (e.g. ``
xla
.devices("cpu")[0]``) to use that Device as the '
"default device for
XLA
operations and jit'd function calls (there is "
"no effect on multi-device computations, e.g. pmapped function calls). "
"Set to None to use the system default device. See "
":ref:`faq-data-placement` for more information on device placement."
...
...
@@ -1063,15 +1058,15 @@ default_device = config.define_string_or_object_state(
def
_update_disable_jit_global
(
val
):
lib
jax
_jit
.
global_state
().
disable_jit
=
val
lib
xla
_jit
.
global_state
().
disable_jit
=
val
def
_update_disable_jit_thread_local
(
val
):
lib
jax
_jit
.
thread_local_state
().
disable_jit
=
val
lib
xla
_jit
.
thread_local_state
().
disable_jit
=
val
disable_jit
=
config
.
define_bool_state
(
name
=
"
jax
_disable_jit"
,
name
=
"
xla
_disable_jit"
,
default
=
False
,
help
=
(
"Disable JIT compilation and just call original Python."
),
update_global_hook
=
_update_disable_jit_global
,
...
...
@@ -1080,7 +1075,7 @@ disable_jit = config.define_bool_state(
numpy_rank_promotion
=
config
.
define_enum_state
(
name
=
"
jax
_numpy_rank_promotion"
,
name
=
"
xla
_numpy_rank_promotion"
,
enum_values
=
[
"allow"
,
"warn"
,
"raise"
],
default
=
"allow"
,
help
=
(
...
...
@@ -1094,7 +1089,7 @@ numpy_rank_promotion = config.define_enum_state(
)
default_matmul_precision
=
config
.
define_enum_state
(
name
=
"
jax
_default_matmul_precision"
,
name
=
"
xla
_default_matmul_precision"
,
enum_values
=
[
"bfloat16"
,
"tensorfloat32"
,
"float32"
],
default
=
None
,
help
=
(
...
...
@@ -1102,8 +1097,8 @@ default_matmul_precision = config.define_enum_state(
"Some platforms, like TPU, offer configurable precision levels for "
"matrix multiplication and convolution computations, trading off "
"accuracy for speed. The precision can be controlled for each "
"operation; for example, see the :func:`
jax
.lax.conv_general_dilated` "
"and :func:`
jax
.lax.dot` docstrings. But it can be useful to control "
"operation; for example, see the :func:`
xla
.lax.conv_general_dilated` "
"and :func:`
xla
.lax.dot` docstrings. But it can be useful to control "
"the default behavior obtained when an operation is not given a "
"specific precision.
\n\n
"
"This option can be used to control the default precision "
...
...
@@ -1122,10 +1117,10 @@ default_matmul_precision = config.define_enum_state(
)
traceback_filtering
=
config
.
define_enum_state
(
name
=
"
jax
_traceback_filtering"
,
name
=
"
xla
_traceback_filtering"
,
enum_values
=
[
"off"
,
"tracebackhide"
,
"remove_frames"
,
"auto"
],
default
=
"auto"
,
help
=
"Controls how
JAX
filters internal frames out of tracebacks.
\n\n
"
help
=
"Controls how
XLA
filters internal frames out of tracebacks.
\n\n
"
"Valid values are:
\n
"
' * "off": disables traceback filtering.
\n
'
' * "auto": use "tracebackhide" if running under a sufficiently '
...
...
@@ -1136,20 +1131,16 @@ traceback_filtering = config.define_enum_state(
" the unfiltered traceback as a __cause__ of the exception.
\n
"
,
)
# This flag is for internal use.
# TODO: Removes once we always enable cusparse lowering.
# TODO: Set to true after bug is fixed
bcoo_cusparse_lowering
=
config
.
define_bool_state
(
name
=
"
jax
_bcoo_cusparse_lowering"
,
name
=
"
xla
_bcoo_cusparse_lowering"
,
default
=
False
,
help
=
(
"Enables lowering BCOO ops to cuSparse."
),
)
# TODO: remove this flag when we ensure we only succeed at trace-staging
# if the intended backend can handle lowering the result
config
.
define_bool_state
(
name
=
"
jax
_dynamic_shapes"
,
default
=
bool
(
os
.
getenv
(
"
JAX
_DYNAMIC_SHAPES"
,
""
)),
name
=
"
xla
_dynamic_shapes"
,
default
=
bool
(
os
.
getenv
(
"
XLA
_DYNAMIC_SHAPES"
,
""
)),
help
=
(
"Enables experimental features for staging out computations with "
"dynamic shapes."
...
...
@@ -1163,14 +1154,14 @@ config.define_bool_state(
# This flag is temporary during rollout of the remat barrier.
# TODO: Remove if there are no complaints.
config
.
define_bool_state
(
name
=
"
jax
_remat_opt_barrier"
,
name
=
"
xla
_remat_opt_barrier"
,
default
=
True
,
help
=
(
"Enables using optimization-barrier op for lowering remat."
),
)
# TODO: Remove flag once coordination service has rolled out.
config
.
define_bool_state
(
name
=
"
jax
_coordination_service"
,
name
=
"
xla
_coordination_service"
,
default
=
True
,
help
=
(
"Use coordination service (experimental) instead of the default PjRT "
...
...
@@ -1180,17 +1171,17 @@ config.define_bool_state(
# TODO: set default to True, then remove
config
.
define_bool_state
(
name
=
"
jax
_eager_pmap"
,
name
=
"
xla
_eager_pmap"
,
default
=
True
,
upgrade
=
True
,
help
=
"Enable eager-mode pmap when
jax
_disable_jit is activated."
,
help
=
"Enable eager-mode pmap when
xla
_disable_jit is activated."
,
)
config
.
define_bool_state
(
name
=
"
jax
_experimental_unsafe_xla_runtime_errors"
,
name
=
"
xla
_experimental_unsafe_xla_runtime_errors"
,
default
=
False
,
help
=
(
"Enable XLA runtime errors for
jax
.experimental.checkify.checks "
"Enable XLA runtime errors for
xla
.experimental.checkify.checks "
"on CPU and GPU. These errors are async, might get lost and are not "
"very readable. But, they crash the computation and enable you "
"to write jittable checks without needing to checkify. Does not "
...
...
@@ -1242,10 +1233,10 @@ def _update_transfer_guard(state, key, val):
transfer_guard_host_to_device
=
config
.
define_enum_state
(
name
=
"
jax
_transfer_guard_host_to_device"
,
name
=
"
xla
_transfer_guard_host_to_device"
,
enum_values
=
[
"allow"
,
"log"
,
"disallow"
,
"log_explicit"
,
"disallow_explicit"
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --
jax
_transfer_guard.
# accidentally overriding --
xla
_transfer_guard.
default
=
None
,
help
=
(
"Select the transfer guard level for host-to-device transfers. "
...
...
@@ -1260,10 +1251,10 @@ transfer_guard_host_to_device = config.define_enum_state(
)
transfer_guard_device_to_device
=
config
.
define_enum_state
(
name
=
"
jax
_transfer_guard_device_to_device"
,
name
=
"
xla
_transfer_guard_device_to_device"
,
enum_values
=
[
"allow"
,
"log"
,
"disallow"
,
"log_explicit"
,
"disallow_explicit"
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --
jax
_transfer_guard.
# accidentally overriding --
xla
_transfer_guard.
default
=
None
,
help
=
(
"Select the transfer guard level for device-to-device transfers. "
...
...
@@ -1278,10 +1269,10 @@ transfer_guard_device_to_device = config.define_enum_state(
)
transfer_guard_device_to_host
=
config
.
define_enum_state
(
name
=
"
jax
_transfer_guard_device_to_host"
,
name
=
"
xla
_transfer_guard_device_to_host"
,
enum_values
=
[
"allow"
,
"log"
,
"disallow"
,
"log_explicit"
,
"disallow_explicit"
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --
jax
_transfer_guard.
# accidentally overriding --
xla
_transfer_guard.
default
=
None
,
help
=
(
"Select the transfer guard level for device-to-host transfers. "
...
...
@@ -1298,18 +1289,18 @@ transfer_guard_device_to_host = config.define_enum_state(
def
_update_all_transfer_guard_global
(
val
):
for
name
in
(
"
jax
_transfer_guard_host_to_device"
,
"
jax
_transfer_guard_device_to_device"
,
"
jax
_transfer_guard_device_to_host"
,
"
xla
_transfer_guard_host_to_device"
,
"
xla
_transfer_guard_device_to_device"
,
"
xla
_transfer_guard_device_to_host"
,
):
config
.
update
(
name
,
val
)
_transfer_guard
=
config
.
define_enum_state
(
name
=
"
jax
_transfer_guard"
,
name
=
"
xla
_transfer_guard"
,
enum_values
=
[
"allow"
,
"log"
,
"disallow"
,
"log_explicit"
,
"disallow_explicit"
],
# The default is applied by transfer_guard_lib. Use None here to avoid
# accidentally overriding --
jax
_transfer_guard_*.
# accidentally overriding --
xla
_transfer_guard_*.
default
=
None
,
help
=
(
"Select the transfer guard level for all transfers. This option is "
...
...
This diff is collapsed.
Click to expand it.
imperative/python/megengine/xla/lib/xla_bridge.py
浏览文件 @
5e013d8c
# code of this file is mainly from jax: https://github.com/google/jax
import
logging
import
os
import
platform
as
py_platform
...
...
@@ -19,41 +20,41 @@ FLAGS = flags.FLAGS
logger
=
logging
.
getLogger
(
__name__
)
flags
.
DEFINE_string
(
"
jax_xla_backend"
,
""
,
"Deprecated, please use --jax
_platforms instead."
"
xla_backend"
,
""
,
"Deprecated, please use --xla
_platforms instead."
)
flags
.
DEFINE_string
(
"
jax
_backend_target"
,
os
.
getenv
(
"
JAX
_BACKEND_TARGET"
,
""
).
lower
(),
"
xla
_backend_target"
,
os
.
getenv
(
"
XLA
_BACKEND_TARGET"
,
""
).
lower
(),
'Either "local" or "rpc:address" to connect to a remote service target.'
,
)
flags
.
DEFINE_string
(
"
jax
_platform_name"
,
os
.
getenv
(
"
JAX
_PLATFORM_NAME"
,
""
).
lower
(),
"Deprecated, please use --
jax
_platforms instead."
,
"
xla
_platform_name"
,
os
.
getenv
(
"
XLA
_PLATFORM_NAME"
,
""
).
lower
(),
"Deprecated, please use --
xla
_platforms instead."
,
)
flags
.
DEFINE_bool
(
"
jax
_disable_most_optimizations"
,
bool_env
(
"
JAX
_DISABLE_MOST_OPTIMIZATIONS"
,
False
),
"
xla
_disable_most_optimizations"
,
bool_env
(
"
XLA
_DISABLE_MOST_OPTIMIZATIONS"
,
False
),
"Try not to do much optimization work. This can be useful if the cost of "
"optimization is greater than that of running a less-optimized program."
,
)
flags
.
DEFINE_integer
(
"
jax_
xla_profile_version"
,
int_env
(
"
JAX_
XLA_PROFILE_VERSION"
,
0
),
"xla_profile_version"
,
int_env
(
"XLA_PROFILE_VERSION"
,
0
),
"Optional profile version for XLA compilation. "
"This is meaningful only when XLA is configured to "
"support the remote compilation profile feature."
,
)
flags
.
DEFINE_string
(
"
jax
_cuda_visible_devices"
,
"
xla
_cuda_visible_devices"
,
"all"
,
'Restricts the set of CUDA devices that
JAX
will use. Either "all", or a '
'Restricts the set of CUDA devices that
XLA
will use. Either "all", or a '
"comma-separate list of integer device IDs."
,
)
flags
.
DEFINE_string
(
"
jax
_rocm_visible_devices"
,
"
xla
_rocm_visible_devices"
,
"all"
,
'Restricts the set of ROCM devices that
JAX
will use. Either "all", or a '
'Restricts the set of ROCM devices that
XLA
will use. Either "all", or a '
"comma-separate list of integer device IDs."
,
)
...
...
@@ -69,22 +70,22 @@ def get_compile_options(
)
->
xla_client
.
CompileOptions
:
"""Returns the compile options to use, as derived from flag values.
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of jax
devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of xla
devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
compile_options
=
xla_client
.
CompileOptions
()
compile_options
.
num_replicas
=
num_replicas
compile_options
.
num_partitions
=
num_partitions
...
...
@@ -130,12 +131,12 @@ def get_compile_options(
if
cuda_path
is
not
None
:
debug_options
.
xla_gpu_cuda_data_dir
=
cuda_path
if
FLAGS
.
jax
_disable_most_optimizations
:
if
FLAGS
.
xla
_disable_most_optimizations
:
debug_options
.
xla_backend_optimization_level
=
0
debug_options
.
xla_llvm_disable_expensive_passes
=
True
debug_options
.
xla_test_all_input_layouts
=
False
compile_options
.
profile_version
=
FLAGS
.
jax_
xla_profile_version
compile_options
.
profile_version
=
FLAGS
.
xla_profile_version
return
compile_options
...
...
@@ -187,7 +188,7 @@ if hasattr(xla_client, "make_gpu_client"):
partial
(
make_gpu_client
,
platform_name
=
"cuda"
,
visible_devices_flag
=
"
jax
_cuda_visible_devices"
,
visible_devices_flag
=
"
xla
_cuda_visible_devices"
,
),
priority
=
200
,
)
...
...
@@ -196,13 +197,13 @@ if hasattr(xla_client, "make_gpu_client"):
partial
(
make_gpu_client
,
platform_name
=
"rocm"
,
visible_devices_flag
=
"
jax
_rocm_visible_devices"
,
visible_devices_flag
=
"
xla
_rocm_visible_devices"
,
),
priority
=
200
,
)
if
hasattr
(
xla_client
,
"make_plugin_device_client"
):
# It is assumed that if
jax
has been built with a plugin client, then the
# It is assumed that if
xla
has been built with a plugin client, then the
# user wants to use the plugin client by default. Therefore, it gets the
# highest priority.
register_backend_factory
(
...
...
@@ -229,11 +230,11 @@ def is_known_platform(platform: str):
def
canonicalize_platform
(
platform
:
str
)
->
str
:
"""Replaces platform aliases with their concrete equivalent.
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
platforms
=
_alias_to_platforms
.
get
(
platform
,
None
)
if
platforms
is
None
:
return
platform
...
...
@@ -252,9 +253,9 @@ def canonicalize_platform(platform: str) -> str:
def
expand_platform_alias
(
platform
:
str
)
->
List
[
str
]:
"""Expands, e.g., "gpu" to ["cuda", "rocm"].
This is used for convenience reasons: we expect cuda and rocm to act similarly
in many respects since they share most of the same code.
"""
This is used for convenience reasons: we expect cuda and rocm to act similarly
in many respects since they share most of the same code.
"""
return
_alias_to_platforms
.
get
(
platform
,
[
platform
])
...
...
@@ -270,11 +271,11 @@ def backends():
with
_backend_lock
:
if
_backends
:
return
_backends
if
config
.
jax
_platforms
:
jax_platforms
=
config
.
jax
_platforms
.
split
(
","
)
if
config
.
xla
_platforms
:
xla_platforms
=
config
.
xla
_platforms
.
split
(
","
)
platforms
=
[]
# Allow platform aliases in the list of platforms.
for
platform
in
jax
_platforms
:
for
platform
in
xla
_platforms
:
platforms
.
extend
(
expand_platform_alias
(
platform
))
priorities
=
range
(
len
(
platforms
),
0
,
-
1
)
platforms_and_priorites
=
zip
(
platforms
,
priorities
)
...
...
@@ -303,8 +304,8 @@ def backends():
# If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError.
err_msg
=
f
"Unable to initialize backend '
{
platform
}
':
{
err
}
"
if
config
.
jax
_platforms
:
err_msg
+=
" (set
JAX
_PLATFORMS='' to automatically choose an available backend)"
if
config
.
xla
_platforms
:
err_msg
+=
" (set
XLA
_PLATFORMS='' to automatically choose an available backend)"
raise
RuntimeError
(
err_msg
)
else
:
_backends_errors
[
platform
]
=
str
(
err
)
...
...
@@ -315,12 +316,9 @@ def backends():
if
(
py_platform
.
system
()
!=
"Darwin"
and
_default_backend
.
platform
==
"cpu"
and
FLAGS
.
jax
_platform_name
!=
"cpu"
and
FLAGS
.
xla
_platform_name
!=
"cpu"
):
logger
.
warning
(
"No GPU/TPU found, falling back to CPU. "
"(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)"
)
logger
.
warning
(
"No GPU/TPU found, falling back to CPU. "
)
return
_backends
...
...
@@ -329,7 +327,7 @@ def _clear_backends():
global
_backends_errors
global
_default_backend
logger
.
info
(
"Clearing
JAX
backend caches."
)
logger
.
info
(
"Clearing
XLA
backend caches."
)
with
_backend_lock
:
_backends
=
{}
_backends_errors
=
{}
...
...
@@ -351,11 +349,6 @@ def _init_backend(platform):
raise
RuntimeError
(
f
"Could not initialize backend '
{
platform
}
'"
)
if
backend
.
device_count
()
==
0
:
raise
RuntimeError
(
f
"Backend '
{
platform
}
' provides no devices."
)
# ccq: disable distributed_debug_log
# util.distributed_debug_log(("Initialized backend", backend.platform),
# ("process_index", backend.process_index()),
# ("device_count", backend.device_count()),
# ("local_devices", backend.local_devices()))
logger
.
debug
(
"Backend '%s' initialized"
,
platform
)
return
backend
...
...
@@ -366,7 +359,7 @@ def _get_backend_uncached(platform=None):
if
not
isinstance
(
platform
,
(
type
(
None
),
str
)):
return
platform
platform
=
platform
or
FLAGS
.
jax_xla_backend
or
FLAGS
.
jax
_platform_name
or
None
platform
=
platform
or
FLAGS
.
xla_backend
or
FLAGS
.
xla
_platform_name
or
None
bs
=
backends
()
if
platform
is
not
None
:
...
...
@@ -399,7 +392,7 @@ def get_device_backend(device=None):
def
device_count
(
backend
:
Optional
[
Union
[
str
,
XlaBackend
]]
=
None
)
->
int
:
"""Returns the total number of devices.
On most platforms, this is the same as :py:func:`
jax
.local_device_count`.
On most platforms, this is the same as :py:func:`
xla
.local_device_count`.
However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across
all processes.
...
...
@@ -430,7 +423,7 @@ def devices(
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by
comparing :attr:`Device.process_index` to the value returned by
:py:func:`
jax
.process_index`.
:py:func:`
xla
.process_index`.
If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
...
...
@@ -457,13 +450,13 @@ def local_devices(
backend
:
Optional
[
Union
[
str
,
XlaBackend
]]
=
None
,
host_id
:
Optional
[
int
]
=
None
,
)
->
List
[
xla_client
.
Device
]:
"""Like :py:func:`
jax
.devices`, but only returns devices local to a given process.
"""Like :py:func:`
xla
.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process.
Args:
process_index: the integer index of the process. Process indices can be
retrieved via ``len(
jax
.process_count())``.
retrieved via ``len(
xla
.process_count())``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
...
...
@@ -473,7 +466,7 @@ def local_devices(
"""
if
host_id
is
not
None
:
warnings
.
warn
(
"The argument to
jax
.local_devices has been renamed from `host_id` to "
"The argument to
xla
.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update "
"your code."
)
...
...
@@ -502,34 +495,31 @@ def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int:
return
get_backend
(
backend
).
process_index
()
# TODO: remove this sometime after jax 0.2.13 is released
def
host_id
(
backend
=
None
):
warnings
.
warn
(
"
jax.host_id has been renamed to jax
.process_index. This alias "
"
xla.host_id has been renamed to xla
.process_index. This alias "
"will eventually be removed; please update your code."
)
return
process_index
(
backend
)
def
process_count
(
backend
:
Optional
[
Union
[
str
,
XlaBackend
]]
=
None
)
->
int
:
"""Returns the number of
JAX
processes associated with the backend."""
"""Returns the number of
XLA
processes associated with the backend."""
return
max
(
d
.
process_index
for
d
in
devices
(
backend
))
+
1
# TODO: remove this sometime after jax 0.2.13 is released
def
host_count
(
backend
=
None
):
warnings
.
warn
(
"
jax.host_count has been renamed to jax
.process_count. This alias "
"
xla.host_count has been renamed to xla
.process_count. This alias "
"will eventually be removed; please update your code."
)
return
process_count
(
backend
)
# TODO: remove this sometime after jax 0.2.13 is released
def
host_ids
(
backend
=
None
):
warnings
.
warn
(
"
jax.host_ids has been deprecated; please use range(jax
.process_count()) "
"instead.
jax
.host_ids will eventually be removed; please update your "
"
xla.host_ids has been deprecated; please use range(xla
.process_count()) "
"instead.
xla
.host_ids will eventually be removed; please update your "
"code."
)
return
list
(
range
(
process_count
(
backend
)))
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
新手
引导
客服
返回
顶部