MegEngine 天元 / MegEngine

Commit 6fb19b66
Authored on Jan 15, 2021 by Megvii Engine Team
feat(imperative/src): name operators automatically when tracing
GitOrigin-RevId: ff8eb003c5e2ee17de7d5ebd55a62391e64a48b1
Parent 09de5a07
Showing 34 changed files with 621 additions and 99 deletions (+621 −99)
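The user-visible effect, pieced together from the tests added in test_dump_naming.py below: operators executed while tracing a Module are named after the module scope, and the names can be preserved through dump(). A minimal sketch (values illustrative):

    import io
    import numpy as np
    import megengine.module as M
    from megengine import Tensor
    from megengine.jit.tracing import trace

    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name      # becomes the naming scope

        def forward(self, x):
            return x + x          # traced as an ADD operator

    f = trace(Simple("simple"), symbolic=True, capture_as_const=True)
    f(Tensor(np.ones((2, 3)))).numpy()

    file = io.BytesIO()
    # keep_opr_name / keep_var_name are among the dump() options added below
    f.dump(file, optimize_for_inference=False, keep_opr_name=True, keep_var_name=2)
    # the dumped ADD operator and its output var are both named "simple.ADD"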
imperative/python/megengine/core/tensor/megbrain_graph.py   +7   −5
imperative/python/megengine/jit/tracing.py                  +61  −8
imperative/python/megengine/module/module.py                +9   −5
imperative/python/megengine/tensor.py                       +14  −3
imperative/python/megengine/utils/naming.py                 +63  −0
imperative/python/src/graph_rt.cpp                          +9   −5
imperative/python/src/ops.cpp                               +29  −0
imperative/python/src/tensor.cpp                            +25  −3
imperative/python/src/tensor.h                              +7   −0
imperative/python/test/unit/test_dump_naming.py             +169 −0
imperative/python/test/unit/test_tracing.py                 +1   −1
imperative/src/impl/op_def.cpp                              +14  −4
imperative/src/impl/op_trait.h                              +4   −1
imperative/src/impl/ops/batch_norm.cpp                      +3   −2
imperative/src/impl/ops/broadcast.cpp                       +5   −3
imperative/src/impl/ops/collective_comm.cpp                 +1   −1
imperative/src/impl/ops/cond_take.cpp                       +2   −2
imperative/src/impl/ops/elemwise.cpp                        +2   −1
imperative/src/impl/ops/img_proc.cpp                        +2   −1
imperative/src/impl/ops/io_remote.cpp                       +4   −2
imperative/src/impl/ops/matrix_inverse.cpp                  +3   −1
imperative/src/impl/ops/nms.cpp                             +3   −1
imperative/src/impl/ops/opr_attr.cpp                        +8   −1
imperative/src/impl/ops/resize.cpp                          +2   −1
imperative/src/impl/ops/specializations.cpp                 +84  −36
imperative/src/impl/ops/tensor_manip.cpp                    +4   −3
imperative/src/impl/ops/tensorrt_runtime.cpp                +2   −1
imperative/src/impl/ops/warp_affine.cpp                     +2   −1
imperative/src/include/megbrain/imperative/op_def.h         +7   −2
imperative/tablegen/autogen.cpp                             +41  −5
imperative/tablegen/helper.h                                +30  −0
sdk/load-and-run/dump_with_testcase_mge.py                  +1   −0
src/core/include/megbrain/ir/base.td                        +1   −0
src/core/include/megbrain/ir/ops.td                         +2   −0
imperative/python/megengine/core/tensor/megbrain_graph.py

@@ -96,7 +96,7 @@ class Graph(_imperative_rt.ComputingGraph):
             data = data.numpy()
         return self._wrap(_imperative_rt.make_const(self, data, device, data.dtype))

-    def make_const(self, data, dtype=None, device=None):
+    def make_const(self, data, dtype=None, device=None, name=None):
         if isinstance(data, _imperative_rt.DeviceTensorND):
             assert dtype is None and device is None
             return self._wrap(_imperative_rt.make_shared(self, data))
@@ -107,7 +107,9 @@ class Graph(_imperative_rt.ComputingGraph):
         elif data.dtype == np.int64:
             data = data.astype(np.int32)
         device = as_device(device).to_c()
-        return self._wrap(_imperative_rt.make_const(self, data, device, dtype))
+        return self._wrap(
+            _imperative_rt.make_const(self, data, device, dtype, name)
+        )

     def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None):
         opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
@@ -305,7 +307,7 @@ def dump_graph(
     output_vars: Union[Dict[str, VarNode], List[VarNode]],
     *,
     keep_var_name: int = 1,
-    keep_op_name: bool = True,
+    keep_opr_name: bool = False,
     keep_param_name: bool = False,
     keep_opr_priority: bool = False,
     strip_info_file=None,
@@ -326,7 +328,7 @@ def dump_graph(
         * 0: none of the names are kept
         * 1: (default) keep names of output vars
         * 2: keep names of all (output and internal) vars
-    :param keep_op_name: whether to keep operator names.
+    :param keep_opr_name: whether to keep operator names.
     :param keep_param_name: whether to keep param names, so param values can be
         easily manipulated after loading model
     :param keep_opr_priority: whether to keep priority setting for operators
@@ -370,7 +372,7 @@ def dump_graph(
     dump_content = _imperative_rt.dump_graph(
         ov,
         keep_var_name,
-        keep_op_name,
+        keep_opr_name,
         keep_param_name,
         keep_opr_priority,
         stat,
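The new name argument is forwarded to _imperative_rt.make_const, where it becomes the OperatorNodeConfig name of the underlying ImmutableTensor (see the graph_rt.cpp hunk further down). A hedged sketch of direct use; user code normally reaches this only through trace.dump():

    import numpy as np
    from megengine.core.tensor import megbrain_graph as G

    graph = G.Graph()
    # name=None keeps the old behavior; a string names the constant's operator node
    k = graph.make_const(np.float32(2.0), name="simple.k")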
imperative/python/megengine/jit/tracing.py

@@ -36,6 +36,7 @@ from ..core.ops.builtin import BackwardGraph, OpDef
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
 from ..core.tensor.utils import setscalar
+from ..utils.naming import auto_naming
 from .sublinear_memory_config import SublinearMemoryConfig
@@ -77,6 +78,7 @@ def exclude_from_trace():
 class TensorInfo:
     __slots__ = (
         # collected attributes
+        "name",
         "external",
         "data_read",
         "shape_read",
@@ -96,6 +98,7 @@ class TensorInfo:
     )

     def __init__(self):
+        self.name = None
         self.exported = None
         self.data_read = None
         self.shape_read = None
@@ -290,12 +293,16 @@ class trace:
             h = getattr(x, "_mixin_handle", -1)
             if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
                 h, info = self._new_handle()
+                name = auto_naming.get_scope() + "." + x.c_name if x.c_name else x._name
+                info.name = name
                 info.external = True
                 info.device = x.device
                 info.dtype = x.dtype
                 info.shape = x.shape
                 if self._capture_as_const:
-                    info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False)
+                    info.bound_data = RawTensor(
+                        x.numpy(), x.dtype, x.device, False, name
+                    )

             ihandles.append(h)
@@ -669,6 +676,12 @@ class trace:
         arg_names=None,
         output_names=None,
         append=False,
+        keep_var_name: int = 1,
+        keep_opr_name: bool = False,
+        keep_param_name: bool = False,
+        keep_opr_priority: bool = False,
+        strip_info_file=None,
+        append_json=False,
         optimize_for_inference=True,
         **kwargs
     ):
@@ -681,6 +694,20 @@ class trace:
             use the default name if not specified.
         :param append: whether output is appended to ``file``.
             Only works when ``file`` is str.
+        :param keep_var_name: level for keeping variable names:
+            * 0: none of the names are kept
+            * 1: (default) keep names of output vars
+            * 2: keep names of all (output and internal) vars
+        :param keep_opr_name: whether to keep operator names.
+        :param keep_param_name: whether to keep param names, so param values can be
+            easily manipulated after loading model
+        :param keep_opr_priority: whether to keep priority setting for operators
+        :param strip_info_file: a string for path or a file handler. If not None,
+            the dump information for code strip will be written to ``strip_info_file``
+        :param append_json: checked when ``strip_info_file`` is not None. If set
+            true, the information for code strip will be appended to ``strip_info_file``;
+            if set false, ``strip_info_file`` will be rewritten
         :param optimize_for_inference: enable optimizations;
             will skip all optimize options if this is False. Default: True
@@ -785,7 +812,10 @@ class trace:
                     assert info.external
                     assert info.bound_data
                     h2v[h] = graph.make_const(
-                        info.bound_data.numpy(), dtype=info.dtype, device=info.device,
+                        info.bound_data.numpy(),
+                        dtype=info.dtype,
+                        device=info.device,
+                        name=info.name,
                     )
                     continue
                 ivars = []
@@ -795,13 +825,26 @@ class trace:
                     assert info.external
                     assert info.bound_data
                     h2v[h] = graph.make_const(
-                        info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
+                        info.bound_data.numpy(),
+                        dtype=info.dtype,
+                        device=dumped_device,
+                        name=info.name,
                     )
                 ivars.append(h2v[h])
             ovars = G.apply_normal_varnode(op, *ivars)
+            auto_naming.record_opnode(ovars[0].op)
+
             assert len(ovars) == len(ohandles)
             h2v.update(zip(ohandles, ovars))
+
+            for i in ohandles:
+                name = auto_naming.get_var_name(i)
+                if name is not None:
+                    h2v[i].name = name
+
+        auto_naming.remove_duplicate_names()
+
         dest_vars = []
         for i, h in enumerate(self._output_bindings):
             v = h2v[h]
@@ -815,7 +858,15 @@ class trace:
         if isinstance(file, str):
             permission = "wb" if append == False else "ab"
             file = open(file, permission)
-        dump_content, dump_info = G.dump_graph(dest_vars)
+        dump_content, dump_info = G.dump_graph(
+            dest_vars,
+            keep_var_name=keep_var_name,
+            keep_opr_name=keep_opr_name,
+            keep_param_name=keep_param_name,
+            keep_opr_priority=keep_opr_priority,
+            strip_info_file=strip_info_file,
+            append_json=append_json,
+        )
         file.write(dump_content)
         return dump_info
@@ -1095,20 +1146,22 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor):
     return active_trace._apply_op(op, args)


-def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
+def apply_const_compiled_mode(value, dtype, device, is_const, no_cache, name):
     if skip_tracing:
         args = [
             RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
             for x in args
         ]
         unset_tracing()
-        ret = RawTensor(value, dtype, device, False)
+        ret = RawTensor(value, dtype, device, False, name)
         set_tracing()
         return ret
     return active_trace._apply_const(value, dtype, device)


 def apply_with_tracing(op: OpDef, *args: RawTensor):
+    if hasattr(op, "scope"):
+        op.scope = auto_naming.get_scope()
     if active_trace._symbolic:
         outputs = apply_symbolic_mode(op, *args)
     else:
@@ -1120,12 +1173,12 @@ def apply_with_tracing(op: OpDef, *args: RawTensor):
     return list(outputs)


-def apply_const_with_tracing(value, dtype, device, is_const, no_cache):
+def apply_const_with_tracing(value, dtype, device, is_const, no_cache, name):
     if active_trace._symbolic:
         outputs = apply_const_symbolic_mode(value, dtype, device)
     else:
         unset_tracing()
-        outputs = (RawTensor(value, dtype, device, False),)
+        outputs = (RawTensor(value, dtype, device, False, name),)
         set_tracing()
     active_trace._record_const(outputs)
     return list(outputs)
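Taken together, trace.dump() now exposes the same naming knobs as megbrain_graph.dump_graph. A sketch of a fuller call, assuming f is a traced function as in the example near the top:

    f.dump(
        "model.mge",
        arg_names=["x"],
        keep_var_name=2,       # keep names of all vars, not only outputs
        keep_opr_name=True,    # keep operator names such as "simple.ADD"
        keep_param_name=True,  # keep Parameter names for post-load editing
    )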
imperative/python/megengine/module/module.py

@@ -12,12 +12,12 @@ from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
 import numpy as np

-from ..core._imperative_rt.core2 import pop_scope, push_scope
 from ..core.tensor.utils import make_shape_tuple
 from ..logger import get_logger
 from ..tensor import Parameter, Tensor
 from ..utils.deprecation import deprecated
 from ..utils.hook import HookHandler
+from ..utils.naming import auto_naming

 logger = get_logger(__name__)
@@ -69,7 +69,9 @@ class Module(metaclass=ABCMeta):
     Base Module class.
     """

-    def __init__(self):
+    def __init__(self, name=""):
+        self.name = name
+
         # runtime attributes
         self.training = True
         self.quantize_disabled = False
@@ -79,6 +81,8 @@ class Module(metaclass=ABCMeta):
         self._forward_hooks = OrderedDict()

         self._modules = []
+        # used for profiler and automatic naming
+        self._name = "{anonymous}"

     @abstractmethod
@@ -105,7 +109,7 @@ class Module(metaclass=ABCMeta):
         return HookHandler(self._forward_hooks, hook)

     def __call__(self, *inputs, **kwargs):
-        push_scope(self._name)
+        auto_naming.push_scope(self.name if self.name else self._name)
         for hook in self._forward_pre_hooks.values():
             modified_inputs = hook(self, inputs)
             if modified_inputs is not None:
@@ -119,7 +123,7 @@ class Module(metaclass=ABCMeta):
             modified_outputs = hook(self, inputs, outputs)
             if modified_outputs is not None:
                 outputs = modified_outputs
-        pop_scope(self._name)
+        auto_naming.pop_scope()
         return outputs

     def _flatten(
@@ -579,7 +583,7 @@ class Module(metaclass=ABCMeta):
         value = super().__getattribute__(name)
         if name == "_name":
             return value
-        if _is_module(value):
+        if isinstance(value, (Tensor, Module)):
             value._name = name
         return value
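Scopes nest per Module.__call__: each call pushes self.name (falling back to the attribute-derived _name) onto the auto_naming stack, so a submodule's operators get dotted prefixes. A sketch, assuming the name= constructor argument added above:

    import megengine.module as M

    class Net(M.Module):
        def __init__(self):
            super().__init__(name="net")
            self.linear = M.Linear(3, 3)             # scope "net.linear" (attribute name)
            self.head = M.Linear(3, 1, name="head")  # scope "net.head" (explicit name wins)

        def forward(self, x):
            return self.head(self.linear(x))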
imperative/python/megengine/tensor.py

@@ -20,6 +20,7 @@ from .core.tensor.array_method import ArrayMethodMixin
 from .device import _valid_device, get_default_device
 from .logger import get_logger
 from .utils.deprecation import deprecated
+from .utils.naming import auto_naming


 class Tensor(_Tensor, ArrayMethodMixin):
@@ -27,7 +28,9 @@ class Tensor(_Tensor, ArrayMethodMixin):
     dmap_callback = None
     _q_dict = None

-    def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False):
+    def __new__(
+        cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=""
+    ):
         if device is None:
             cn = get_default_device()
         elif isinstance(device, str):
@@ -51,8 +54,7 @@ class Tensor(_Tensor, ArrayMethodMixin):
         if isinstance(data, np.ndarray):
             if 0 in data.strides:
                 data = data.squeeze().reshape(data.shape)
-        obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache)
+        obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache, name)
         return obj

     @property
@@ -91,6 +93,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
             piece += ", device={}".format(self.device) + ")"
         return piece

+    @property
+    def name(self):
+        return self.c_name
+
+    @name.setter
+    def name(self, name):
+        self.c_name = name
+        auto_naming.record_var_name(self._mixin_handle, name)
+
     @deprecated(version="1.0", reason="no need to reuse an existing tensor since 1.0")
     def set_value(self, value):
         if not isinstance(value, _Tensor):
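The setter records the user-chosen name keyed by the tensor's mixin handle, so it can override the automatic name when the traced graph is dumped. A sketch grounded in test_user_named_tensor below:

    def forward(self, x):
        x = x + x
        x.name = "o_x"   # writes c_name and calls auto_naming.record_var_name
        return x         # the dumped output var is named "o_x", the op "simple.ADD"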
imperative/python/megengine/utils/naming.py (new file, mode 100644)

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..core._imperative_rt.core2 import pop_scope, push_scope


class AutoNaming:
    r"""
    Name all executed operators automatically during tracing and record all tensors
    renamed by the user.
    """

    def __init__(self):
        self.scopes = []
        self.c_ops = []
        self.name2ops = {}
        self.handle2names = {}

    def clear(self):
        for var in vars(self).values():
            var.clear()

    def push_scope(self, scope):
        push_scope(scope)
        self.scopes.append(scope)

    def pop_scope(self):
        scope = self.scopes.pop()
        pop_scope(scope)

    def get_scope(self):
        return ".".join(self.scopes)

    def record_var_name(self, handle, name):
        self.handle2names[handle] = name

    def get_var_name(self, handle):
        return self.handle2names.pop(handle, None)

    def record_opnode(self, op):
        ops = self.name2ops.get(op.name, [])
        ops.append(op)
        self.name2ops[op.name] = ops

    def remove_duplicate_names(self):
        for key, ops in self.name2ops.items():
            if len(ops) == 1:
                continue
            for i, op in enumerate(ops):
                op.name = key + "[%s]" % str(i)
                if len(op.outputs) == 1:
                    continue
                for var in op.outputs:
                    var.name = var.name.replace(key, op.name)
        self.name2ops.clear()


auto_naming = AutoNaming()
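A sketch of how the pieces behave in isolation (normally they are driven by Module.__call__ and trace.dump() rather than called directly): scopes concatenate with dots, and remove_duplicate_names() disambiguates repeated operator names with an index suffix, so two RELUs under scope "simple" become "simple.RELU[0]" and "simple.RELU[1]":

    from megengine.utils.naming import auto_naming

    auto_naming.push_scope("simple")
    auto_naming.push_scope("linear")
    print(auto_naming.get_scope())   # "simple.linear"
    auto_naming.pop_scope()
    auto_naming.pop_scope()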
imperative/python/src/graph_rt.cpp

@@ -294,7 +294,7 @@ void init_graph_rt(py::module m) {
     m.def("dump_graph", [](
             const std::vector<VarNode*>& dest_vars,
             int keep_var_name,
-            bool keep_op_name,
+            bool keep_opr_name,
             bool keep_param_name,
             bool keep_opr_priority,
             py::list& stat,
@@ -307,7 +307,7 @@ void init_graph_rt(py::module m) {
         SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
         ser::GraphDumper::DumpConfig config{keep_var_name, keep_param_name,
-                                            keep_opr_priority, keep_op_name};
+                                            keep_opr_priority, keep_opr_name};
         auto rst = dumper->dump(symvars, config);
         for (auto i : rst.inputs) {
@@ -457,13 +457,17 @@ void init_graph_rt(py::module m) {
         return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
     });

-    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
+    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype,
+                           std::optional<std::string> name) {
         if (!cn.valid()) {
             cn = CompNode::load(get_default_device());
         }
+        OperatorNodeConfig config(cn);
+        if (name) {
+            config.name(*name);
+        }
         auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
-        return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
-    });
+        return opr::ImmutableTensor::make(*graph, hv, config).node();
+    }, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = py::none());

     m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape,
                          std::optional<std::string> name) {
         if (!cn.valid()) {
imperative/python/src/ops.cpp

@@ -99,6 +99,14 @@ PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
 #define py_get_generic(name, attr) \
     py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

+template <typename T>
+PyObject* py_get_scope_impl(PyObject* obj, void* /* closure */) {
+    // T: PyOpXXX    inst(): return XXX in opdef.h.inl
+    auto& op = reinterpret_cast<T*>(obj)->inst();
+    return pyobj_convert_generic<std::string>::to(op.scope());
+}
+
+#define py_get_scope(class) py_get_scope_impl<PyOp(class)>
+
 template <typename T, typename U, U T::Ty::* attr>
 int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
     if (value == NULL) {
@@ -121,6 +129,27 @@ int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
 #define py_set_generic(name, attr) \
     py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

+template <typename T>
+int py_set_scope_impl(PyObject* obj, PyObject* value, void* /* closure */) {
+    if (value == NULL) {
+        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
+        return -1;
+    }
+    auto& op = reinterpret_cast<T*>(obj)->inst();
+    try {
+        op.set_scope(pyobj_convert_generic<std::string>::from(value));
+        return 0;
+    } catch (py::error_already_set& e) {
+        e.restore();
+    } catch (py::builtin_exception& e) {
+        e.set_error();
+    } catch (...) {
+        PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
+    }
+    return -1;
+}
+
+#define py_set_scope(class) py_set_scope_impl<PyOp(class)>
+
 struct PyOpDef {
     PyObject_HEAD
     std::shared_ptr<OpDef> op;
imperative/python/src/tensor.cpp

@@ -24,6 +24,7 @@
 #include <pybind11/numpy.h>
 #include <pybind11/operators.h>
 #include <range/v3/all.hpp>
+#include <string>

 #include <unordered_map>
@@ -222,14 +223,15 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
         }
     } else {
         py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType
-        if (nargs != 4 && nargs != 5) {
-            throw py::type_error("expect 4 or 5 arguments");
+        if (nargs != 5 && nargs != 6) {
+            throw py::type_error("expect 5 or 6 arguments");
         }
         auto data = tup[0].cast<py::array>();
         DType dtype = tup[1].cast<DType>();
         CompNode cn = tup[2].cast<CompNode>();
         bool is_const = tup[3].cast<bool>();
-        bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false;
+        bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
+        std::string name = tup[nargs - 1].cast<std::string>();

         // const op
         if (is_const && is_tracing) {
@@ -259,6 +261,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
         }

         m_tensor = std::make_shared<Tensor>(handle);
+        m_tensor->user_custom_name = name;

         if (data.ndim() == 0) {
             m_tensor->m_flags |= Tensor::Flags::SCALAR;
@@ -313,6 +316,19 @@ REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)
 #undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC

+#define SET_GET_NAME(member)                                     \
+    PyObject* TensorWrapper::member() {                          \
+        return py::cast(m_tensor->member).release().ptr();       \
+    }                                                            \
+    void TensorWrapper::set_##member(PyObject* dest) {           \
+        auto py_dest = py::reinterpret_borrow<py::object>(dest); \
+        m_tensor->member = py_dest.cast<std::string>();          \
+    }
+SET_GET_NAME(user_custom_name)
+SET_GET_NAME(automatic_name)
+#undef SET_GET_NAME
+
 PyObject* TensorWrapper::handle() {
     return py::cast(m_tensor->m_handle).release().ptr();
 }
@@ -453,7 +469,11 @@ void TensorWrapper::reset(PyObject* tensor) {
     if (!t) {
         throw py::type_error("expect Tensor");
     }
+    std::string user_custom_name = m_tensor->user_custom_name;
+    std::string automatic_name = m_tensor->automatic_name;
     m_tensor = t->m_tensor;
+    m_tensor->user_custom_name = user_custom_name;
+    m_tensor->automatic_name = automatic_name;
 }

 void TensorWrapper::reset_varnode() {
@@ -785,6 +805,8 @@ void init_tensor(py::module m) {
         .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle")
         .def_getset<&TensorWrapper::compiled_info, &TensorWrapper::set_compiled_info>("_compiled_info")
         .def_getset<&TensorWrapper::trace_mixin_info, &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
+        .def_getset<&TensorWrapper::user_custom_name, &TensorWrapper::set_user_custom_name>("c_name")
+        .def_getset<&TensorWrapper::automatic_name, &TensorWrapper::set_automatic_name>("_name")
         .finalize();
     if (!tensor_type) throw py::error_already_set();
     py::setattr(m, "Tensor", tensor_type);
imperative/python/src/tensor.h

@@ -15,6 +15,7 @@
 #include "megbrain/imperative/interpreter.h"
 #include "pybind11/pybind11.h"

+#include <string>
 #include "./pyext17.h"
@@ -70,6 +71,8 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     GradInfo m_grad_info;
     TraceInfo m_trace_info;
     SharedHandle m_handle;
+    std::string user_custom_name;
+    std::string automatic_name;
     cg::VarNode* m_var;

     using Handle = interpreter::Interpreter::Handle;
@@ -170,6 +173,10 @@ struct TensorWrapper {
     void set_compiled_info(PyObject *);
     PyObject* trace_mixin_info();
     void set_trace_mixin_info(PyObject *);
+    PyObject* user_custom_name();
+    void set_user_custom_name(PyObject *);
+    PyObject* automatic_name();
+    void set_automatic_name(PyObject *);
     PyObject* _use_cnt() { return PyLong_FromSize_t(m_tensor.use_count()); };
 };
imperative/python/test/unit/test_dump_naming.py (new file, mode 100644)

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io

import numpy as np
import pytest

import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine import Parameter, Tensor
from megengine.core.tensor import megbrain_graph as G
from megengine.jit.tracing import trace
from megengine.utils.naming import auto_naming


def _dump_and_load(func, symbolic, keep_opr_name=True):
    auto_naming.clear()
    func = trace(func, symbolic=symbolic, capture_as_const=True)
    x = Tensor(np.ones(shape=(2, 3)))
    func(x).numpy()
    file = io.BytesIO()
    func.dump(
        file,
        optimize_for_inference=False,
        arg_names="x",
        keep_opr_name=keep_opr_name,
        keep_var_name=2,
    )
    file.seek(0)
    *_, outputs = G.load_graph(file)
    op = cgtools.get_oprs_seq(outputs)[-1]
    return op


@pytest.mark.parametrize("symbolic", [False, True])
def test_auto_naming(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name

        def forward(self, x):
            return x + x

    m = Simple("simple")
    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.ADD"
    assert op.outputs[0].name == "simple.ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_user_named_tensor(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.k = Parameter(1.0, name="k")

        def forward(self, x):
            x = x + x
            x.name = "o_x"
            return x

    m = Simple("simple")
    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.ADD"
    assert op.outputs[0].name == "o_x"


@pytest.mark.parametrize("symbolic", [False, True])
def test_user_named_param(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.k = Parameter(2.0, name="k")

        def forward(self, x):
            return self.k * x

    m = Simple("simple")
    op = _dump_and_load(m, symbolic)
    assert op.inputs[0].name == "x"
    assert op.inputs[1].name == "simple.k"


@pytest.mark.parametrize("symbolic", [False, True])
def test_without_module(symbolic):
    def f(x):
        return 2 * x

    op = _dump_and_load(f, symbolic)
    assert op.name == "MUL"


@pytest.mark.parametrize("symbolic", [False, True])
def test_with_submodule(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.linear = M.Linear(3, 3)

        def forward(self, x):
            x = self.linear(x)
            return x

    m = Simple("simple")
    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.linear.ADD"
    assert op.inputs[0].owner.name == "simple.linear.MatrixMul"
    assert op.outputs[0].name == "simple.linear.ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_named_submodule(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name
            self.linear = M.Linear(3, 3, name="x")

        def forward(self, x):
            x = self.linear(x)
            return x

    m = Simple("simple")
    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.x.ADD"
    assert op.inputs[0].owner.name == "simple.x.MatrixMul"
    assert op.outputs[0].name == "simple.x.ADD"


@pytest.mark.parametrize("symbolic", [False, True])
def test_with_same_operators(symbolic):
    class Simple(M.Module):
        def __init__(self, name):
            super().__init__()
            self.name = name

        def forward(self, x):
            x = F.relu(x)
            x = F.relu(x)
            return x

    m = Simple("simple")
    op = _dump_and_load(m, symbolic)
    assert op.name == "simple.RELU[1]"
    assert op.inputs[0].owner.name == "simple.RELU[0]"


def test_not_keep_opr_name():
    def f(x):
        return 2 * x

    op = _dump_and_load(f, True, False)
    assert op.name == "MUL(x,2[2])[4]"
imperative/python/test/unit/test_tracing.py

@@ -148,7 +148,7 @@ def test_dump():
     dump_info = f.dump(file)
     assert dump_info.nr_opr == 3
     np.testing.assert_equal(dump_info.inputs, ["arg_0", "arg_1"])
-    np.testing.assert_equal(dump_info.outputs, ["ADD(arg_0,arg_1)[4]"])
+    np.testing.assert_equal(dump_info.outputs, ["ADD"])
     file.seek(0)
     infer_cg = cgtools.GraphInference(file)
     result = list((infer_cg.run(a, b)).values())[0]
imperative/src/impl/op_def.cpp

@@ -75,10 +75,6 @@ std::vector<std::pair<const char*, std::string>> OpDef::props(
     return def.trait()->props(def);
 }

-const char* OpDef::name() const {
-    return trait()->name;
-}
-
 std::string OpDef::to_string() const {
     std::string builder = "{";
     for (auto&& [name, value]: props(*this)) {
@@ -107,6 +103,20 @@ const OpTrait* OpDef::trait() const {
     return m_trait;
 }

+const std::string OpDef::scope() const {
+    return m_scope;
+}
+
+void OpDef::set_scope(const std::string& scope) {
+    m_scope = scope;
+}
+
+const std::string OpDef::make_name() const {
+    if (m_scope.empty())
+        return trait()->make_name(*this);
+    return m_scope + "." + trait()->make_name(*this);
+}
+
 } // namespace imperative
 } // namespace mgb
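make_name() is the single point that combines the scope recorded at trace time with the trait-generated operator name. A Python mirror of the logic, for illustration only:

    def make_name(scope: str, trait_name: str) -> str:
        # mirrors OpDef::make_name() above
        return trait_name if not scope else scope + "." + trait_name

    assert make_name("", "ADD") == "ADD"
    assert make_name("simple.linear", "MatrixMul") == "simple.linear.MatrixMul"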
imperative/src/impl/op_trait.h

@@ -75,6 +75,7 @@ using GradMaker = detail::OpMeth<
 using Props = detail::OpMeth<decltype(OpDef::props)>;
 using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
 using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
+using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>;

 struct OpTrait {
     const char* name;
@@ -88,6 +89,7 @@ struct OpTrait {
     Props props;
     HashFunc hash;
     IsSame is_same_st;
+    MakeNameFunc make_name;
     OpTrait(const char* name);
     static OpTrait* find_by_name(const char* name);
     static OpTrait* find_by_typeinfo(Typeinfo* type);
@@ -104,7 +106,8 @@ struct OpTrait {
     cb(make_backward_graph) \
     cb(props) \
     cb(hash) \
-    cb(is_same_st)
+    cb(is_same_st) \
+    cb(make_name)

 struct OpTraitRegistry {
     OpTrait* trait;
imperative/src/impl/ops/batch_norm.cpp

@@ -30,13 +30,14 @@ cg::OperatorNodeBase* apply_on_var_node(
     size_t nr_inp = inputs.size();
     mgb_assert(nr_inp == 3 || nr_inp == 5,
         "BatchNorm expects 3 or 5 inputs; got %lu actually", nr_inp);
+    OperatorNodeConfig config{bn_opr.make_name()};
     if (nr_inp == 3) {
         return opr::BatchNorm::make(
-                    inputs[0], inputs[1], inputs[2], bn_opr.param())[0]
+                    inputs[0], inputs[1], inputs[2], bn_opr.param(), config)[0]
                 .node()->owner_opr();
     } else {
         return opr::BatchNorm::make(
-                    inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param())[0]
+                    inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], bn_opr.param(), config)[0]
                 .node()->owner_opr();
     }
 }
imperative/src/impl/ops/broadcast.cpp

@@ -27,10 +27,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    def.cast_final_safe<Broadcast>();
+    auto&& op = def.cast_final_safe<Broadcast>();
     size_t nr_inp = inputs.size();
     mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
-    return opr::Broadcast::make(inputs[0], inputs[1]).node()->owner_opr();
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr();
 }

 bool valid_broadcast(const TensorShape& src_shape,
@@ -96,7 +97,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const Reshape&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::Reshape::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Reshape::make(inputs[0], inputs[1], op.param(), config);
 }

 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
imperative/src/impl/ops/collective_comm.cpp

@@ -35,7 +35,7 @@ cg::OperatorNodeBase* apply_on_var_node(
     auto disable = std::make_shared<DTypeScalar>();
     disable->set(0);

-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{comm.make_name()};
     if (comm.comp_node.size() > 0) {
         config.comp_node(CompNode::load(comm.comp_node));
     }
imperative/src/impl/ops/cond_take.cpp

@@ -23,12 +23,12 @@ namespace {
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    def.cast_final_safe<CondTake>();
+    auto&& op = def.cast_final_safe<CondTake>();
     auto&& graph = inputs[0]->owner_graph();

     opr::CondTake::Param param;
     param.val = 1;
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{op.make_name()};
     cg::OperatorNodeBase* opr = graph->insert_opr(
             std::make_unique<opr::CondTake>(
                     inputs[0], inputs[1], param, config));
imperative/src/impl/ops/elemwise.cpp

@@ -31,7 +31,8 @@ cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& elemwise_opr = def.cast_final_safe<Elemwise>();
-    return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr();
+    OperatorNodeConfig config{elemwise_opr.make_name()};
+    return opr::Elemwise::make(inputs, elemwise_opr.mode, config).node()->owner_opr();
 }

 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
imperative/src/impl/ops/img_proc.cpp

@@ -23,7 +23,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const CvtColor&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::CvtColor::make(inputs[0], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::CvtColor::make(inputs[0], op.param(), config);
 }
 OP_TRAIT_REG(CvtColor, CvtColor)
     .apply_on_var_node(apply_on_var_node)
imperative/src/impl/ops/io_remote.cpp

@@ -32,7 +32,7 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
             ssprintf("%s:%d", send.addr.data(), send.port));
     auto&& graph = inputs[0]->owner_graph();

-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{send.make_name()};
     cg::OperatorNodeBase* opr =
             graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
                     send.key, inputs[0], group_client, true, config));
@@ -42,11 +42,13 @@ cg::OperatorNodeBase* apply_on_var_node_remote_send(
 cg::OperatorNodeBase* apply_on_var_node_remote_recv(
         const OpDef& def, const VarNodeArray& inputs) {
     auto&& recv = def.cast_final_safe<RemoteRecv>();
+    OperatorNodeConfig config{recv.cn};
+    config.name(recv.make_name());
     auto group_client = std::make_shared<GroupClientProxy>(
             ssprintf("%s:%d", recv.addr.data(), recv.port));
     auto&& graph = inputs[0]->owner_graph();
     return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
-            recv.key, inputs[0], *graph, group_client,
-            OperatorNodeConfig{recv.cn}, recv.shape, recv.dtype));
+            recv.key, inputs[0], *graph, group_client, config,
+            recv.shape, recv.dtype));
 }
imperative/src/impl/ops/matrix_inverse.cpp

@@ -21,8 +21,10 @@ namespace {
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<MatrixInverse>();
     mgb_assert(inputs.size() == 1);
-    return opr::MatrixInverse::make(inputs[0]);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::MatrixInverse::make(inputs[0], {}, config);
 }
 OP_TRAIT_REG(MatrixInverse, MatrixInverse)
     .apply_on_var_node(apply_on_var_node)
imperative/src/impl/ops/nms.cpp

@@ -29,7 +29,9 @@ cg::OperatorNodeBase* apply_on_var_node(
     param.iou_thresh = nms_keep.iou_thresh;
     param.max_output = nms_keep.max_output;

-    return NMSKeepOpr::make(inputs[0], param).node()->owner_opr();
+    OperatorNodeConfig config{nms_keep.make_name()};
+
+    return NMSKeepOpr::make(inputs[0], param, config).node()->owner_opr();
 }

 OP_TRAIT_REG(NMSKeep, NMSKeep, NMSKeepOpr)
imperative/src/impl/ops/opr_attr.cpp

@@ -79,11 +79,13 @@ public:
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def, const VarNodeArray& inputs) {
     auto&& attr = def.cast_final_safe<OprAttr>();
+    auto config = attr.config;
+    config.name(attr.make_name());
     mgb_assert(!inputs.empty());
     auto registry = serialization::OprRegistry::find_by_name(attr.type);
     mgb_assert(registry, "operator %s not found", attr.type.c_str());
     OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
-    return registry->loader(ctx, inputs, attr.config);
+    return registry->loader(ctx, inputs, config);
 }

 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
@@ -99,10 +101,15 @@ std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
     return {};
 }

+std::string make_name(const OpDef& def) {
+    return "OprAttr";
+}
+
 OP_TRAIT_REG(OprAttr, OprAttr)
     .make_from_op_node(make_from_op_node)
     .apply_on_var_node(apply_on_var_node)
     .props(props)
+    .make_name(make_name)
     .fallback();

 } // anonymous namespace
imperative/src/impl/ops/resize.cpp

@@ -24,7 +24,8 @@ auto apply_on_var_node(
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const Resize&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::Resize::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Resize::make(inputs[0], inputs[1], op.param(), config);
 }

 OP_TRAIT_REG(Resize, Resize)
imperative/src/impl/ops/specializations.cpp

@@ -46,7 +46,8 @@ auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
     auto&& conv = static_cast<const Convolution&>(def);
-    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy());
+    OperatorNodeConfig config{conv.make_name()};
+    return opr::Convolution::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
 }
 OP_TRAIT_REG(Convolution, Convolution, opr::Convolution)
@@ -60,7 +61,7 @@ auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
     auto&& conv = static_cast<const ConvolutionBackwardData&>(def);
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{conv.make_name()};
     if (inputs.size() == 2) {
         return opr::ConvolutionBackwardData::make(
                 inputs[0], inputs[1], conv.param(), conv.policy(), config);
     } else {
@@ -88,7 +89,8 @@ auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
     auto&& ds = static_cast<const Dimshuffle&>(def);
-    return opr::Dimshuffle::make(inputs[0], ds.pattern);
+    OperatorNodeConfig config{ds.make_name()};
+    return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
 }
 OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
@@ -107,7 +109,8 @@ auto apply_on_var_node(
     for (auto&& i : add_axis.axis) {
         param.push_back(Desc::make_add(i));
     }
-    return opr::AxisAddRemove::make(inputs[0], param);
+    OperatorNodeConfig config{add_axis.make_name()};
+    return opr::AxisAddRemove::make(inputs[0], param, config);
 }
 OP_TRAIT_REG(AddAxis, AddAxis)
@@ -125,7 +128,8 @@ auto apply_on_var_node(
     for (auto&& i : remove_axis.axis) {
         param.push_back(Desc::make_remove(i));
     }
-    return opr::AxisAddRemove::make(inputs[0], param);
+    OperatorNodeConfig config{remove_axis.make_name()};
+    return opr::AxisAddRemove::make(inputs[0], param, config);
 }
 OP_TRAIT_REG(RemoveAxis, RemoveAxis)
@@ -138,7 +142,8 @@ auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
     auto&& topk = static_cast<const TopK&>(def);
-    return opr::TopK::make(inputs[0], inputs[1], topk.param())[0]
+    OperatorNodeConfig config{topk.make_name()};
+    return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0]
             .node()->owner_opr();
 }
@@ -152,10 +157,12 @@ auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
     auto&& reduce = static_cast<const Reduce&>(def);
+    OperatorNodeConfig config{reduce.make_name()};
     if (inputs.size() > 1) {
-        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1]);
+        return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config);
     } else {
-        return opr::Reduce::make(inputs[0], reduce.param());
+        return opr::Reduce::make(
+                inputs[0], reduce.param(), (cg::VarNode*)nullptr, config);
     }
 }
@@ -175,7 +182,8 @@ auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
     auto&& pool = static_cast<const AdaptivePooling&>(def);
-    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param());
+    OperatorNodeConfig config{pool.make_name()};
+    return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config);
 }
 OP_TRAIT_REG(AdaptivePooling, AdaptivePooling)
@@ -189,6 +197,7 @@ auto apply_on_var_node(
        const VarNodeArray& inputs) {
     auto&& conv = static_cast<const ConvBias&>(def);
     cg::OperatorNodeConfig config{conv.dtype};
+    config.name(conv.make_name());
     if (inputs.size() == 2) {
         return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
     } else if (inputs.size() == 3) {
@@ -210,6 +219,7 @@ auto apply_on_var_node(
        const VarNodeArray& inputs) {
     auto&& conv = static_cast<const BatchConvBias&>(def);
     cg::OperatorNodeConfig config{conv.dtype};
+    config.name(conv.make_name());
     if (inputs.size() == 2) {
         return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config);
     } else if (inputs.size() == 3) {
@@ -230,7 +240,8 @@ auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
     auto&& pool = static_cast<const Pooling&>(def);
-    return opr::Pooling::make(inputs[0], pool.param());
+    OperatorNodeConfig config{pool.make_name()};
+    return opr::Pooling::make(inputs[0], pool.param(), config);
 }
 OP_TRAIT_REG(Pooling, Pooling)
     .apply_on_var_node(apply_on_var_node)
@@ -243,8 +254,9 @@ auto apply_on_var_node(
        const VarNodeArray& inputs) {
     auto&& matmul = static_cast<const MatrixMul&>(def);
     mgb_assert(inputs.size() == 2);
+    OperatorNodeConfig config{matmul.make_name()};
     return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
-                                matmul.policy());
+                                matmul.policy(), config);
 }
 OP_TRAIT_REG(MatrixMul, MatrixMul)
     .apply_on_var_node(apply_on_var_node)
@@ -257,8 +269,9 @@ auto apply_on_var_node(
        const VarNodeArray& inputs) {
     auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
     mgb_assert(inputs.size() == 2);
+    OperatorNodeConfig config{matmul.make_name()};
     return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
-                                       matmul.policy());
+                                       matmul.policy(), config);
 }
 OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
     .apply_on_var_node(apply_on_var_node)
@@ -267,10 +280,12 @@ OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
 namespace { namespace dot {
 auto apply_on_var_node(
-        const OpDef&,
+        const OpDef& def,
         const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Dot>();
     mgb_assert(inputs.size() == 2);
-    return opr::Dot::make(inputs[0], inputs[1]);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Dot::make(inputs[0], inputs[1], config);
 }
 OP_TRAIT_REG(Dot, Dot)
     .apply_on_var_node(apply_on_var_node)
@@ -282,7 +297,8 @@ auto apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
     auto&& argsort = static_cast<const Argsort&>(def);
-    return opr::Argsort::make(inputs[0], argsort.param());
+    OperatorNodeConfig config{argsort.make_name()};
+    return opr::Argsort::make(inputs[0], argsort.param(), config);
 }
 OP_TRAIT_REG(Argsort, Argsort)
     .apply_on_var_node(apply_on_var_node)
@@ -294,7 +310,8 @@ auto apply_on_var_node(
     auto&& argmax = static_cast<const Argmax&>(def);
-    return opr::Argmax::make(inputs[0], argmax.param());
+    OperatorNodeConfig config{argmax.make_name()};
+    return opr::Argmax::make(inputs[0], argmax.param(), config);
 }
 OP_TRAIT_REG(Argmax, Argmax)
     .apply_on_var_node(apply_on_var_node)
@@ -306,7 +323,8 @@ auto apply_on_var_node(
     auto&& argmin = static_cast<const Argmin&>(def);
-    return opr::Argmin::make(inputs[0], argmin.param());
+    OperatorNodeConfig config{argmin.make_name()};
+    return opr::Argmin::make(inputs[0], argmin.param(), config);
 }
 OP_TRAIT_REG(Argmin, Argmin)
     .apply_on_var_node(apply_on_var_node)
@@ -318,11 +336,13 @@ auto apply_on_var_node(
     auto&& warp = static_cast<const WarpPerspective&>(def);
+    OperatorNodeConfig config{warp.make_name()};
     if (inputs.size() == 3) {
-        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param());
+        return opr::WarpPerspective::make(
+                inputs[0], inputs[1], inputs[2], warp.param(), config);
     } else {
         mgb_assert(inputs.size() == 4);
-        return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], inputs[3], warp.param());
+        return opr::WarpPerspective::make(
+                inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config);
     }
 }
 OP_TRAIT_REG(WarpPerspective, WarpPerspective)
@@ -336,7 +356,8 @@ auto apply_on_var_node(
     auto&& local = static_cast<const GroupLocal&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::GroupLocal::make(inputs[0], inputs[1], local.param());
+    OperatorNodeConfig config{local.make_name()};
+    return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config);
 }
 OP_TRAIT_REG(GroupLocal, GroupLocal)
     .apply_on_var_node(apply_on_var_node)
@@ -349,7 +370,8 @@ auto apply_on_var_node(
     auto&& op = static_cast<const IndexingOneHot&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config);
 }
 OP_TRAIT_REG(IndexingOneHot, IndexingOneHot)
     .apply_on_var_node(apply_on_var_node)
@@ -362,7 +384,8 @@ auto apply_on_var_node(
     auto&& op = static_cast<const IndexingSetOneHot&>(def);
     mgb_assert(inputs.size() == 3);
-    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config);
 }
 OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot)
     .apply_on_var_node(apply_on_var_node)
@@ -375,7 +398,8 @@ auto apply_on_var_node(
     auto&& op = static_cast<const TypeCvt&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::TypeCvt::make(inputs[0], op.dtype);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::TypeCvt::make(inputs[0], op.dtype, config);
 }
 OP_TRAIT_REG(TypeCvt, TypeCvt)
     .apply_on_var_node(apply_on_var_node)
@@ -388,6 +412,7 @@ auto apply_on_var_node(
     auto&& op = static_cast<const Concat&>(def);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     return opr::Concat::make(inputs, op.axis, config);
 }
 OP_TRAIT_REG(Concat, Concat)
@@ -402,6 +427,7 @@ auto apply_on_var_node(
     auto&& op = static_cast<const Copy&>(def);
     mgb_assert(inputs.size() == 1);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     return opr::Copy::make(inputs[0], config);
 }
 OP_TRAIT_REG(Copy, Copy)
@@ -411,10 +437,12 @@ OP_TRAIT_REG(Copy, Copy)
 namespace { namespace identity {
 auto apply_on_var_node(
-        const OpDef&,
+        const OpDef& def,
         const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Identity>();
     mgb_assert(inputs.size() == 1);
-    return opr::Identity::make(inputs[0]);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Identity::make(inputs[0], config);
 }
 OP_TRAIT_REG(Identity, Identity)
     .apply_on_var_node(apply_on_var_node)
@@ -427,7 +455,8 @@ auto apply_on_var_node(
     auto&& op = static_cast<const AssertEqual&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::AssertEqual::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config);
 }
@@ -443,7 +472,8 @@ auto apply_on_var_node(
     auto&& op = static_cast<const UniformRNG&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::UniformRNG::make(inputs[0], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::UniformRNG::make(inputs[0], op.param(), config);
 }
 OP_TRAIT_REG(UniformRNG, UniformRNG)
     .apply_on_var_node(apply_on_var_node)
@@ -456,7 +486,8 @@ auto apply_on_var_node(
     auto&& op = static_cast<const GaussianRNG&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::GaussianRNG::make(inputs[0], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::GaussianRNG::make(inputs[0], op.param(), config);
 }
 OP_TRAIT_REG(GaussianRNG, GaussianRNG)
     .apply_on_var_node(apply_on_var_node)
@@ -469,7 +500,9 @@ VarNodeArray apply_on_var_node(
     auto&& op = static_cast<const ROIAlign&>(def);
     mgb_assert(inputs.size() == 2);
-    auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr();
+    OperatorNodeConfig config{op.make_name()};
+    auto* opr = opr::ROIAlign::make(
+            inputs[0], inputs[1], op.param(), config).node()->owner_opr();
     return {opr->output(0), opr->output(1)};
 }
 OP_TRAIT_REG(ROIAlign, ROIAlign)
@@ -484,7 +517,8 @@ auto apply_on_var_node(
     auto&& op = static_cast<const NvOf&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::NvOf::make(inputs[0], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::NvOf::make(inputs[0], op.param(), config);
 }
 OP_TRAIT_REG(NvOf, NvOf)
     .apply_on_var_node(apply_on_var_node)
@@ -499,6 +533,7 @@ auto apply_on_var_node(
     auto&& op = static_cast<const Linspace&>(def);
     mgb_assert(inputs.size() == 3);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config);
 }
 OP_TRAIT_REG(Linspace, Linspace)
@@ -513,6 +548,7 @@ auto apply_on_var_node(
     auto&& op = static_cast<const Eye&>(def);
     mgb_assert(inputs.size() == 1);
     cg::OperatorNodeConfig config{op.comp_node};
+    config.name(op.make_name());
     opr::Eye::Param param{op.k, op.dtype.enumv()};
     return opr::Eye::make(inputs[0], param, config);
 }
@@ -527,7 +563,10 @@ VarNodeArray apply_on_var_node(
     auto&& op = static_cast<const ROIPooling&>(def);
     mgb_assert(inputs.size() == 3);
-    auto* opr = opr::ROIPooling::make(
-            inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr();
+    OperatorNodeConfig config{op.make_name()};
+    auto* opr = opr::ROIPooling::make(
+            inputs[0], inputs[1], inputs[2], op.param(), config
+    ).node()->owner_opr();
     return {opr->output(0), opr->output(1)};
 }
 OP_TRAIT_REG(ROIPooling, ROIPooling)
@@ -541,7 +580,8 @@ auto apply_on_var_node(
     auto&& op = static_cast<const Remap&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::Remap::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Remap::make(inputs[0], inputs[1], op.param(), config);
 }
 OP_TRAIT_REG(Remap, Remap)
     .apply_on_var_node(apply_on_var_node)
@@ -578,7 +618,8 @@ auto apply_on_var_node( \
         const OpDef& def, \
         const VarNodeArray& inputs) { \
     auto&& op = static_cast<const NAME&>(def); \
-    return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items)); \
+    OperatorNodeConfig config{op.make_name()}; \
+    return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \
 } \
 OP_TRAIT_REG(NAME, NAME) \
     .apply_on_var_node(apply_on_var_node) \
@@ -609,30 +650,35 @@ auto apply_on_var_node(
     auto&& op = static_cast<const FakeQuant&>(def);
     mgb_assert(inputs.size() == 3);
-    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config);
 }
 OP_TRAIT_REG(FakeQuant, FakeQuant)
     .apply_on_var_node(apply_on_var_node)
     .fallback();
 }} // fake_quant

 namespace { namespace tqt {
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const TQT&>(def);
     mgb_assert(inputs.size() == 2);
-    return opr::TQT::make(inputs[0], inputs[1], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::TQT::make(inputs[0], inputs[1], op.param(), config);
 }
 OP_TRAIT_REG(TQT, TQT)
     .apply_on_var_node(apply_on_var_node)
     .fallback();
 }} // tqt

 namespace { namespace elemwise_multi_type {
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& op = static_cast<const ElemwiseMultiType&>(def);
     OperatorNodeConfig config{op.dtype};
+    config.name(op.make_name());
     return opr::ElemwiseMultiType::make(inputs, op.param(), config);
 }
 OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType)
@@ -646,7 +692,9 @@ auto apply_on_var_node(
     auto&& op = static_cast<const SVD&>(def);
     mgb_assert(inputs.size() == 1);
-    return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output();
+    OperatorNodeConfig config{op.make_name()};
+    return opr::SVD::make(inputs[0], op.param(), config)[0]
+            .node()->owner_opr()->usable_output();
 }
 OP_TRAIT_REG(SVD, SVD)
     .apply_on_var_node(apply_on_var_node)
imperative/src/impl/ops/tensor_manip.cpp
浏览文件 @
6fb19b66
...
...
@@ -21,7 +21,8 @@ cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op_def = def.cast_final_safe<GetVarShape>();
-    return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
+    OperatorNodeConfig config{op_def.make_name()};
+    return opr::GetVarShape::make(inputs, op_def.param(), config).node()->owner_opr();
}

DispatchMode decide_dispatch_mode(
...
...
@@ -152,7 +153,7 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node(
    auto&& graph = inputs[0]->owner_graph();
    auto&& shapes = get_shapes(param.shapes);
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config(param.make_name());
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackSplit>(
                    inputs[0], param.offsets, shapes, config));
...
...
@@ -189,7 +190,7 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
    auto&& graph = inputs[0]->owner_graph();
    VarNodeArray inps(inputs.begin(), inputs.end() - 1);
-    cg::OperatorNodeConfig config;
+    OperatorNodeConfig config{param.make_name()};
    cg::OperatorNodeBase* opr =
            graph->insert_opr(std::make_unique<mgb::opr::ParamPackConcat>(
                    inps, inputs.back(), param.offsets, config));
...
...
imperative/src/impl/ops/tensorrt_runtime.cpp
...
...
@@ -20,8 +20,9 @@ namespace { namespace tensorrt_runtime {
        const OpDef& def,
        const VarNodeArray& inputs) {
    auto&& op = static_cast<const TensorRTRuntime&>(def);
+    OperatorNodeConfig config{op.make_name()};
    SymbolVarArray sinputs(inputs.begin(), inputs.end());
-    return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs);
+    return opr::TensorRTRuntimeOpr::make(op.buf.c_str(), op.buf_size, sinputs, config);
}
OP_TRAIT_REG(TensorRTRuntime, TensorRTRuntime)
    .apply_on_var_node(apply_on_var_node)
...
...
imperative/src/impl/ops/warp_affine.cpp
...
...
@@ -21,7 +21,8 @@ namespace { namespace warp_affine {
        const VarNodeArray& inputs) {
    mgb_assert(inputs.size() == 3);
    auto&& op = static_cast<const WarpAffine&>(def);
-    return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param());
+    OperatorNodeConfig config{op.make_name()};
+    return opr::WarpAffine::make(inputs[0], inputs[1], inputs[2], op.param(), config);
}
OP_TRAIT_REG(WarpAffine, WarpAffine)
...
...
imperative/src/include/megbrain/imperative/op_def.h
...
...
@@ -36,6 +36,7 @@ class OpDef : public Hashable,
                public NonCopyableObj,
                public std::enable_shared_from_this<OpDef> {
    mutable const OpTrait* m_trait = nullptr;
+    std::string m_scope;
public:
    virtual ~OpDef() = default;
...
...
@@ -86,10 +87,14 @@ public:
    const OpTrait* trait() const;
    const char* name() const;
    std::string to_string() const;
+    const std::string scope() const;
+    const std::string make_name() const;
+    void set_scope(const std::string& scope);
    virtual size_t hash() const;
    virtual bool is_same_st(const Hashable&) const;
...
...
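The header now stores a per-op scope next to the cached trait pointer and exposes scope(), make_name(), and set_scope(). A sketch of how the three plausibly fit together, assuming make_name() prefixes the trait's type name with the scope (the actual composition lives in op_def.cpp, which is collapsed in this view):

#include <string>

class OpDefSketch {
    std::string m_scope;  // filled in while tracing, e.g. "net.conv1"
public:
    const char* name() const { return "Elemwise"; }  // would come from the OpTrait
    void set_scope(const std::string& scope) { m_scope = scope; }
    const std::string scope() const { return m_scope; }
    // assumed composition: "<scope>.<type name>", falling back to the bare name
    const std::string make_name() const {
        if (m_scope.empty())
            return name();
        return m_scope + "." + name();
    }
};

int main() {
    OpDefSketch def;
    def.set_scope("net.conv1");
    return def.make_name() == "net.conv1.Elemwise" ? 0 : 1;
}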
imperative/tablegen/autogen.cpp
...
...
@@ -113,9 +113,10 @@ static void gen_op_def_c_header_single(raw_ostream &os, MgbOp& op) {
"{0}({0}_)"
,
i
.
name
));
}
paramList
.
push_back
(
"std::string scope_ = {}"
);
gen_ctor
(
llvm
::
join
(
paramList
,
", "
),
": "
+
llvm
::
join
(
initList
,
", "
),
" {}"
);
" {
set_scope(scope_);
}"
);
}
auto
packedParams
=
op
.
getPackedParams
();
...
...
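With this change every generated op class gains a trailing scope_ parameter whose value is forwarded to set_scope() in the constructor body (previously the body was an empty {}). For a hypothetical op with a single param attribute, the emitted constructor would look roughly like this (a sketch of the generated header, not copied from the repository):

// Illustrative output of gen_op_def_c_header_single after this hunk:
class Remap : public OpDefImplBase<Remap> {
public:
    Param param;
    Remap() = default;
    Remap(Param param_, std::string scope_ = {}) : param(param_) {
        set_scope(scope_);  // new: the ctor registers the tracing scope
    }
};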
@@ -236,11 +237,19 @@ static void gen_op_def_c_body_single(raw_ostream &os, MgbOp& op) {
        os << mlir::tblgen::tgfmt(hashable->getPropsFunctionTemplate(), &ctx);
        os << "}\n";

+        // generate make_name()
+        os << formatv(
+            "std::string {0}(const OpDef& def_) {{\n", formatMethImpl("make_name"));
+        os << mlir::tblgen::tgfmt(hashable->getNameFunctionTemplate(), &ctx);
+        os << "}\n";
+
        os << "} // anonymous namespace\n";

        methods.push_back("hash");
        methods.push_back("is_same_st");
        methods.push_back("props");
+        methods.push_back("make_name");
    }
    if (!methods.empty()) {
        os << formatv(
...
...
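For ordinary ops the make_name implementation emitted here is a one-liner returning the class name; mode-based ops instead get the switch generated by getModeName() in helper.h further down. A self-contained sketch of the plain case (symbol names are illustrative; the real one comes from formatMethImpl("make_name")):

#include <string>

struct OpDef {};  // hypothetical stand-in for MegBrain's OpDef

// Roughly what the generator emits for an op without usingModeName:
std::string Remap_make_name(const OpDef& /*def_*/) {
    return "Remap";  // plain ops simply report their class name
}

int main() {
    OpDef def;
    return Remap_make_name(def) == "Remap" ? 0 : 1;
}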
@@ -327,7 +336,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
            targs.push_back(i.attr.getReturnType());
        }
        os << llvm::join(targs, ", ");
-        os << ">()";
+        os << ", std::string>()";
        for (auto &&i : op.getMgbAttributes()) {
            os << formatv(", py::arg(\"{0}\")", i.name);
            auto defaultValue = i.attr.getDefaultValue();
...
...
@@ -337,7 +346,7 @@ static void gen_op_def_pybind11_single(raw_ostream &os, MgbOp& op, EnumContext&
                hasDefaultCtor = true;
            }
        }
-        os << ")";
+        os << ", py::arg(\"scope\") = {})";
    }
    if (hasDefaultCtor) {
        os << "\n    .def(py::init<>())";
...
...
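Taken together, the two pybind11 hunks append a trailing std::string slot to every generated py::init<> and a defaulted scope keyword, so the op can be named from Python at construction time. A sketch of what a generated binding then amounts to (Remap and its attribute are illustrative; the generator literally emits py::arg("scope") = {}, shown here as an empty string):

#include <pybind11/pybind11.h>
#include <memory>
#include <string>
#include <utility>
namespace py = pybind11;

// Hypothetical stand-ins; the real classes come from the generated opdefs.
struct Remap {
    std::string imode, scope;
    Remap(std::string imode_, std::string scope_ = {})
            : imode(std::move(imode_)), scope(std::move(scope_)) {}
};

PYBIND11_MODULE(opdef_sketch, m) {
    py::class_<Remap, std::shared_ptr<Remap>>(m, "Remap")
        .def(py::init<std::string, std::string>(),  // extra std::string slot
             py::arg("imode"),
             py::arg("scope") = "");                // defaulted keyword argument
}

From Python this permits calls such as Remap(imode="linear", scope="backbone.stage1").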
@@ -442,6 +451,10 @@ EnumWrapper<{0}::{1}>::type2str = {{
                className, i.name));
    }
+    getsetters.push_back(formatv(
+        "{{\"scope\", py_get_scope({0}), py_set_scope({0}), \"scope\", NULL},",
+        className));
+
    // generate tp_init
    std::string initBody;
    if (!op.getMgbAttributes().empty()) {
...
...
@@ -449,6 +462,7 @@ EnumWrapper<{0}::{1}>::type2str = {{
        llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
            initBody += formatv("\"{0}\", ", attr.name);
        });
+        initBody += "\"scope\", ";
        initBody += "NULL};\n";
        initBody += "    PyObject ";
        std::vector<std::string> attrs;
...
...
@@ -456,12 +470,15 @@ EnumWrapper<{0}::{1}>::type2str = {{
            attrs.push_back(formatv("*{0} = NULL", attr.name));
        });
        initBody += llvm::join(attrs, ", ") + ";\n";
+        initBody += "    PyObject *scope = NULL;\n";
        initBody += "    if (!PyArg_ParseTupleAndKeywords(args, kwds, \"|";
-        initBody += std::string(op.getMgbAttributes().size(), 'O');
+        // an extra slot created for name
+        initBody += std::string(op.getMgbAttributes().size() + 1, 'O');
        initBody += "\", const_cast<char**>(kwlist)";
        llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
            initBody += formatv(", &{0}", attr.name);
        });
+        initBody += ", &scope";
        initBody += "))\n";
        initBody += "    return -1;\n";
        llvm::for_each(op.getMgbAttributes(), [&](auto&& attr) {
...
...
@@ -483,6 +500,25 @@ EnumWrapper<{0}::{1}>::type2str = {{
    }
)", className, attr.name);
        });
+        initBody += formatv(R"(
+    if (scope) {{
+        try {{
+            reinterpret_cast<PyOp({0})*>(self)->inst().set_scope(
+                pyobj_convert_generic<std::string>::from(scope));
+        } catch(py::error_already_set& e) {{
+            e.restore();
+            return -1;
+        } catch(py::builtin_exception& e) {{
+            e.set_error();
+            return -1;
+        } catch(...) {{
+            PyErr_SetString(PyExc_RuntimeError, "Unknown Error");
+            return -1;
+        }
+    }
+)", className);
    }
    initBody += "\n    return 0;";
...
...
imperative/tablegen/helper.h
...
...
@@ -241,6 +241,30 @@ private:
        body += "    return props_;\n";
        return body;
    }
+    std::string getModeName() const {
+        std::string body = formatv(
+            "    auto&& op_ = def_.cast_final_safe<{0}>();\n"
+            "    static_cast<void>(op_);\n",
+            getCppClassName());
+        for (auto&& it : getMgbAttributes()) {
+            if (it.name == "mode") {
+                auto* enumAttr = llvm::dyn_cast<MgbEnumAttrMixin>(&it.attr);
+                body += "    switch (op_.mode){\n";
+                for (auto&& enumMember : enumAttr->getEnumMembers()) {
+                    body += formatv(
+                        "        case {0}::{1}::{2}:\n",
+                        getCppClassName(), enumAttr->getEnumName(), enumMember);
+                    body += formatv(
+                        "            return \"{0}\";\n", enumMember);
+                }
+                body += formatv(
+                    "        default: return \"{0}::Unknown\";\n",
+                    getCppClassName());
+                body += "    }\n";
+            }
+        }
+        return body;
+    }
public:
    static bool classof(const Operator* op) {
        return op->getDef().isSubClassOf("MgbHashableOpMixin");
...
...
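For an op flagged with usingModeName (see the .td hunks below), getModeName() makes the generated make_name report the current mode rather than the class name, so a traced Elemwise node is labeled e.g. ADD. A self-contained sketch of the code this helper emits (the real Mode enum has many more members):

#include <string>

struct Elemwise {  // hypothetical stub standing in for the generated op class
    enum class Mode { ADD, MUL };
    Mode mode;
};

// Shape of the switch that getModeName() generates:
std::string make_name(const Elemwise& op_) {
    switch (op_.mode) {
        case Elemwise::Mode::ADD: return "ADD";
        case Elemwise::Mode::MUL: return "MUL";
        default: return "Elemwise::Unknown";
    }
}

int main() {
    Elemwise op{Elemwise::Mode::ADD};
    return make_name(op) == "ADD" ? 0 : 1;
}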
@@ -264,6 +288,12 @@ public:
        }
        return getDefaultPropsFunction();
    }
+    std::string getNameFunctionTemplate() const {
+        if (getDef().getValueAsBit("usingModeName")) {
+            return getModeName();
+        }
+        return formatv("    return \"{0}\";\n", getCppClassName());
+    }
};
} // namespace tblgen
...
...
sdk/load-and-run/dump_with_testcase_mge.py
...
...
@@ -476,6 +476,7 @@ def main():
    output_mgbvars = feeds["outputs"]
    output_mgbvars = optimize_for_inference(args, output_mgbvars)
+    output_mgbvars = [var._node for var in output_mgbvars]

    inputs = cgtools.get_dep_vars(output_mgbvars, "Host2DeviceCopy")
    inputs = sorted((i.name, i.dtype) for i in inputs)
...
...
src/core/include/megbrain/ir/base.td
...
...
@@ -242,6 +242,7 @@ class MgbPackedParamBase<string className, string accessor>:
class MgbHashableOpMixin {
  string hashFunction = ?;
  string cmpFunction = ?;
+  bit usingModeName = 0;
}

class MgbOp<string mnemonic, list<MgbParamBase> params=[], list<OpTrait> traits=[]>:
...
...
src/core/include/megbrain/ir/ops.td
...
...
@@ -21,6 +21,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def Elemwise : MgbHashableOp<"Elemwise", [ElemwiseParam], [NoSideEffect]> {
  let inputs = (ins Variadic<AnyType>:$input);
  let results = (outs AnyType);
+  let usingModeName = 1;
}

def Reduce: MgbHashableOp<"Reduce", [ReduceParam]>;
...
...
@@ -247,6 +248,7 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara
  let extraArguments = (ins
    MgbDTypeAttr:$dtype
  );
+  let usingModeName = 1;
}

def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>;
...
...