Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cdb692d2
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cdb692d2
编写于
9月 07, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(imperative): add TODO tag for some functions
GitOrigin-RevId: e295a1fa5537f13bc65f9e82b44a3f9cd56992a6
上级
90dd0716
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
18 addition
and
18 deletion
+18
-18
imperative/python/megengine/core/ops/custom.py
imperative/python/megengine/core/ops/custom.py
+8
-1
imperative/src/impl/ops/custom_opdef.cpp
imperative/src/impl/ops/custom_opdef.cpp
+6
-15
src/opr/impl/custom_opnode.cpp
src/opr/impl/custom_opnode.cpp
+4
-2
未找到文件。
imperative/python/megengine/core/ops/custom.py
浏览文件 @
cdb692d2
...
...
@@ -7,13 +7,20 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
.._imperative_rt.ops._custom
import
_install
,
_uninstall
,
_get_custom_op_list
,
_make_custom_op
from
.._imperative_rt.ops._custom
import
(
_get_custom_op_list
,
_install
,
_make_custom_op
,
_uninstall
,
)
__all__
=
[
"load"
]
def
_gen_custom_op_maker
(
custom_op_name
):
def
op_maker
(
**
kwargs
):
return
_make_custom_op
(
custom_op_name
,
kwargs
)
return
op_maker
...
...
imperative/src/impl/ops/custom_opdef.cpp
浏览文件 @
cdb692d2
...
...
@@ -95,6 +95,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs
for
(
auto
i_shape
:
i_shapes
)
{
if
(
i_shape
.
ndim
==
0
)
{
success
=
false
;
break
;
}
}
...
...
@@ -187,14 +188,11 @@ void apply_on_device_tensornd(const OpDef& def,
auto
cn
=
output
.
comp_node
();
cn
.
activate
();
}
// [TODO] sync should be modified
CompNode
::
sync_all
();
auto
&&
op
=
static_cast
<
const
CustomOpDef
&>
(
def
);
op
.
compute
(
inputs
,
outputs
);
// for (auto &&output: (*outputs)) {
// auto cn = output.comp_node();
// cn.sync(); // cannot sync ??????????
// }
CompNode
::
sync_all
();
}
...
...
@@ -224,19 +222,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
}
VarNodeArray
apply_on_var_node
(
const
OpDef
&
def
,
const
cg
::
VarNodeArray
&
inputs
)
{
SymbolVarArray
input_syms
;
for
(
auto
&
input_var
:
inputs
)
input_syms
.
emplace_back
(
input_var
);
auto
&&
op
=
static_cast
<
const
CustomOpDef
&>
(
def
);
OperatorNodeConfig
config
;
SymbolVarArray
output_sym
s
=
opr
::
CustomOpNode
::
make
(
op
.
impl
(),
input
_sym
s
,
op
.
param
(),
config
VarNodeArray
output
s
=
opr
::
CustomOpNode
::
make
(
op
.
impl
(),
inputs
,
op
.
param
(),
config
);
VarNodeArray
outputs
;
for
(
auto
&
output_sym
:
output_syms
)
outputs
.
push_back
(
output_sym
.
node
());
return
outputs
;
}
...
...
@@ -273,6 +263,7 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) {
return
a
.
param
()
==
b
.
param
()
&&
a
.
runtime_id
()
==
b
.
runtime_id
();
}
// [TODO] to be implemented
std
::
vector
<
std
::
pair
<
const
char
*
,
std
::
string
>>
props
(
const
OpDef
&
def
)
{
mgb_assert
(
false
,
"Custom OpDef Props Function is not IMPLEMENTED now"
);
// can be implement with param schema
...
...
src/opr/impl/custom_opnode.cpp
浏览文件 @
cdb692d2
...
...
@@ -140,7 +140,8 @@ void CustomOpNode::do_execute(ExecEnv &env) {
std
::
vector
<
custom
::
Tensor
>
custom_inputs
=
custom
::
to_custom
<
DeviceTensorND
,
custom
::
Tensor
>
(
inputs
);
std
::
vector
<
custom
::
Tensor
>
custom_outputs
=
custom
::
to_custom
<
DeviceTensorND
,
custom
::
Tensor
>
(
outputs
);
m_op
->
compute
(
custom_inputs
,
m_param
,
custom_outputs
);
CompNode
::
sync_all
();
// whether reasonable
// [TODO] sync should be modified
CompNode
::
sync_all
();
this
->
owner_graph
()
->
event
().
signal_inplace
<
cg
::
event
::
AfterKernel
>
(
this
,
m_comp_node
...
...
@@ -157,7 +158,8 @@ void CustomOpNode::init_output_static_infer_desc() {
auto
&&
mgr
=
owner_graph
()
->
static_infer_manager
();
DepVal
dep
;
if
(
true
)
{
// need design a function to allow user to decide it
// [TODO] need design a interface to allow user to decide it
if
(
true
)
{
for
(
auto
input_var
:
input
())
dep
.
push_back
({
input_var
,
DepType
::
SHAPE
});
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录