MegEngine 天元 / MegEngine
Commit d4bad711
Authored on Aug 29, 2020 by Megvii Engine Team

feat(mge): add jit.trace

GitOrigin-RevId: ec647324c0e207b6185efe118b61a094c959ce7f
Parent: 0b88ec3c
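In brief, the new `jit.trace` wraps a function so that its sequence of imperative ops is recorded on the first call and replayed through a compiled graph on subsequent calls. A minimal usage sketch, following `imperative/python/test/unit/test_tracing.py` from this commit:

```python
import numpy as np

from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import trace


@trace  # first call records the op sequence; later calls replay it compiled
def negate(x):
    op = ops.Elemwise(mode="negate")
    (y,) = apply(op, x)
    return y


x = as_raw_tensor([1]).numpy()
for _ in range(3):
    print(negate(as_raw_tensor(x)).numpy())  # [-1] on every call
```

Passing `symbolic=True` (i.e. `@trace(symbolic=True)`) makes the first call build a lazy-evaluation graph instead of executing eagerly.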
Showing 16 changed files with 905 additions and 72 deletions (+905 −72).
imperative/python/megengine/core/tensor/megbrain_graph.py (+67 −27)
imperative/python/megengine/core/tensor/raw_tensor/__init__.py (+14 −6)
imperative/python/megengine/jit/__init__.py (+1 −0)
imperative/python/megengine/jit/tracing.py (+514 −0)
imperative/python/src/common.cpp (+36 −6)
imperative/python/src/graph_rt.cpp (+50 −8)
imperative/python/src/graph_rt.h (+20 −1)
imperative/python/src/helper.h (+4 −1)
imperative/python/src/imperative_rt.cpp (+1 −0)
imperative/python/src/pyext17.h (+68 −19)
imperative/python/test/unit/test_tracing.py (+65 −0)
imperative/src/impl/interpreter_impl.cpp (+9 −0)
imperative/src/impl/interpreter_impl.h (+1 −0)
imperative/src/impl/opr_utility.cpp (+42 −4)
imperative/src/include/megbrain/imperative/interpreter.h (+1 −0)
imperative/src/include/megbrain/imperative/opr_utility.h (+12 −0)
imperative/python/megengine/core/tensor/megbrain_graph.py

```diff
@@ -17,15 +17,31 @@ from ..ops.builtin import OpDef
 from .core import OpBase, TensorBase, apply
 
 
-class CompiledFunction:
-    def __init__(self, graph, function):
-        self._graph = graph
-        self._function = function
+class Graph(_imperative_rt.ComputingGraph):
+    def __init__(self):
+        super().__init__()
+        self._var_cache = weakref.WeakKeyDictionary()
+        self._op_cache = weakref.WeakKeyDictionary()
+        self._executor = ThreadPoolExecutor(1)
+        self._function = None
+        self._future = None
+
+    def _wrap(self, obj):
+        if type(obj) is _imperative_rt.VarNode:
+            wrapper, cache = VarNode, self._var_cache
+        elif type(obj) is _imperative_rt.OperatorNode:
+            wrapper, cache = OpNode, self._op_cache
+        if obj not in cache:
+            cache[obj] = wrapper(obj)
+        return cache[obj]
+
+    def compile(self, *args):
+        self._function = super().compile(_unwrap(args))
+        return self
 
     def execute(self, *args):
         assert self._future is None
-        self._future = self._graph._executor.submit(self._function.execute, *args)
+        self._future = self._executor.submit(self._function.execute, *args)
 
     def wait(self):
         assert self._future is not None
@@ -40,30 +56,23 @@ class CompiledFunction:
         self.execute(*args)
         return self.wait()
 
+    def make_const(self, data, dtype=None, device=None):
+        if isinstance(data, _imperative_rt.DeviceTensorND):
+            assert dtype is None and device is None
+            return self._wrap(_imperative_rt.make_shared(self, data))
+        else:
+            device = as_device(device).to_c()
+            return self._wrap(_imperative_rt.make_const(self, data, device, dtype))
 
-class Graph(_imperative_rt.ComputingGraph):
-    def __init__(self):
-        super().__init__()
-        self._var_cache = weakref.WeakKeyDictionary()
-        self._op_cache = weakref.WeakKeyDictionary()
-        self._executor = ThreadPoolExecutor(1)
-
-    def _wrap(self, obj):
-        if type(obj) is _imperative_rt.VarNode:
-            wrapper, cache = VarNode, self._var_cache
-        elif type(obj) is _imperative_rt.OperatorNode:
-            wrapper, cache = OpNode, self._op_cache
-        if obj not in cache:
-            cache[obj] = wrapper(obj)
-        return cache[obj]
-
-    def compile(self, *args):
-        return CompiledFunction(self, super().compile(_unwrap(args)))
+    def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None):
+        opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
+        return opnode.outputs[0]
 
 
 class VarNode(TensorBase):
     def __init__(self, node: _imperative_rt.VarNode):
         self._node = node
         self.graph._var_cache[node] = self
 
     @property
     def graph(self) -> Graph:
@@ -81,10 +90,15 @@ class VarNode(TensorBase):
     def device(self):
         return as_device(self._node.comp_node)
 
+    @property
+    def shape(self):
+        return self._node.shape
+
 
 class OpNode:
     def __init__(self, node: _imperative_rt.OperatorNode):
         self._node = node
         self.graph._op_cache[node] = self
 
     @property
     def graph(self) -> Graph:
@@ -117,21 +131,21 @@ def _(op: OpDef, *args: VarNode):
     return _wrap(outputs)
 
 
-def input_callback(callback, *args, device=None, dtype=None, graph=None):
+def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
     outputs = _imperative_rt.input_callback(
-        callback, as_device(device).to_c(), dtype, _unwrap(args), graph=graph
+        callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph
     )
     value, dummy = _wrap(outputs)
     return value, dummy
 
 
 class InputNode(OpNode):
-    def __init__(self, *args: VarNode, device=None, dtype=None, graph=None):
+    def __init__(self, *args: VarNode, device=None, dtype=None, shape=None, graph=None):
         r = _imperative_rt.DeviceTensorNDRendezvous()
         if device is not None:
             device = as_device(device).to_c()
         outputs = _imperative_rt.input_callback(
-            r, device, dtype, _unwrap(args), graph=graph
+            r, device, dtype, shape, _unwrap(args), graph=graph
         )
         super().__init__(outputs[0].owner)
         self._rendezvous = r
@@ -169,6 +183,29 @@ class OutputNode(OpNode):
     def get_value(self):
         return self._rendezvous.get()
 
+    def drop_value(self):
+        self._rendezvous.drop()
+
+    def reset(self):
+        self._rendezvous.reset()
+
+
+class ValueOutputNode(OpNode):
+    def __init__(self, var, *args):
+        args = (var,) + args
+        r = _imperative_rt.HostTensorNDRendezvous()
+        dummy = _imperative_rt.value_output_callback(r, _unwrap(args))
+        super().__init__(dummy.owner)
+        self._rendezvous = r
+
+    def get_value(self):
+        hostnd, event = self._rendezvous.get()
+        event.wait()
+        return hostnd.numpy()
+
+    def drop_value(self):
+        self._rendezvous.drop()
+
+    def reset(self):
+        self._rendezvous.reset()
+
@@ -192,5 +229,8 @@ class AttrOutputNode(OpNode):
         attr = self._rendezvous.get()
         return TensorAttr(attr.shape, attr.dtype, as_device(attr.comp_node))
+
+    def drop_value(self):
+        self._rendezvous.drop()
+
+    def reset(self):
+        self._rendezvous.reset()
```
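The `InputNode`/`OutputNode` machinery above is the I/O surface `jit.trace` compiles against: inputs are fed through a rendezvous-backed setter, results are pulled from reader nodes. A hypothetical sketch of the plumbing, modeled on how `trace._compile` in `tracing.py` (further down) drives it; the `set_value` method on `InputNode` and the `"xpux"` device string are assumptions inferred from that usage:

```python
import numpy as np

import megengine.core.tensor.megbrain_graph as G
from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor

graph = G.Graph()
inp = G.InputNode(device="xpux", dtype=np.float32, graph=graph)  # feed point
(var,) = apply(ops.Elemwise(mode="negate"), inp.outputs[0])      # build graph
out = G.OutputNode(var)                                          # read point
graph.compile(out.outputs[0])   # Graph.compile() now returns the graph itself

graph.execute()                 # async: runs on the graph's executor thread
dv = as_raw_tensor(np.array([2.0], dtype=np.float32))._dev_tensor()
inp.set_value(dv)               # assumed setter, as used via info.data_setter
print(out.get_value().numpy())  # expected: [-2.]
graph.wait()
```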
imperative/python/megengine/core/tensor/raw_tensor/__init__.py

```diff
@@ -31,11 +31,13 @@ class RawTensor(TensorBase):
 
     _init_cb = None
     _del_cb = None
+    _handle = None
 
-    def __init__(self, handle):
+    def __init__(self, handle=None):
         self._handle = handle
-        if self._init_cb:
-            self._init_cb()
+        if handle is not None:
+            if self._init_cb:
+                self._init_cb()
 
     @property
     def dtype(self):
@@ -61,9 +63,10 @@ class RawTensor(TensorBase):
         )
 
     def __del__(self):
-        if self._del_cb:
-            self._del_cb()
-        delete(self._handle)
+        if self._handle is not None:
+            if self._del_cb:
+                self._del_cb()
+            delete(self._handle)
 
 
 @apply.register()
@@ -89,6 +92,11 @@ def as_raw_tensor(obj, dtype=None, device=None):
     return as_raw_tensor(obj, device=device)
 
 
+@as_raw_tensor.register(DeviceTensorND)
+def _(data: DeviceTensorND):
+    return RawTensor(put(data))
+
+
 @as_raw_tensor.register(np.ndarray)
 def _(array: np.ndarray, dtype=None, device=None):
     device = None if device is None else as_device(device).to_c()
```
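Together these changes let a tensor enter the imperative runtime directly from device memory: `as_raw_tensor` now dispatches on `DeviceTensorND` (backed by the new interpreter `put` overload further down), and `RawTensor` tolerates construction and destruction without a handle. A small sketch, assuming a build of this commit:

```python
import numpy as np

from megengine.core.tensor.raw_tensor import RawTensor, as_raw_tensor

t = as_raw_tensor(np.array([1.0, 2.0], dtype=np.float32))
dv = t._dev_tensor()      # DeviceTensorND already resident on the device
t2 = as_raw_tensor(dv)    # new path: RawTensor(put(data)), no host round trip
print(t2.numpy())         # -> [1. 2.]

proxy = RawTensor()       # handle=None: __init__ callback and __del__ are now
                          # no-ops, which duck-typed subclasses such as
                          # CompiledTensorProxy rely on
```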
imperative/python/megengine/jit/__init__.py (new file, mode 100644)

```python
from .tracing import exclude_from_trace, trace
```
imperative/python/megengine/jit/tracing.py (new file, mode 100644)

```python
import contextlib
import functools
import typing
import weakref

from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor


class TraceMismatchError(RuntimeError):
    pass


active_trace = None
skip_tracing = False


@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False


class TensorInfo:
    __slots__ = (
        # collected attributes
        "external",
        "exported",
        "data_read",
        "shape_read",
        "value_read",
        "device",
        "dtype",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None


class trace:
    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        self = super().__new__(cls)
        self.__init__(*args, **kwargs)
        return self

    def __init__(self, function, symbolic=False, capture_as_const=False):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._capture_static_shape = False

        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = weakref.WeakSet()
        self._active_tensors = weakref.WeakSet()

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if op != op_:
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        raise TraceMismatchError(
                            "const capture violated: got "
                            "a different tensor this time"
                        )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x.__class__ is not CompiledTensorProxy:
                    raise TraceMismatchError(
                        "unexpected capture: trying to use an external tensor as input, "
                        "but that input was an internal tensor last time"
                    )
                if x._CompiledTensorProxy__handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to an data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
        self._active_tensors.update(outputs)
        return outputs

    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_TraceMixin__handle", None)
                if h is not None:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_TraceMixin__handle", None)
            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                if self._capture_as_const:
                    info.bound_data = x
            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            TraceMixin._TraceMixin__inject(x, h)

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
        self._active_tensors.update(outputs)

    @contextlib.contextmanager
    def _setup(self):
        global active_trace
        if active_trace:
            raise NotImplementedError("sorry, not implemented: nested trace")
        active_trace = self

        if self._untraced:
            apply.enable(apply_with_tracing)
            if self._symbolic:
                apply.enable(apply_symbolic_mode)
                self._lazy_eval_graph = G.Graph()
        else:
            apply.enable(apply_compiled_mode)
            if self._graph is None:
                self._compile()
            self._graph.execute()

        yield

        escaped_tensors = tuple(self._active_tensors)
        self._active_tensors.clear()

        if self._untraced:
            for x in escaped_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.data_read = True
                x._TraceMixin__restore()
            if self._symbolic:
                # eval lazy eval tensors
                lazy_eval_tensors = tuple(self._lazy_eval_tensors)
                if lazy_eval_tensors:
                    readers = [
                        G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                        for x in lazy_eval_tensors
                    ]
                    self._lazy_eval_graph.compile(*readers)
                    self._lazy_eval_graph()
                    for r, x in zip(readers, lazy_eval_tensors):
                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
            self._untraced = False
        else:
            if self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            for x in escaped_tensors:
                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
            self._graph.wait()
            self._reset_exec_env()
            self._pc = 0

        apply.disable(apply_with_tracing)
        apply.disable(apply_symbolic_mode)
        apply.disable(apply_compiled_mode)
        active_trace = None

    def _begin_excluded_region(self):
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.exported = True
                info.data_read = True

    def _compile(self):
        graph = self._graph = G.Graph()
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        links = ()
        for op, ihandles, ohandles in self._seq:
            ivars = []
            readers = []
            for h in ihandles:
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *links, device=info.device, dtype=info.dtype, graph=graph
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *links = opnode.outputs
                ivars.append(info.varnode)
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
                    add_reader(opnode)
        graph.compile(*readers)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def _require_shape(self, handle):
        info = self._tinfo[handle]
        info.shape_read = True

    def _require_value(self, handle):
        info = self._tinfo[handle]
        info.value_read = True

    def _require_data(self, handle):
        info = self._tinfo[handle]
        info.data_read = True

    def __call__(self, *args, **kwargs):
        with self._setup():
            return self.__wrapped__(*args, **kwargs)


class CompiledTensorProxy(RawTensor):
    """
    Duck-typed RawTensor
    """

    def __init__(self, handle):
        self.__handle = handle
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                raise TraceMismatchError("shape of this tensor is not read in trace")
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor is not read in trace")
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                raise TraceMismatchError("raw data of this tensor is not read in trace")
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()


class LazyEvalTensor(RawTensor):
    def __init__(self, varnode):
        self.__varnode = varnode

    @property
    def dtype(self):
        return self.__varnode.dtype

    @property
    def device(self):
        return self.__varnode.device

    @property
    def shape(self):
        return self.__varnode.shape

    def numpy(self):
        raise RuntimeError("cannot read value during symbolic tracing")

    def _dev_tensor(self):
        raise RuntimeError("cannot access data during symbolic tracing")


class TraceMixin:
    __subclass_cache = {}

    def __inject(self, handle):
        cache = __class__.__subclass_cache
        cls = self.__class__
        subcls = cache.get(cls)
        if subcls is None:
            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
        self.__class__ = subcls
        self.__handle = handle
        self.__cls = cls
        return self

    def __restore(self):
        cls = self.__cls
        del self.__handle
        del self.__cls
        self.__class__ = cls
        return self

    @property
    def shape(self):
        if not skip_tracing:
            active_trace._require_shape(self.__handle)
        return super().shape

    def numpy(self):
        if not skip_tracing:
            active_trace._require_value(self.__handle)
        return super().numpy()

    def _dev_tensor(self):
        if not skip_tracing:
            active_trace._require_data(self.__handle)
        return super()._dev_tensor()


class TracedRawTensor(TraceMixin, RawTensor):
    pass


class TracedLazyTensor(TraceMixin, LazyEvalTensor):
    pass


def assign_raw_tensor(lhs, rhs):
    handle = rhs._handle
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)


# this hook turns RawTensor into LazyEvalTensor
@apply.register()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = [
        getattr(x, "_LazyEvalTensor__varnode", None) or graph.make_const(x._dev_tensor())
        for x in args
    ]
    ovars = apply(op, *ivars)
    outputs = [LazyEvalTensor(v) for v in ovars]
    active_trace._lazy_eval_tensors.update(outputs)
    return outputs


apply.disable(apply_symbolic_mode)


@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        return apply.super(op, *args)
    return active_trace._apply_op(op, args)


apply.disable(apply_compiled_mode)


# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_op(op, args, outputs)
    return outputs


apply.disable(apply_with_tracing)


# @apply.register()
# def _(op: Const, *args: RawTensor):
#     return active_trace._apply_const(op, args)


class BrokenRawTensor(RawTensor):
    def __getattribute__(self, _):
        raise RuntimeError("broken due to misuse of tracing")

    def __setattr__(self, *_):
        raise RuntimeError("broken due to misuse of tracing")
```
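Two behaviors of the tracer are worth calling out. Work done under `exclude_from_trace()` runs eagerly on every call instead of being baked into the compiled sequence: during tracing the active trace marks all live tensors as readable (`_begin_excluded_region`), and during replay `apply_compiled_mode` falls through to the plain interpreter. A sketch of call-dependent control flow that stays legal across replays, patterned on `test_exclude_from_trace` below:

```python
import numpy as np

from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import exclude_from_trace, trace

step = 0


@trace
def f(x):
    neg = ops.Elemwise(mode="negate")
    (x,) = apply(neg, x)
    with exclude_from_trace():
        if step % 2:          # eager region: may differ between calls
            (x,) = apply(neg, x)
    return x


x = np.array([1], dtype=np.int32)
for step in range(4):
    print(f(as_raw_tensor(x)).numpy())  # alternates between [-1] and [1]
```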
imperative/python/src/common.cpp

```diff
@@ -23,10 +23,29 @@ namespace py = pybind11;
 using namespace mgb;
 using namespace imperative;
 
+namespace {
+
+template<typename XTensorND>
+auto def_TensorND(py::object parent, const char* name) {
+    return py::class_<XTensorND>(parent, name)
+        .def_property_readonly("shape", py::overload_cast<>(&XTensorND::shape, py::const_))
+        .def_property_readonly("dtype", py::overload_cast<>(&XTensorND::dtype, py::const_))
+        .def_property_readonly("comp_node", py::overload_cast<>(&XTensorND::comp_node, py::const_))
+        .def("copy_from", &XTensorND::template copy_from<DeviceTensorStorage>)
+        .def("copy_from", &XTensorND::template copy_from<HostTensorStorage>)
+        .def("copy_from_fixlayout", py::overload_cast<const DeviceTensorND&>(
+                &XTensorND::template copy_from_fixlayout<DeviceTensorStorage>))
+        .def("copy_from_fixlayout", py::overload_cast<const HostTensorND&>(
+                &XTensorND::template copy_from_fixlayout<HostTensorStorage>));
+}
+
+} // namespace
+
 void init_common(py::module m) {
-    py::class_<CompNode>(m, "CompNode")
+    auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
         .def(py::init())
         .def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
+        .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
         .def("__str__", &CompNode::to_string_logical)
         .def_static("_sync_all", &CompNode::sync_all)
         .def(py::self == py::self)
@@ -40,19 +59,30 @@ void init_common(py::module m) {
             return CompNode::load(cn);
         }));
+    py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event")
+        .def("record", &CompNode::Event::record)
+        .def("wait", &CompNode::Event::host_wait);
     py::implicitly_convertible<std::string, CompNode>();
 
-    py::class_<DeviceTensorND>(m, "DeviceTensorND")
-        .def(py::init())
-        .def_property_readonly("shape", py::overload_cast<>(&DeviceTensorND::shape, py::const_))
-        .def_property_readonly("dtype", py::overload_cast<>(&DeviceTensorND::dtype, py::const_))
-        .def_property_readonly("comp_node", py::overload_cast<>(&DeviceTensorND::comp_node, py::const_))
+    def_TensorND<DeviceTensorND>(m, "DeviceTensorND")
         .def("numpy", [](const DeviceTensorND& self) {
             HostTensorND hv;
             hv.copy_from(self).sync();
             return py::handle(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
         });
 
+    def_TensorND<HostTensorND>(m, "HostTensorND")
+        .def(py::init([](py::array data, CompNode cn, DType dtype) {
+            if (!cn.valid()) {
+                throw py::type_error("device must not be None");
+            }
+            return npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
+        }))
+        .def("numpy", [](const HostTensorND& self) {
+            return py::reinterpret_steal<py::object>(
+                npy::ndarray_from_tensor(self, npy::ShareType::TRY_SHARE));
+        });
+
     py::class_<cg::OperatorNodeConfig>(m, "OperatorNodeConfig")
         .def(py::init())
         .def_property("name",
```
imperative/python/src/graph_rt.cpp

```diff
@@ -12,6 +12,7 @@
 #include "./graph_rt.h"
 #include "megbrain/imperative/opr_utility.h"
 #include "megbrain/opr/io.h"
+#include "megbrain/opr/basic_arith.h"
 #include "megbrain/imperative.h"
 #include "./helper.h"
@@ -29,29 +30,44 @@ auto def_rendezvous(py::object m, const char* name) {
         .def(py::init([](){ return std::make_shared<Rendezvous<T>>(); }))
         .def("set", [](Rendezvous<T>& r, T v) { r.set(std::move(v)); })
         .def("get", [](Rendezvous<T>& r) { return r.get(); },
-             py::call_guard<py::gil_scoped_release>());
+             py::call_guard<py::gil_scoped_release>())
+        .def("drop", &Rendezvous<T>::drop)
+        .def("reset", &Rendezvous<T>::reset);
 }
 
 using TensorAttr = LogicalTensorDesc;
+using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;
 
 void init_graph_rt(py::module m) {
     def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
+    def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
     def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");
 
     py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
         .def_property_readonly("owner", [](cg::VarNode* v) { return v->owner_opr(); })
         .def_property_readonly("graph", [](cg::VarNode* v) { return v->owner_graph(); })
         .def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_))
         .def_property_readonly("dtype", [](cg::VarNode* v) { return v->dtype(); })
-        .def_property_readonly("comp_node", [](cg::VarNode* v) { return v->comp_node(); });
+        .def_property_readonly("comp_node", [](cg::VarNode* v) { return v->comp_node(); })
+        .def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
+            auto&& mgr = v->owner_graph()->static_infer_manager();
+            auto&& type = mgr.get_infer_type(v);
+            using InferType = cg::static_infer::InferType;
+            if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
+                return nullptr;
+            }
+            return mgr.infer_shape_fallible(v);
+        });
 
     py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
         .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) { return opr->owner_graph(); })
         .def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_))
         .def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
             return to_tuple(opr->input());
         })
         .def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
-            return to_tuple(opr->output());
+            return to_tuple(opr->usable_output());
         });
 
     py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
@@ -117,7 +133,7 @@ void init_graph_rt(py::module m) {
     common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {
             cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
             auto opr = OpDef::apply_on_var_node(def, vinputs);
-            auto outputs = opr->output();
+            auto outputs = opr->usable_output();
             return to_tuple(outputs);
         },
         py::arg(), py::arg(), py::arg("graph") = py::none());
@@ -125,6 +141,7 @@ void init_graph_rt(py::module m) {
     auto input_callback = [](auto callback,
                              const CompNode& comp_node,
                              const DType& dtype,
+                             const TensorShape& shape,
                              const std::vector<cg::VarNode*>& inputs,
                              cg::ComputingGraph* graph) {
         if (!graph) {
@@ -135,7 +152,7 @@ void init_graph_rt(py::module m) {
             sinputs.emplace_back(i);
         }
         static_assert(!std::is_reference<decltype(callback)>::value);
-        auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, sinputs);
+        auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, shape, sinputs);
         std::vector<VarNode*> outputs;
         outputs.reserve(soutputs.size());
         for (auto i : soutputs) {
@@ -144,26 +161,40 @@ void init_graph_rt(py::module m) {
         return outputs;
     };
 
+    m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) {
+        return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
+    });
+
+    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
+        if (!cn.valid()) {
+            throw py::type_error("device must not be None");
+        }
+        auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
+        opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
+    });
+
     m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                              const CompNode& comp_node,
                                              const DType& dtype,
+                                             const TensorShape& shape,
                                              const std::vector<cg::VarNode*>& inputs,
                                              cg::ComputingGraph* graph) {
-            return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, inputs, graph);
+            return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, shape, inputs, graph);
         },
-        py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none());
+        py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
 
     m.def("input_callback", [input_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p,
                                              const CompNode& comp_node,
                                              const DType& dtype,
+                                             const TensorShape& shape,
                                              const std::vector<cg::VarNode*>& inputs,
                                              cg::ComputingGraph* graph) {
             auto f = [p]() -> DeviceTensorND { return p->get(); };
-            return input_callback(std::move(f), comp_node, dtype, inputs, graph);
+            return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph);
         },
-        py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none());
+        py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
 
     auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false) {
         SymbolVarArray sinputs;
@@ -193,6 +224,17 @@ void init_graph_rt(py::module m) {
             return output_callback(std::move(f), std::move(inputs));
         });
 
+    m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) {
+            auto f = [p](DeviceTensorND dv) {
+                HostNDWithEvent hv_with_event;
+                hv_with_event.first.copy_from(dv);
+                hv_with_event.second = dv.comp_node().create_event();
+                hv_with_event.second->record();
+                p->set(std::move(hv_with_event));
+            };
+            return output_callback(std::move(f), std::move(inputs), true);
+        });
+
     m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
             auto f = [p](DeviceTensorND dv) {
                 p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});
```
imperative/python/src/graph_rt.h

```diff
@@ -39,6 +39,7 @@ template<typename R>
 class Rendezvous {
     std::mutex m_lock;
     int m_read_ahead = 0;
+    bool m_drop_next = false;
     std::promise<R> m_promise;
 public:
     Rendezvous() = default;
@@ -47,6 +48,7 @@ public:
     Rendezvous& operator=(const Rendezvous& rhs) = delete;
     Rendezvous& operator=(Rendezvous&& rhs) {
         MGB_LOCK_GUARD(m_lock);
+        m_drop_next = rhs.m_drop_next;
         m_read_ahead = rhs.m_read_ahead;
         m_promise = std::move(rhs.m_promise);
         return *this;
@@ -67,12 +69,28 @@ public:
         return f.get();
     }
 
+    void drop() {
+        MGB_LOCK_GUARD(m_lock);
+        mgb_assert(m_read_ahead <= 0);
+        mgb_assert(m_read_ahead >= -1);
+        if (m_read_ahead == -1) {
+            m_promise = {};
+        } else {
+            m_drop_next = true;
+        }
+        ++m_read_ahead;
+    }
+
     template<typename T>
     void set(T&& value) {
         MGB_LOCK_GUARD(m_lock);
         mgb_assert(m_read_ahead >= 0);
         mgb_assert(m_read_ahead <= 1);
-        m_promise.set_value(std::forward<T>(value));
+        if (m_drop_next) {
+            m_drop_next = false;
+        } else {
+            m_promise.set_value(std::forward<T>(value));
+        }
         if (m_read_ahead == 1) {
             m_promise = {};
         }
@@ -83,6 +101,7 @@ public:
         MGB_LOCK_GUARD(m_lock);
         m_promise = {};
         m_read_ahead = 0;
+        m_drop_next = false;
     }
 };
```
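The `drop`/`m_drop_next` machinery above lets a consumer cancel exactly one pending exchange: if `drop()` arrives before the producer's `set()`, that `set()` discards its value instead of fulfilling the promise; if a `get()` is already in flight (`m_read_ahead == -1`), the promise itself is abandoned. The same objects are exposed to Python by the `def_rendezvous` bindings in `graph_rt.cpp` above; a hedged sketch of the protocol (the `_imperative_rt` import path is an assumption):

```python
from megengine.core._imperative_rt import DeviceTensorNDRendezvous  # assumed path

r = DeviceTensorNDRendezvous()
# One value per round trip:
#   consumer side (an OutputNode reader):  value = r.get()   # blocks, GIL released
#   producer side (graph execution):       r.set(dev_tensor)
r.drop()    # consumer no longer wants the next value; the matching set()
            # will be silently discarded
r.reset()   # clear all pending state between two graph executions
```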
imperative/python/src/helper.h

```diff
@@ -280,9 +280,12 @@ namespace detail {
 public:
     bool load(handle src, bool convert) {
         auto obj = reinterpret_steal<object>(src);
-        if (!isinstance<tuple>(obj)) {
+        if (!convert && !isinstance<tuple>(obj)) {
             return false;
         }
+        if (obj.is_none()) {
+            return true;
+        }
         value.ndim = len(obj);
         mgb_assert(value.ndim <= mgb::TensorShape::MAX_NDIM);
         size_t i = 0;
```
imperative/python/src/imperative_rt.cpp

```diff
@@ -63,6 +63,7 @@ void init_imperative_rt(py::module m) {
                 return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
             }
         }, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none())
+        .def("put", py::overload_cast<const DeviceTensorND&>(&Interpreter::Channel::put))
         .def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) {
             return self.del(handle);
         })
```
imperative/python/src/pyext17.h

```diff
@@ -24,6 +24,12 @@ constexpr bool has_fastcall = true;
 constexpr bool has_fastcall = false;
 #endif
 
+#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+constexpr bool has_vectorcall = true;
+#else
+constexpr bool has_vectorcall = false;
+#endif
+
 template <typename... Args>
 struct invocable_with {
     template <typename T>
@@ -55,6 +61,9 @@ private:
 public:
     PyObject_HEAD
     std::aligned_storage_t<sizeof(T), alignof(T)> storage;
+#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+    PyObject* vectorcall_slot;
+#endif
 
     inline T* inst() {
         return reinterpret_cast<T*>(&storage);
@@ -155,6 +164,51 @@ private:
     // polyfills
 
+    struct tp_vectorcall {
+        static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall);
+        static constexpr bool haskw = [](){
+            if constexpr (valid)
+                if constexpr (std::is_invocable_v<T::tp_vectorcall, T, PyObject*const*, size_t, PyObject*>)
+                    return true;
+            return false;}();
+
+        template<typename = void>
+        static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject* kwnames) {
+            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
+            if constexpr (haskw) {
+                CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames));
+            } else {
+                if (kwnames && PyTuple_GET_SIZE(kwnames)) {
+                    PyErr_SetString(PyExc_TypeError, "expect no keyword argument");
+                    return nullptr;
+                }
+                CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf));
+            }
+        }
+
+        static constexpr Py_ssize_t offset = []() {
+            if constexpr (valid)
+                return offsetof(wrap_t, vectorcall_slot);
+            else
+                return 0;}();
+    };
+
+    struct tp_call {
+        static constexpr bool provided = HAS_MEMBER(T, tp_call);
+        static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
+            [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
+        static constexpr bool valid = provided || tp_vectorcall::valid;
+
+        template<typename = void>
+        static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
+            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
+            CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
+        }
+
+        static constexpr ternaryfunc value = []() {
+            if constexpr (static_form) return T::tp_call;
+            else if constexpr (provided) return impl<>;
+#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+            else if constexpr (valid) return PyVectorcall_Call;
+#endif
+            else return nullptr;}();
+    };
+
     struct tp_new {
         static constexpr bool provided = HAS_MEMBER(T, tp_new);
         static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;
@@ -163,11 +217,14 @@ private:
         template<typename = void>
        static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
             auto* self = type->tp_alloc(type, 0);
-            auto* ptr = reinterpret_cast<wrap_t*>(self)->inst();
+            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
+            if constexpr (has_vectorcall && tp_vectorcall::valid) {
+                reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
+            }
             if constexpr (varkw) {
-                new(ptr) T(args, kwargs);
+                new(inst) T(args, kwargs);
             } else {
-                new(ptr) T();
+                new(inst) T();
             }
             return self;
         }
@@ -190,22 +247,6 @@ private:
             else return impl<>;}();
     };
 
-    struct tp_call {
-        static constexpr bool valid = HAS_MEMBER(T, tp_call);
-        static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
-            [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
-
-        template<typename = void>
-        static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
-            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
-            CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
-        }
-
-        static constexpr ternaryfunc value = []() {
-            if constexpr (static_form) return T::tp_call;
-            else if constexpr (valid) return impl<>;
-            else return nullptr;}();
-    };
-
 public:
     class TypeBuilder {
         std::vector<PyMethodDef> m_methods;
@@ -228,9 +269,17 @@ public:
                 m_type.tp_name = T::tp_name;
             }
             m_type.tp_dealloc = tp_dealloc::value;
+#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+            m_type.tp_vectorcall_offset = tp_vectorcall::offset;
+#endif
             m_type.tp_call = tp_call::value;
             m_type.tp_basicsize = sizeof(wrap_t);
             m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
+#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+            if constexpr (tp_vectorcall::valid) {
+                m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL;
+            }
+#endif
             m_type.tp_new = tp_new::value;
         }
```
imperative/python/test/unit/test_tracing.py (new file, mode 100644)

```python
import numpy as np

from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import exclude_from_trace, trace


def test_trace():
    for symbolic in [False, True]:

        @trace(symbolic=symbolic)
        def f(x):
            op = ops.Elemwise(mode="negate")
            (y,) = apply(op, x)
            return y

        x = as_raw_tensor([1]).numpy()
        y = f.__wrapped__(as_raw_tensor(x)).numpy()

        for i in range(3):
            np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)


def test_exclude_from_trace():
    for symbolic in [False, True]:

        @trace(symbolic=symbolic)
        def f(x):
            neg = ops.Elemwise(mode="negate")
            (x,) = apply(neg, x)
            with exclude_from_trace():
                if i % 2:
                    (x,) = apply(neg, x)
            (x,) = apply(neg, x)
            return x

        x = as_raw_tensor([1]).numpy()

        for i in range(3):
            y = f.__wrapped__(as_raw_tensor(x)).numpy()
            np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)


def test_print_in_trace():
    for symbolic in [False]:  # cannot read value in symbolic mode

        @trace(symbolic=symbolic)
        def f(x):
            nonlocal buf
            neg = ops.Elemwise(mode="negate")
            (x,) = apply(neg, x)
            buf = x.numpy()
            (x,) = apply(neg, x)
            return x

        buf = None
        x = as_raw_tensor([1]).numpy()

        for i in range(3):
            y = f.__wrapped__(as_raw_tensor(x)).numpy()
            z = buf
            buf = None
            np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
            np.testing.assert_equal(z, buf)
```
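On replay, `trace._apply_op` re-validates every recorded step, so inputs that differ structurally from the tracing run fail loudly instead of silently producing stale results. A sketch of the failure mode (the error text is taken verbatim from `tracing.py`):

```python
import numpy as np

from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import trace
from megengine.jit.tracing import TraceMismatchError


@trace
def f(x):
    (y,) = apply(ops.Elemwise(mode="negate"), x)
    return y


f(as_raw_tensor(np.array([1], dtype=np.int32)))  # first call: records the trace
try:
    f(as_raw_tensor(np.array([1.0], dtype=np.float32)))  # replay with new dtype
except TraceMismatchError as e:
    print(e)  # failed to capture: different dtype from last time
```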
imperative/src/impl/interpreter_impl.cpp

```diff
@@ -37,6 +37,15 @@ void* ChannelImpl::put(const HostTensorND& value) {
     return info;
 }
 
+void* ChannelImpl::put(const DeviceTensorND& data) {
+    auto info = alloc();
+    info->desc.layout = data.layout();
+    info->desc.comp_node = data.comp_node();
+    info->ptr = Tensor::make(data);
+    m_valid_handle.insert(info);
+    return info;
+}
+
 void ChannelImpl::del(void* handle) {
     mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
     m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)});
```
imperative/src/impl/interpreter_impl.h

```diff
@@ -55,6 +55,7 @@ struct ChannelImpl : Interpreter::Channel {
     ~ChannelImpl() override;
 
     Handle put(const HostTensorND& value) override;
+    Handle put(const DeviceTensorND& value) override;
 
     void del(Handle) override;
```
imperative/src/impl/opr_utility.cpp

```diff
@@ -31,9 +31,10 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback);
 InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
                              const VarNodeArray& inputs,
+                             const TensorShape& output_shape,
                              const OperatorNodeConfig& config)
         : Super(&graph, config, "input_callback", inputs),
-          m_callback(callback) {
+          m_output_shape(output_shape), m_callback(callback) {
     for (VarNode* i : inputs) {
         add_input({i});
     }
@@ -48,7 +49,8 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
 SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
                                    callback_t callback, CompNode comp_node,
-                                   DType dtype, const SymbolVarArray& inputs) {
+                                   DType dtype, const TensorShape& shape,
+                                   const SymbolVarArray& inputs) {
     mgb_assert(comp_node.valid());
     mgb_assert(dtype.valid());
     OperatorNodeConfig config;
@@ -56,11 +58,22 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
     config.output_dtype(dtype);
     auto vinputs = to_var_node_array(inputs);
     auto opr = graph.insert_opr(
-            std::make_unique<InputCallback>(graph, callback, vinputs, config));
+            std::make_unique<InputCallback>(graph, callback, vinputs, shape, config));
     return to_symbol_var_array(opr->output());
 }
 
-void InputCallback::init_output_static_infer_desc() {}
+void InputCallback::init_output_static_infer_desc() {
+    if (m_output_shape.ndim) {
+        using namespace cg::static_infer;
+        auto&& mgr = owner_graph()->static_infer_manager();
+        auto infer_shape = [this](TensorShape& dest, const InpVal&) {
+            dest = m_output_shape;
+            return true;
+        };
+        mgr.register_shape_infer(output(0), {SourceType::CONSTANT, {}, infer_shape});
+    }
+}
 
 cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
     NodeProp* prop = Super::do_make_node_prop();
@@ -73,9 +86,23 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
 void InputCallback::scn_do_execute() {
     auto dev_tensor = m_callback();
+    if (m_output_shape.ndim) {
+        mgb_assert(dev_tensor.shape().eq_shape(m_output_shape));
+    }
     output(0)->reset_dev_tensor_from_tensor(dev_tensor);
 }
 
+cg::OperatorNodeBase* InputCallback::shallow_copy(
+        const serialization::OprShallowCopyContext &ctx,
+        const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+        const OperatorNodeConfig &config) {
+    auto &&opr = opr_.cast_final_safe<InputCallback>();
+    auto* graph = ctx.owner_graph(opr, inputs);
+    return graph->insert_opr(std::make_unique<InputCallback>(
+            *graph, opr.m_callback, inputs, opr.m_output_shape, config));
+}
+
+MGB_REG_OPR_SHALLOW_COPY(InputCallback, InputCallback::shallow_copy);
+
 /* ================ OutputCallback ================== */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(OutputCallback);
@@ -122,6 +149,17 @@ void OutputCallback::scn_do_execute() {
     m_param.callback(input(0)->dev_tensor());
 }
 
+cg::OperatorNodeBase* OutputCallback::shallow_copy(
+        const serialization::OprShallowCopyContext &ctx,
+        const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+        const OperatorNodeConfig &config) {
+    auto &&opr = opr_.cast_final_safe<OutputCallback>();
+    auto* graph = ctx.owner_graph(opr, inputs);
+    return graph->insert_opr(std::make_unique<OutputCallback>(opr.m_param, inputs, config));
+}
+
+MGB_REG_OPR_SHALLOW_COPY(OutputCallback, OutputCallback::shallow_copy);
+
 /* ================ NopCallback ================== */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(NopCallback);
```
imperative/src/include/megbrain/imperative/interpreter.h

```diff
@@ -22,6 +22,7 @@ struct Interpreter {
         virtual ~Channel() = default;
 
         virtual Handle put(const HostTensorND& value) = 0;
+        virtual Handle put(const DeviceTensorND& value) = 0;
 
         virtual void del(Handle) = 0;
```
imperative/src/include/megbrain/imperative/opr_utility.h

```diff
@@ -17,6 +17,7 @@
 #include "megbrain/opr/internal/param_tag_defs.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
 #include "megbrain/opr/param_defs.h"
+#include "megbrain/serialization/sereg.h"
 #include "megdnn/oprs/utils.h"
@@ -33,17 +34,24 @@ public:
     InputCallback(cg::ComputingGraph& graph,
                   callback_t callback,
                   const VarNodeArray& inputs,
+                  const TensorShape& output_shape,
                   const OperatorNodeConfig& config);
     static SymbolVarArray make(cg::ComputingGraph& graph,
                                callback_t callback,
                                CompNode comp_node,
                                DType dtype,
+                               const TensorShape& shape,
                                const SymbolVarArray& inputs = {});
+    static cg::OperatorNodeBase* shallow_copy(
+            const serialization::OprShallowCopyContext &ctx,
+            const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+            const OperatorNodeConfig &config);
 protected:
     void scn_do_execute() override;
     void init_output_static_infer_desc() override;
     NodeProp* do_make_node_prop() const override;
 private:
+    TensorShape m_output_shape;
     callback_t m_callback;
 };
@@ -63,6 +71,10 @@ public:
                           SymbolVar input) {
         return make(std::move(param), SymbolVarArray{input});
     }
+    static cg::OperatorNodeBase* shallow_copy(
+            const serialization::OprShallowCopyContext &ctx,
+            const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+            const OperatorNodeConfig &config);
 protected:
     void scn_do_execute() override;
     void init_output_static_infer_desc() override;
```