Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5e013d8c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5e013d8c
编写于
7月 06, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(xla): add xla acknowledgement
GitOrigin-RevId: f2fafebedfc3b520684655cff8c86b1df7c2c1cd
上级
4c7905f3
变更
5
展开全部
隐藏空白更改
内联
并排
Showing
5 changed files
with
244 additions
and
240 deletions
+244
-240
ACKNOWLEDGMENTS
ACKNOWLEDGMENTS
+5
-0
imperative/python/megengine/xla/__init__.py
imperative/python/megengine/xla/__init__.py
+2
-0
imperative/python/megengine/xla/lib/__init__.py
imperative/python/megengine/xla/lib/__init__.py
+31
-15
imperative/python/megengine/xla/lib/config.py
imperative/python/megengine/xla/lib/config.py
+143
-152
imperative/python/megengine/xla/lib/xla_bridge.py
imperative/python/megengine/xla/lib/xla_bridge.py
+63
-73
未找到文件。
ACKNOWLEDGMENTS
浏览文件 @
5e013d8c
...
...
@@ -755,6 +755,11 @@ Copyright 2014 Google Inc. All rights reserved.
5. MACE
Copyright 2018 Xiaomi Inc. All rights reserved.
6. XLA
Copyright 2017 The TensorFlow Authors. All Rights Reserved.
7. JAX
Copyright 2018 The JAX Authors.
Terms of Apache License Version 2.0
---------------------------------------------------
...
...
imperative/python/megengine/xla/__init__.py
浏览文件 @
5e013d8c
# some code of this directory is from jax: https://github.com/google/jax
from
.build
import
build_xla
imperative/python/megengine/xla/lib/__init__.py
浏览文件 @
5e013d8c
# code of this directory is mainly from jax: https://github.com/google/jax
try
:
import
mge_xlalib
as
mge_xlalib
except
ModuleNotFoundError
as
err
:
...
...
@@ -9,8 +11,10 @@ except ModuleNotFoundError as err:
raise
ModuleNotFoundError
(
msg
)
import
gc
import
pathlib
import
os
import
platform
import
subprocess
import
sys
import
warnings
from
typing
import
Optional
...
...
@@ -43,7 +47,8 @@ cpu_feature_guard.check_cpu_features()
xla_extension
=
xla_client
.
_xla
pytree
=
xla_client
.
_xla
.
pytree
jax_jit
=
xla_client
.
_xla
.
jax_jit
# we use some api in jaxlib
xla_jit
=
xla_client
.
_xla
.
jax_jit
pmap_lib
=
xla_client
.
_xla
.
pmap_lib
...
...
@@ -57,18 +62,29 @@ gc.callbacks.append(_xla_gc_callback)
xla_extension_version
:
int
=
getattr
(
xla_client
,
"_version"
,
0
)
mlir_api_version
=
xla_client
.
mlir_api_version
def _cuda_path() -> Optional[str]:
    """Return a CUDA directory located relative to the ``mge_xlalib`` package.

    Two locations are probed, in this order:

    1. ``nvidia/cuda_nvcc`` next to the directory that contains
       ``mge_xlalib`` (i.e. a sibling of the package directory).
    2. ``cuda`` inside the ``mge_xlalib`` package directory itself.

    Returns:
        The first probed location that is an existing directory, as a string,
        or ``None`` when neither exists.
    """
    package_dir = pathlib.Path(mge_xlalib.__file__).parent
    # Order matters: the sibling "nvidia/cuda_nvcc" directory takes precedence
    # over the "cuda" directory inside the package.
    candidates = (
        package_dir.parent / "nvidia" / "cuda_nvcc",
        package_dir / "cuda",
    )
    return next(
        (str(candidate) for candidate in candidates if candidate.is_dir()),
        None,
    )
cuda_path
=
_cuda_path
()
# Finds the CUDA install path
def
_find_cuda_root_dir
()
->
Optional
[
str
]:
cuda_root_dir
=
os
.
environ
.
get
(
"CUDA_ROOT_DIR"
)
if
cuda_root_dir
is
None
:
try
:
which
=
"where"
if
sys
.
platform
==
"win32"
else
"which"
with
open
(
os
.
devnull
,
"w"
)
as
devnull
:
nvcc
=
(
subprocess
.
check_output
([
which
,
"nvcc"
],
stderr
=
devnull
)
.
decode
()
.
rstrip
(
"
\r\n
"
)
)
cuda_root_dir
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
nvcc
))
except
Exception
:
if
sys
.
platform
==
"win32"
:
assert
False
,
"xla not supported on windows"
else
:
cuda_root_dir
=
"/usr/local/cuda"
if
not
os
.
path
.
exists
(
cuda_root_dir
):
cuda_root_dir
=
None
return
cuda_root_dir
cuda_path
=
_find_cuda_root_dir
()
transfer_guard_lib
=
xla_client
.
_xla
.
transfer_guard_lib
imperative/python/megengine/xla/lib/config.py
浏览文件 @
5e013d8c
此差异已折叠。
点击以展开。
imperative/python/megengine/xla/lib/xla_bridge.py
浏览文件 @
5e013d8c
# code of this file is mainly from jax: https://github.com/google/jax
import
logging
import
os
import
platform
as
py_platform
...
...
@@ -19,41 +20,41 @@ FLAGS = flags.FLAGS
logger
=
logging
.
getLogger
(
__name__
)
flags
.
DEFINE_string
(
"
jax_xla_backend"
,
""
,
"Deprecated, please use --jax
_platforms instead."
"
xla_backend"
,
""
,
"Deprecated, please use --xla
_platforms instead."
)
flags
.
DEFINE_string
(
"
jax
_backend_target"
,
os
.
getenv
(
"
JAX
_BACKEND_TARGET"
,
""
).
lower
(),
"
xla
_backend_target"
,
os
.
getenv
(
"
XLA
_BACKEND_TARGET"
,
""
).
lower
(),
'Either "local" or "rpc:address" to connect to a remote service target.'
,
)
flags
.
DEFINE_string
(
"
jax
_platform_name"
,
os
.
getenv
(
"
JAX
_PLATFORM_NAME"
,
""
).
lower
(),
"Deprecated, please use --
jax
_platforms instead."
,
"
xla
_platform_name"
,
os
.
getenv
(
"
XLA
_PLATFORM_NAME"
,
""
).
lower
(),
"Deprecated, please use --
xla
_platforms instead."
,
)
flags
.
DEFINE_bool
(
"
jax
_disable_most_optimizations"
,
bool_env
(
"
JAX
_DISABLE_MOST_OPTIMIZATIONS"
,
False
),
"
xla
_disable_most_optimizations"
,
bool_env
(
"
XLA
_DISABLE_MOST_OPTIMIZATIONS"
,
False
),
"Try not to do much optimization work. This can be useful if the cost of "
"optimization is greater than that of running a less-optimized program."
,
)
flags
.
DEFINE_integer
(
"
jax_
xla_profile_version"
,
int_env
(
"
JAX_
XLA_PROFILE_VERSION"
,
0
),
"xla_profile_version"
,
int_env
(
"XLA_PROFILE_VERSION"
,
0
),
"Optional profile version for XLA compilation. "
"This is meaningful only when XLA is configured to "
"support the remote compilation profile feature."
,
)
flags
.
DEFINE_string
(
"
jax
_cuda_visible_devices"
,
"
xla
_cuda_visible_devices"
,
"all"
,
'Restricts the set of CUDA devices that
JAX
will use. Either "all", or a '
'Restricts the set of CUDA devices that
XLA
will use. Either "all", or a '
"comma-separate list of integer device IDs."
,
)
flags
.
DEFINE_string
(
"
jax
_rocm_visible_devices"
,
"
xla
_rocm_visible_devices"
,
"all"
,
'Restricts the set of ROCM devices that
JAX
will use. Either "all", or a '
'Restricts the set of ROCM devices that
XLA
will use. Either "all", or a '
"comma-separate list of integer device IDs."
,
)
...
...
@@ -69,22 +70,22 @@ def get_compile_options(
)
->
xla_client
.
CompileOptions
:
"""Returns the compile options to use, as derived from flag values.
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of jax
devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of xla
devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
"""
compile_options
=
xla_client
.
CompileOptions
()
compile_options
.
num_replicas
=
num_replicas
compile_options
.
num_partitions
=
num_partitions
...
...
@@ -130,12 +131,12 @@ def get_compile_options(
if
cuda_path
is
not
None
:
debug_options
.
xla_gpu_cuda_data_dir
=
cuda_path
if
FLAGS
.
jax
_disable_most_optimizations
:
if
FLAGS
.
xla
_disable_most_optimizations
:
debug_options
.
xla_backend_optimization_level
=
0
debug_options
.
xla_llvm_disable_expensive_passes
=
True
debug_options
.
xla_test_all_input_layouts
=
False
compile_options
.
profile_version
=
FLAGS
.
jax_
xla_profile_version
compile_options
.
profile_version
=
FLAGS
.
xla_profile_version
return
compile_options
...
...
@@ -187,7 +188,7 @@ if hasattr(xla_client, "make_gpu_client"):
partial
(
make_gpu_client
,
platform_name
=
"cuda"
,
visible_devices_flag
=
"
jax
_cuda_visible_devices"
,
visible_devices_flag
=
"
xla
_cuda_visible_devices"
,
),
priority
=
200
,
)
...
...
@@ -196,13 +197,13 @@ if hasattr(xla_client, "make_gpu_client"):
partial
(
make_gpu_client
,
platform_name
=
"rocm"
,
visible_devices_flag
=
"
jax
_rocm_visible_devices"
,
visible_devices_flag
=
"
xla
_rocm_visible_devices"
,
),
priority
=
200
,
)
if
hasattr
(
xla_client
,
"make_plugin_device_client"
):
# It is assumed that if
jax
has been built with a plugin client, then the
# It is assumed that if
xla
has been built with a plugin client, then the
# user wants to use the plugin client by default. Therefore, it gets the
# highest priority.
register_backend_factory
(
...
...
@@ -229,11 +230,11 @@ def is_known_platform(platform: str):
def
canonicalize_platform
(
platform
:
str
)
->
str
:
"""Replaces platform aliases with their concrete equivalent.
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
platforms
=
_alias_to_platforms
.
get
(
platform
,
None
)
if
platforms
is
None
:
return
platform
...
...
@@ -252,9 +253,9 @@ def canonicalize_platform(platform: str) -> str:
def
expand_platform_alias
(
platform
:
str
)
->
List
[
str
]:
"""Expands, e.g., "gpu" to ["cuda", "rocm"].
This is used for convenience reasons: we expect cuda and rocm to act similarly
in many respects since they share most of the same code.
"""
This is used for convenience reasons: we expect cuda and rocm to act similarly
in many respects since they share most of the same code.
"""
return
_alias_to_platforms
.
get
(
platform
,
[
platform
])
...
...
@@ -270,11 +271,11 @@ def backends():
with
_backend_lock
:
if
_backends
:
return
_backends
if
config
.
jax
_platforms
:
jax_platforms
=
config
.
jax
_platforms
.
split
(
","
)
if
config
.
xla
_platforms
:
xla_platforms
=
config
.
xla
_platforms
.
split
(
","
)
platforms
=
[]
# Allow platform aliases in the list of platforms.
for
platform
in
jax
_platforms
:
for
platform
in
xla
_platforms
:
platforms
.
extend
(
expand_platform_alias
(
platform
))
priorities
=
range
(
len
(
platforms
),
0
,
-
1
)
platforms_and_priorites
=
zip
(
platforms
,
priorities
)
...
...
@@ -303,8 +304,8 @@ def backends():
# If the backend isn't built into the binary, or if it has no devices,
# we expect a RuntimeError.
err_msg
=
f
"Unable to initialize backend '
{
platform
}
':
{
err
}
"
if
config
.
jax
_platforms
:
err_msg
+=
" (set
JAX
_PLATFORMS='' to automatically choose an available backend)"
if
config
.
xla
_platforms
:
err_msg
+=
" (set
XLA
_PLATFORMS='' to automatically choose an available backend)"
raise
RuntimeError
(
err_msg
)
else
:
_backends_errors
[
platform
]
=
str
(
err
)
...
...
@@ -315,12 +316,9 @@ def backends():
if
(
py_platform
.
system
()
!=
"Darwin"
and
_default_backend
.
platform
==
"cpu"
and
FLAGS
.
jax
_platform_name
!=
"cpu"
and
FLAGS
.
xla
_platform_name
!=
"cpu"
):
logger
.
warning
(
"No GPU/TPU found, falling back to CPU. "
"(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)"
)
logger
.
warning
(
"No GPU/TPU found, falling back to CPU. "
)
return
_backends
...
...
@@ -329,7 +327,7 @@ def _clear_backends():
global
_backends_errors
global
_default_backend
logger
.
info
(
"Clearing
JAX
backend caches."
)
logger
.
info
(
"Clearing
XLA
backend caches."
)
with
_backend_lock
:
_backends
=
{}
_backends_errors
=
{}
...
...
@@ -351,11 +349,6 @@ def _init_backend(platform):
raise
RuntimeError
(
f
"Could not initialize backend '
{
platform
}
'"
)
if
backend
.
device_count
()
==
0
:
raise
RuntimeError
(
f
"Backend '
{
platform
}
' provides no devices."
)
# ccq: disable distributed_debug_log
# util.distributed_debug_log(("Initialized backend", backend.platform),
# ("process_index", backend.process_index()),
# ("device_count", backend.device_count()),
# ("local_devices", backend.local_devices()))
logger
.
debug
(
"Backend '%s' initialized"
,
platform
)
return
backend
...
...
@@ -366,7 +359,7 @@ def _get_backend_uncached(platform=None):
if
not
isinstance
(
platform
,
(
type
(
None
),
str
)):
return
platform
platform
=
platform
or
FLAGS
.
jax_xla_backend
or
FLAGS
.
jax
_platform_name
or
None
platform
=
platform
or
FLAGS
.
xla_backend
or
FLAGS
.
xla
_platform_name
or
None
bs
=
backends
()
if
platform
is
not
None
:
...
...
@@ -399,7 +392,7 @@ def get_device_backend(device=None):
def
device_count
(
backend
:
Optional
[
Union
[
str
,
XlaBackend
]]
=
None
)
->
int
:
"""Returns the total number of devices.
On most platforms, this is the same as :py:func:`
jax
.local_device_count`.
On most platforms, this is the same as :py:func:`
xla
.local_device_count`.
However, on multi-process platforms where different devices are associated
with different processes, this will return the total number of devices across
all processes.
...
...
@@ -430,7 +423,7 @@ def devices(
:class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
equal to ``device_count(backend)``. Local devices can be identified by
comparing :attr:`Device.process_index` to the value returned by
:py:func:`
jax
.process_index`.
:py:func:`
xla
.process_index`.
If ``backend`` is ``None``, returns all the devices from the default backend.
The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
...
...
@@ -457,13 +450,13 @@ def local_devices(
backend
:
Optional
[
Union
[
str
,
XlaBackend
]]
=
None
,
host_id
:
Optional
[
int
]
=
None
,
)
->
List
[
xla_client
.
Device
]:
"""Like :py:func:`
jax
.devices`, but only returns devices local to a given process.
"""Like :py:func:`
xla
.devices`, but only returns devices local to a given process.
If ``process_index`` is ``None``, returns devices local to this process.
Args:
process_index: the integer index of the process. Process indices can be
retrieved via ``len(
jax
.process_count())``.
retrieved via ``len(
xla
.process_count())``.
backend: This is an experimental feature and the API is likely to change.
Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
``'tpu'``.
...
...
@@ -473,7 +466,7 @@ def local_devices(
"""
if
host_id
is
not
None
:
warnings
.
warn
(
"The argument to
jax
.local_devices has been renamed from `host_id` to "
"The argument to
xla
.local_devices has been renamed from `host_id` to "
"`process_index`. This alias will eventually be removed; please update "
"your code."
)
...
...
@@ -502,34 +495,31 @@ def process_index(backend: Optional[Union[str, XlaBackend]] = None) -> int:
return
get_backend
(
backend
).
process_index
()
# TODO: remove this sometime after jax 0.2.13 is released
def
host_id
(
backend
=
None
):
warnings
.
warn
(
"
jax.host_id has been renamed to jax
.process_index. This alias "
"
xla.host_id has been renamed to xla
.process_index. This alias "
"will eventually be removed; please update your code."
)
return
process_index
(
backend
)
def
process_count
(
backend
:
Optional
[
Union
[
str
,
XlaBackend
]]
=
None
)
->
int
:
"""Returns the number of
JAX
processes associated with the backend."""
"""Returns the number of
XLA
processes associated with the backend."""
return
max
(
d
.
process_index
for
d
in
devices
(
backend
))
+
1
# TODO: remove this sometime after jax 0.2.13 is released
def
host_count
(
backend
=
None
):
warnings
.
warn
(
"
jax.host_count has been renamed to jax
.process_count. This alias "
"
xla.host_count has been renamed to xla
.process_count. This alias "
"will eventually be removed; please update your code."
)
return
process_count
(
backend
)
# TODO: remove this sometime after jax 0.2.13 is released
def
host_ids
(
backend
=
None
):
warnings
.
warn
(
"
jax.host_ids has been deprecated; please use range(jax
.process_count()) "
"instead.
jax
.host_ids will eventually be removed; please update your "
"
xla.host_ids has been deprecated; please use range(xla
.process_count()) "
"instead.
xla
.host_ids will eventually be removed; please update your "
"code."
)
return
list
(
range
(
process_count
(
backend
)))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录