Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cbf024bf
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cbf024bf
编写于
9月 01, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative): add adaptor between custom op and imperative runtime
GitOrigin-RevId: d7877f2e321ea2835006bbba6b80f6eb9a7e3111
上级
39ba3021
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
523 addition
and
0 deletion
+523
-0
imperative/python/megengine/core/ops/custom/__init__.py
imperative/python/megengine/core/ops/custom/__init__.py
+30
-0
imperative/python/megengine/tools/load_network_and_run.py
imperative/python/megengine/tools/load_network_and_run.py
+7
-0
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+103
-0
imperative/python/src/ops.h
imperative/python/src/ops.h
+2
-0
imperative/src/impl/ops/custom_opdef.cpp
imperative/src/impl/ops/custom_opdef.cpp
+304
-0
imperative/src/include/megbrain/imperative/ops/custom_opdef.h
...rative/src/include/megbrain/imperative/ops/custom_opdef.h
+77
-0
未找到文件。
imperative/python/megengine/core/ops/custom/__init__.py
0 → 100644
浏览文件 @
cbf024bf
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..._imperative_rt.ops import _custom

__all__ = []

# Re-export every public symbol of the C++ `_custom` submodule at this
# package level, so callers can use `megengine.core.ops.custom.<name>`
# directly instead of reaching into the native module.
for k, v in _custom.__dict__.items():
    globals()[k] = v
    __all__.append(k)
def gen_custom_op_maker(custom_op_name):
    """Build a factory for the named custom op.

    The returned callable accepts arbitrary keyword arguments and forwards
    them (as a dict) to the runtime-provided ``make_custom_op``.
    """

    def op_maker(**attrs):
        # pack the op attributes into a dict, as make_custom_op expects
        return make_custom_op(custom_op_name, attrs)

    return op_maker
def load(lib_path):
    """Load a custom-op shared library and expose its ops in this module.

    :param lib_path: path to the compiled custom-op library file.

    Side effects: registers every op found in the library as a module-level
    callable and appends its name to ``__all__``.
    """
    import os

    # Derive the library's registered name from its path. The original code
    # stripped a hard-coded 3-character suffix (only correct for ".so");
    # splitext handles any extension (.so/.dylib/.dll) and is equivalent
    # for the ".so" case.
    lib_name = os.path.splitext(lib_path)[0]
    op_in_this_lib = install(lib_name, lib_path)
    for op in op_in_this_lib:
        # each op gets a keyword-argument factory bound to its name
        globals()[op] = gen_custom_op_maker(op)
        __all__.append(op)
imperative/python/megengine/tools/load_network_and_run.py
浏览文件 @
cbf024bf
...
...
@@ -13,6 +13,7 @@ from collections import OrderedDict
import
numpy
as
np
import
megengine
as
mge
from
megengine.core.ops
import
custom
from
megengine.core.tensor
import
megbrain_graph
as
G
from
megengine.device
import
get_device_count
,
set_default_device
from
megengine.functional.debug_param
import
set_execution_strategy
...
...
@@ -397,6 +398,10 @@ def main():
type
=
str
,
help
=
"Record the static graph's static memory info."
,
)
parser
.
add_argument
(
"--custom-op-lib"
,
type
=
str
,
help
=
"path of the custom op"
,
)
args
=
parser
.
parse_args
()
if
args
.
verbose
:
...
...
@@ -409,6 +414,8 @@ def main():
if
args
.
dump_cpp_model
:
args
.
embed_input
=
True
if
args
.
custom_op_lib
is
not
None
:
custom
.
load
(
args
.
custom_op_lib
)
logger
.
info
(
"loading model ..."
)
ret
=
G
.
load_graph
(
args
.
net
)
...
...
imperative/python/src/ops.cpp
浏览文件 @
cbf024bf
...
...
@@ -607,4 +607,107 @@ void init_ops(py::module m) {
.
def
(
"compile"
,
[](
PySubgraphBuilder
&
self
,
int
gopt_level
){
return
(
std
::
shared_ptr
<
OpDef
>
)
CompiledOp
::
make
(
self
.
build
(),
gopt_level
);
});
auto
custom
=
submodule
(
m
,
"_custom"
);
init_custom
(
custom
);
}
// Expand to one switch-case that parses a scalar python kwarg value into the
// dynamically-typed param slot `param_val`.
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type)           \
    case mgb::custom::ParamDynType::dyn_type: {                        \
        param_val = py::handle(kv.second).cast<static_type>();         \
        break;                                                         \
    }

// Expand to one switch-case that parses a python list kwarg element-by-element
// into the matching list-typed param slot.
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type)                    \
    case mgb::custom::ParamDynType::dyn_type: {                             \
        auto pyvals = py::handle(kv.second).cast<py::list>();               \
        static_type vals;                                                   \
        using basic_type =                                                  \
            mgb::custom::get_vector_template_arg_type<static_type>::type;   \
        for (auto &pyval: pyvals) {                                         \
            vals.push_back(py::handle(pyval).cast<basic_type>());           \
        }                                                                   \
        param_val = vals;                                                   \
        break;                                                              \
    }

// METH_FASTCALL entry point: build a CustomOpDef from (op_name, kwargs-dict)
// and wrap it in a python OpDef object.
// NOTE(review): CPython's fastcall convention declares args as
// `PyObject *const *`; `PyObject **` works but differs from the documented
// signature — confirm against the CPython version in use. `nargs`/`kwnames`
// are accepted but never validated here.
PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs,
                         PyObject* kwnames) {
    // args[0]: op type name; args[1]: dict of param name -> python value
    auto op_name = py::handle(args[0]).cast<std::string>();
    auto kwargs = py::handle(args[1]).cast<py::dict>();

    std::shared_ptr<OpDef> opdef = CustomOpDefFactory::inst()->create_opdef(op_name);
    auto& custom_opdef = static_cast<mgb::imperative::CustomOpDef&>(*opdef);
    auto& param = custom_opdef.param();

    // Copy every recognized kwarg into the op's param store; unknown names
    // are warned about and skipped rather than raising.
    for (auto&& kv: kwargs) {
        std::string param_name = py::handle(kv.first).cast<std::string>();
        // NOTE(review): `type_name` is computed but never used below.
        std::string type_name = py::handle(kv.second).ptr()->ob_type->tp_name;

        if (!param.exist(param_name)) {
            mgb_log_warn("op %s have no param named %s, ignore this param parsed from python",
                         op_name.c_str(), param_name.c_str());
            continue;
        }

        auto& param_val = param[param_name];
        // Dispatch on the param's declared dynamic type; the macros above
        // generate one case per supported scalar / list type.
        switch (param_val.type()) {
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
            CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
            CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
            CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
            CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
            default: {
                mgb_assert(false, "param dtype of %s:%s is invalid",
                           op_name.c_str(), param_name.c_str());
            }
        }
    }

    // Allocate the python wrapper and hand it the shared OpDef.
    PyTypeObject* pytype;
    pytype = &PyOpType(OpDef);
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = opdef;
    return obj;
}

#undef CUSTOM_CASE_TO_PARSE_LIST
#undef CUSTOM_CASE_TO_PARSE_NON_LIST
/// Install the custom-op library at `path` under `name` and return the list
/// of op names the library registered, as a python list of strings.
py::list install_custom(const std::string& name, const std::string& path) {
    py::list ret;
    const auto& ops_in_lib = mgb::custom::LibManager::inst()->install(name, path);
    for (const auto& op: ops_in_lib) {
        ret.append(op);
    }
    // fixed: was `return std::move(ret);` — std::move on a returned local
    // defeats NRVO/copy elision and is flagged by clang-tidy.
    return ret;
}
/// Unload the custom-op library previously installed under `name`.
/// Returns whatever LibManager::uninstall reports.
bool uninstall_custom(const std::string& name) {
    auto* manager = mgb::custom::LibManager::inst();
    return manager->uninstall(name);
}
/// Return the names of all currently registered custom ops as a python list.
py::list get_custom_op_list(void) {
    std::vector<std::string> all_ops = CustomOpDefFactory::inst()->op_list();
    py::list ret;
    for (auto& op: all_ops) {
        ret.append(op);
    }
    // fixed: was `return std::move(ret);` — moving a returned local blocks
    // copy elision for no benefit.
    return ret;
}
// Register the custom-op python bindings on submodule `m`:
// plain pybind11 functions for install/uninstall/listing, plus a raw
// METH_FASTCALL CPython function for make_custom_op (which needs the
// fastcall calling convention rather than pybind11's dispatch).
void init_custom(pybind11::module m) {
    m.def("install", &install_custom);
    m.def("uninstall", &uninstall_custom);
    m.def("get_custom_op_list", &get_custom_op_list);

    // `method_def` must outlive the module, hence `static`.
    static PyMethodDef method_def = {
            "make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""};
    auto* func = PyCFunction_NewEx(&method_def, nullptr, nullptr);
    // Expose the raw function object as m.make_custom_op.
    pybind11::setattr(m, method_def.ml_name, func);
}
imperative/python/src/ops.h
浏览文件 @
cbf024bf
...
...
@@ -16,6 +16,7 @@
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/imperative/ops/custom_opdef.h"
namespace
PYBIND11_NAMESPACE
{
namespace
detail
{
...
...
@@ -35,3 +36,4 @@ FOR_EACH_BIT_COMBINED_ENUM_PARAM(ENUM_CASTER_DEF)
}
// PYBIND11_NAMESPACE
// Register all builtin op bindings on module `m` (defined in ops.cpp).
void init_ops(pybind11::module m);
// Register the custom-op adaptor bindings on submodule `m` (defined in ops.cpp).
void init_custom(pybind11::module m);
imperative/src/impl/ops/custom_opdef.cpp
0 → 100644
浏览文件 @
cbf024bf
/**
* \file imperative/src/impl/ops/custom_opdef.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/ops/custom_opdef.h"
#include "megbrain/opr/custom_opnode.h"
#include "megbrain/custom/data_adaptor.h"
#include "../op_trait.h"
namespace
mgb
{
namespace
imperative
{
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpDef);

// Construct with the op's default parameters, taken from its param schema.
CustomOpDef::CustomOpDef(const std::shared_ptr<const custom::CustomOp>& op)
        : m_op(op), m_param(op->param_info()) {}

// Construct with an explicit parameter set (e.g. restored from a caller).
CustomOpDef::CustomOpDef(const std::shared_ptr<const custom::CustomOp>& op,
                         const custom::Param& param)
        : m_op(op), m_param(param) {}
// Replace the whole parameter set.
void CustomOpDef::param(const custom::Param& rhs) {
    m_param = rhs;
}

// Mutable access to the parameter set.
custom::Param& CustomOpDef::param(void) {
    return m_param;
}

// Const access returns a copy of the parameter set.
custom::Param CustomOpDef::param(void) const {
    return m_param;
}

size_t CustomOpDef::input_num(void) const {
    return m_op->input_num();
}

size_t CustomOpDef::output_num(void) const {
    return m_op->output_num();
}

// The op's registered type name.
std::string CustomOpDef::name(void) const {
    return m_op->op_type();
}

// Runtime identity of the underlying CustomOp; used for hashing/equality.
custom::RunTimeId CustomOpDef::runtime_id(void) const {
    return m_op->runtime_id();
}

// The underlying (immutable) CustomOp implementation.
const std::shared_ptr<const custom::CustomOp>& CustomOpDef::impl(void) const {
    return m_op;
}
// Run the op's kernel: wrap megbrain tensors in the custom-op Tensor adaptor
// and invoke the user-provided compute() with the current params.
// `outputs` must already be allocated by the caller; results are written
// through the adaptor views.
void CustomOpDef::compute(const SmallVector<DeviceTensorND>& inputs,
                          SmallVector<DeviceTensorND>* outputs) const {
    std::vector<custom::Tensor> custom_inputs =
            custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
    std::vector<custom::Tensor> custom_outputs =
            custom::to_custom<DeviceTensorND, custom::Tensor>(*outputs);
    m_op->compute(custom_inputs, this->m_param, custom_outputs);
}
/// Infer output descriptors from concrete input tensors by collecting their
/// comp node and layout, then delegating to the LogicalTensorDesc overload.
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
        const SmallVector<TensorPtr>& inputs) const {
    SmallVector<LogicalTensorDesc> input_descs(inputs.size());
    // fixed: loop index was `int`, mixing signed/unsigned with size().
    for (size_t i = 0; i < inputs.size(); ++i) {
        input_descs[i].comp_node = inputs[i]->comp_node();
        input_descs[i].layout = inputs[i]->layout();
    }
    // fixed: was `return std::move(...)` on an rvalue — the move is a no-op
    // that only suppresses copy elision.
    return this->infer_output_attrs(input_descs);
}
/// Infer the output descriptors (device, dtype, format, shape) for the given
/// logical input descriptors by calling the user op's infer_* hooks.
/// Returns {descs, success}; success is false when any input shape is still
/// unknown (ndim == 0), in which case output shapes are left empty and only
/// device/dtype/format are meaningful.
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
        const SmallVector<LogicalTensorDesc>& inputs) const {
    SmallVector<CompNode> i_devices(inputs.size());
    SmallVector<TensorShape> i_shapes(inputs.size());
    SmallVector<megdnn::DType> i_dtypes(inputs.size());
    SmallVector<TensorFormat> i_formats(inputs.size());

    // fixed: loop indices were `int`, mixing signed/unsigned with size().
    for (size_t i = 0; i < inputs.size(); ++i) {
        i_devices[i] = inputs[i].comp_node;
        i_shapes[i] = inputs[i].layout;  // TensorLayout is derived from TensorShape
        i_dtypes[i] = inputs[i].layout.dtype;
        i_formats[i] = inputs[i].layout.format;
    }

    // Shape inference is only possible once every input shape is known.
    bool success = true;
    // fixed: was `for (auto i_shape : ...)` — copied each TensorShape.
    for (const auto& i_shape: i_shapes) {
        if (i_shape.ndim == 0) {
            success = false;
        }
    }

    SmallVector<CompNode> o_devices;
    SmallVector<megdnn::DType> o_dtypes;
    SmallVector<TensorFormat> o_formats;
    SmallVector<TensorShape> o_shapes;

    o_devices = custom::to_builtin<CompNode, custom::Device>(
            m_op->infer_output_device(
                    custom::to_custom<CompNode, custom::Device>(i_devices),
                    this->m_param));
    o_dtypes = custom::to_builtin<megdnn::DType, custom::DType>(
            m_op->infer_output_dtype(
                    custom::to_custom<megdnn::DType, custom::DType>(i_dtypes),
                    this->m_param));
    o_formats = custom::to_builtin<TensorFormat, custom::Format>(
            m_op->infer_output_format(
                    custom::to_custom<TensorFormat, custom::Format>(i_formats),
                    this->m_param));

    if (success) {
        o_shapes = custom::to_builtin<TensorShape, custom::Shape>(
                m_op->infer_output_shape(
                        custom::to_custom<TensorShape, custom::Shape>(i_shapes),
                        this->m_param));
    } else {
        // Placeholder (ndim == 0) shapes until the inputs are concrete.
        o_shapes = SmallVector<TensorShape>(this->output_num());
    }

    SmallVector<LogicalTensorDesc> outputs(this->output_num());
    for (size_t i = 0; i < this->output_num(); ++i) {
        outputs[i].comp_node = std::move(o_devices[i]);
        // fixed: was std::move(TensorLayout(...)) — moving a temporary is a
        // no-op; plain construction assigns the same value.
        outputs[i].layout = TensorLayout(o_shapes[i], o_dtypes[i], o_formats[i]);
    }
    return std::tuple<SmallVector<LogicalTensorDesc>, bool>(outputs, success);
}
// Process-wide singleton factory for CustomOpDef instances.
CustomOpDefFactory* CustomOpDefFactory::inst(void) {
    static CustomOpDefFactory factory;
    return &factory;
}

// True when `op` is (exactly) a CustomOpDef, judged by dynamic type info.
bool CustomOpDefFactory::is_custom_op(const OpDef& op) {
    return op.dyn_typeinfo() == CustomOpDef::typeinfo();
}

// Bind the factory to the global CustomOpManager that owns the registered ops.
CustomOpDefFactory::CustomOpDefFactory() {
    ops = custom::CustomOpManager::inst();
}
// Names of all ops currently registered with the manager.
std::vector<std::string> CustomOpDefFactory::op_list(void) const {
    return ops->op_name_list();
}

// Create an OpDef for the op registered under `op_type`, with default params.
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(
        const std::string& op_type) const {
    auto op = ops->find(op_type);
    return std::make_shared<CustomOpDef>(op);
}

// Create an OpDef looked up by runtime id, with default params.
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(
        const custom::RunTimeId& op_id) const {
    auto op = ops->find(op_id);
    return std::make_shared<CustomOpDef>(op);
}

// Create an OpDef by name with an explicit param set.
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(
        const std::string& op_type, const custom::Param& param) const {
    auto op = ops->find(op_type);
    return std::make_shared<CustomOpDef>(op, param);
}

// Create an OpDef by runtime id with an explicit param set.
std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(
        const custom::RunTimeId& op_id, const custom::Param& param) const {
    auto op = ops->find(op_id);
    return std::make_shared<CustomOpDef>(op, param);
}
namespace
custom_opdef
{
// avoid name conflict
// Execute the custom op on concrete device tensors.
// Activates each output's comp node first, then runs the kernel between two
// global syncs. NOTE(review): the full CompNode::sync_all() on both sides
// looks conservative — the commented-out per-node sync below suggests a
// finer-grained sync was attempted and abandoned; confirm before changing.
void apply_on_device_tensornd(const OpDef& def,
                              const SmallVector<DeviceTensorND>& inputs,
                              SmallVector<DeviceTensorND>* outputs) {
    for (auto&& output: (*outputs)) {
        auto cn = output.comp_node();
        cn.activate();
    }
    CompNode::sync_all();
    auto&& op = static_cast<const CustomOpDef&>(def);
    op.compute(inputs, outputs);
    // for (auto &&output: (*outputs)) {
    //     auto cn = output.comp_node();
    //     cn.sync(); // cannot sync ??????????
    // }
    CompNode::sync_all();
}
/// Imperative execution path: infer output descriptors, allocate output
/// tensors accordingly, then run the kernel via apply_on_device_tensornd.
/// Asserts when shape inference is not possible (some input shape unknown).
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    auto&& op = static_cast<const CustomOpDef&>(def);
    auto [output_descs, success] = op.infer_output_attrs(inputs);
    // fixed: message said "fall" instead of "failed"; `== true` dropped.
    mgb_assert(success, "infer output attributes failed\n");

    // Allocate one tensor per inferred descriptor.
    SmallVector<TensorPtr> outputs(output_descs.size());
    for (size_t i = 0; i < outputs.size(); ++i) {
        auto& output = outputs[i];
        auto& output_desc = output_descs[i];
        output = Tensor::make(output_desc.layout, output_desc.comp_node);
    }

    // Unwrap to DeviceTensorND views for the kernel call.
    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
    SmallVector<DeviceTensorND> oup_tensornds(outputs.size());
    for (size_t i = 0; i < inputs.size(); ++i)
        inp_tensornds[i] = inputs[i]->dev_tensor();
    for (size_t i = 0; i < outputs.size(); ++i)
        oup_tensornds[i] = outputs[i]->dev_tensor();
    apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
    return outputs;
}
/// Graph-building path: wrap the input var nodes as symbolic vars, insert a
/// CustomOpNode for this op, and return the raw output var nodes.
VarNodeArray apply_on_var_node(const OpDef& def, const cg::VarNodeArray& inputs) {
    // lift raw VarNode* into SymbolVar handles
    SymbolVarArray sym_inputs;
    for (auto& var: inputs)
        sym_inputs.emplace_back(var);

    auto&& opdef = static_cast<const CustomOpDef&>(def);
    OperatorNodeConfig config;
    SymbolVarArray sym_outputs = opr::CustomOpNode::make(
            opdef.impl(), sym_inputs, opdef.param(), config);

    // unwrap back to raw var nodes for the caller
    VarNodeArray var_outputs;
    for (auto& sym: sym_outputs)
        var_outputs.push_back(sym.node());
    return var_outputs;
}
/// Trait hook: forward fallible attribute inference to the op's own
/// LogicalTensorDesc overload.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    const auto& opdef = static_cast<const CustomOpDef&>(def);
    return opdef.infer_output_attrs(inputs);
}
// Memory-descriptor inference is not implemented for custom ops; returning
// two empty lists presumably makes the runtime fall back to its default
// allocation path — TODO confirm against the OpDef trait contract.
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    return {{}, {}};
}
/// Trait hook: hash a CustomOpDef by its runtime id combined with a string
/// fold of all (name, value) param entries.
size_t hash(const OpDef& def) {
    auto&& op = static_cast<const CustomOpDef&>(def);
    const custom::Param& param = op.param();
    size_t val = mgb::hash(op.runtime_id());
    std::string hash_str = "";
    // fixed: the loop variable was also named `val`, shadowing the hash
    // accumulator above (bugprone, though harmless here); renamed to `kv`.
    for (auto&& kv: param.raw()) {
        hash_str += kv.first;
        hash_str += kv.second.str();
    }
    val = mgb::hash_pair_combine(val, mgb::hash(hash_str));
    return val;
}
/// Trait hook: two CustomOpDefs are equal when they wrap the same registered
/// op (runtime id) with identical parameter values.
bool is_same_st(const OpDef& lhs, const OpDef& rhs) {
    const auto& a = static_cast<const CustomOpDef&>(lhs);
    const auto& b = static_cast<const CustomOpDef&>(rhs);
    if (a.runtime_id() != b.runtime_id())
        return false;
    return a.param() == b.param();
}
// Trait hook: property listing for display/serialization. Unimplemented —
// always asserts. The return statement after the assert is unreachable in
// debug; in builds where mgb_assert compiles out it returns an empty list.
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
    mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now");
    // can be implement with param schema
    // auto&& custom_opdef = def.cast_final_safe<CustomOpDef>();
    std::vector<std::pair<const char*, std::string>> props_;
    return props_;
}
/// Trait hook: display name of the op — its registered type name.
std::string make_name(const OpDef& def) {
    const auto& opdef = static_cast<const CustomOpDef&>(def);
    return opdef.name();
}
}
// custom_opdef
// Register all CustomOpDef trait hooks with the imperative runtime.
// `.fallback()` fills any hook not listed here with the default behavior.
OP_TRAIT_REG(CustomOpDef, CustomOpDef)
        .apply_on_physical_tensor(imperative::custom_opdef::apply_on_physical_tensor)
        .apply_on_var_node(imperative::custom_opdef::apply_on_var_node)
        .apply_on_device_tensornd(imperative::custom_opdef::apply_on_device_tensornd)
        .infer_output_attrs_fallible(imperative::custom_opdef::infer_output_attrs_fallible)
        .infer_output_mem_desc(imperative::custom_opdef::infer_output_mem_desc)
        .hash(imperative::custom_opdef::hash)
        .is_same_st(imperative::custom_opdef::is_same_st)
        .props(imperative::custom_opdef::props)
        .make_name(imperative::custom_opdef::make_name)
        .fallback();
}
// imperative
}
// mgb
imperative/src/include/megbrain/imperative/ops/custom_opdef.h
0 → 100644
浏览文件 @
cbf024bf
/**
* \file imperative/src/include/megbrain/imperative/ops/custom_opdef.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/custom/custom.h"
#include "megbrain/custom/manager.h"
#include "megbrain/imperative/op_def.h"
namespace
mgb
{
namespace
imperative
{
/*!
 * \brief OpDef adaptor that lets a user-registered custom::CustomOp run in
 * the imperative runtime. Holds the (immutable) op implementation plus the
 * per-instance parameter values.
 */
class CustomOpDef: public OpDefImplBase<CustomOpDef> {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
    const std::shared_ptr<const custom::CustomOp> m_op;  //!< shared, read-only impl
    custom::Param m_param;                               //!< instance param values
public:
    //! construct with the op's default params
    CustomOpDef(const std::shared_ptr<const custom::CustomOp>& op);
    //! construct with explicit params
    CustomOpDef(const std::shared_ptr<const custom::CustomOp>& op,
                const custom::Param&);

    void param(const custom::Param&);           //!< replace params
    custom::Param& param(void);                 //!< mutable param access
    custom::Param param(void) const;            //!< param copy
    size_t input_num(void) const;
    size_t output_num(void) const;
    std::string name(void) const;               //!< registered op type name
    custom::RunTimeId runtime_id(void) const;   //!< identity for hash/equality
    const std::shared_ptr<const custom::CustomOp>& impl(void) const;

    //! run the kernel; outputs must be pre-allocated by the caller
    void compute(const SmallVector<DeviceTensorND>&,
                 SmallVector<DeviceTensorND>*) const;
    //! infer output descs from concrete tensors
    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs(
            const SmallVector<TensorPtr>& inputs) const;
    //! infer output descs from logical descs; bool is false if shapes unknown
    std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs(
            const SmallVector<LogicalTensorDesc>&) const;
};
/*!
 * \brief Singleton factory creating CustomOpDef instances from the global
 * CustomOpManager registry, by op name or runtime id.
 */
class CustomOpDefFactory {
    custom::CustomOpManager* ops;  //!< non-owning handle to the registry
    CustomOpDefFactory();
public:
    PREVENT_COPY_AND_ASSIGN(CustomOpDefFactory);
    static CustomOpDefFactory* inst(void);
    //! true when `op`'s dynamic type is CustomOpDef
    static bool is_custom_op(const OpDef& op);

    //! names of all registered custom ops
    std::vector<std::string> op_list(void) const;

    //! create with default params, looked up by type name
    std::shared_ptr<OpDef> create_opdef(const std::string&) const;
    //! create with default params, looked up by runtime id
    std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&) const;
    //! create with explicit params, looked up by type name
    std::shared_ptr<OpDef> create_opdef(const std::string&,
                                        const custom::Param&) const;
    //! create with explicit params, looked up by runtime id
    std::shared_ptr<OpDef> create_opdef(const custom::RunTimeId&,
                                        const custom::Param&) const;
};
namespace
custom_opdef
{
// avoid name conflict
// Trait-hook declarations for CustomOpDef (implemented in custom_opdef.cpp);
// kept in their own namespace to avoid clashing with other ops' hooks.

//! execute on concrete device tensors (outputs pre-allocated)
void apply_on_device_tensornd(const OpDef&,
                              const SmallVector<DeviceTensorND>&,
                              SmallVector<DeviceTensorND>*);
//! imperative execution: allocate outputs, then run the kernel
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef&, const SmallVector<TensorPtr>&);
//! graph building: insert a CustomOpNode
VarNodeArray apply_on_var_node(const OpDef&, const cg::VarNodeArray&);
//! fallible output-desc inference; bool is false when shapes are unknown
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef&, const SmallVector<LogicalTensorDesc>&);
//! hash of runtime id + param values
size_t hash(const OpDef&);
//! equality: same runtime id and same param values
bool is_same_st(const OpDef&, const OpDef&);
//! property listing (currently unimplemented; asserts)
std::vector<std::pair<const char*, std::string>> props(const OpDef&);
//! display name (the op's registered type name)
std::string make_name(const OpDef&);
}
// custom_opdef
}
// imperative
}
// mgb
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录