Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
6f1246ac
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
6f1246ac
编写于
3月 15, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(trace): fix name duplication and fix error message for invalid input
GitOrigin-RevId: 7fe1605c2639ba2c67488e0bffd6d0e2fab73e6a
上级
c79bcfaa
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
45 addition
and
2 deletion
+45
-2
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+3
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+1
-0
imperative/python/test/unit/jit/test_tracing.py
imperative/python/test/unit/jit/test_tracing.py
+34
-0
imperative/src/impl/transformations/trace.cpp
imperative/src/impl/transformations/trace.cpp
+7
-2
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
6f1246ac
...
...
@@ -216,6 +216,9 @@ class trace:
def
_process_inputs
(
self
,
*
args
,
**
kwargs
):
for
i
,
arg
in
enumerate
(
args
):
assert
isinstance
(
arg
,
RawTensor
),
"Only support tensor type args when capture_as_const is enabled"
name_tensor
(
"arg_{}"
.
format
(
i
),
arg
)
# TODO: mark kwargs in order
...
...
imperative/python/src/tensor.cpp
浏览文件 @
6f1246ac
...
...
@@ -1229,6 +1229,7 @@ void init_tensor(py::module m) {
m
.
def
(
"name_tensor"
,
[](
std
::
string
name
,
py
::
object
tensor
)
{
auto
*
tw
=
TensorWrapper
::
try_cast
(
tensor
.
ptr
());
mgb_assert
(
tw
,
"Arg_1 shoud be Tensor!"
);
auto
output
=
imperative
::
apply
(
TraceMarkVar
(
name
),
tw
->
m_tensor
->
data
())[
0
];
tw
->
m_tensor
->
reset
(
output
);
});
...
...
imperative/python/test/unit/jit/test_tracing.py
浏览文件 @
6f1246ac
...
...
@@ -748,3 +748,37 @@ def test_trace_jit_config():
for
fuse_dimshuffle
in
[
None
,
False
,
True
]:
for
fuse_reduce
in
[
None
,
False
,
True
]:
run
(
fuse_dimshuffle
,
fuse_reduce
)
def
test_trace_naming
():
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
func
(
x
):
return
F
.
max
(
x
,
axis
=
2
,
keepdims
=
False
)
+
1
inp
=
tensor
(
np
.
random
.
random
((
1
,
3
,
3
,
3
)))
func
(
inp
)
file
=
io
.
BytesIO
()
func
.
dump
(
file
,
optimize_for_inference
=
False
)
file
.
seek
(
0
)
import
megengine.utils.network
as
network
net
=
network
.
Network
.
load
(
file
)
names
=
set
()
for
var
in
net
.
all_vars
:
assert
var
.
name
not
in
names
names
.
add
(
var
.
name
)
def
test_invalid_inp_error
():
@
trace
(
capture_as_const
=
True
)
def
func
(
a
):
return
a
*
2
try
:
func
(
1
)
except
Exception
as
e
:
assert
(
str
(
e
)
==
"Only support tensor type args when capture_as_const is enabled"
)
else
:
assert
False
imperative/src/impl/transformations/trace.cpp
浏览文件 @
6f1246ac
...
...
@@ -98,8 +98,6 @@ VarNodeArray TraceResult::dump(
"do model.eval()?"
);
}
output_nodes
=
OpDef
::
apply_on_var_node
(
*
op
,
input_nodes
);
name2ops
[
output_nodes
[
0
]
->
owner_opr
()
->
name
()].
push_back
(
output_nodes
[
0
]
->
owner_opr
());
}
else
{
// no opr, just forward VarNode
mgb_assert
(
...
...
@@ -121,6 +119,13 @@ VarNodeArray TraceResult::dump(
}
}
}
auto
on_opr
=
[
&
name2ops
](
cg
::
OperatorNodeBase
*
opr
)
{
name2ops
[
opr
->
name
()].
push_back
(
opr
);
};
cg
::
DepOprIter
dep_iter
(
on_opr
);
for
(
auto
&&
[
output
,
name
]
:
outputs
)
{
dep_iter
.
add
(
nodes
[
output
]
->
owner_opr
());
}
for
(
auto
&&
[
name
,
ops
]
:
name2ops
)
{
if
(
ops
.
size
()
<=
1
)
{
continue
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录