Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cc85047b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cc85047b
编写于
11月 19, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/trace): fix sublinear in trace
GitOrigin-RevId: 356dcd9523fde3c041deb590a9cd7b19ec31a918
上级
b9918c32
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
224 addition
and
49 deletion
+224
-49
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+24
-3
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+135
-12
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+17
-13
imperative/python/test/integration/test_optimizer.py
imperative/python/test/integration/test_optimizer.py
+3
-1
imperative/python/test/unit/test_tracing.py
imperative/python/test/unit/test_tracing.py
+2
-0
imperative/src/impl/opr_utility.cpp
imperative/src/impl/opr_utility.cpp
+38
-18
imperative/src/include/megbrain/imperative/opr_utility.h
imperative/src/include/megbrain/imperative/opr_utility.h
+5
-2
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
cc85047b
...
...
@@ -74,6 +74,11 @@ class Graph(_imperative_rt.ComputingGraph):
self
.
execute
(
*
args
)
return
self
.
wait
()
def
_make_const_for_backward
(
self
,
data
):
device
=
as_device
(
data
.
comp_node
).
to_c
()
data
=
data
.
numpy
()
return
self
.
_wrap
(
_imperative_rt
.
make_const
(
self
,
data
,
device
,
data
.
dtype
))
def
make_const
(
self
,
data
,
dtype
=
None
,
device
=
None
):
if
isinstance
(
data
,
_imperative_rt
.
DeviceTensorND
):
assert
dtype
is
None
and
device
is
None
...
...
@@ -437,7 +442,9 @@ def _(op: OpDef, *args: VarNode):
def
_
(
op
:
BackwardGraph
,
*
args
:
VarNode
):
assert
args
graph
=
args
[
0
].
graph
return
op
.
interpret
(
lambda
op
,
args
:
apply
(
op
,
*
args
),
graph
.
make_const
,
args
)
return
op
.
interpret
(
lambda
op
,
args
:
apply
(
op
,
*
args
),
graph
.
_make_const_for_backward
,
args
)
def
input_callback
(
callback
,
*
args
,
device
=
None
,
dtype
=
None
,
shape
=
None
,
graph
=
None
):
...
...
@@ -449,12 +456,26 @@ def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=N
class
InputNode
(
OpNode
):
def
__init__
(
self
,
*
args
:
VarNode
,
device
=
None
,
dtype
=
None
,
shape
=
None
,
graph
=
None
):
def
__init__
(
self
,
*
args
:
VarNode
,
device
=
None
,
dtype
=
None
,
shape
=
None
,
graph
=
None
,
use_static_shape
=
False
):
r
=
_imperative_rt
.
DeviceTensorNDRendezvous
()
if
device
is
not
None
:
device
=
as_device
(
device
).
to_c
()
outputs
=
_imperative_rt
.
input_callback
(
r
,
device
,
dtype
,
shape
,
_unwrap
(
args
),
graph
=
graph
r
,
device
,
dtype
,
shape
,
_unwrap
(
args
),
graph
=
graph
,
use_static_shape
=
use_static_shape
,
)
super
().
__init__
(
outputs
[
0
].
owner
)
self
.
_rendezvous
=
r
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
cc85047b
...
...
@@ -11,6 +11,7 @@ import contextlib
import
functools
import
itertools
import
json
import
os
import
typing
import
warnings
import
weakref
...
...
@@ -35,6 +36,10 @@ from ..core.tensor.tensor import Tensor
from
.sublinear_memory_config
import
SublinearMemoryConfig
def
_input_node_use_static_shape
():
return
os
.
environ
.
get
(
"MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE"
)
is
not
None
class
TraceMismatchError
(
RuntimeError
):
pass
...
...
@@ -76,6 +81,7 @@ class TensorInfo:
"device"
,
"dtype"
,
"shape"
,
"is_const"
,
"bound_data"
,
# resources for execution
"varnode"
,
...
...
@@ -242,6 +248,28 @@ class trace:
self
.
_active_tensors
.
update
(
outputs
)
return
outputs
def
_apply_const
(
self
,
op
,
args
):
assert
not
self
.
_untraced
# check against trace
if
self
.
_pc
>=
len
(
self
.
_seq
):
raise
TraceMismatchError
(
"trace should end here, but more op observed"
)
record
=
self
.
_seq
[
self
.
_pc
]
op_
,
ihandles
,
ohandles
=
record
assert
isinstance
(
op_
,
Const
)
eq
=
op_
.
value
==
op
.
value
if
not
isinstance
(
eq
,
bool
):
eq
=
all
(
eq
)
if
not
eq
:
raise
TraceMismatchError
(
"const tensor violated: got a different tensor this time"
)
self
.
_pc
+=
1
(
h
,)
=
ohandles
outputs
=
tuple
([
self
.
_tinfo
[
h
].
bound_data
])
return
outputs
def
_record_op
(
self
,
op
,
inputs
,
outputs
):
if
skip_tracing
:
for
x
in
inputs
:
...
...
@@ -275,7 +303,24 @@ class trace:
self
.
_active_tensors
.
update
(
outputs
)
def
_record_const
(
self
,
op
,
outputs
):
pass
if
skip_tracing
:
(
x
,)
=
outputs
h
=
getattr
(
x
,
"_TraceMixin__handle"
,
None
)
if
h
is
not
None
:
self
.
_tinfo
[
h
].
data_read
=
True
return
(
x
,)
=
outputs
h
,
info
=
self
.
_new_handle
()
ohandles
=
[
h
]
info
.
external
=
True
info
.
device
=
x
.
device
info
.
dtype
=
x
.
dtype
info
.
shape
=
x
.
shape
info
.
bound_data
=
x
info
.
is_const
=
True
TraceMixin
.
_TraceMixin__inject
(
x
,
h
)
self
.
_seq
.
append
((
op
,
tuple
(),
tuple
(
ohandles
)))
def
_set_active
(
self
,
active
:
bool
):
global
active_trace
...
...
@@ -308,6 +353,11 @@ class trace:
for
x
in
lazy_eval_tensors
]
self
.
_apply_graph_options
(
lazy_eval_graph
)
# FIXME
if
self
.
_graph_opt_level
is
not
None
:
lazy_eval_graph
.
options
.
graph_opt_level
=
self
.
_graph_opt_level
else
:
lazy_eval_graph
.
options
.
graph_opt_level
=
2
lazy_eval_graph
.
compile
(
*
lazy_eval_links
,
*
readers
)
lazy_eval_graph
()
for
r
,
x
in
zip
(
readers
,
lazy_eval_tensors
):
...
...
@@ -323,6 +373,7 @@ class trace:
self
.
_init_trace
(
self
.
_symbolic
)
else
:
apply
.
enable
(
apply_compiled_mode
)
apply
.
enable
(
apply_const_compiled_mode
)
if
self
.
_graph
is
None
:
self
.
_compile
()
self
.
_graph
.
execute
()
...
...
@@ -370,6 +421,7 @@ class trace:
apply
.
disable
(
apply_symbolic_mode
)
apply
.
disable
(
apply_const_symbolic_mode
)
apply
.
disable
(
apply_compiled_mode
)
apply
.
disable
(
apply_const_compiled_mode
)
self
.
_set_active
(
False
)
def
do_exit
():
...
...
@@ -409,8 +461,10 @@ class trace:
graph
.
options
.
no_force_inplace
=
True
graph
.
options
.
seq_opt
.
enable_seq_comp_node_opt
=
False
# graph opt level
if
self
.
_graph_opt_level
is
not
None
:
graph
.
options
.
graph_opt_level
=
self
.
_graph_opt_level
# if self._graph_opt_level is not None:
# graph.options.graph_opt_level = self._graph_opt_level
# FIXME
graph
.
options
.
graph_opt_level
=
0
# sublinear
if
self
.
_sublinear_memory_config
is
not
None
:
graph
.
options
.
enable_sublinear_memory_opt
=
True
...
...
@@ -442,22 +496,49 @@ class trace:
for
h
in
itertools
.
chain
(
self
.
_arg_bindings
,
self
.
_kwarg_bindings
.
values
()):
info
=
self
.
_tinfo
[
h
]
opnode
=
info
.
data_setter
=
G
.
InputNode
(
device
=
info
.
device
,
dtype
=
info
.
dtype
,
shape
=
info
.
shape
,
graph
=
graph
device
=
info
.
device
,
dtype
=
info
.
dtype
,
shape
=
info
.
shape
,
graph
=
graph
,
use_static_shape
=
_input_node_use_static_shape
(),
)
need_reset_nodes
.
append
(
opnode
)
info
.
varnode
=
opnode
.
outputs
[
0
]
links
+=
opnode
.
outputs
[
1
:]
for
op
,
ihandles
,
ohandles
in
self
.
_seq
:
require_links
=
type
(
op
)
in
_io_op_types
if
isinstance
(
op
,
Const
):
assert
len
(
ihandles
)
==
0
(
h
,)
=
ohandles
info
=
self
.
_tinfo
[
h
]
if
not
hasattr
(
info
,
"varnode"
):
assert
info
.
external
assert
info
.
bound_data
info
.
varnode
=
graph
.
make_const
(
info
.
bound_data
.
numpy
(),
info
.
bound_data
.
dtype
,
info
.
bound_data
.
device
,
)
continue
require_links
=
type
(
op
)
in
_io_op_types
ivars
=
[]
for
i
,
h
in
enumerate
(
ihandles
):
info
=
self
.
_tinfo
[
h
]
if
not
hasattr
(
info
,
"varnode"
):
assert
info
.
external
if
info
.
bound_data
:
info
.
varnode
=
graph
.
make_const
(
info
.
bound_data
.
_dev_tensor
())
if
hasattr
(
info
,
"is_const"
)
and
info
.
is_const
:
info
.
varnode
=
graph
.
make_const
(
info
.
bound_data
.
numpy
(),
info
.
bound_data
.
dtype
,
info
.
bound_data
.
device
,
)
else
:
info
.
varnode
=
graph
.
make_const
(
info
.
bound_data
.
_dev_tensor
()
# info.bound_data.numpy()
)
else
:
opnode
=
info
.
data_setter
=
G
.
InputNode
(
*
links
,
...
...
@@ -465,6 +546,7 @@ class trace:
dtype
=
info
.
dtype
,
shape
=
info
.
shape
,
graph
=
graph
,
use_static_shape
=
_input_node_use_static_shape
(),
)
need_reset_nodes
.
append
(
opnode
)
info
.
varnode
,
*
links
=
opnode
.
outputs
...
...
@@ -500,7 +582,11 @@ class trace:
if
info
.
shape_read
:
opnode
=
info
.
shape_reader
=
G
.
AttrOutputNode
(
v
,
*
links
)
add_reader
(
opnode
)
# FIXME
if
self
.
_graph_opt_level
is
not
None
:
graph
.
options
.
graph_opt_level
=
self
.
_graph_opt_level
else
:
graph
.
options
.
graph_opt_level
=
2
graph
.
compile
(
*
readers
)
def
_reset_exec_env
(
self
):
...
...
@@ -643,6 +729,17 @@ class trace:
)
for
op
,
ihandles
,
ohandles
in
self
.
_seq
:
if
isinstance
(
op
,
Const
):
assert
len
(
ihandles
)
==
0
(
h
,)
=
ohandles
info
=
self
.
_tinfo
[
h
]
if
h
not
in
h2v
:
assert
info
.
external
assert
info
.
bound_data
h2v
[
h
]
=
graph
.
make_const
(
info
.
bound_data
.
numpy
(),
dtype
=
info
.
dtype
,
device
=
info
.
device
,
)
continue
ivars
=
[]
for
h
in
ihandles
:
info
=
self
.
_tinfo
[
h
]
...
...
@@ -874,6 +971,7 @@ class CompiledTensorProxy(RawTensor):
class
LazyEvalTensor
(
RawTensor
):
def
__init__
(
self
,
varnode
):
super
(
LazyEvalTensor
,
self
).
__init__
()
self
.
__varnode
=
varnode
@
property
...
...
@@ -953,11 +1051,22 @@ def assign_raw_tensor(lhs, rhs):
@
apply
.
register
()
def
apply_symbolic_mode
(
op
:
OpDef
,
*
args
:
RawTensor
):
graph
=
active_trace
.
_lazy_eval_graph
ivars
=
[
getattr
(
x
,
"_LazyEvalTensor__varnode"
,
None
)
or
graph
.
make_const
(
x
.
_dev_tensor
())
for
x
in
args
]
ivars
=
[]
for
x
in
args
:
var
=
getattr
(
x
,
"_LazyEvalTensor__varnode"
,
None
)
if
var
:
ivars
.
append
(
var
)
else
:
data_setter
=
G
.
InputNode
(
device
=
x
.
device
,
dtype
=
x
.
dtype
,
shape
=
x
.
shape
,
graph
=
graph
,
use_static_shape
=
True
,
)
var
=
data_setter
.
outputs
[
0
]
ivars
.
append
(
var
)
data_setter
.
set_value
(
x
.
_dev_tensor
())
require_links
=
type
(
op
)
in
_io_op_types
...
...
@@ -1004,6 +1113,20 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
apply
.
disable
(
apply_compiled_mode
)
@
apply
.
register
()
def
apply_const_compiled_mode
(
op
:
Const
,
*
args
:
RawTensor
):
if
skip_tracing
:
args
=
[
as_raw_tensor
(
x
.
_dev_tensor
())
if
x
.
__class__
is
CompiledTensorProxy
else
x
for
x
in
args
]
return
apply
.
super
(
op
,
*
args
)
return
active_trace
.
_apply_const
(
op
,
args
)
apply
.
disable
(
apply_const_compiled_mode
)
# this hook injects TraceMixin
@
apply
.
register
()
def
apply_with_tracing
(
op
:
OpDef
,
*
args
:
RawTensor
):
...
...
imperative/python/src/graph_rt.cpp
浏览文件 @
cc85047b
...
...
@@ -145,11 +145,6 @@ void init_graph_rt(py::module m) {
.
def_property_readonly
(
"comp_node"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
comp_node
();})
.
def_property_readonly
(
"shape"
,
[](
cg
::
VarNode
*
v
)
->
const
TensorShape
*
{
auto
&&
mgr
=
v
->
owner_graph
()
->
static_infer_manager
();
auto
&&
type
=
mgr
.
get_infer_type
(
v
);
using
InferType
=
cg
::
static_infer
::
InferType
;
if
(
!
(
type
.
shape
&
(
InferType
::
CONST
|
InferType
::
RT_STATIC
)))
{
return
nullptr
;
}
return
mgr
.
infer_shape_fallible
(
v
);
})
.
def_property_readonly
(
"value"
,
[](
cg
::
VarNode
*
v
)
->
py
::
object
{
...
...
@@ -437,7 +432,8 @@ void init_graph_rt(py::module m) {
const
DType
&
dtype
,
const
TensorShape
&
shape
,
const
std
::
vector
<
cg
::
VarNode
*>&
inputs
,
cg
::
ComputingGraph
*
graph
)
{
cg
::
ComputingGraph
*
graph
,
bool
use_static_shape
)
{
if
(
!
graph
)
{
graph
=
inputs
[
0
]
->
owner_graph
();
}
...
...
@@ -446,7 +442,9 @@ void init_graph_rt(py::module m) {
sinputs
.
emplace_back
(
i
);
}
static_assert
(
!
std
::
is_reference
<
decltype
(
callback
)
>::
value
);
auto
soutputs
=
opr
::
InputCallback
::
make
(
*
graph
,
std
::
move
(
callback
),
comp_node
,
dtype
,
shape
,
sinputs
);
auto
soutputs
=
opr
::
InputCallback
::
make
(
*
graph
,
std
::
move
(
callback
),
comp_node
,
dtype
,
shape
,
sinputs
,
use_static_shape
);
std
::
vector
<
VarNode
*>
outputs
;
outputs
.
reserve
(
soutputs
.
size
());
for
(
auto
i
:
soutputs
)
{
...
...
@@ -490,23 +488,29 @@ void init_graph_rt(py::module m) {
const
DType
&
dtype
,
const
TensorShape
&
shape
,
const
std
::
vector
<
cg
::
VarNode
*>&
inputs
,
cg
::
ComputingGraph
*
graph
)
{
return
input_callback
([
f
=
std
::
move
(
callback
)](){
py
::
gil_scoped_acquire
_
;
return
f
();},
comp_node
,
dtype
,
shape
,
inputs
,
graph
);
cg
::
ComputingGraph
*
graph
,
bool
use_static_shape
)
{
return
input_callback
(
[
f
=
std
::
move
(
callback
)](){
py
::
gil_scoped_acquire
_
;
return
f
();},
comp_node
,
dtype
,
shape
,
inputs
,
graph
,
use_static_shape
);
},
py
::
arg
(),
py
::
arg
(),
py
::
arg
(),
py
::
arg
()
=
py
::
none
(),
py
::
arg
()
=
py
::
tuple
(),
py
::
arg
(
"graph"
)
=
py
::
none
());
py
::
arg
(),
py
::
arg
(),
py
::
arg
(),
py
::
arg
()
=
py
::
none
(),
py
::
arg
()
=
py
::
tuple
(),
py
::
arg
(
"graph"
)
=
py
::
none
(),
py
::
arg
(
"use_static_shape"
)
=
false
);
m
.
def
(
"input_callback"
,
[
input_callback
](
std
::
shared_ptr
<
Rendezvous
<
DeviceTensorND
>>
p
,
const
CompNode
&
comp_node
,
const
DType
&
dtype
,
const
TensorShape
&
shape
,
const
std
::
vector
<
cg
::
VarNode
*>&
inputs
,
cg
::
ComputingGraph
*
graph
)
{
cg
::
ComputingGraph
*
graph
,
bool
use_static_shape
)
{
auto
f
=
[
p
]()
->
DeviceTensorND
{
return
p
->
get
();
};
return
input_callback
(
std
::
move
(
f
),
comp_node
,
dtype
,
shape
,
inputs
,
graph
);
return
input_callback
(
std
::
move
(
f
),
comp_node
,
dtype
,
shape
,
inputs
,
graph
,
use_static_shape
);
},
py
::
arg
(),
py
::
arg
(),
py
::
arg
(),
py
::
arg
()
=
py
::
none
(),
py
::
arg
()
=
py
::
tuple
(),
py
::
arg
(
"graph"
)
=
py
::
none
());
py
::
arg
(),
py
::
arg
(),
py
::
arg
(),
py
::
arg
()
=
py
::
none
(),
py
::
arg
()
=
py
::
tuple
(),
py
::
arg
(
"graph"
)
=
py
::
none
(),
py
::
arg
(
"use_static_shape"
)
=
false
);
auto
output_callback
=
[](
auto
callback
,
const
std
::
vector
<
cg
::
VarNode
*>&
inputs
,
std
::
shared_ptr
<
RendezvousBase
>
r
=
{},
bool
borrow
=
false
,
bool
prefer_host_value
=
false
)
{
...
...
imperative/python/test/integration/test_optimizer.py
浏览文件 @
cc85047b
...
...
@@ -97,7 +97,9 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
for
param
in
net
.
parameters
():
ori_params
[
param
]
=
np
.
copy
(
param
.
numpy
())
train_func
(
np
.
random
.
random
(
data_shape
).
astype
(
np
.
float32
),
opt
=
opt
,
gm
=
gm
)
train_func
(
tensor
(
np
.
random
.
random
(
data_shape
).
astype
(
np
.
float32
)),
opt
=
opt
,
gm
=
gm
)
step
+=
1
check_func
(
ori_params
,
net
.
parameters
(),
step
)
...
...
imperative/python/test/unit/test_tracing.py
浏览文件 @
cc85047b
...
...
@@ -176,6 +176,7 @@ def test_trace_profiler():
assert
out
.
get
(
"profiler"
)
@
pytest
.
mark
.
skip
(
reason
=
"force opt_level=0 when building graph"
)
def
test_goptions
():
@
trace
(
symbolic
=
True
,
opt_level
=
0
,
capture_as_const
=
True
)
def
f
(
x
):
...
...
@@ -194,6 +195,7 @@ def test_goptions():
np
.
testing
.
assert_equal
(
g
(
d
).
numpy
().
item
(),
1.0
)
@
pytest
.
mark
.
skip
(
reason
=
"force opt_level=0 when building graph"
)
def
test_goptions_log_sum_exp
():
@
trace
(
symbolic
=
True
,
opt_level
=
0
,
capture_as_const
=
True
)
def
f
(
x
,
y
):
...
...
imperative/src/impl/opr_utility.cpp
浏览文件 @
cc85047b
...
...
@@ -33,14 +33,18 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback);
InputCallback
::
InputCallback
(
cg
::
ComputingGraph
&
graph
,
callback_t
callback
,
const
VarNodeArray
&
inputs
,
const
TensorShape
&
output_shape
,
const
OperatorNodeConfig
&
config
)
const
OperatorNodeConfig
&
config
,
bool
use_static_shape
)
:
Super
(
&
graph
,
config
,
"input_callback"
,
inputs
),
m_output_shape
(
output_shape
),
m_callback
(
callback
)
{
m_output_shape
(
output_shape
),
m_callback
(
callback
)
,
m_use_static_shape
(
use_static_shape
)
{
for
(
VarNode
*
i
:
inputs
)
{
add_input
({
i
});
}
DType
dt
=
config
.
output_dtype
();
mgb_assert
(
dt
.
valid
());
if
(
m_use_static_shape
){
mgb_assert
(
m_output_shape
.
ndim
);
}
add_output
(
None
)
->
add_flag
(
VarNode
::
Flag
::
NO_SYS_MEM_ALLOC
).
dtype
(
dt
);
add_output
(
None
)
->
add_flag
(
VarNode
::
Flag
::
ALLOW_EMPTY_SHAPE
)
...
...
@@ -52,7 +56,8 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
SymbolVarArray
InputCallback
::
make
(
cg
::
ComputingGraph
&
graph
,
callback_t
callback
,
CompNode
comp_node
,
DType
dtype
,
const
TensorShape
&
shape
,
const
SymbolVarArray
&
inputs
)
{
const
SymbolVarArray
&
inputs
,
bool
use_static_shape
)
{
mgb_assert
(
comp_node
.
valid
());
mgb_assert
(
dtype
.
valid
());
OperatorNodeConfig
config
;
...
...
@@ -60,24 +65,33 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
config
.
output_dtype
(
dtype
);
auto
vinputs
=
to_var_node_array
(
inputs
);
auto
opr
=
graph
.
insert_opr
(
std
::
make_unique
<
InputCallback
>
(
graph
,
callback
,
vinputs
,
shape
,
config
));
std
::
make_unique
<
InputCallback
>
(
graph
,
callback
,
vinputs
,
shape
,
config
,
use_static_shape
));
return
to_symbol_var_array
(
opr
->
output
());
}
void
InputCallback
::
init_output_static_infer_desc
()
{
if
(
m_output_shape
.
ndim
)
{
// Write this shape to static infer manager. The effect is
// that infer_shape_fallible() will return a non-empty shape
// while get_infer_type() remains NO_DESC. Most places check
// infer type before relying on inferred shape so things
// won't break. Memory optimizer however, deliberately omits
// infer type check so it will be able to use this shape for hint.
using
namespace
cg
::
static_infer
;
auto
*
var
=
output
(
0
);
var
->
shape
(
m_output_shape
);
auto
&&
mgr
=
cg
::
ComputingGraphImpl
::
downcast
(
owner_graph
())
->
static_infer_manager_impl
();
auto
*
handle
=
mgr
.
get_tag_handler_for_shape
(
var
);
handle
->
sync_from_var
();
using
namespace
cg
::
static_infer
;
if
(
m_use_static_shape
)
{
auto
&&
mgr
=
owner_graph
()
->
static_infer_manager
();
auto
infer_shape
=
[
this
](
TensorShape
&
dest
,
const
InpVal
&
)
{
dest
=
m_output_shape
;
return
true
;
};
mgr
.
register_shape_infer
(
output
(
0
),
{
SourceType
::
CONSTANT
,
{},
infer_shape
});
}
else
{
if
(
m_output_shape
.
ndim
)
{
// Write this shape to static infer manager. The effect is
// that infer_shape_fallible() will return a non-empty shape
// while get_infer_type() remains NO_DESC. Most places check
// infer type before relying on inferred shape so things
// won't break. Memory optimizer however, deliberately omits
// infer type check so it will be able to use this shape for hint.
auto
*
var
=
output
(
0
);
var
->
shape
(
m_output_shape
);
auto
&&
mgr
=
cg
::
ComputingGraphImpl
::
downcast
(
owner_graph
())
->
static_infer_manager_impl
();
auto
*
handle
=
mgr
.
get_tag_handler_for_shape
(
var
);
handle
->
sync_from_var
();
}
}
}
...
...
@@ -92,6 +106,9 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
void
InputCallback
::
scn_do_execute
()
{
auto
dev_tensor
=
m_callback
();
if
(
m_use_static_shape
)
{
mgb_assert
(
dev_tensor
.
shape
().
eq_shape
(
m_output_shape
));
}
output
(
0
)
->
reset_dev_tensor_from_tensor
(
dev_tensor
);
}
...
...
@@ -101,7 +118,10 @@ cg::OperatorNodeBase* InputCallback::shallow_copy(
const
OperatorNodeConfig
&
config
)
{
auto
&&
opr
=
opr_
.
cast_final_safe
<
InputCallback
>
();
auto
*
graph
=
ctx
.
owner_graph
(
opr
,
inputs
);
return
graph
->
insert_opr
(
std
::
make_unique
<
InputCallback
>
(
*
graph
,
opr
.
m_callback
,
inputs
,
opr
.
m_output_shape
,
config
));
return
graph
->
insert_opr
(
std
::
make_unique
<
InputCallback
>
(
*
graph
,
opr
.
m_callback
,
inputs
,
opr
.
m_output_shape
,
config
,
opr
.
m_use_static_shape
));
}
MGB_REG_OPR_SHALLOW_COPY
(
InputCallback
,
InputCallback
::
shallow_copy
);
...
...
imperative/src/include/megbrain/imperative/opr_utility.h
浏览文件 @
cc85047b
...
...
@@ -35,13 +35,15 @@ public:
callback_t
callback
,
const
VarNodeArray
&
inputs
,
const
TensorShape
&
output_shape
,
const
OperatorNodeConfig
&
config
);
const
OperatorNodeConfig
&
config
,
bool
use_static_shape
);
static
SymbolVarArray
make
(
cg
::
ComputingGraph
&
graph
,
callback_t
callback
,
CompNode
comp_node
,
DType
dtype
,
const
TensorShape
&
shape
,
const
SymbolVarArray
&
inputs
=
{});
const
SymbolVarArray
&
inputs
=
{},
bool
use_static_shape
=
false
);
static
cg
::
OperatorNodeBase
*
shallow_copy
(
const
serialization
::
OprShallowCopyContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr_
,
const
VarNodeArray
&
inputs
,
...
...
@@ -53,6 +55,7 @@ protected:
private:
TensorShape
m_output_shape
;
callback_t
m_callback
;
bool
m_use_static_shape
;
};
MGB_DEFINE_OPR_CLASS
(
OutputCallback
,
cg
::
SingleCNOperatorNodeBase
)
// {
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录