Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cf5e9488
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cf5e9488
编写于
1月 17, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(traced_module): fix module trace transformation
GitOrigin-RevId: ce11fe5e093d89cd444595f673a849c16dadbe4a
上级
97c90d91
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
54 addition
and
32 deletion
+54
-32
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+4
-8
imperative/python/src/module_trace.h
imperative/python/src/module_trace.h
+6
-0
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+18
-22
imperative/python/test/unit/traced_module/test_trace_module.py
...ative/python/test/unit/traced_module/test_trace_module.py
+26
-2
未找到文件。
imperative/python/megengine/traced_module/expr.py
浏览文件 @
cf5e9488
...
...
@@ -606,7 +606,8 @@ class Apply(Expr):
def
apply_module_trace_hook
(
cls
,
opdef
,
*
inputs
):
for
i
in
inputs
:
node
=
NodeMixin
.
get
(
i
,
None
)
assert
node
is
not
None
if
node
is
None
:
# capture as constant
NodeMixin
.
wrap_safe
(
i
,
Constant
.
make
(
i
))
if
isinstance
(
opdef
,
FakeQuant
):
inp_nodes
=
[
NodeMixin
.
get
(
inputs
[
0
])]
...
...
@@ -627,7 +628,6 @@ class Apply(Expr):
unset_module_tracing
()
outputs
=
apply
(
opdef
,
*
inputs
)
outputs
=
list
(
map
(
Tensor
,
outputs
))
set_module_tracing
()
apply_node
.
add_outputs
(
outputs
)
...
...
@@ -741,12 +741,8 @@ class Constant(Expr):
assert
isinstance
(
c
,
(
RawTensor
,
Module
))
if
isinstance
(
c
,
Module
):
assert
module_tracer
.
is_builtin
(
c
)
or
c
.
is_qat
if
isinstance
(
c
,
RawTensor
):
if
is_tracing_module
():
unset_module_tracing
()
c
=
Tensor
(
c
)
set_module_tracing
()
else
:
if
type
(
c
)
is
RawTensor
:
with
_exclude_from_trace
():
c
=
Tensor
(
c
)
self
.
value
=
c
self
.
name
=
name
...
...
imperative/python/src/module_trace.h
浏览文件 @
cf5e9488
...
...
@@ -52,6 +52,12 @@ public:
}
}
void
enable
()
{
m_enabled
=
1
;
}
void
disable
()
{
m_enabled
=
0
;
}
bool
enabled
()
const
{
return
m_enabled
;
}
ValueRef
unwrap
(
ValueRef
value
)
override
{
return
value
;
}
std
::
string
name
()
const
override
{
return
"ModuleTraceTransformation"
;
}
...
...
imperative/python/src/tensor.cpp
浏览文件 @
cf5e9488
...
...
@@ -219,17 +219,19 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
PyObject
*
TensorWrapper
::
module_trace_info
()
{
if
(
auto
module_trace_info
=
module_trace_info_map
.
try_get
(
m_tensor
->
data
()))
{
if
(
module_trace_info
->
ptr
())
{
return
module_trace_info
->
inc_ref
().
ptr
();
}
else
{
}
}
PyErr_SetString
(
PyExc_AttributeError
,
"Has no attribute named
\'
_NodeMixin__node
\'
, please "
"set it first"
);
return
nullptr
;
}
}
void
TensorWrapper
::
set_module_trace_info
(
PyObject
*
obj
)
{
// TODO: erase when obj == nullptr
module_trace_info_map
[
m_tensor
->
data
()]
=
py
::
reinterpret_borrow
<
py
::
object
>
(
obj
);
}
...
...
@@ -1031,29 +1033,23 @@ void init_tensor(py::module m) {
static
py
::
function
module_trace_hook
;
static
auto
get_module_trace
=
[]
{
static
std
::
shared_ptr
<
ModuleTraceTransformation
>
module_trace_transformation
;
static
int
module_tracing
=
0
;
m
.
def
(
"set_module_tracing"
,
[
=
]
{
if
(
!
module_trace_transformation
)
{
mgb_assert
(
module_trace_hook
);
module_trace_transformation
=
std
::
make_shared
<
ModuleTraceTransformation
>
(
module_trace_hook
);
}
if
(
++
module_tracing
==
1
)
{
transformations
.
register_at
<
TransformationManager
::
ModuleTrace
>
(
transformations
.
register_at
<
Segment
::
ModuleTrace
>
(
module_trace_transformation
);
}
});
return
module_trace_transformation
;
};
m
.
def
(
"unset_module_tracing"
,
[
=
]
{
if
(
--
module_tracing
==
0
)
{
transformations
.
unregister
<
TransformationManager
::
ModuleTrace
>
(
module_trace_transformation
);
}
});
m
.
def
(
"set_module_tracing"
,
[
=
]
{
get_module_trace
()
->
enable
();
});
m
.
def
(
"unset_module_tracing"
,
[
=
]
{
get_module_trace
()
->
disable
();
});
m
.
def
(
"is_tracing_module"
,
[
=
]
{
return
module_tracing
>
0
;
});
m
.
def
(
"is_tracing_module"
,
[
=
]
{
return
get_module_trace
()
->
enabled
()
;
});
m
.
def
(
"set_module_trace_hook"
,
[](
py
::
function
function
)
{
module_trace_hook
=
function
;
});
...
...
imperative/python/test/unit/traced_module/test_trace_module.py
浏览文件 @
cf5e9488
...
...
@@ -5,9 +5,11 @@ import numpy as np
import
megengine.functional
as
F
import
megengine.module
as
M
from
megengine
import
Tensor
from
megengine.module.module
import
Module
from
megengine.core._imperative_rt.core2
import
apply
from
megengine.core.ops
import
builtin
from
megengine.module
import
Module
from
megengine.traced_module
import
TracedModule
,
enable_expr_checker
,
trace_module
from
megengine.traced_module.expr
import
CallFunction
from
megengine.traced_module.expr
import
Apply
,
CallFunction
,
Constant
class
MyModule1
(
M
.
Module
):
...
...
@@ -133,3 +135,25 @@ def test_trace_module():
tm6
=
trace_module
(
MyModule5
(),
a
,
b
)
assert
tm6
.
m1
.
argspec
is
None
assert
tm6
.
m1
.
_is_top
is
False
def
test_trace_module_2
():
class
Model
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
def
forward
(
self
,
x
):
out
=
x
.
shape
out
=
apply
(
builtin
.
Elemwise
(
mode
=
"ADD"
),
out
,
Tensor
(
1
))
return
out
traced_model
=
trace_module
(
Model
(),
Tensor
(([
1
,])))
assert
isinstance
(
traced_model
.
graph
.
_exprs
[
0
],
Apply
)
and
isinstance
(
traced_model
.
graph
.
_exprs
[
0
].
opdef
,
builtin
.
GetVarShape
)
assert
isinstance
(
traced_model
.
graph
.
_exprs
[
1
],
Constant
)
assert
isinstance
(
traced_model
.
graph
.
_exprs
[
2
],
Apply
)
and
isinstance
(
traced_model
.
graph
.
_exprs
[
2
].
opdef
,
builtin
.
Elemwise
)
assert
int
(
traced_model
(
Tensor
([
1
,
2
]))[
0
])
==
3
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录