MegEngine 天元 / MegEngine

Commit d3bfb0e9
Authored Jan 08, 2021 by Megvii Engine Team
fix(mge): fix trace exit code and reformat
GitOrigin-RevId: 145c06b7e7a7f98f40f0e7acc1b555c16f27e2ba
Parent: 23b9a98f
Showing 6 changed files with 117 additions and 46 deletions (+117 / -46)
imperative/python/megengine/distributed/functional.py  +0   -2
imperative/python/megengine/jit/tracing.py             +48  -35
imperative/python/src/tensor.cpp                       +21  -5
imperative/python/src/trace.cpp                        +0   -1
imperative/python/src/trace_info.h                     +3   -1
imperative/python/test/unit/test_tracing.py            +45  -2
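The dominant change across these files is a rename of the trace-bookkeeping attributes exposed on tensors: mixin_handle and recording become the underscore-private _mixin_handle and _recording, and class-identity checks are replaced by checks on _compiled_info. A minimal sketch of the lookup pattern the diff standardizes on (FakeTensor is a hypothetical stand-in; the real attributes come from the C++ TensorWrapper bindings):

class FakeTensor:
    _mixin_handle = -1  # -1 means "not tracked by the active trace"
    _recording = False

x = FakeTensor()
h = getattr(x, "_mixin_handle", -1)  # the lookup pattern used by _record_op
assert h == -1  # untracked tensors are skipped by the recorder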
imperative/python/megengine/distributed/functional.py

@@ -292,8 +292,6 @@ def remote_recv(
     op = RemoteRecv()
     op.key = key
     op.cn = device
-    if isinstance(shape, Tensor):
-        shape = shape.numpy()
     op.shape = shape
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
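The two deleted lines converted a Tensor-valued shape to a numpy array before filling the RemoteRecv descriptor. A rough sketch of the conversion that guard performed, assuming only an object exposing .numpy() (normalize_shape is a hypothetical helper, not MegEngine API):

import numpy as np

def normalize_shape(shape):
    # stand-in for: if isinstance(shape, Tensor): shape = shape.numpy()
    if hasattr(shape, "numpy"):
        return shape.numpy()
    return np.asarray(shape)

print(normalize_shape((2, 3)))  # -> [2 3]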
imperative/python/megengine/jit/tracing.py

@@ -191,19 +191,20 @@ class trace:
         if len(ihandles) != len(args):
             raise TraceMismatchError("op input size different from last time")
+        # check all inputs of crrent op
         for h, x in zip(ihandles, args):
             info = self._tinfo[h]
             if info.external:
                 if (
-                    x.__class__ is CompiledTensorProxy
-                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
+                    x._compiled_info is not None
+                    and not self._tinfo[x._mixin_handle].exported
                 ):
                     raise TraceMismatchError(
                         "failed to capture: input was an external tensor "
                         "last time, got an internal tensor this time"
                     )
                 if info.bound_data:
-                    if x.__class__ is CompiledTensorProxy:
+                    if x._compiled_info is not None:
                         raise TraceMismatchError(
                             "const capture violated: was an external tensor "
                             "last time, got an internal tensor this time"
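The hunk above replaces class-identity checks (x.__class__ is CompiledTensorProxy) with a check on the _compiled_info slot, so any tensor carrying compiled-trace state is treated uniformly. A toy illustration of the new predicate (Plain and Compiled are hypothetical stand-ins):

class Plain:
    _compiled_info = None

class Compiled:
    _compiled_info = object()  # stand-in for a CompiledTensorProxy

def is_compiled(x):
    return getattr(x, "_compiled_info", None) is not None

assert not is_compiled(Plain())
assert is_compiled(Compiled())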
@@ -225,17 +226,17 @@ class trace:
                     )
                     info.data_setter.set_value(x._dev_tensor())
             else:
-                if x.mixin_handle == -1:
+                if x._mixin_handle == -1:
                     if x._handle not in self._tensor_remaps:
                         raise TraceMismatchError(
                             "unexpected capture: trying to use an external tensor as "
                             "input, but that input was an internal tensor last time"
                         )
                     else:
-                        x.mixin_handle = self._tensor_remaps[
+                        x._mixin_handle = self._tensor_remaps[
                             x._handle
                         ]._CompiledTensorProxy__handle
-            if x.mixin_handle != h:
+            if x._mixin_handle != h:
                 raise TraceMismatchError(
                     "mis-wiring: input edge to an data flow "
                     "graph node is different from last time"
@@ -245,9 +246,10 @@ class trace:
         outputs = []
         for h in ohandles:
             info = self._tinfo[h]
+            # generate output tensor and create compied info
             y = RawTensor(info.varnode)
             y._compiled_info = CompiledTensorProxy(h)
-            y.mixin_handle = h
+            y._mixin_handle = h
             outputs += [y]
             self._active_tensors[h] = TensorWeakRef(y)
         self._output_handles.update(ohandles)
@@ -260,6 +262,7 @@ class trace:
             raise TraceMismatchError("trace should end here, but more op observed")
         record = self._seq[self._pc]
         op_, ihandles, ohandles = record
+        # Const op is represented by a str
        assert isinstance(op_, str) and op_ == "Const"
        eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
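The replay path asserts that a Const op reproduces the value recorded in the first step; np.atleast_1d promotes both sides so scalars and 0-d arrays compare cleanly. A self-contained sketch of that comparison, where recorded stands in for bound_data.numpy():

import numpy as np

recorded = np.array(3.0)  # stand-in for self._tinfo[ohandles[0]].bound_data.numpy()
assert np.all(np.atleast_1d(3.0) == np.atleast_1d(recorded))
assert not np.all(np.atleast_1d(4.0) == np.atleast_1d(recorded))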
@@ -273,17 +276,18 @@ class trace:
                 outputs = [self._tinfo[h].bound_data]
         return outputs
 
+    # run in first step, record information for trace
     def _record_op(self, op, inputs, outputs):
         if skip_tracing:
             for x in inputs:
-                h = getattr(x, "mixin_handle", -1)
+                h = getattr(x, "_mixin_handle", -1)
                 if h >= 0:
                     self._tinfo[h].data = True
             return
 
         ihandles = []
         for x in inputs:
-            h = getattr(x, "mixin_handle", -1)
+            h = getattr(x, "_mixin_handle", -1)
             if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                 h, info = self._new_handle()
                 info.external = True
@@ -300,8 +304,8 @@ class trace:
             h, info = self._new_handle()
             ohandles.append(h)
             info.external = False
-            x.mixin_handle = h
-            x.recording = True
+            x._mixin_handle = h
+            x._recording = True
             x._trace_mixin_info = info
             self._active_tensors[h] = TensorWeakRef(x)
             if self._symbolic:
@@ -312,7 +316,7 @@ class trace:
     def _record_const(self, outputs):
         if skip_tracing:
             (x,) = outputs
-            h = getattr(x, "mixin_handle", -1)
+            h = getattr(x, "_mixin_handle", -1)
             if h >= 0:
                 self._tinfo[h].data_read = True
             return
@@ -326,8 +330,8 @@ class trace:
             info.shape = x.shape
             info.bound_data = x
             info.is_const = True
-            x.mixin_handle = h
-            x.recording = True
+            x._mixin_handle = h
+            x._recording = True
             x._trace_mixin_info = info
             if self._symbolic:
                 self._lazy_eval_tensors[h] = TensorWeakRef(x)
@@ -371,6 +375,7 @@ class trace:
             lazy_eval_graph.compile(*lazy_eval_links, *readers)
             lazy_eval_graph()
             for r, x in zip(readers, lazy_eval_tensors):
+                # get values from lazy_eval_graph and assign to lazy_eval tensor
                 x()._handle = RawTensor(r.op.get_value())._handle
                 x()._reset_varnode()
@@ -395,14 +400,14 @@ class trace:
         if self._untraced:
             for x in escaped_tensors:
                 if x():
-                    info = self._tinfo[x().mixin_handle]
+                    info = self._tinfo[x()._mixin_handle]
                     info.data_read = True
-                    x().mixin_handle = -1
-                    x().recording = False
+                    x()._mixin_handle = -1
+                    x()._recording = False
             if self._inputs_to_restore:
                 for x in self._inputs_to_restore:
-                    x.mixin_handle = -1
-                    x.recording = False
+                    x._mixin_handle = -1
+                    x._recording = False
             if self._symbolic and (
                 self._lazy_eval_tensors or self._lazy_eval_links
             ):
@@ -441,12 +446,13 @@ class trace:
         if not self._untraced and self._pc != len(self._seq):
             raise TraceMismatchError("premature end")
         if not self._symbolic or not self._untraced:
+            # reset output tensors
             for x in self._active_tensors.values():
                 if x() is not None:
                     x()._dev_tensor()
                     x()._reset_varnode()
-                    x().mixin_handle = -1
-                    x().recording = False
+                    x()._mixin_handle = -1
+                    x()._recording = False
                     x()._trace_mixin_info = None
 
         try:
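_active_tensors holds TensorWeakRef entries, so by exit time some referents may already be gone; the `if x() is not None` guards make the exit path skip dead weak references instead of failing. A minimal weakref sketch of the pattern, assuming nothing beyond the standard library:

import weakref

class Obj:
    pass

live = Obj()
refs = [weakref.ref(live), weakref.ref(Obj())]  # the second referent dies immediately

for r in refs:
    if r() is not None:  # same guard as the exit path above
        pass  # safe to reset r()'s trace bookkeeping here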
@@ -470,9 +476,13 @@ class trace:
             # conditionally reading a compiled tensor in excluded region
             # is permitted, so we have to assume every tensor might be read
             for x in self._active_tensors.values():
-                info = self._tinfo[x().mixin_handle]
-                info.exported = True
-                info.data_read = True
+                if x():
+                    info = self._tinfo[x()._mixin_handle]
+                    info.exported = True
+                    info.data_read = True
         else:
             for x in self._active_tensors.values():
-                x()._dev_tensor()
+                if x():
+                    x()._dev_tensor()
 
     def _apply_graph_options(self, graph):
@@ -528,7 +538,6 @@ class trace:
             info.varnode = opnode.outputs[0]
             in_out_links += opnode.outputs[1:]
 
-        cnt_data, cnt_value, cnt_shape = 0, 0, 0
         for op, ihandles, ohandles in self._seq:
             if isinstance(op, str) and op == "Const":
                 assert len(ihandles) == 0
@@ -604,16 +613,13 @@ class trace:
                 # Shape can be obtained from data so doesn't need its own
                 # output node. On the other hand, value is read separately
                 # to leverage eager h2d copy
-                cnt_data += 1
                 info.shape_read = False
                 opnode = info.data_reader = G.OutputNode(v, *in_out_links)
                 add_reader(opnode)
             if info.value_read:
-                cnt_value += 1
                 opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
                 add_reader(opnode)
             if info.shape_read:
-                cnt_shape += 1
                 opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
                 add_reader(opnode)
@@ -637,15 +643,17 @@ class trace:
             self._process_inputs(*args, **kwargs)
         outputs = self.__wrapped__(*args, **kwargs)
         transform = False
+        # outputs can be None
         if outputs is not None:
             if not isinstance(outputs, collections.abc.Sequence):
                 transform = True
                 outputs = (outputs,)
             for o in outputs:
+                # if outputs are copied, then use the newest info in trace data structure
                 if o._copied:
-                    self._active_tensors[o.mixin_handle] = TensorWeakRef(o)
+                    self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
                     if self._untraced and self._symbolic:
-                        self._lazy_eval_tensors[o.mixin_handle] = TensorWeakRef(o)
+                        self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
             if self._capture_as_const:
                 self._process_outputs(outputs)
         if transform:
@@ -819,8 +827,8 @@ class trace:
         info.device = x.device
         info.dtype = x.dtype
         info.shape = x.numpy().shape
-        x.mixin_handle = h
-        x.recording = True
+        x._mixin_handle = h
+        x._recording = True
         x._trace_mixin_info = info
         self._inputs_to_restore.append(x)
         return h
@@ -914,12 +922,12 @@ class trace:
             if not isinstance(x, RawTensor):
                 raise TypeError("every item of return value should be tensor")
             if self._untraced:
-                h = x.mixin_handle
+                h = x._mixin_handle
                 if h < 0:
                     raise RuntimeError("output is not computed from inputs")
                 self._output_bindings.append(h)
             else:
-                h = x.mixin_handle
+                h = x._mixin_handle
                 if h not in self._output_handles:
                     raise RuntimeError("output is not computed from inputs")
                 if h != self._output_bindings[i]:
@@ -938,6 +946,11 @@ class trace:
             raise RuntimeError("trace is not set with profiling=True")
         return json.loads(self._profiler.get())
 
+    def __del__(self):
+        for x in self._tinfo:
+            if getattr(x, "bound_data", None):
+                x.bound_data = None
+
     def trace(self, *args, **kwargs):
         raise NotImplementedError(
             "trace is deemed unbeneficial with the new "
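The new __del__ is the "exit code" half of the commit title: when a trace object is destroyed, it drops its bound_data references so captured constants do not keep device state alive through interpreter shutdown. A toy model of the same cleanup (Info and Trace are hypothetical stand-ins for the real TensorInfo and trace classes):

class Info:
    bound_data = None

class Trace:
    def __init__(self):
        self._tinfo = [Info(), Info()]

    def __del__(self):
        # same shape as the new trace.__del__ above
        for x in self._tinfo:
            if getattr(x, "bound_data", None):
                x.bound_data = None

t = Trace()
t._tinfo[0].bound_data = [1.0]  # stand-in for a captured constant tensor
del t  # __del__ clears the reference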
imperative/python/src/tensor.cpp

@@ -291,7 +291,11 @@ PyObject* TensorWrapper::copied() {
 #define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member)        \
     PyObject* TensorWrapper::member() {                    \
-        return m_tensor->m_trace_info.member;              \
+        if (m_tensor->m_trace_info.member) {               \
+            return m_tensor->m_trace_info.member;          \
+        } else {                                           \
+            Py_RETURN_NONE;                                \
+        }                                                  \
     }                                                      \
     void TensorWrapper::set_##member(PyObject* dest) {     \
         if (dest == Py_None) {                             \
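With the guard added to the macro above, reading an unset PyObject slot (e.g. _compiled_info) yields Python None instead of an invalid pointer, which is what makes the `x._compiled_info is not None` checks in tracing.py safe. A rough pure-Python analogue:

class FakeWrapper:
    _compiled_info = None  # an unset slot now surfaces as None

x = FakeWrapper()
if x._compiled_info is not None:  # the guard used throughout tracing.py
    raise AssertionError("x should not be treated as a compiled tensor")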
@@ -322,6 +326,7 @@ void TensorWrapper::set_handle(PyObject* dest) {
 
 PyObject* TensorWrapper::shape() {
+    // if it's tracing compiled mode, get value from compiled_info
     if (m_tensor->m_trace_info.compiled_info != nullptr) {
         if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
             return PyTuple_New(0);
@@ -332,15 +337,18 @@ PyObject* TensorWrapper::shape() {
         }
         return shp;
     }
+    // inside trace, if tensor shape is useful for other operations, set shape_read = true
     if (m_tensor->m_trace_info.recording && !skip_tracing) {
         PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "shape_read", py::cast(true).release().ptr());
     }
     if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
         return PyTuple_New(0);
     }
     TensorShape shape;
     if (m_tensor->m_var) {
+        // get shape from m_var
         auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
         auto* tshp = mgr.infer_shape_fallible(m_tensor->m_var);
         if (!tshp) {
@@ -389,9 +397,11 @@ PyObject* TensorWrapper::numpy() {
         }
         return np_val;
     }
     if (m_tensor->m_trace_info.recording && !skip_tracing) {
         PyObject_SetAttrString(m_tensor->m_trace_info.trace_mixin_info, "value_read", py::cast(true).release().ptr());
     }
     if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
         auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
         auto&& type = mgr.get_infer_type(m_tensor->m_var);
@@ -411,12 +421,14 @@ PyObject* TensorWrapper::numpy() {
         }
         return np_val.release().ptr();
     }
     auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get());
     auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
     if (!arr) {
         PyErr_SetString(PyExc_ValueError, "tensor invalid");
         return nullptr;
     }
     if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
         mgb_assert(PyArray_Check(arr.ptr()));
         return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
@@ -428,7 +440,7 @@ PyObject* TensorWrapper::varnode() {
     if (m_tensor->m_var) {
         return py::cast(m_tensor->m_var).release().ptr();
     }
-    return py::none().release().ptr();
+    Py_RETURN_NONE;
 }
 
 void TensorWrapper::reset(PyObject* tensor) {
@@ -465,9 +477,13 @@ PyObject* TensorWrapper::_dev_tensor(){
         if (dev_tensor == Py_None) {
             throw TraceReadError("raw data of this tensor is not read in trace");
         }
+
+        // set m_handle to make it a real tensor
         auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
         auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>());
         m_tensor->m_handle = std::move(SharedHandle(sh));
+
+        // compiled info is useless after m_handle is set
         Py_DECREF(m_tensor->m_trace_info.compiled_info);
         m_tensor->m_trace_info.compiled_info = nullptr;
@@ -753,8 +769,8 @@ void init_tensor(py::module m) {
         .def<&TensorWrapper::reset_varnode>("_reset_varnode")
         .def_getset<&TensorWrapper::varnode>("_varnode")
         .def_getset<&TensorWrapper::copied>("_copied")
-        .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle")
-        .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("recording")
+        .def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("_mixin_handle")
+        .def_getset<&TensorWrapper::recording, &TensorWrapper::set_recording>("_recording")
         .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
         .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
         .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
imperative/python/src/trace.cpp

@@ -55,7 +55,6 @@ apply_result_t apply_trace(ApplyContext& ctx) {
         auto args = py::tuple(ctx.nargs + 1);
         args[0] = py::cast(ctx.op);
-        py::tuple args(ctx.nargs);
         for (size_t i = 0; i < ctx.nargs; i++) {
             args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
         }
imperative/python/src/trace_info.h

@@ -19,7 +19,9 @@ struct TraceInfo {
     bool recording = false;
     bool copied = false;
 
+    // refer to CompiledTensorProxy in tracing.py, works from second trace step
     PyObject* compiled_info = nullptr;
+    // refer to TensorInfo in tracing.py, only works in first trace step
     PyObject* trace_mixin_info = nullptr;
 
     TraceInfo() = default;
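The two new comments pin down the lifecycle split: trace_mixin_info (a tracing.py TensorInfo) is populated only while the first step records, and compiled_info (a CompiledTensorProxy) takes over from the second step onward. A schematic sketch of that hand-off, using placeholder objects rather than the real classes:

class TraceInfoSketch:
    def __init__(self):
        self.trace_mixin_info = None  # TensorInfo: used only in the first (recording) step
        self.compiled_info = None     # CompiledTensorProxy: used from the second step on

info = TraceInfoSketch()
info.trace_mixin_info = {"shape_read": False, "value_read": False}  # step 1: recording
# after the trace is compiled, later steps consult compiled_info instead:
info.trace_mixin_info = None
info.compiled_info = object()  # placeholder for a CompiledTensorProxy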
imperative/python/test/unit/test_tracing.py

@@ -14,14 +14,17 @@ import pytest
 import megengine.core.tensor.megbrain_graph as G
 import megengine.functional as F
+import megengine.optimizer as optim
 import megengine.utils.comp_graph_tools as cgtools
-from megengine import tensor
+from megengine import Parameter, tensor
+from megengine.autodiff import GradManager
 from megengine.core._trace_option import set_symbolic_shape
 from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.utils import isscalar
 from megengine.functional import exp, log
 from megengine.jit import exclude_from_trace, trace
+from megengine.module import Module
 from megengine.random import normal, uniform
@@ -39,8 +42,48 @@ def test_trace():
         np.testing.assert_equal(f(x).numpy(), y)
 
 
+def test_output_copy_trace():
+    class Simple(Module):
+        def __init__(self):
+            super().__init__()
+            self.a = Parameter([1.0], dtype=np.float32)
+
+        def forward(self, x):
+            x = x * self.a
+            # will result into a copy of output in grad
+            x = F.exp(x)
+            return x
+
+    net = Simple()
+
+    gm = GradManager().attach(net.parameters())
+    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
+    data = tensor(np.arange(4).reshape(2, 2), dtype="float32")
+
+    @trace(symbolic=False)
+    def train_f1(d):
+        with gm:
+            loss = net(d)
+            gm.backward(loss)
+            opt.step().clear_grad()
+        return loss
+
+    @trace(symbolic=True)
+    def train_f2(d):
+        with gm:
+            loss = net(d)
+            gm.backward(loss)
+            opt.step().clear_grad()
+        return loss
+
+    for i in range(2):
+        y1 = train_f1(data).numpy()
+        y2 = train_f2(data).numpy()
+        np.testing.assert_equal(y1, y2)
+
+
 def test_exclude_from_trace():
-    for symbolic in [False]:
+    for symbolic in [False, True]:
 
         @trace(symbolic=symbolic)
         def f(x):
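To exercise the new test alongside the extended exclude_from_trace parametrization, something like the following should work, assuming pytest and a built MegEngine are importable:

import pytest

pytest.main(["-k", "output_copy_trace or exclude_from_trace",
             "imperative/python/test/unit/test_tracing.py"])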