Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
90dd0716
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
90dd0716
编写于
9月 03, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(imperative): modify the python interface of custom op
GitOrigin-RevId: e82e5de480048bda95faf4107fbf9bbacfb79233
上级
cbf024bf
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
22 addition
and
45 deletion
+22
-45
imperative/python/megengine/core/ops/custom.py
imperative/python/megengine/core/ops/custom.py
+6
-11
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+5
-5
imperative/src/impl/ops/custom_opdef.cpp
imperative/src/impl/ops/custom_opdef.cpp
+11
-11
imperative/src/include/megbrain/imperative/ops/custom_opdef.h
...rative/src/include/megbrain/imperative/ops/custom_opdef.h
+0
-13
src/opr/impl/custom_opnode.cpp
src/opr/impl/custom_opnode.cpp
+0
-5
未找到文件。
imperative/python/megengine/core/ops/custom
/__init__
.py
→
imperative/python/megengine/core/ops/custom.py
浏览文件 @
90dd0716
...
@@ -7,24 +7,19 @@
...
@@ -7,24 +7,19 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
..
._imperative_rt.ops
import
_custom
from
..
_imperative_rt.ops._custom
import
_install
,
_uninstall
,
_get_custom_op_list
,
_make_custom_op
__all__
=
[]
__all__
=
[
"load"
]
for
k
,
v
in
_custom
.
__dict__
.
items
():
def
_gen_custom_op_maker
(
custom_op_name
):
globals
()[
k
]
=
v
__all__
.
append
(
k
)
def
gen_custom_op_maker
(
custom_op_name
):
def
op_maker
(
**
kwargs
):
def
op_maker
(
**
kwargs
):
return
make_custom_op
(
custom_op_name
,
kwargs
)
return
_
make_custom_op
(
custom_op_name
,
kwargs
)
return
op_maker
return
op_maker
def
load
(
lib_path
):
def
load
(
lib_path
):
op_in_this_lib
=
install
(
lib_path
[
0
:
-
3
],
lib_path
)
op_in_this_lib
=
_
install
(
lib_path
[
0
:
-
3
],
lib_path
)
for
op
in
op_in_this_lib
:
for
op
in
op_in_this_lib
:
op_maker
=
gen_custom_op_maker
(
op
)
op_maker
=
_
gen_custom_op_maker
(
op
)
globals
()[
op
]
=
op_maker
globals
()[
op
]
=
op_maker
__all__
.
append
(
op
)
__all__
.
append
(
op
)
imperative/python/src/ops.cpp
浏览文件 @
90dd0716
...
@@ -684,7 +684,7 @@ py::list install_custom(const std::string &name, const std::string &path) {
...
@@ -684,7 +684,7 @@ py::list install_custom(const std::string &name, const std::string &path) {
for
(
const
auto
&
op
:
ops_in_lib
)
{
for
(
const
auto
&
op
:
ops_in_lib
)
{
ret
.
append
(
op
);
ret
.
append
(
op
);
}
}
return
std
::
move
(
ret
)
;
return
ret
;
}
}
bool
uninstall_custom
(
const
std
::
string
&
name
)
{
bool
uninstall_custom
(
const
std
::
string
&
name
)
{
...
@@ -701,12 +701,12 @@ py::list get_custom_op_list(void) {
...
@@ -701,12 +701,12 @@ py::list get_custom_op_list(void) {
}
}
void
init_custom
(
pybind11
::
module
m
)
{
void
init_custom
(
pybind11
::
module
m
)
{
m
.
def
(
"install"
,
&
install_custom
);
m
.
def
(
"
_
install"
,
&
install_custom
);
m
.
def
(
"uninstall"
,
&
uninstall_custom
);
m
.
def
(
"
_
uninstall"
,
&
uninstall_custom
);
m
.
def
(
"get_custom_op_list"
,
&
get_custom_op_list
);
m
.
def
(
"
_
get_custom_op_list"
,
&
get_custom_op_list
);
static
PyMethodDef
method_def
=
{
static
PyMethodDef
method_def
=
{
"make_custom_op"
,
(
PyCFunction
)
make_custom_op
,
METH_FASTCALL
,
""
"
_
make_custom_op"
,
(
PyCFunction
)
make_custom_op
,
METH_FASTCALL
,
""
};
};
auto
*
func
=
PyCFunction_NewEx
(
&
method_def
,
nullptr
,
nullptr
);
auto
*
func
=
PyCFunction_NewEx
(
&
method_def
,
nullptr
,
nullptr
);
pybind11
::
setattr
(
m
,
method_def
.
ml_name
,
func
);
pybind11
::
setattr
(
m
,
method_def
.
ml_name
,
func
);
...
...
imperative/src/impl/ops/custom_opdef.cpp
浏览文件 @
90dd0716
...
@@ -286,19 +286,19 @@ std::string make_name(const OpDef& def) {
...
@@ -286,19 +286,19 @@ std::string make_name(const OpDef& def) {
return
op
.
name
();
return
op
.
name
();
}
}
}
// custom_opdef
OP_TRAIT_REG
(
CustomOpDef
,
CustomOpDef
)
OP_TRAIT_REG
(
CustomOpDef
,
CustomOpDef
)
.
apply_on_physical_tensor
(
imperative
::
custom_opdef
::
apply_on_physical_tensor
)
.
apply_on_physical_tensor
(
apply_on_physical_tensor
)
.
apply_on_var_node
(
imperative
::
custom_opdef
::
apply_on_var_node
)
.
apply_on_var_node
(
apply_on_var_node
)
.
apply_on_device_tensornd
(
imperative
::
custom_opdef
::
apply_on_device_tensornd
)
.
apply_on_device_tensornd
(
apply_on_device_tensornd
)
.
infer_output_attrs_fallible
(
i
mperative
::
custom_opdef
::
i
nfer_output_attrs_fallible
)
.
infer_output_attrs_fallible
(
infer_output_attrs_fallible
)
.
infer_output_mem_desc
(
i
mperative
::
custom_opdef
::
i
nfer_output_mem_desc
)
.
infer_output_mem_desc
(
infer_output_mem_desc
)
.
hash
(
imperative
::
custom_opdef
::
hash
)
.
hash
(
hash
)
.
is_same_st
(
i
mperative
::
custom_opdef
::
i
s_same_st
)
.
is_same_st
(
is_same_st
)
.
props
(
imperative
::
custom_opdef
::
props
)
.
props
(
props
)
.
make_name
(
imperative
::
custom_opdef
::
make_name
)
.
make_name
(
make_name
)
.
fallback
();
.
fallback
();
}
// custom_opdef
}
// imperative
}
// imperative
}
// mgb
}
// mgb
imperative/src/include/megbrain/imperative/ops/custom_opdef.h
浏览文件 @
90dd0716
...
@@ -60,18 +60,5 @@ public:
...
@@ -60,18 +60,5 @@ public:
std
::
shared_ptr
<
OpDef
>
create_opdef
(
const
custom
::
RunTimeId
&
,
const
custom
::
Param
&
)
const
;
std
::
shared_ptr
<
OpDef
>
create_opdef
(
const
custom
::
RunTimeId
&
,
const
custom
::
Param
&
)
const
;
};
};
namespace
custom_opdef
{
// avoid name conflict
void
apply_on_device_tensornd
(
const
OpDef
&
,
const
SmallVector
<
DeviceTensorND
>&
,
SmallVector
<
DeviceTensorND
>*
);
SmallVector
<
TensorPtr
>
apply_on_physical_tensor
(
const
OpDef
&
,
const
SmallVector
<
TensorPtr
>&
);
VarNodeArray
apply_on_var_node
(
const
OpDef
&
,
const
cg
::
VarNodeArray
&
);
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_output_attrs_fallible
(
const
OpDef
&
,
const
SmallVector
<
LogicalTensorDesc
>&
);
size_t
hash
(
const
OpDef
&
);
bool
is_same_st
(
const
OpDef
&
,
const
OpDef
&
);
std
::
vector
<
std
::
pair
<
const
char
*
,
std
::
string
>>
props
(
const
OpDef
&
);
std
::
string
make_name
(
const
OpDef
&
);
}
// custom_opdef
}
// imperative
}
// imperative
}
// mgb
}
// mgb
src/opr/impl/custom_opnode.cpp
浏览文件 @
90dd0716
...
@@ -214,11 +214,6 @@ void CustomOpNode::on_output_comp_node_stream_changed() {
...
@@ -214,11 +214,6 @@ void CustomOpNode::on_output_comp_node_stream_changed() {
}
}
cg
::
OperatorNodeBase
::
NodeProp
*
CustomOpNode
::
do_make_node_prop
()
const
{
cg
::
OperatorNodeBase
::
NodeProp
*
CustomOpNode
::
do_make_node_prop
()
const
{
// auto ret = &const_cast<OperatorNodeBase::NodeProp&>(node_prop());
// for (auto &&inp_var: input())
// ret->add_dep_type(inp_var, NodeProp::DepType::DEV_VALUE);
// ret->add_flag(NodeProp::Flag::SINGLE_COMP_NODE);
// return ret;
return
OperatorNodeBase
::
do_make_node_prop
();
return
OperatorNodeBase
::
do_make_node_prop
();
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录