Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2e530779
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2e530779
编写于
11月 11, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/trace): use xpux device when dump
GitOrigin-RevId: f37285f70e9d21ca0c3951ebe917351e94e1ec3f
上级
739f927c
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
10 addition
and
5 deletion
+10
-5
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+8
-3
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+2
-2
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
2e530779
...
...
@@ -20,6 +20,7 @@ import numpy as np
from
..core._imperative_rt
import
GraphProfiler
from
..core._imperative_rt.ops
import
OprAttr
from
..core._trace_option
import
set_symbolic_shape
from
..core._wrap
import
device
as
as_device
from
..core.ops.special
import
Const
from
..core.tensor
import
megbrain_graph
as
G
from
..core.tensor.core
import
OpBase
,
TensorBase
,
TensorWrapperBase
,
apply
...
...
@@ -588,6 +589,8 @@ class trace:
len
(
self
.
_output_bindings
)
)
)
if
arg_names
is
None
:
arg_names
=
[
"arg_%d"
%
i
for
i
in
range
(
len
(
self
.
_arg_bindings
))]
if
arg_names
and
not
isinstance
(
arg_names
,
collections
.
abc
.
Sequence
):
arg_names
=
(
arg_names
,)
if
arg_names
and
len
(
arg_names
)
!=
len
(
self
.
_arg_bindings
):
...
...
@@ -598,6 +601,8 @@ class trace:
)
output_names
=
output_names
or
self
.
_output_names
dumped_device
=
as_device
(
"xpux"
)
h2v
=
{}
graph
=
G
.
Graph
()
# only graph_opt_level takes effect in dump
...
...
@@ -607,14 +612,14 @@ class trace:
info
=
self
.
_tinfo
[
h
]
h2v
[
h
]
=
graph
.
make_h2d
(
dtype
=
info
.
dtype
,
device
=
info
.
device
,
device
=
dumped_
device
,
shape
=
info
.
shape
,
name
=
arg_names
[
i
]
if
arg_names
else
None
,
)
for
k
,
h
in
self
.
_kwarg_bindings
.
items
():
info
=
self
.
_tinfo
[
h
]
h2v
[
h
]
=
graph
.
make_h2d
(
dtype
=
info
.
dtype
,
device
=
info
.
device
,
shape
=
info
.
shape
,
name
=
k
dtype
=
info
.
dtype
,
device
=
dumped_
device
,
shape
=
info
.
shape
,
name
=
k
)
for
op
,
ihandles
,
ohandles
in
self
.
_seq
:
...
...
@@ -625,7 +630,7 @@ class trace:
assert
info
.
external
assert
info
.
bound_data
h2v
[
h
]
=
graph
.
make_const
(
info
.
bound_data
.
numpy
(),
dtype
=
info
.
dtype
,
device
=
info
.
device
info
.
bound_data
.
numpy
(),
dtype
=
info
.
dtype
,
device
=
dumped_
device
)
ivars
.
append
(
h2v
[
h
])
ovars
=
apply
(
op
,
*
ivars
)
...
...
imperative/python/test/unit/test_tracing.py
浏览文件 @
2e530779
...
...
@@ -100,8 +100,8 @@ def test_dump():
file
=
io
.
BytesIO
()
dump_info
=
f
.
dump
(
file
)
assert
dump_info
.
nr_opr
==
3
np
.
testing
.
assert_equal
(
dump_info
.
inputs
,
[
"
h2d[0]"
,
"h2d[2]
"
])
np
.
testing
.
assert_equal
(
dump_info
.
outputs
,
[
"ADD(
h2d[0],h2d[2]
)[4]"
])
np
.
testing
.
assert_equal
(
dump_info
.
inputs
,
[
"
arg_0"
,
"arg_1
"
])
np
.
testing
.
assert_equal
(
dump_info
.
outputs
,
[
"ADD(
arg_0,arg_1
)[4]"
])
file
.
seek
(
0
)
result
=
cgtools
.
load_and_inference
(
file
,
[
a
,
b
])
np
.
testing
.
assert_equal
(
result
[
0
],
y
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录