Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c9c3429a
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c9c3429a
编写于
1月 05, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge): fix sublinear
GitOrigin-RevId: 5bb038378121f244fa13c891e497f72507465413
上级
de0742be
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
126 addition
and
86 deletion
+126
-86
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+36
-52
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+55
-23
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+7
-7
imperative/python/src/trace.cpp
imperative/python/src/trace.cpp
+1
-1
imperative/python/src/trace_info.h
imperative/python/src/trace_info.h
+26
-3
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+1
-0
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
c9c3429a
...
...
@@ -20,30 +20,22 @@ import numpy as np
from
..core._imperative_rt
import
GraphProfiler
,
common
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
TensorWeakRef
from
..core._imperative_rt.core2
import
__make_empty_tensor
as
make_empty_tensor
from
..core._imperative_rt.core2
import
(
TensorWeakRef
,
apply
,
set_compiled
,
set_symbolic
,
set_tracing
,
skip_tracing
,
unset_compiled
,
unset_symbolic
,
unset_tracing
,
)
from
..core._imperative_rt.ops
import
(
CollectiveComm
,
GaussianRNG
,
RemoteRecv
,
RemoteSend
,
UniformRNG
,
)
from
..core._imperative_rt.ops
import
CollectiveComm
,
RemoteRecv
,
RemoteSend
from
..core._trace_option
import
set_symbolic_shape
from
..core._wrap
import
device
as
as_device
from
..core.ops.builtin
import
BackwardGraph
,
OpDef
from
..core.ops.special
import
Const
from
..core.tensor
import
megbrain_graph
as
G
from
..core.tensor.utils
import
setscalar
from
.sublinear_memory_config
import
SublinearMemoryConfig
...
...
@@ -159,7 +151,6 @@ class trace:
self
.
_profiler
=
None
self
.
_graph_opt_level
=
opt_level
self
.
_symbolic_shape
=
symbolic_shape
self
.
_handle2tensors
=
{}
self
.
_output_handles
=
set
()
self
.
_reset
()
...
...
@@ -195,7 +186,7 @@ class trace:
raise
TraceMismatchError
(
"trace should end here, but more op observed"
)
record
=
self
.
_seq
[
self
.
_pc
]
op_
,
ihandles
,
ohandles
=
record
if
op
!=
op_
:
if
(
isinstance
(
op_
,
str
)
and
op_
==
"Const"
)
or
(
op
!=
op_
)
:
raise
TraceMismatchError
(
"op different from last time"
)
if
len
(
ihandles
)
!=
len
(
args
):
raise
TraceMismatchError
(
"op input size different from last time"
)
...
...
@@ -253,9 +244,11 @@ class trace:
self
.
_pc
+=
1
outputs
=
[]
for
h
in
ohandles
:
t
=
CompiledTensorProxy
(
h
)
t
.
_dev_tensor
()
outputs
+=
[
t
.
_CompiledTensorProxy__tensor
]
info
=
self
.
_tinfo
[
h
]
y
=
RawTensor
(
info
.
varnode
)
y
.
_compiled_info
=
CompiledTensorProxy
(
h
)
y
.
mixin_handle
=
h
outputs
+=
[
y
]
self
.
_output_handles
.
update
(
ohandles
)
self
.
_active_tensors
.
update
([
TensorWeakRef
(
o
)
for
o
in
outputs
])
return
outputs
...
...
@@ -285,7 +278,7 @@ class trace:
for
x
in
inputs
:
h
=
getattr
(
x
,
"mixin_handle"
,
-
1
)
if
h
>=
0
:
x
.
data_read
=
True
self
.
_tinfo
[
h
].
data
=
True
return
ihandles
=
[]
...
...
@@ -308,7 +301,8 @@ class trace:
ohandles
.
append
(
h
)
info
.
external
=
False
x
.
mixin_handle
=
h
self
.
_handle2tensors
[
h
]
=
x
x
.
recording
=
True
x
.
_trace_mixin_info
=
info
self
.
_seq
.
append
((
op
,
tuple
(
ihandles
),
tuple
(
ohandles
)))
self
.
_active_tensors
.
update
([
TensorWeakRef
(
o
)
for
o
in
outputs
])
...
...
@@ -318,7 +312,7 @@ class trace:
(
x
,)
=
outputs
h
=
getattr
(
x
,
"mixin_handle"
,
-
1
)
if
h
>=
0
:
x
.
data_read
=
True
self
.
_tinfo
[
h
]
.
data_read
=
True
return
(
x
,)
=
outputs
...
...
@@ -331,7 +325,8 @@ class trace:
info
.
bound_data
=
x
info
.
is_const
=
True
x
.
mixin_handle
=
h
self
.
_handle2tensors
[
h
]
=
x
x
.
recording
=
True
x
.
_trace_mixin_info
=
info
self
.
_seq
.
append
((
"Const"
,
tuple
(),
tuple
(
ohandles
)))
def
_set_active
(
self
,
active
:
bool
):
...
...
@@ -346,7 +341,6 @@ class trace:
def
_init_trace
(
self
,
symbolic
:
bool
):
if
symbolic
:
set_symbolic
()
self
.
_lazy_eval_graph
=
G
.
Graph
()
self
.
_apply_graph_options
(
self
.
_lazy_eval_graph
)
self
.
_lazy_eval_links
=
()
...
...
@@ -383,8 +377,6 @@ class trace:
if
self
.
_untraced
:
self
.
_init_trace
(
self
.
_symbolic
)
else
:
# disable symbolic mode
unset_symbolic
()
set_compiled
()
if
self
.
_graph
is
None
:
self
.
_compile
()
...
...
@@ -394,18 +386,15 @@ class trace:
escaped_tensors
=
self
.
_take_escaped_tensors
()
if
self
.
_untraced
:
for
x
in
escaped_tensors
:
if
x
():
info
=
self
.
_tinfo
[
x
().
mixin_handle
]
x
()
.
data_read
=
True
info
.
data_read
=
True
x
().
mixin_handle
=
-
1
x
().
recording
=
False
if
self
.
_inputs_to_restore
:
for
x
in
self
.
_inputs_to_restore
:
x
.
mixin_handle
=
-
1
for
h
,
x
in
list
(
self
.
_handle2tensors
.
items
()):
info
=
self
.
_tinfo
[
h
]
info
.
data_read
=
x
.
data_read
info
.
shape_read
=
x
.
shape_read
info
.
value_read
=
x
.
value_read
del
self
.
_handle2tensors
[
h
]
x
.
recording
=
False
if
self
.
_symbolic
and
(
self
.
_lazy_eval_tensors
or
self
.
_lazy_eval_links
):
...
...
@@ -437,7 +426,6 @@ class trace:
self
.
_set_active
(
False
)
set_symbolic_shape
(
self
.
_save_symbolic_shape
)
unset_compiled
()
unset_symbolic
()
unset_tracing
()
def
do_exit
():
...
...
@@ -449,6 +437,7 @@ class trace:
if
x
()
is
not
None
:
x
().
_dev_tensor
()
x
().
mixin_handle
=
-
1
x
().
recording
=
False
try
:
do_enter
()
...
...
@@ -473,7 +462,8 @@ class trace:
for
x
in
self
.
_active_tensors
:
info
=
self
.
_tinfo
[
x
().
mixin_handle
]
info
.
exported
=
True
x
().
data_read
=
True
info
.
data_read
=
True
x
().
_dev_tensor
()
def
_apply_graph_options
(
self
,
graph
):
...
...
@@ -528,6 +518,7 @@ class trace:
info
.
varnode
=
opnode
.
outputs
[
0
]
in_out_links
+=
opnode
.
outputs
[
1
:]
cnt_data
,
cnt_value
,
cnt_shape
=
0
,
0
,
0
for
op
,
ihandles
,
ohandles
in
self
.
_seq
:
if
isinstance
(
op
,
str
)
and
op
==
"Const"
:
assert
len
(
ihandles
)
==
0
...
...
@@ -603,13 +594,16 @@ class trace:
# Shape can be obtained from data so doesn't need its own
# output node. On the other hand, value is read separately
# to leverage eager h2d copy
cnt_data
+=
1
info
.
shape_read
=
False
opnode
=
info
.
data_reader
=
G
.
OutputNode
(
v
,
*
in_out_links
)
add_reader
(
opnode
)
if
info
.
value_read
:
cnt_value
+=
1
opnode
=
info
.
value_reader
=
G
.
ValueOutputNode
(
v
,
*
in_out_links
)
add_reader
(
opnode
)
if
info
.
shape_read
:
cnt_shape
+=
1
opnode
=
info
.
shape_reader
=
G
.
AttrOutputNode
(
v
,
*
in_out_links
)
add_reader
(
opnode
)
...
...
@@ -804,7 +798,8 @@ class trace:
info
.
dtype
=
x
.
dtype
info
.
shape
=
x
.
numpy
().
shape
x
.
mixin_handle
=
h
self
.
_handle2tensors
[
h
]
=
x
x
.
recording
=
True
x
.
_trace_mixin_info
=
info
self
.
_inputs_to_restore
.
append
(
x
)
return
h
...
...
@@ -940,7 +935,6 @@ class CompiledTensorProxy:
self
.
__shape
=
None
self
.
__data
=
None
self
.
__value
=
None
self
.
__tensor
=
make_empty_tensor
()
@
property
def
dtype
(
self
):
...
...
@@ -958,7 +952,7 @@ class CompiledTensorProxy:
if
self
.
__info
.
shape_read
:
self
.
__shape
=
self
.
__info
.
shape_reader
.
get_value
().
shape
elif
self
.
__info
.
data_read
:
self
.
__shape
=
self
.
_
_info
.
_
dev_tensor
().
shape
self
.
__shape
=
self
.
_dev_tensor
().
shape
else
:
raise
TraceMismatchError
(
"shape of this tensor is not read in trace"
)
return
self
.
__shape
...
...
@@ -980,25 +974,14 @@ class CompiledTensorProxy:
if
not
self
.
__info
.
data_read
:
raise
TraceMismatchError
(
"raw data of this tensor is not read in trace"
)
self
.
__data
=
self
.
__info
.
data_reader
.
get_value
()
self
.
__tensor
.
_reset
(
RawTensor
(
self
.
__data
))
self
.
__tensor
.
mixin_handle
=
self
.
__handle
return
self
.
__data
def
_drop
(
self
):
return
def
_swap_in
(
self
):
return
def
_swap_out
(
self
):
return
def
__del__
(
self
):
if
self
.
__
tensor
.
shape_read
and
self
.
__shape
is
not
None
:
if
self
.
__
info
.
shape_read
and
self
.
__shape
is
not
None
:
self
.
__info
.
shape_reader
.
drop_value
()
if
self
.
__
tensor
.
value_read
and
self
.
__value
is
not
None
:
if
self
.
__
info
.
value_read
and
self
.
__value
is
not
None
:
self
.
__info
.
value_reader
.
drop_value
()
if
self
.
__
tensor
.
data_read
and
self
.
__data
is
not
None
:
if
self
.
__
info
.
data_read
and
self
.
__data
is
not
None
:
self
.
__info
.
data_reader
.
drop_value
()
...
...
@@ -1054,6 +1037,8 @@ def apply_const_symbolic_mode(value, dtype, device):
# don't need to unset tracing
# because varnode construction will ignore tracing flag
ret
=
RawTensor
(
graph
.
make_const
(
value
,
dtype
=
dtype
,
device
=
device
))
if
np
.
array
(
value
).
ndim
==
0
:
setscalar
(
ret
)
active_trace
.
_lazy_eval_tensors
.
add
(
TensorWeakRef
(
ret
))
return
(
ret
,)
...
...
@@ -1084,7 +1069,6 @@ def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
return
active_trace
.
_apply_const
(
value
,
dtype
,
device
)
# this hook injects TraceMixin
def
apply_with_tracing
(
op
:
OpDef
,
*
args
:
RawTensor
):
if
active_trace
.
_symbolic
:
outputs
=
apply_symbolic_mode
(
op
,
*
args
)
...
...
imperative/python/src/tensor.cpp
浏览文件 @
c9c3429a
...
...
@@ -54,7 +54,6 @@ REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
#undef REGISTE_APPLY_FUNC
bool
is_tracing
=
false
;
bool
is_symbolic
=
false
;
bool
is_compiled
=
false
;
#define SET_UNSET_PROP(mode) \
...
...
@@ -66,7 +65,6 @@ bool is_compiled = false;
} \
SET_UNSET_PROP
(
tracing
)
SET_UNSET_PROP
(
symbolic
)
SET_UNSET_PROP
(
compiled
)
#undef SET_UNSET_PROP
...
...
@@ -280,14 +278,27 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
m_tensor->m_trace_info.member = real_dest; \
}
REGISTE_TENSORWRAPPER_FUNC
(
bool
,
data_read
)
REGISTE_TENSORWRAPPER_FUNC
(
bool
,
value_read
)
REGISTE_TENSORWRAPPER_FUNC
(
bool
,
shape_read
)
REGISTE_TENSORWRAPPER_FUNC
(
int64_t
,
mixin_handle
)
REGISTE_TENSORWRAPPER_FUNC
(
bool
,
recording
)
#undef REGISTE_TENSORWRAPPER_FUNC
#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member) \
PyObject* TensorWrapper::member() { \
return m_tensor->m_trace_info.member; \
} \
void TensorWrapper::set_##member(PyObject* dest) { \
Py_INCREF(dest); \
m_tensor->m_trace_info.member = dest; \
}
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC
(
compiled_info
)
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC
(
trace_mixin_info
)
#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC
PyObject
*
TensorWrapper
::
handle
()
{
return
py
::
cast
(
m_tensor
->
m_handle
).
release
().
ptr
();
}
...
...
@@ -301,8 +312,14 @@ void TensorWrapper::set_handle(PyObject* dest) {
PyObject
*
TensorWrapper
::
shape
()
{
if
(
!
skip_tracing
)
{
set_shape_read
(
py
::
cast
(
true
).
release
().
ptr
());
if
(
m_tensor
->
m_trace_info
.
compiled_info
!=
nullptr
)
{
if
(
m_tensor
->
m_flags
&
Tensor
::
Flags
::
SCALAR
)
{
return
PyTuple_New
(
0
);
}
return
PyObject_GetAttrString
(
m_tensor
->
m_trace_info
.
compiled_info
,
"shape"
);
}
if
(
m_tensor
->
m_trace_info
.
recording
&&
!
skip_tracing
)
{
PyObject_SetAttrString
(
m_tensor
->
m_trace_info
.
trace_mixin_info
,
"shape_read"
,
py
::
cast
(
true
).
release
().
ptr
());
}
if
(
m_tensor
->
m_flags
&
Tensor
::
Flags
::
SCALAR
)
{
return
PyTuple_New
(
0
);
...
...
@@ -310,7 +327,12 @@ PyObject* TensorWrapper::shape() {
TensorShape
shape
;
if
(
m_tensor
->
m_var
)
{
shape
=
m_tensor
->
m_var
->
shape
();
auto
&&
mgr
=
m_tensor
->
m_var
->
owner_graph
()
->
static_infer_manager
();
auto
*
tshp
=
mgr
.
infer_shape_fallible
(
m_tensor
->
m_var
);
if
(
!
tshp
)
{
Py_RETURN_NONE
;
}
shape
=
*
tshp
;
}
else
{
shape
=
m_tensor
->
shape
();
}
...
...
@@ -343,8 +365,15 @@ PyObject* TensorWrapper::device() {
PyObject
*
TensorWrapper
::
numpy
()
{
if
(
!
skip_tracing
)
{
set_value_read
(
py
::
cast
(
true
).
release
().
ptr
());
if
(
m_tensor
->
m_trace_info
.
compiled_info
!=
nullptr
)
{
PyObject
*
np_val
=
PyObject_CallMethod
(
m_tensor
->
m_trace_info
.
compiled_info
,
"numpy"
,
nullptr
);
if
(
m_tensor
->
m_flags
&
Tensor
::
Flags
::
SCALAR
)
{
np_val
=
PyArray_Squeeze
(
reinterpret_cast
<
PyArrayObject
*>
(
np_val
));
}
return
np_val
;
}
if
(
m_tensor
->
m_trace_info
.
recording
&&
!
skip_tracing
)
{
PyObject_SetAttrString
(
m_tensor
->
m_trace_info
.
trace_mixin_info
,
"value_read"
,
py
::
cast
(
true
).
release
().
ptr
());
}
if
(
m_tensor
->
m_handle
.
get
()
==
nullptr
&&
m_tensor
->
m_var
!=
nullptr
)
{
auto
&&
mgr
=
m_tensor
->
m_var
->
owner_graph
()
->
static_infer_manager
();
...
...
@@ -359,7 +388,11 @@ PyObject* TensorWrapper::numpy() {
PyErr_SetString
(
PyExc_ValueError
,
"tensor invalid"
);
return
nullptr
;
}
return
py
::
cast
(
*
val
).
attr
(
"numpy"
)().
release
().
ptr
();
auto
np_val
=
py
::
cast
(
*
val
).
attr
(
"numpy"
)();
if
(
m_tensor
->
m_flags
&
Tensor
::
Flags
::
SCALAR
)
{
return
PyArray_Squeeze
(
reinterpret_cast
<
PyArrayObject
*>
(
np_val
.
release
().
ptr
()));
}
return
np_val
.
release
().
ptr
();
}
auto
&&
hv
=
interpreter_for_py
->
get_value
(
m_tensor
->
m_handle
.
get
());
auto
arr
=
py
::
reinterpret_steal
<
py
::
array
>
(
npy
::
ndarray_from_tensor
(
hv
,
npy
::
ShareType
::
TRY_SHARE
));
...
...
@@ -410,8 +443,14 @@ PyObject* TensorWrapper::detach() {
}
PyObject
*
TensorWrapper
::
_dev_tensor
(){
if
(
!
skip_tracing
)
{
set_data_read
(
py
::
cast
(
true
).
release
().
ptr
());
if
(
m_tensor
->
m_trace_info
.
compiled_info
!=
nullptr
)
{
auto
*
dev_tensor
=
PyObject_CallMethod
(
m_tensor
->
m_trace_info
.
compiled_info
,
"_dev_tensor"
,
nullptr
);
auto
py_dev_tensor
=
py
::
reinterpret_borrow
<
py
::
object
>
(
dev_tensor
);
auto
sh
=
interpreter_for_py
->
put
(
py_dev_tensor
.
cast
<
DeviceTensorND
>
());
m_tensor
->
m_handle
=
std
::
move
(
SharedHandle
(
sh
));
}
if
(
m_tensor
->
m_trace_info
.
recording
&&
!
skip_tracing
)
{
PyObject_SetAttrString
(
m_tensor
->
m_trace_info
.
trace_mixin_info
,
"data_read"
,
py
::
cast
(
true
).
release
().
ptr
());
}
auto
dev_tensor
=
interpreter_for_py
->
get_dev_tensor
(
m_tensor
->
m_handle
.
get
());
return
py
::
cast
(
dev_tensor
).
release
().
ptr
();
...
...
@@ -668,9 +707,6 @@ WRAP_FUNC_PY35(get_device);
{ #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
py
::
object
make_empty_tensorwrapper
()
{
return
TensorWrapper
::
make
(
std
::
move
(
std
::
make_shared
<
Tensor
>
()));
}
void
init_tensor
(
py
::
module
m
)
{
imperative
::
Tensor
::
static_initialize
();
...
...
@@ -692,11 +728,11 @@ void init_tensor(py::module m) {
.
def
<&
TensorWrapper
::
_drop
>
(
"_drop"
)
.
def
<&
TensorWrapper
::
reset_varnode
>
(
"_reset_varnode"
)
.
def_getset
<&
TensorWrapper
::
varnode
>
(
"_varnode"
)
.
def_getset
<&
TensorWrapper
::
data_read
,
&
TensorWrapper
::
set_data_read
>
(
"data_read"
)
.
def_getset
<&
TensorWrapper
::
value_read
,
&
TensorWrapper
::
set_value_read
>
(
"value_read"
)
.
def_getset
<&
TensorWrapper
::
shape_read
,
&
TensorWrapper
::
set_shape_read
>
(
"shape_read"
)
.
def_getset
<&
TensorWrapper
::
mixin_handle
,
&
TensorWrapper
::
set_mixin_handle
>
(
"mixin_handle"
)
.
def_getset
<&
TensorWrapper
::
recording
,
&
TensorWrapper
::
set_recording
>
(
"recording"
)
.
def_getset
<&
TensorWrapper
::
handle
,
&
TensorWrapper
::
set_handle
>
(
"_handle"
)
.
def_getset
<&
TensorWrapper
::
compiled_info
,
&
TensorWrapper
::
set_compiled_info
>
(
"_compiled_info"
)
.
def_getset
<&
TensorWrapper
::
trace_mixin_info
,
&
TensorWrapper
::
set_trace_mixin_info
>
(
"_trace_mixin_info"
)
.
finalize
();
if
(
!
tensor_type
)
throw
py
::
error_already_set
();
py
::
setattr
(
m
,
"Tensor"
,
tensor_type
);
...
...
@@ -771,12 +807,8 @@ void init_tensor(py::module m) {
m
.
def
(
"set_tracing"
,
&
set_tracing
);
m
.
def
(
"unset_tracing"
,
&
unset_tracing
);
m
.
def
(
"set_symbolic"
,
&
set_symbolic
);
m
.
def
(
"unset_symbolic"
,
&
unset_symbolic
);
m
.
def
(
"set_compiled"
,
&
set_compiled
);
m
.
def
(
"unset_compiled"
,
&
unset_compiled
);
m
.
def
(
"__make_empty_tensor"
,
&
make_empty_tensorwrapper
);
}
#undef MGE_PY_INTERFACE
...
...
imperative/python/src/tensor.h
浏览文件 @
c9c3429a
...
...
@@ -159,15 +159,16 @@ struct TensorWrapper {
PyObject
*
handle
();
void
set_handle
(
PyObject
*
);
PyObject
*
data_read
();
PyObject
*
value_read
();
PyObject
*
shape_read
();
PyObject
*
mixin_handle
();
PyObject
*
recording
();
void
set_data_read
(
PyObject
*
);
void
set_value_read
(
PyObject
*
);
void
set_shape_read
(
PyObject
*
);
void
set_mixin_handle
(
PyObject
*
);
void
set_recording
(
PyObject
*
);
PyObject
*
compiled_info
();
void
set_compiled_info
(
PyObject
*
);
PyObject
*
trace_mixin_info
();
void
set_trace_mixin_info
(
PyObject
*
);
};
...
...
@@ -219,7 +220,6 @@ template <typename... Args>
constexpr
bool
is_all_tensor_ptr
=
(...
&&
std
::
is_same_v
<
decltype
(
resolve_arrow
(
std
::
declval
<
Args
>
())),
Tensor
*>
);
extern
bool
is_tracing
;
// FIXME: should use ApplyContext::global_enable
extern
bool
is_symbolic
;
extern
bool
is_compiled
;
template
<
typename
...
Args
,
std
::
enable_if_t
<
is_all_tensor_ptr
<
Args
...>,
int
>
=
0
>
...
...
imperative/python/src/trace.cpp
浏览文件 @
c9c3429a
...
...
@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t
outputs
;
if
(
ctx
.
backward
)
{
// reach here when
symbolic=True or
compiled=True
// reach here when compiled=True
// call megbrain_graph.py apply(BackwardGraph, *args)
auto
args
=
py
::
tuple
(
ctx
.
nargs
+
1
);
args
[
0
]
=
py
::
cast
(
ctx
.
op
);
...
...
imperative/python/src/trace_info.h
浏览文件 @
c9c3429a
...
...
@@ -10,15 +10,38 @@
*/
#include "inttypes.h"
#include "Python.h"
namespace
mgb
::
imperative
::
python
{
struct
TraceInfo
{
int64_t
mixin_handle
=
-
1
;
bool
recording
=
false
;
bool
data_read
=
false
;
bool
value_read
=
false
;
bool
shape_read
=
false
;
PyObject
*
compiled_info
=
nullptr
;
PyObject
*
trace_mixin_info
=
nullptr
;
TraceInfo
()
=
default
;
TraceInfo
&
operator
=
(
const
TraceInfo
&
that
)
{
mixin_handle
=
that
.
mixin_handle
;
recording
=
that
.
recording
;
compiled_info
=
that
.
compiled_info
;
Py_XINCREF
(
compiled_info
);
trace_mixin_info
=
that
.
trace_mixin_info
;
Py_XINCREF
(
trace_mixin_info
);
return
*
this
;
}
~
TraceInfo
()
{
Py_XDECREF
(
trace_mixin_info
);
// Py_XDECREF(compiled_info);
}
private:
TraceInfo
(
const
TraceInfo
&
that
)
=
default
;
};
}
// namespace mgb::imperative::python
imperative/python/test/unit/test_tracing.py
浏览文件 @
c9c3429a
...
...
@@ -311,6 +311,7 @@ def test_trace_warp_perspective():
f
(
x
,
M
)
@
pytest
.
mark
.
skip
(
reason
=
"skip"
)
def
test_raise_on_trace
():
step_count
=
0
catch_count
=
0
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录