Commit d8917c22
Authored on Jul 05, 2023 by Megvii Engine Team
refactor(xla_trace): convert params in compile
GitOrigin-RevId: c00e0592810d717aa2f786e64165ce78041a83ed
Parent: 4ae9dd00
Showing 3 changed files with 114 additions and 39 deletions (+114, -39)
imperative/python/megengine/jit/xla_backend.py  (+43, -18)
imperative/python/megengine/xla/compile.py  (+3, -10)
imperative/python/test/unit/xla/functional/test_xla_convert.py  (+68, -11)
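For orientation, the paths above are all exercised through the xla_trace decorator tested in test_xla_convert.py below. The following is a rough, hypothetical usage sketch (not code from this commit): it assumes a Linux CUDA build with mge_xlalib installed, and that parameters of a module captured via capture_as_const=True are what compile()/convert_params_to_xla later turns into XLA device buffers.

import numpy as np

import megengine.functional as F
from megengine import Tensor
from megengine.jit import xla_trace
from megengine.module import Conv2d

m = Conv2d(9, 9, 3, groups=9)

# capture_as_const=True folds m's parameters into the trace; after tracing,
# compile() converts those parameters to XLA device buffers (see
# convert_params_to_xla in xla_backend.py below).
@xla_trace(capture_as_const=True)
def forward(inp):
    return F.relu(m(inp))

inp = Tensor(np.random.random((2, 9, 32, 32)).astype("float32"))
out_first = forward(inp)  # first call: traced and run by the imperative engine
out_xla = forward(inp)    # later calls: run through the XLA backend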
imperative/python/megengine/jit/xla_backend.py
from collections import OrderedDict, defaultdict
from .. import tensor

import numpy as np

from .. import _full_sync, tensor
from ..core._imperative_rt import CompNode
from ..core._imperative_rt.core2 import Tensor as RawTensor
from ..core._imperative_rt.core2 import (
...
...
@@ -15,16 +17,11 @@ from ..device import get_default_device
from ..utils.dlpack import from_dlpack, to_dlpack
from .tracing import trace

# try:
#     from mge_xlalib.xla_extension import ArrayImpl
#     from ..xla.lib import xla_client as xc
# except ImportError:
#     pass
from mge_xlalib.xla_extension import ArrayImpl
from ..xla.lib import xla_client as xc

try:
    from mge_xlalib.xla_extension import ArrayImpl
    from ..xla.lib import xla_client as xc
except ImportError as e:
    pass

xla_client_compute_stream = None
...
...
@@ -93,22 +90,48 @@ class xla_trace(trace):
    def unset_env(self):
        set_use_xla_backend(self.orig_use_xla)

    def compile(self):
        from ..xla import build_xla
        from ..traced_module.pytree import SUPPORTED_LEAF_TYPE, register_supported_type

    def convert_params_to_xla(self):
        from ..device import coalesce_free_memory
        from ..utils.module_utils import get_expand_structure
        from ..xla.device import get_xla_backend_and_device
        from ..tensor import Tensor
        from ..distributed import get_mm_server_addr, is_distributed

        assert self.traced
        if self.overall:
            backend = self.xla_exec.backend
            devices = backend.local_devices()
            _, device_id, _ = CompNode(get_default_device()).physical_locator
            device_index = (
                0 if len(devices) == 0 else [d.id for d in devices].index(device_id)
            )
            device = devices[device_index]
            for attr, _ in self.attr_to_key.items():
                param = get_expand_structure(attr[0], attr[1])
                param._reset(param.to("cpux"))
            for tensor, _ in self.opt_param_dict.items():
                tensor._reset(tensor.to("cpux"))

            def as_xla_array(tensor, backend, device):
                np_array = tensor.numpy()
                if np_array.shape == ():
                    np_array = np_array[np.newaxis]
                xla_array = backend.buffer_from_pyval(np_array, device)
                tensor._reset(Tensor(xla_array))

            for attr, _ in self.attr_to_key.items():
                param = get_expand_structure(attr[0], attr[1])
                as_xla_array(param, backend, device)
            for tensor, _ in self.opt_param_dict.items():
                as_xla_array(tensor, backend, device)

    def compile(self):
        from ..xla import build_xla
        from ..traced_module.pytree import SUPPORTED_LEAF_TYPE, register_supported_type
        from ..utils.module_utils import get_expand_structure
        from ..xla.device import get_xla_backend_and_device
        from ..tensor import Tensor
        from ..distributed import get_mm_server_addr, is_distributed

        assert self.traced
        self.xla_exec, self.inp_ids, self.out_ids = build_xla(
            self,
            return_with_io=True,
...
...
@@ -116,6 +139,8 @@ class xla_trace(trace):
            ip=get_mm_server_addr()[0] if is_distributed() else None,
            port=get_mm_server_addr()[1] + 1 if is_distributed() else None,
        )
        if self.overall:
            self.convert_params_to_xla()
        id2inpidx = defaultdict(list)
        id2outidx = defaultdict(list)
        for idx, id in enumerate(self.inp_ids):
...
...
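The id2inpidx / id2outidx maps built at the end of compile() use defaultdict(list), presumably because a single tensor id can appear at more than one input or output position, so each id has to map to a list of indices. A small self-contained illustration of that pattern (hypothetical ids, not code from the commit):

from collections import defaultdict

# Hypothetical trace input ids; id 7 appears twice because the same tensor
# is passed at two argument positions.
inp_ids = [3, 7, 7, 12]

id2inpidx = defaultdict(list)
for idx, tid in enumerate(inp_ids):
    id2inpidx[tid].append(idx)

assert id2inpidx[7] == [1, 2]  # one id, several positions
assert id2inpidx[3] == [0]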
imperative/python/megengine/xla/compile.py
...
...
@@ -73,16 +73,9 @@ class InputsHandler:
            if i._is_external_value():
                rst.append([i._external_obj()])
            else:
                if "gpu" in i.device.physical_name:
                    capsule = to_dlpack(i)
                    xla_array = self.from_dlpack(capsule)
                    rst.append([xla_array])
                else:
                    r = self.handler(
                        self.local_devices, [self.input_indices[idx],], [i,]
                    )[0]
                    rst.append(r)
                    i._reset(tensor(r[0]))
        return rst
    def __str__(self):
...
...
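In InputsHandler, inputs that already live on a GPU are exported with to_dlpack and handed to the handler's from_dlpack, presumably so the XLA client can adopt the buffer without a host round-trip; everything else still goes through the generic device-put handler. Below is a minimal round-trip sketch using MegEngine's own DLPack helpers, assuming the single-argument to_dlpack(t) / from_dlpack(capsule) signatures and a device on which DLPack export is supported (illustration only, not the XLA-side consumer used above):

import numpy as np

from megengine import Tensor
from megengine.utils.dlpack import from_dlpack, to_dlpack

x = Tensor(np.arange(6, dtype="float32").reshape(2, 3))

# Export the tensor as a DLPack capsule and re-import it. In InputsHandler
# the consumer is the XLA client; here MegEngine consumes its own capsule.
capsule = to_dlpack(x)
y = from_dlpack(capsule)

np.testing.assert_allclose(x.numpy(), y.numpy())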
imperative/python/test/unit/xla/functional/test_xla_convert.py
...
...
@@ -3,27 +3,48 @@ import platform
import numpy as np
import pytest

import megengine.distributed as dist
import megengine.functional as F
import megengine.jit as jit
import megengine.functional.distributed as fdist
import megengine.tensor as tensor
from megengine import autodiff, is_cuda_available
from megengine.autodiff.grad_manager import GradManager
from meg_xlalib.xla_extension import ArrayImpl
from megengine.core._imperative_rt.core2 import (
    is_external_convert,
    set_external_convert_hook,
)
from megengine.jit import xla_trace
from megengine.module import Conv2d
def test_external_flag_set():

@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_external_tsf_set():
    from mge_xlalib.xla_extension import ArrayImpl

    @xla_trace(capture_as_const=True)
    def test_fun():
        pass

    def test_func(inp):
        return inp

    assert is_external_convert()

    inp = tensor(np.random.random((9, 9, 32, 32)))
    mge_inp = test_func(inp)
    xla_inp = test_func(inp)
    assert xla_inp._is_external_value()
    assert isinstance(xla_inp._external_obj(), ArrayImpl)
    assert mge_inp.shape == xla_inp.shape
    assert mge_inp.dtype == xla_inp.dtype
    assert not xla_inp._is_external_value()
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_external_value():
    m = Conv2d(9, 9, 3, groups=9)
    gm = GradManager()
    gm.attach(m.parameters())
...
...
@@ -39,8 +60,44 @@ def test_external_value():
        model.weight.grad = None
        return ig, wg

    inp = tensor(np.random.random((9, 9, 32, 32))) * 100
    mge_ig, mge_wg = conv_grad(inp, m)
    xla_ig, xla_wg = conv_grad(inp, m)
    np.testing.assert_allclose(mge_ig.numpy(), xla_ig.numpy())
    np.testing.assert_allclose(mge_wg.numpy(), xla_wg.numpy(), atol=1e-5)
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_distributed_convert():
    from mge_xlalib.xla_extension import ArrayImpl

    def tester(ishape, n_gpus, dtype=None):
        @dist.launcher(n_gpus=n_gpus)
        def worker(data):
            rank = dist.get_rank()
            inp = tensor(data[rank])

            @xla_trace(without_host=True)
            def func1(inp):
                return fdist.all_reduce_sum(inp)

            mge_rst = func1(inp)
            xla_rst = func1(inp)
            assert xla_rst._is_external_value()
            assert isinstance(xla_rst._external_obj(), ArrayImpl)
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)
            assert mge_rst.shape == xla_rst.shape
            assert mge_rst.dtype == xla_rst.dtype
            assert not xla_rst._is_external_value()

        x = np.random.randn(*ishape).astype(dtype)
        y = np.random.randn(*ishape).astype(dtype)
        data = (x, y)
        worker(data)

    a, b = conv_grad(inp, m)
    a1, b1 = conv_grad(inp, m)
    np.testing.assert_allclose(a.numpy(), a1.numpy())
\ No newline at end of file
    tester((16, 1, 64,), 2)
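The test changes also move the ArrayImpl import from module level into the individual tests, so the file can apparently still be collected by pytest on hosts without the mge_xlalib wheel; the skipif guards then decide whether a test body runs at all. The same lazy-import pattern in isolation (a sketch with a hypothetical test name):

import platform

import pytest
from megengine import is_cuda_available

@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_requires_xla_runtime():
    # Imported lazily so that collecting this file does not require mge_xlalib.
    from mge_xlalib.xla_extension import ArrayImpl  # noqa: F401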