Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a3b2232b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
10 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
a3b2232b
编写于
9月 03, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge): add trace.dump
GitOrigin-RevId: ea4c9d33c8d3dbef4b93ba75ecc2a9fa80b8152c
上级
76dbaa27
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
279 additions
and
11 deletions
+279
-11
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+260
-11
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+19
-0
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
a3b2232b
import
collections
import
contextlib
import
functools
import
itertools
import
typing
import
warnings
import
weakref
import
numpy
as
np
from
..core.ops.special
import
Const
from
..core.tensor
import
megbrain_graph
as
G
from
..core.tensor.core
import
OpBase
,
apply
from
..core.tensor.core
import
OpBase
,
TensorBase
,
TensorWrapperBase
,
apply
from
..core.tensor.raw_tensor
import
OpDef
,
RawTensor
,
as_raw_tensor
from
..core.tensor.tensor
import
Tensor
from
.sublinear_memory_config
import
SublinearMemoryConfig
...
...
@@ -83,7 +89,6 @@ class trace:
self
.
__wrapped__
=
function
self
.
_symbolic
=
symbolic
self
.
_capture_as_const
=
capture_as_const
self
.
_capture_static_shape
=
False
self
.
_sublinear_memory_config
=
sublinear_memory_config
self
.
_untraced
=
True
...
...
@@ -95,6 +100,12 @@ class trace:
self
.
_lazy_eval_graph
=
None
self
.
_lazy_eval_tensors
=
weakref
.
WeakSet
()
self
.
_active_tensors
=
weakref
.
WeakSet
()
self
.
_tensor_remaps
=
None
self
.
_inputs_to_restore
=
None
self
.
_args_bindings
=
None
self
.
_kwargs_bindings
=
None
self
.
_output_bindings
=
None
self
.
_output_names
=
None
def
_new_handle
(
self
):
handle
=
len
(
self
.
_tinfo
)
...
...
@@ -132,10 +143,13 @@ class trace:
"last time, got an internal tensor this time"
)
if
x
.
_handle
!=
info
.
bound_data
.
_handle
:
raise
TraceMismatchError
(
"const capture violated: got "
"a different tensor this time"
)
if
not
np
.
array_equal
(
x
.
numpy
(),
info
.
bound_data
.
numpy
(),
equal_nan
=
True
):
raise
TraceMismatchError
(
"const capture violated: got "
"a different tensor this time"
)
else
:
if
info
.
dtype
!=
x
.
dtype
:
raise
TraceMismatchError
(
...
...
@@ -148,10 +162,13 @@ class trace:
info
.
data_setter
.
set_value
(
x
.
_dev_tensor
())
else
:
if
x
.
__class__
is
not
CompiledTensorProxy
:
raise
TraceMismatchError
(
"unexpected capture: trying to use an external tensor as input, "
"but that input was an internal tensor last time"
)
if
x
not
in
self
.
_tensor_remaps
:
raise
TraceMismatchError
(
"unexpected capture: trying to use an external tensor as "
"input, but that input was an internal tensor last time"
)
else
:
x
=
self
.
_tensor_remaps
[
x
]
if
x
.
_CompiledTensorProxy__handle
!=
h
:
raise
TraceMismatchError
(
"mis-wiring: input edge to an data flow "
...
...
@@ -227,6 +244,9 @@ class trace:
info
=
self
.
_tinfo
[
x
.
_TraceMixin__handle
]
info
.
data_read
=
True
x
.
_TraceMixin__restore
()
if
self
.
_inputs_to_restore
:
for
x
in
self
.
_inputs_to_restore
:
x
.
_TraceMixin__restore
()
if
self
.
_symbolic
:
# eval lazy eval tensors
lazy_eval_tensors
=
tuple
(
self
.
_lazy_eval_tensors
)
...
...
@@ -252,6 +272,7 @@ class trace:
self
.
_reset_exec_env
()
self
.
_pc
=
0
self
.
_tensor_remaps
=
None
apply
.
disable
(
apply_with_tracing
)
apply
.
disable
(
apply_const_with_tracing
)
apply
.
disable
(
apply_symbolic_mode
)
...
...
@@ -260,6 +281,10 @@ class trace:
active_trace
=
None
def
_begin_excluded_region
(
self
):
if
self
.
_capture_as_const
:
raise
RuntimeError
(
"exclude_from_trace cannot be used with capture_as_const"
)
if
self
.
_untraced
:
# conditionally reading a compiled tensor in excluded region
# is permitted, so we have to assume every tensor might be read
...
...
@@ -292,6 +317,19 @@ class trace:
need_reset_nodes
=
self
.
_need_reset_nodes
=
[]
# links enforce ordering of I/O nodes
links
=
()
if
self
.
_capture_as_const
:
for
h
in
itertools
.
chain
(
self
.
_args_bindings
,
self
.
_kwargs_bindings
.
values
()
):
info
=
self
.
_tinfo
[
h
]
opnode
=
info
.
data_setter
=
G
.
InputNode
(
device
=
info
.
device
,
dtype
=
info
.
dtype
,
graph
=
graph
)
need_reset_nodes
.
append
(
opnode
)
info
.
varnode
=
opnode
.
outputs
[
0
]
links
+=
opnode
.
outputs
[
1
:]
for
op
,
ihandles
,
ohandles
in
self
.
_seq
:
ivars
=
[]
readers
=
[]
...
...
@@ -355,7 +393,193 @@ class trace:
def __call__(self, *args, **kwargs):
    """Run the wrapped function inside the tracing environment.

    When ``capture_as_const`` is enabled, the call's arguments are handed to
    ``_process_inputs`` before the wrapped function runs and its return value
    is handed to ``_process_outputs`` afterwards.

    NOTE(review): the scraped diff interleaves the removed statement
    ``return self.__wrapped__(*args, **kwargs)`` with the added lines. Read
    literally, that early return makes the capture hooks unreachable, so only
    the post-change statements are kept here.

    :param args: positional arguments forwarded to the wrapped function.
    :param kwargs: keyword arguments forwarded to the wrapped function.
    :return: whatever the wrapped function returns.
    """
    with self._setup():
        if self._capture_as_const:
            self._process_inputs(*args, **kwargs)
        outputs = self.__wrapped__(*args, **kwargs)
        if self._capture_as_const:
            self._process_outputs(outputs)
        return outputs
def dump(self, file, *, arg_names=None, output_names=None):
    """Serialize the traced computation and write it to ``file``.

    A fresh graph is built for the dump: each recorded positional and keyword
    input gets a variable from ``graph.make_h2d``, any other handle first
    seen as an op input is embedded with ``graph.make_const`` from its bound
    data, the recorded op sequence is replayed through ``apply``, and the
    variables bound to the outputs are passed to ``G.dump``.

    :param file: a path (``str``), which is opened in ``"wb"`` mode and
        closed again here, or an object whose ``write`` method receives the
        dumped bytes.
    :param arg_names: optional names for the positional inputs; a single
        string is taken as the name of a single input.
    :param output_names: optional names for the outputs; a single string is
        taken as the name of a single output. Must not be given when the
        traced function returned a dict, whose keys already serve as names.
    :raises ValueError: if the trace was not created with
        ``capture_as_const=True`` or a name list has the wrong length.
    :raises RuntimeError: if the function has not been run yet.
    :raises TypeError: if ``output_names`` is given although the output is
        already in dict format.
    """
    # ``collections.Sequence`` was removed in Python 3.10; the ABCs live in
    # ``collections.abc``.
    from collections.abc import Sequence

    def as_name_sequence(names):
        # A bare string is one name, not a sequence of one-character names.
        if isinstance(names, str) or not isinstance(names, Sequence):
            return (names,)
        return names

    if not self._capture_as_const:
        raise ValueError(
            "you must specify capture_as_const=True at __init__ to use dump"
        )
    if self._untraced:
        raise RuntimeError("should run at least once before calling dump")
    if self._output_names and output_names:
        raise TypeError(
            "cannot specify output_names when output is already in dict format"
        )
    if output_names:
        output_names = as_name_sequence(output_names)
        if len(output_names) != len(self._output_bindings):
            raise ValueError("wrong number of output_names")
    if arg_names:
        arg_names = as_name_sequence(arg_names)
        # The attribute is ``_args_bindings``; the former ``_arg_bindings``
        # spelling raised AttributeError whenever arg_names was passed.
        if len(arg_names) != len(self._args_bindings):
            raise ValueError("wrong number of arg_names")
    output_names = output_names or self._output_names

    h2v = {}  # handle -> variable in the dump graph
    graph = G.Graph()

    for i, h in enumerate(self._args_bindings):
        info = self._tinfo[h]
        h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
        if arg_names:
            h2v[h].name = arg_names[i]
    for k, h in self._kwargs_bindings.items():
        info = self._tinfo[h]
        h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
        h2v[h].name = k

    for op, ihandles, ohandles in self._seq:
        ivars = []
        for h in ihandles:
            info = self._tinfo[h]
            if h not in h2v:
                # Not an input and not produced by an earlier op: it has to
                # be an external tensor whose data was captured.
                assert info.external
                assert info.bound_data
                h2v[h] = graph.make_const(info.bound_data._dev_tensor())
            ivars.append(h2v[h])
        ovars = apply(op, *ivars)
        assert len(ovars) == len(ohandles)
        h2v.update(zip(ohandles, ovars))

    dest_vars = []
    for i, h in enumerate(self._output_bindings):
        v = h2v[h]
        if output_names:
            v.name = output_names[i]
        dest_vars.append(v)

    # Serialize before touching the destination so a failing dump does not
    # leave a truncated file behind.
    payload = G.dump(*dest_vars)
    if isinstance(file, str):
        # We opened it, so we close it.
        with open(file, "wb") as fout:
            fout.write(payload)
    else:
        file.write(payload)
def _process_inputs(self, *args, **kwargs):
    """Bind the tensor arguments of a call to trace handles.

    On the first (untraced) run every positional argument, and every keyword
    argument that ``find_raw_tensor`` recognizes, receives a new handle whose
    device and dtype are recorded. On later runs the arguments are checked
    against those records, their data is fed to the recorded input nodes and
    each raw tensor is mapped to a ``CompiledTensorProxy`` in
    ``self._tensor_remaps``.

    :param args: positional arguments of the traced call; all must be tensors.
    :param kwargs: keyword arguments of the traced call; non-tensors are
        ignored.
    :raises TypeError: if a positional argument is not a tensor, or a dtype
        or device differs from the first run.
    :raises TraceMismatchError: if the number of positional arguments or the
        set of tensor-valued keywords differs from the first run.
    """
    if self._untraced:
        self._inputs_to_restore = []

        def record_input(tensor):
            if tensor is None:
                return
            handle, info = self._new_handle()
            info.external = False
            info.device = tensor.device
            info.dtype = tensor.dtype
            TraceMixin._TraceMixin__inject(tensor, handle)
            # Remembered so the injected state can be restored later.
            self._inputs_to_restore.append(tensor)
            return handle

        self._args_bindings = []
        for index, arg in enumerate(args):
            raw = find_raw_tensor(arg)
            if raw is None:
                raise TypeError(
                    "positional arguments should all be tensor "
                    "but args[%d] cannot be recognized as one" % index
                )
            self._args_bindings.append(record_input(raw))

        self._kwargs_bindings = {}
        for key, value in kwargs.items():
            raw = find_raw_tensor(value)
            if raw is not None:
                self._kwargs_bindings[key] = record_input(raw)
        return

    # Subsequent runs: validate against what the first run recorded.
    if len(args) != len(self._args_bindings):
        raise TraceMismatchError("positional argument length mismatch")

    self._tensor_remaps = {}

    for index, (handle, arg) in enumerate(zip(self._args_bindings, args)):
        raw = find_raw_tensor(arg)
        if raw is None:
            raise TypeError(
                "positional arguments should all be tensor "
                "but args[%d] cannot be recognized as one" % index
            )
        info = self._tinfo[handle]
        if raw.dtype != info.dtype:
            raise TypeError("args[%d].dtype different from last time" % index)
        if raw.device != info.device:
            raise TypeError("args[%d].device different from last time" % index)
        info.data_setter.set_value(raw._dev_tensor())
        self._tensor_remaps[raw] = CompiledTensorProxy(handle)

    kwargs_tensors = {}
    for key, value in kwargs.items():
        raw = find_raw_tensor(value)
        if raw is not None:
            kwargs_tensors[key] = raw

    seen = set(kwargs_tensors)
    expected = set(self._kwargs_bindings)
    if seen != expected:
        too_many = seen - expected
        too_few = expected - seen
        if too_many:
            raise TraceMismatchError(
                "keyword arguments found to be tensor this time "
                "but were non-tensor previously: %s" % " ".join(too_many)
            )
        if too_few:
            raise TraceMismatchError(
                "keyword arguments found to be non-tensor this time "
                "but were tensor previously: %s" % " ".join(too_few)
            )

    for key, handle in self._kwargs_bindings.items():
        raw = kwargs_tensors[key]
        info = self._tinfo[handle]
        if raw.dtype != info.dtype:
            raise TypeError("kwargs[%s].dtype different from last time" % key)
        if raw.device != info.device:
            raise TypeError("kwargs[%s].device different from last time" % key)
        info.data_setter.set_value(raw._dev_tensor())
        self._tensor_remaps[raw] = CompiledTensorProxy(handle)
def _process_outputs(self, outputs):
    """Record, or re-validate, which trace handles the outputs are bound to.

    A mapping is flattened into key-sorted names and values; a value that is
    neither a mapping nor a sequence is treated as a single output. On the
    first (untraced) run the names and the handle of every output are stored.
    On later runs the names, the number of outputs and each output's handle
    are compared with the stored ones.

    :param outputs: return value of the traced function: a tensor, a
        sequence of tensors, or a mapping from names to tensors.
    :raises TypeError: if an output item is not recognized as a tensor.
    :raises RuntimeError: if an output is not a traced / compiled tensor.
    :raises TraceMismatchError: if names, count or identity of the outputs
        differ from the first run.
    """
    # ``collections.Mapping`` / ``collections.Sequence`` were removed in
    # Python 3.10; the ABCs live in ``collections.abc``.
    from collections.abc import Mapping, Sequence

    output_names = None
    if isinstance(outputs, Mapping):
        # Sort by key so the binding order does not depend on dict order.
        # Built with comprehensions so that an empty mapping yields two
        # empty tuples instead of failing to unpack ``zip()``.
        items = sorted(outputs.items())
        output_names = tuple(name for name, _ in items)
        outputs = tuple(value for _, value in items)
    elif not isinstance(outputs, Sequence):
        outputs = (outputs,)

    if not self._untraced:
        if output_names != self._output_names:
            # Either side is None when that run did not return a mapping;
            # treat it as "no names" so the mismatch is reported as a
            # TraceMismatchError rather than a TypeError from set(None).
            current = set(output_names or ())
            previous = set(self._output_names or ())
            too_many = current - previous
            too_few = previous - current
            if too_many:
                raise TraceMismatchError(
                    "output has more keys than last time: %s"
                    % " ".join(map(str, too_many))
                )
            if too_few:
                raise TraceMismatchError(
                    "output has less keys than last time: %s"
                    % " ".join(map(str, too_few))
                )
        if len(outputs) != len(self._output_bindings):
            raise TraceMismatchError("output size differs from last time")
    else:
        self._output_names = output_names
        self._output_bindings = []

    for i, x in enumerate(outputs):
        x = find_raw_tensor(x)
        if x is None:
            raise TypeError("every item of return value should be tensor")
        if self._untraced:
            if not isinstance(x, TraceMixin):
                raise RuntimeError("output is not computed from inputs")
            self._output_bindings.append(x._TraceMixin__handle)
        else:
            if not isinstance(x, CompiledTensorProxy):
                raise RuntimeError("output is not computed from inputs")
            h = x._CompiledTensorProxy__handle
            if h != self._output_bindings[i]:
                label = output_names[i] if output_names else i
                raise TraceMismatchError(
                    "retval[%s] is a different tensor than last time" % label
                )
class
CompiledTensorProxy
(
RawTensor
):
...
...
@@ -514,6 +738,7 @@ apply.disable(apply_symbolic_mode)
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
    """Create a lazily evaluated constant in the active trace's graph.

    The constant variable is made in ``active_trace._lazy_eval_graph``,
    wrapped in a ``LazyEvalTensor`` and added to the trace's
    ``_lazy_eval_tensors`` set.

    :param op: the constant op; its ``value``, ``dtype`` and ``device`` are
        passed to ``make_const``.
    :param args: not used by this function.
    :return: a 1-tuple holding the new ``LazyEvalTensor``.
    """
    const_var = active_trace._lazy_eval_graph.make_const(
        op.value, dtype=op.dtype, device=op.device
    )
    lazy_tensor = LazyEvalTensor(const_var)
    active_trace._lazy_eval_tensors.add(lazy_tensor)
    return (lazy_tensor,)
...
...
@@ -561,3 +786,27 @@ class BrokenRawTensor(RawTensor):
def __setattr__(self, *_):
    """Reject every attribute assignment.

    :raises RuntimeError: always.
    """
    raise RuntimeError("broken due to misuse of tracing")
@functools.singledispatch
def find_raw_tensor(x):
    """Fallback of the ``find_raw_tensor`` single-dispatch function.

    Used for any argument type that has no more specific registration.

    :param x: an arbitrary object.
    :return: always ``None``.
    """
    return None
@find_raw_tensor.register(RawTensor)
def _(x):
    """Return a ``RawTensor`` argument unchanged."""
    return x
@find_raw_tensor.register(TensorWrapperBase)
def _(x):
    """Look for a raw tensor behind a wrapper's ``__wrapped__`` attribute.

    :return: the result of ``find_raw_tensor`` on ``x.__wrapped__``, or
        ``None`` when that attribute is missing or ``None``.
    """
    inner = getattr(x, "__wrapped__", None)
    if inner is None:
        return None
    return find_raw_tensor(inner)
@find_raw_tensor.register(Tensor)
def _(x):
    """Look for a raw tensor behind a ``Tensor``'s ``_data`` attribute.

    :return: the result of ``find_raw_tensor`` on ``x._data``, or ``None``
        when that attribute is missing or ``None``.
    """
    data = getattr(x, "_data", None)
    if data is None:
        return None
    return find_raw_tensor(data)
imperative/python/test/unit/test_tracing.py
浏览文件 @
a3b2232b
import
io
import
numpy
as
np
from
megengine.core.ops
import
builtin
as
ops
...
...
@@ -63,3 +65,20 @@ def test_print_in_trace():
buf
=
None
np
.
testing
.
assert_equal
(
f
(
as_raw_tensor
(
x
)).
numpy
(),
y
)
np
.
testing
.
assert_equal
(
z
,
buf
)
def test_dump():
    """Trace a negate op with ``capture_as_const`` and dump it to a buffer."""

    @trace(symbolic=True, capture_as_const=True)
    def f(x):
        negate = ops.Elemwise(mode="negate")
        (result,) = apply(negate, x)
        return result

    data = as_raw_tensor([1]).numpy()
    # Reference result from the undecorated function.
    expected = f.__wrapped__(as_raw_tensor(data)).numpy()

    # The traced function must agree with the reference on repeated calls.
    for _ in range(3):
        actual = f(as_raw_tensor(data)).numpy()
        np.testing.assert_equal(actual, expected)

    buffer = io.BytesIO()
    f.dump(buffer)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录