Commit 8a692573
Authored on Aug 24, 2023 by Megvii Engine Team
refactor(customop): support write builtin op with custom op
GitOrigin-RevId: cd90002fe851a025b002e918f4b6f638936e660f
Parent: 8db64303
Showing 21 changed files with 667 additions and 437 deletions (+667, -437)
imperative/python/megengine/core/ops/custom.py  (+18, -4)
imperative/python/megengine/utils/custom_op_tools.py  (+8, -7)
imperative/python/src/ops.cpp  (+30, -28)
imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu  (+12, -8)
imperative/python/test/unit/core/test_custom_op.py  (+232, -65)
imperative/src/impl/ops/custom_opdef.cpp  (+16, -24)
imperative/src/include/megbrain/imperative/ops/custom_opdef.h  (+2, -1)
src/custom/impl/manager.cpp  (+64, -79)
src/custom/impl/op.cpp  (+42, -0)
src/custom/impl/param_val.cpp  (+15, -3)
src/custom/impl/platform/custom_cuda.cpp  (+1, -1)
src/custom/impl/tensor.cpp  (+74, -58)
src/custom/include/megbrain/custom/adaptor.h  (+13, -5)
src/custom/include/megbrain/custom/manager.h  (+23, -36)
src/custom/include/megbrain/custom/param_val.h  (+2, -8)
src/custom/include/megbrain/custom/utils.h  (+4, -0)
src/custom/test/manager.cpp  (+8, -10)
src/custom/test/op.cpp  (+89, -83)
src/custom/test/tensor.cpp  (+1, -1)
src/opr/impl/custom_opnode.cpp  (+12, -15)
src/opr/include/megbrain/opr/custom_opnode.h  (+1, -1)
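As a quick orientation before the per-file diffs (reconstructed below from the side-by-side view), the user-visible flow this commit targets looks roughly like the following. This is a sketch based on the updated test_custom_op.py, not a verbatim excerpt; the source path is a placeholder and the build keyword arguments (include paths, build dir, linker flags) are omitted for brevity.

import os
from megengine.core.ops import custom
from megengine.utils import custom_op_tools

# build a custom-op library and register its ops; returns the built .so path
srcs = [os.path.join("custom_opsrc", "elem_add.cpp")]   # placeholder source list
lib_path = custom_op_tools.build_and_load("elem", srcs, verbose=True)

# the ops now appear in the global op list and in the per-library info dict,
# which is keyed by the library path that was passed to _install
assert "ElemAddSmoothForward" in custom._get_custom_op_list()
assert lib_path in custom._get_custom_op_lib_info()

custom.unload(lib_path)       # the same path works as the uninstall key
assert lib_path not in custom._get_custom_op_lib_info()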
imperative/python/megengine/core/ops/custom.py

@@ -3,6 +3,7 @@
 import os
 
 from .._imperative_rt.ops._custom import (
+    _get_custom_op_lib_info,
     _get_custom_op_list,
     _install,
     _make_custom_op,
@@ -22,8 +23,7 @@ def _gen_custom_op_maker(custom_op_name):
 def load(lib_path):
     lib_path = os.path.abspath(lib_path)
-    lib_name = os.path.splitext(lib_path)[0]
-    op_in_this_lib = _install(lib_name, lib_path)
+    op_in_this_lib = _install(lib_path, lib_path)
     for op in op_in_this_lib:
         op_maker = _gen_custom_op_maker(op)
         globals()[op] = op_maker
@@ -32,5 +32,19 @@ def load(lib_path):
 def unload(lib_path):
     lib_path = os.path.abspath(lib_path)
-    lib_name = os.path.splitext(lib_path)[0]
-    _uninstall(lib_name)
+    op_in_lib = _uninstall(lib_path)
+    for op in op_in_lib:
+        del globals()[op]
+        __all__.remove(op)
+
+
+def _make_official_custom_op():
+    official_opr_list = _get_custom_op_list()
+    for op in official_opr_list:
+        op_maker = _gen_custom_op_maker(op)
+        if op not in globals():
+            globals()[op] = op_maker
+            __all__.append(op)
+
+
+_make_official_custom_op()
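With this change, load() passes the absolute library path as both the install key and the path, and each op exposed by the library becomes an attribute of megengine.core.ops.custom. A minimal sketch of applying a loaded op, assuming the elem_add example library from the tests has been built; the path below is a placeholder, and apply is the low-level dispatcher the tests use.

from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import custom

custom.load("/path/to/elem_add.so")            # placeholder path to a built library
op = custom.ElemAddSmoothForward(smooth=0.5)   # maker generated by _gen_custom_op_maker
out = apply(op, Tensor([1.0]), Tensor([2.0]))[0]
custom.unload("/path/to/elem_add.so")          # same path is the uninstall key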
imperative/python/megengine/utils/custom_op_tools.py

@@ -782,6 +782,10 @@ def build(
         with_cudnn,
         abi_tag,
     )
+    target_libpath = "{}_v{}".format(name, version) + str(".dll" if IS_WINDOWS else ".so")
+
     if verbose:
         if version != old_version and old_version != None:
             print(
@@ -795,8 +799,7 @@ def build(
             print(
                 "No modifications detected for {}, skipping build step...".format(name)
             )
-        return
-    name = "{}_v{}".format(name, version)
+        return os.path.join(build_dir, "{}".format(target_libpath))
 
     # phase 3: compiler and ninja check
     _check_ninja_availability()
@@ -830,8 +833,6 @@ def build(
     try:
         # phase 5: generate ninja build file
        objs = [_obj_file_path(src) for src in sources]
-        name += ".dll" if IS_WINDOWS else ".so"
-
         build_file_path = os.path.join(build_dir, "build.ninja")
         if verbose:
             print("Emitting ninja build file {}".format(build_file_path))
@@ -844,7 +845,7 @@ def build(
             sources=sources,
             objects=objs,
             ldflags=ldflags,
-            library_target=name,
+            library_target=target_libpath,
             with_cuda=with_cuda,
         )
@@ -852,7 +853,7 @@ def build(
         if verbose:
             print(
                 "Compiling and linking your custom op {}".format(
-                    os.path.join(build_dir, name)
+                    os.path.join(build_dir, target_libpath)
                 )
             )
         _build_with_ninja(build_dir, verbose, "compiling error")
@@ -861,7 +862,7 @@ def build(
     else:
         baton.wait()
 
-    return os.path.join(build_dir, name)
+    return os.path.join(build_dir, target_libpath)
 
 
 def build_and_load(
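The build step now derives the output file name once, as the op-library name plus a version suffix plus the platform extension, and returns the full path on both the cache-hit branch and the fresh-build branch. A small sketch of the naming rule, with example values standing in for the real name, version and build_dir:

import os

IS_WINDOWS = False
name, version, build_dir = "elem", 3, "/tmp/mge_custom_build"   # example values

# mirrors the expression added in build()
target_libpath = "{}_v{}".format(name, version) + str(".dll" if IS_WINDOWS else ".so")
print(os.path.join(build_dir, target_libpath))   # -> /tmp/mge_custom_build/elem_v3.so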
imperative/python/src/ops.cpp

@@ -3,7 +3,7 @@
 #include "./tensor.h"
 #include "megbrain/common.h"
-#include "megbrain/custom/data_adaptor.h"
+#include "megbrain/custom/adaptor.h"
 #include "megbrain/imperative.h"
 #include "megbrain/imperative/graph_builder.h"
 #include "megbrain/imperative/ops/autogen.h"
@@ -725,9 +725,7 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
     return obj;
 #else
-    mgb_assert(
-            false,
-            "Custom Op is disabled now, please build megengine with Custom Op open");
+    mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
     return nullptr;
 #endif
 }
@@ -737,46 +735,49 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
 py::list install_custom(const std::string& name, const std::string& path) {
 #if MGB_CUSTOM_OP
-    py::list ret;
-    const auto& ops_in_lib = custom::LibManager::inst()->install(name, path);
-    for (const auto& op : ops_in_lib) {
-        ret.append(op);
-    }
+    const auto& ops_in_lib = custom::CustomOpManager::inst()->install(name, path);
+    py::list ret = py::cast(ops_in_lib);
     return ret;
 #else
-    mgb_assert(
-            false,
-            "Custom Op is disabled now, please build megengine with Custom Op open");
-    py::list ret;
-    return ret;
+    mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
+    return py::list{};
 #endif
 }
 
-bool uninstall_custom(const std::string& name) {
+py::list uninstall_custom(const std::string& name) {
 #if MGB_CUSTOM_OP
-    return custom::LibManager::inst()->uninstall(name);
+    const auto& ops_in_lib = custom::CustomOpManager::inst()->uninstall(name);
+    py::list ret = py::cast(ops_in_lib);
+    return ret;
 #else
-    mgb_assert(
-            false,
-            "Custom Op is disabled now, please build megengine with Custom Op open");
+    mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
     return false;
 #endif
 }
 
 py::list get_custom_op_list(void) {
 #if MGB_CUSTOM_OP
-    std::vector<std::string> all_ops = CustomOpDefFactory::inst()->op_list();
-    py::list ret;
-    for (auto& op : all_ops) {
-        ret.append(op);
-    }
+    std::vector<std::string> all_ops = custom::CustomOpManager::inst()->op_name_list();
+    py::list ret = py::cast(all_ops);
     return ret;
 #else
-    mgb_assert(
-            false,
-            "Custom Op is disabled now, please build megengine with Custom Op open");
-    py::list ret;
-    return ret;
+    mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
+    return py::list{};
 #endif
 }
 
+py::dict get_custom_op_lib_info(void) {
+#if MGB_CUSTOM_OP
+    auto&& libs = custom::CustomOpManager::inst()->lib_info();
+    py::dict ret;
+    for (auto&& [lib_name, lib_handle] : libs) {
+        py::list ops = py::cast(lib_handle->ops_in_lib());
+        ret[py::str(lib_name)] = ops;
+    }
+    return ret;
+#else
+    mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
+    return py::list{};
+#endif
+}
@@ -792,6 +793,7 @@ void init_custom(pybind11::module m) {
     m.def("_install", &install_custom);
     m.def("_uninstall", &uninstall_custom);
     m.def("_get_custom_op_list", &get_custom_op_list);
+    m.def("_get_custom_op_lib_info", &get_custom_op_lib_info);
     m.def("get_custom_op_abi_tag", [](void) -> int {
         int ret = 0;
 #ifdef _GLIBCXX_USE_CXX11_ABI
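The new _get_custom_op_lib_info binding maps each installed library (keyed by the name passed to _install, which load() now sets to the library path) to the list of ops that library registered. A hedged sketch of what this looks like from Python; the path and op names are illustrative values taken from the tests:

from megengine.core.ops import custom

info = custom._get_custom_op_lib_info()
# e.g. {"/some/build/dir/elem_v0.so": ["ElemAddSmoothForward", "ElemAddSmoothBackward"]}
for lib_name, ops_in_lib in info.items():
    print(lib_name, ops_in_lib)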
imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu

@@ -2,6 +2,7 @@
 #include <cuda_runtime.h>
 #include <stdio.h>
 #include "./matmul_scale.h"
+#include "megbrain/custom/platform/custom_cuda.h"
 
 using namespace custom;
@@ -51,12 +52,13 @@ void matmul_forward_helper(
         float scale) {
     dim3 block(1, 1);
     dim3 grid(N / block.x, M / block.y);
+    auto stream = get_cuda_stream(lhs.device());
 
     DISPATCH_INT_AND_FLOAT_TYPES(res.dtype(), "matmul_forward", ([&]() {
-        matmul_forward_naive<scalar_t><<<grid, block>>>(
+        matmul_forward_naive<scalar_t><<<grid, block, 0, stream>>>(
                 lhs.data<scalar_t>(), rhs.data<scalar_t>(),
                 res.data<scalar_t>(), M, K, N, scale);
     }));
 }
 
 void matmul_backward_lhs_helper(
@@ -64,9 +66,10 @@ void matmul_backward_lhs_helper(
         size_t N, float scale) {
     dim3 block(1, 1);
     dim3 grid(K / block.x, M / block.y);
+    auto stream = get_cuda_stream(rhs.device());
 
     DISPATCH_INT_AND_FLOAT_TYPES(lhs_grad.dtype(), "matmul_backward_lhs", ([&]() {
-        matmul_backward_lhs_naive<scalar_t><<<grid, block>>>(
+        matmul_backward_lhs_naive<scalar_t><<<grid, block, 0, stream>>>(
                 rhs.data<scalar_t>(), ograd.data<scalar_t>(),
                 lhs_grad.data<scalar_t>(), M, K, N, scale);
     }));
@@ -77,9 +80,10 @@ void matmul_backward_rhs_helper(
         size_t N, float scale) {
     dim3 block(1, 1);
     dim3 grid(N / block.x, K / block.y);
+    auto stream = get_cuda_stream(lhs.device());
 
     DISPATCH_INT_AND_FLOAT_TYPES(rhs_grad.dtype(), "matmul_backward_rhs", ([&]() {
-        matmul_backward_rhs_naive<scalar_t><<<grid, block>>>(
+        matmul_backward_rhs_naive<scalar_t><<<grid, block, 0, stream>>>(
                 lhs.data<scalar_t>(), ograd.data<scalar_t>(),
                 rhs_grad.data<scalar_t>(), M, K, N, scale);
     }));
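With get_cuda_stream, the example kernels now launch on the stream MegEngine associates with the tensors' device rather than the default stream, so the custom op stays correctly ordered with surrounding MegEngine operations. A hedged Python sketch, adapted from the updated test, of checking the GPU op against the builtin matmul; it assumes the matmul_scale library has already been built and loaded and that a GPU is the default device:

import numpy as np
import megengine.functional as F
from megengine import Tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import custom

lhs = np.random.uniform(-0.5, 0.5, (4, 8)).astype("float32")
rhs = np.random.uniform(-0.5, 0.5, (8, 2)).astype("float32")
scale = 0.5

op = custom.MatMulScaleForward(scale=scale)        # op exposed by the loaded library
real = apply(op, Tensor(lhs), Tensor(rhs))[0]
ref = F.matmul(Tensor(lhs), Tensor(rhs)) * scale   # builtin reference
np.testing.assert_allclose(real.numpy(), ref.numpy(), rtol=1e-3, atol=1e-5)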
imperative/python/test/unit/core/test_custom_op.py

@@ -6,92 +6,132 @@ import sys
 import numpy as np
 import pytest
 
-import megengine
 import megengine.functional as F
-import megengine.optimizer as optim
 from megengine import jit
 from megengine.autodiff import Function, GradManager
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core.ops import custom
 from megengine.device import get_device_count
-from megengine.module import Conv2d, Linear, Module
-from megengine.random import normal
-from megengine.tensor import Parameter, Tensor
+from megengine.tensor import Tensor
 from megengine.utils import custom_op_tools
 
+build_path = os.path.join(
+    custom_op_tools._get_default_build_root(), "custom_opsrc", "build"
+)
+cur_dir_path = os.path.dirname(os.path.abspath(__file__))
+mgb_root_path = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(cur_dir_path))))
+)
+extra_include_paths = [os.path.join(mgb_root_path, "src", "custom", "include")]
+
+extra_ld_flags = []
+if sys.platform != "win32":
+    ld_path = os.environ.get("LD_LIBRARY_PATH")
+    if ld_path != None:
+        ld_dirs = ld_path.split(":")
+        for ld_dir in ld_dirs:
+            if os.path.exists(ld_dir) and os.path.isdir(ld_dir):
+                for lib in os.listdir(ld_dir):
+                    if "megengine_shared" in lib:
+                        extra_ld_flags += ["-L{} -Wl,-rpath,{}".format(ld_dir, ld_dir)]
+                        break
 
-def compare(ref, real):
-    if ref.shape != real.shape:
-        real = real.T
-    np.testing.assert_allclose(ref, real, rtol=1e-3, atol=1e-5)
 
+def build_and_clean(*srcs):
+    def deco(test_func):
+        custom_op_srcs = [os.path.join(cur_dir_path, "custom_opsrc", s) for s in srcs]
+
+        def wrapper(*args, **kwargs):
+            try:
+                lib_path = custom_op_tools.build_and_load(
+                    "test_op",
+                    custom_op_srcs,
+                    extra_include_paths=extra_include_paths,
+                    build_dir=build_path,
+                    extra_ldflags=extra_ld_flags,
+                    verbose=True,
+                )
+                test_func(*args, **kwargs)
+                custom.unload(lib_path)
+            finally:
+                if os.path.exists(build_path):
+                    shutil.rmtree(build_path)
+
+        return wrapper
+
+    return deco
 
-def build_and_clean(test_func):
-    def wrapper():
-        cur_dir_path = os.path.dirname(os.path.abspath(__file__))
-        build_root_dir = custom_op_tools._get_default_build_root()
-        build_path = os.path.join(build_root_dir, "custom_opsrc", "build")
-        if os.path.exists(build_path):
-            shutil.rmtree(build_path)
-
-        mgb_root_path = os.path.dirname(
-            os.path.dirname(
-                os.path.dirname(os.path.dirname(os.path.dirname(cur_dir_path)))
-            )
-        )
-        extra_include_paths = [os.path.join(mgb_root_path, "src", "custom", "include")]
-        extra_ld_flags = []
-        if sys.platform != "win32":
-            ld_path = os.environ.get("LD_LIBRARY_PATH")
-            if ld_path != None:
-                ld_dirs = ld_path.split(":")
-                for ld_dir in ld_dirs:
-                    if os.path.exists(ld_dir) and os.path.isdir(ld_dir):
-                        for lib in os.listdir(ld_dir):
-                            if "megengine_shared" in lib:
-                                extra_ld_flags += [
-                                    "-L{} -Wl,-rpath,{}".format(ld_dir, ld_dir)
-                                ]
-                                break
-
-        if get_device_count("gpu") > 0:
-            custom_opsrc = [
-                os.path.join(cur_dir_path, "custom_opsrc", "matmul_scale.cpp"),
-                os.path.join(cur_dir_path, "custom_opsrc", "matmul_scale.cu"),
-            ]
-        else:
-            custom_opsrc = [os.path.join(cur_dir_path, "custom_opsrc", "elem_add.cpp")]
-
-        lib_path = custom_op_tools.build_and_load(
-            "test_op",
-            custom_opsrc,
-            extra_include_paths=extra_include_paths,
-            extra_ldflags=extra_ld_flags,
-            build_dir=build_path,
-            verbose=False,
-        )
-        test_func()
-        custom.unload(lib_path)
-
-    return wrapper
 
 
 @pytest.mark.skipif(
     get_device_count("gpu") > 0, reason="elem_add operator is only supported on CPU"
 )
-@build_and_clean
-def test_custom_op_cpu_build():
-    assert "ElemAddSmoothForward" in custom._get_custom_op_list()
-    assert "ElemAddSmoothBackward" in custom._get_custom_op_list()
-    assert hasattr(custom, "ElemAddSmoothForward")
-    assert hasattr(custom, "ElemAddSmoothBackward")
+@build_and_clean("elem_add.cpp")
+def test_cpu_func():
+    class ElemAddSmooth(Function):
+        def __init__(self, smooth):
+            super().__init__()
+            self.smooth = smooth
+
+        def forward(self, lhs, rhs):
+            op = custom.ElemAddSmoothForward(smooth=self.smooth)
+            return apply(op, lhs, rhs)[0]
+
+        def backward(self, ograd):
+            op = custom.ElemAddSmoothBackward()
+            return apply(op, ograd)
+
+    def gen_elemadd_data(seed, shape, low=-1, high=1):
+        rng = np.random.RandomState(seed=seed)
+        lhs_np = rng.uniform(low=low, high=high, size=shape).astype(np.float32)
+        rhs_np = rng.uniform(low=low, high=high, size=shape).astype(np.float32)
+        ograd_np = rng.uniform(low=low, high=high, size=shape).astype(np.float32)
+        return lhs_np, rhs_np, ograd_np
+
+    def builtin_func(lhs, rhs, smooth):
+        out = lhs + rhs
+        return F.where(out < 0, out + smooth, out - smooth)
+
+    def test_elemadd_smooth_train(smooth=0.5, m=4, n=2, seed=2021):
+        lhs_np, rhs_np, ograd_np = gen_elemadd_data(seed, (m, n))
+        custom_lhs, custom_rhs = Tensor(lhs_np), Tensor(rhs_np)
+        builtin_lhs, builtin_rhs = Tensor(lhs_np), Tensor(rhs_np)
+        ograd_tensor = Tensor(ograd_np)
+
+        custom_func = ElemAddSmooth(smooth=smooth)
+        gm = GradManager().attach([custom_lhs, custom_rhs])
+        with gm:
+            custom_out = custom_func(custom_lhs, custom_rhs)
+            gm.backward(custom_out, ograd_tensor)
+
+        gm = GradManager().attach([builtin_lhs, builtin_rhs])
+        with gm:
+            builtin_out = builtin_func(builtin_lhs, builtin_rhs, smooth)
+            gm.backward(builtin_out, ograd_tensor)
+
+        np.testing.assert_allclose(custom_out, builtin_out, rtol=1e-3, atol=1e-5)
+        np.testing.assert_allclose(
+            custom_lhs.grad.numpy(), builtin_lhs.grad.numpy(), rtol=1e-3, atol=1e-5
+        )
+        np.testing.assert_allclose(
+            custom_rhs.grad.numpy(), builtin_rhs.grad.numpy(), rtol=1e-3, atol=1e-5
+        )
+
+    def test_elemadd_smooth_trace(smooth=0.5, m=4, n=2, seed=2021):
+        @jit.trace(capture_as_const=True)
+        def func_dumper(lhs, rhs, *, net):
+            return net(lhs, rhs)
+
+        lhs_np, rhs_np, _ = gen_elemadd_data(seed, (m, n))
+        lhs_tensor = Tensor(lhs_np)
+        rhs_tensor = Tensor(rhs_np)
+        func = ElemAddSmooth(smooth=smooth)
+
+        real = func_dumper(lhs_tensor, rhs_tensor, net=func)
+        real = func_dumper(lhs_tensor, rhs_tensor, net=func)
+        ref = builtin_func(Tensor(lhs_np), Tensor(rhs_np), smooth)
+        np.testing.assert_allclose(real.numpy(), ref.numpy(), rtol=1e-3, atol=1e-5)
+
+    test_elemadd_smooth_train(0.2, 128, 256, 2027)
+    test_elemadd_smooth_train(0.3, 256, 128, 2028)
+    test_elemadd_smooth_train(0.4, 128, 512, 2029)
+    test_elemadd_smooth_trace(0.2, 256, 64, 2030)
 
 
 @pytest.mark.skipif(
@@ -101,9 +141,136 @@ def test_custom_op_cpu_build():
 @pytest.mark.skipif(
     get_device_count("gpu") < 1, reason="matmul scale operator is only supported on GPU"
 )
-@build_and_clean
-def test_custom_op_gpu_build():
+@build_and_clean("matmul_scale.cpp", "matmul_scale.cu")
+def test_gpu_func():
+    class MatMulScale(Function):
+        def __init__(self, scale):
+            super().__init__()
+            self.scale = scale
+
+        def forward(self, lhs, rhs):
+            op = custom.MatMulScaleForward(scale=self.scale)
+            self.lhs = lhs
+            self.rhs = rhs
+            return apply(op, lhs, rhs)[0]
+
+        def backward(self, ograd):
+            op = custom.MatMulScaleBackward(scale=self.scale)
+            return apply(op, ograd, self.lhs, self.rhs)
+
+    def gen_matmul_data(seed, m, k, n, low=-0.5, high=0.5, dtype=np.float32):
+        rng = np.random.RandomState(seed=seed)
+        lhs_np = rng.uniform(low=low, high=high, size=(m, k)).astype(dtype)
+        rhs_np = rng.uniform(low=low, high=high, size=(k, n)).astype(dtype)
+        ograd_np = rng.uniform(low=low, high=high, size=(m, n)).astype(dtype)
+        scale = rng.uniform(low=0.1, high=0.9, size=(1)).astype(np.float32)[0]
+        return lhs_np, rhs_np, ograd_np, scale
+
+    def builtin_func(lhs, rhs, scale):
+        out = F.matmul(lhs, rhs) * scale
+        return out
+
+    def test_matmul_scale(m=1, k=1, n=1, seed=2021):
+        lhs_np, rhs_np, _, scale = gen_matmul_data(seed, m, k, n)
+        custom_lhs, custom_rhs = Tensor(lhs_np), Tensor(rhs_np)
+        builtin_lhs, builtin_rhs = Tensor(lhs_np), Tensor(rhs_np)
+
+        custom_func = MatMulScale(scale=scale)
+        custom_out = custom_func(custom_lhs, custom_rhs)
+        builtin_out = builtin_func(builtin_lhs, builtin_rhs, scale)
+        np.testing.assert_allclose(custom_out, builtin_out, rtol=1e-3, atol=1e-5)
+
+    def test_matmul_scale_trace(m=1, k=1, n=1, seed=2021):
+        @jit.trace(capture_as_const=True)
+        def func_dumper(lhs, rhs, *, net):
+            return net(lhs, rhs)
+
+        lhs_np, rhs_np, _, scale = gen_matmul_data(seed, m, k, n)
+        lhs_tensor, rhs_tensor = Tensor(lhs_np), Tensor(rhs_np)
+        func = MatMulScale(scale=scale)
+
+        real = func_dumper(lhs_tensor, rhs_tensor, net=func)
+        real = func_dumper(lhs_tensor, rhs_tensor, net=func)
+        ref = builtin_func(Tensor(lhs_np), Tensor(rhs_np), scale)
+        np.testing.assert_allclose(real.numpy(), ref.numpy(), rtol=1e-3, atol=1e-5)
+
+    test_matmul_scale(128, 256, 64, 2028)
+    test_matmul_scale(64, 32, 16, 2029)
+    test_matmul_scale_trace(64, 32, 16, 2030)
+
+
+@pytest.mark.skipif(
+    get_device_count("gpu") < 1, reason="matmul scale operator is only supported on GPU"
+)
+def test_custom_op():
+    org_op_list = custom._get_custom_op_list()
+    assert len(custom._get_custom_op_lib_info()) == 0
+    assert "ElemAddSmoothForward" not in custom._get_custom_op_list()
+    assert not hasattr(custom, "ElemAddSmoothForward")
+    assert "MatMulScaleForward" not in custom._get_custom_op_list()
+    assert not hasattr(custom, "MatMulScaleForward")
+
+    srcs1 = [os.path.join(cur_dir_path, "custom_opsrc", "elem_add.cpp")]
+    lib_path1 = custom_op_tools.build_and_load(
+        "elem",
+        srcs1,
+        extra_include_paths=extra_include_paths,
+        build_dir=build_path,
+        extra_ldflags=extra_ld_flags,
+        verbose=True,
+    )
+    assert "ElemAddSmoothForward" in custom._get_custom_op_list()
+    assert hasattr(custom, "ElemAddSmoothForward")
+    assert lib_path1 in custom._get_custom_op_lib_info()
+    assert "ElemAddSmoothForward" in custom._get_custom_op_lib_info()[lib_path1]
+
+    srcs2 = [
+        os.path.join(cur_dir_path, "custom_opsrc", src)
+        for src in ["matmul_scale.cpp", "matmul_scale.cu"]
+    ]
+    lib_path2 = custom_op_tools.build_and_load(
+        "matmul",
+        srcs2,
+        extra_include_paths=extra_include_paths,
+        build_dir=build_path,
+        extra_ldflags=extra_ld_flags,
+        verbose=True,
+    )
+    assert "MatMulScaleForward" in custom._get_custom_op_list()
+    assert hasattr(custom, "MatMulScaleForward")
+    assert lib_path2 in custom._get_custom_op_lib_info()
+    assert "MatMulScaleForward" in custom._get_custom_op_lib_info()[lib_path2]
+
+    assert len(custom._get_custom_op_list()) == len(org_op_list) + 4
+
+    custom.unload(lib_path1)
+    assert "ElemAddSmoothForward" not in custom._get_custom_op_list()
+    assert not hasattr(custom, "ElemAddSmoothForward")
+    assert lib_path1 not in custom._get_custom_op_lib_info()
+
+    custom.unload(lib_path2)
+    assert "MatMulScaleForward" not in custom._get_custom_op_list()
+    assert not hasattr(custom, "MatMulScaleForward")
+    assert lib_path1 not in custom._get_custom_op_lib_info()
+
+    assert len(custom._get_custom_op_lib_info()) == 0
+    assert custom._get_custom_op_list() == org_op_list
+
+    custom.load(lib_path2)
+    assert "MatMulScaleForward" in custom._get_custom_op_list()
+    assert "MatMulScaleForward" in custom._get_custom_op_list()
+    assert "MatMulScaleBackward" in custom._get_custom_op_list()
+    assert hasattr(custom, "MatMulScaleForward")
+    assert hasattr(custom, "MatMulScaleForward")
+    assert hasattr(custom, "MatMulScaleBackward")
+    assert lib_path2 in custom._get_custom_op_lib_info()
+    assert "MatMulScaleForward" in custom._get_custom_op_lib_info()[lib_path2]
+
+    custom.unload(lib_path2)
+    assert "MatMulScaleForward" not in custom._get_custom_op_list()
+    assert not hasattr(custom, "MatMulScaleForward")
+    assert lib_path1 not in custom._get_custom_op_lib_info()
+
+    assert len(custom._get_custom_op_lib_info()) == 0
+    assert custom._get_custom_op_list() == org_op_list
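The rewritten tests use build_and_clean as a decorator factory, so each test names exactly the sources it needs instead of relying on a global CPU/GPU switch inside the decorator. A short sketch of adding another test in the same style, reusing the helpers defined in this test file; my_op.cpp and MyOpForward are hypothetical names:

@pytest.mark.skipif(get_device_count("gpu") > 0, reason="cpu-only example")
@build_and_clean("my_op.cpp")          # hypothetical source under custom_opsrc/
def test_my_op():
    # ops from my_op.cpp are registered for the duration of the test
    assert "MyOpForward" in custom._get_custom_op_list()   # hypothetical op name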
imperative/src/impl/ops/custom_opdef.cpp

@@ -3,7 +3,7 @@
 #if MGB_CUSTOM_OP
 
 #include "../op_trait.h"
-#include "megbrain/custom/data_adaptor.h"
+#include "megbrain/custom/adaptor.h"
 #include "megbrain/opr/custom_opnode.h"
 
 namespace mgb {
@@ -51,13 +51,9 @@ const std::shared_ptr<const custom::CustomOp>& CustomOpDef::impl(void) const {
 }
 
 void CustomOpDef::compute(
-        const SmallVector<DeviceTensorND>& inputs,
-        SmallVector<DeviceTensorND>* outputs) const {
-    std::vector<custom::Tensor> custom_inputs =
-            custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
-    std::vector<custom::Tensor> custom_outputs =
-            custom::to_custom<DeviceTensorND, custom::Tensor>(*outputs);
-    m_op->compute(custom_inputs, this->m_param, custom_outputs);
+        std::shared_ptr<SmallVector<DeviceTensorND>> inputs,
+        std::shared_ptr<SmallVector<DeviceTensorND>> outputs) const {
+    custom::dispatch_custom_op(m_op, m_param, inputs, outputs);
 }
 
 std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
@@ -169,13 +165,6 @@ std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(
 
 namespace custom_opdef {  // avoid name conflict
 
-void apply_on_device_tensornd(
-        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
-        SmallVector<DeviceTensorND>* outputs) {
-    auto&& op = static_cast<const CustomOpDef&>(def);
-    op.compute(inputs, outputs);
-}
-
 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs,
         SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
@@ -194,15 +183,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         output = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
     }
 
-    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
-    SmallVector<DeviceTensorND> oup_tensornds(outputs.size());
-    for (size_t i = 0; i < inputs.size(); ++i)
-        inp_tensornds[i] = inputs[i]->dev_tensor();
-    for (size_t i = 0; i < outputs.size(); ++i)
-        oup_tensornds[i] = outputs[i]->dev_tensor();
-    apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
+    std::shared_ptr<SmallVector<DeviceTensorND>> inp_tensornds =
+            std::make_shared<SmallVector<DeviceTensorND>>();
+    std::shared_ptr<SmallVector<DeviceTensorND>> oup_tensornds =
+            std::make_shared<SmallVector<DeviceTensorND>>();
+    for (size_t i = 0; i < inputs.size(); ++i) {
+        inp_tensornds->emplace_back(inputs[i]->dev_tensor(true));
+    }
+    for (size_t i = 0; i < outputs.size(); ++i) {
+        oup_tensornds->emplace_back(outputs[i]->dev_tensor(true));
+    }
+    auto&& op = static_cast<const CustomOpDef&>(def);
+    op.compute(inp_tensornds, oup_tensornds);
     return outputs;
 }
@@ -258,7 +251,6 @@ std::string make_name(const OpDef& def) {
 OP_TRAIT_REG(CustomOpDef, CustomOpDef)
         .apply_on_physical_tensor(apply_on_physical_tensor)
         .apply_on_var_node(apply_on_var_node)
-        .apply_on_device_tensornd(apply_on_device_tensornd)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .hash(hash)
        .is_same_st(is_same_st)
imperative/src/include/megbrain/imperative/ops/custom_opdef.h

@@ -31,7 +31,8 @@ public:
     const std::shared_ptr<const custom::CustomOp>& impl(void) const;
     void compute(
-            const SmallVector<DeviceTensorND>&, SmallVector<DeviceTensorND>*) const;
+            std::shared_ptr<SmallVector<DeviceTensorND>>,
+            std::shared_ptr<SmallVector<DeviceTensorND>>) const;
     std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs(
             const SmallVector<TensorPtr>& inputs) const;
     std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs(
src/custom/impl/manager.cpp

@@ -32,6 +32,39 @@ const char* dlerror(void) {
 }
 #endif
 
+CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY)
+        : m_handle(nullptr, [](void* handle) { dlclose(handle); }) {
+    auto op_list_before_load = CustomOpManager::inst()->op_name_list();
+    std::unordered_set<std::string> op_set_before_load(
+            op_list_before_load.begin(), op_list_before_load.end());
+
+    m_handle.reset(dlopen(path.c_str(), mode));
+    mgb_assert(
+            m_handle != nullptr, "open custom op lib failed, error type: %s", dlerror());
+
+    auto op_list_after_load = CustomOpManager::inst()->op_name_list();
+    for (auto& op : op_list_after_load) {
+        if (op_set_before_load.find(op) == op_set_before_load.end()) {
+            m_ops.emplace_back(op);
+        }
+    }
+}
+
+CustomLib::~CustomLib() {
+    for (auto& op : m_ops) {
+        CustomOpManager::inst()->erase(op);
+    }
+}
+
+const std::vector<std::string>& CustomLib::ops_in_lib(void) const {
+    return m_ops;
+}
+
+bool CustomLib::valid() const {
+    return m_handle != nullptr;
+}
+
 CustomOpManager* CustomOpManager::inst(void) {
     static CustomOpManager op_manager;
     return &op_manager;
@@ -39,12 +72,40 @@ CustomOpManager* CustomOpManager::inst(void) {
 
 CustomOpManager::~CustomOpManager() {
     mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!");
-    LibManager::inst()->m_custom_libs.clear();
+    {
+        MGB_LOCK_GUARD(m_lib_mtx);
+        m_custom_libs.clear();
+    }
 
+    mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!");
+    MGB_LOCK_GUARD(m_op_mtx);
     m_name2op.clear();
     m_id2op.clear();
 }
 
+const std::vector<std::string>& CustomOpManager::install(
+        const std::string& name, const std::string& path) {
+    MGB_LOCK_GUARD(m_lib_mtx);
+    LibHandle handle = std::make_shared<CustomLib>(path);
+    m_custom_libs.insert({name, handle});
+    return m_custom_libs[name]->ops_in_lib();
+}
+
+std::vector<std::string> CustomOpManager::uninstall(const std::string& name) {
+    MGB_LOCK_GUARD(m_lib_mtx);
+    std::vector<std::string> op_names = m_custom_libs[name]->ops_in_lib();
+    mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error");
+    return op_names;
+}
+
+const std::unordered_map<std::string, LibHandle>& CustomOpManager::lib_info(void) const {
+    return m_custom_libs;
+}
+
 std::shared_ptr<CustomOp> CustomOpManager::insert(
         const std::string& name, uint32_t version) {
-    MGB_LOCK_GUARD(m_mtx);
+    MGB_LOCK_GUARD(m_op_mtx);
     auto iter = m_name2op.find(name);
     if (iter != m_name2op.end()) {
         mgb_log_warn(
@@ -59,7 +120,7 @@ std::shared_ptr<CustomOp> CustomOpManager::insert(
 }
 
 bool CustomOpManager::erase(const std::string& name) {
-    MGB_LOCK_GUARD(m_mtx);
+    MGB_LOCK_GUARD(m_op_mtx);
     auto iter = m_name2op.find(name);
     if (iter == m_name2op.end()) {
         mgb_log_warn(
@@ -72,28 +133,6 @@ bool CustomOpManager::erase(const std::string& name) {
     return true;
 }
 
-bool CustomOpManager::erase(const RunTimeId& id) {
-    MGB_LOCK_GUARD(m_mtx);
-    auto iter = m_id2op.find(id);
-    if (iter == m_id2op.end()) {
-        mgb_log_warn("Erase Custom Op Failed! The Op has not been registered");
-        return false;
-    }
-    std::shared_ptr<const CustomOp> op = iter->second;
-    m_id2op.erase(op->runtime_id());
-    m_name2op.erase(op->op_type());
-    return true;
-}
-
-std::shared_ptr<CustomOp> CustomOpManager::find_or_reg(
-        const std::string& name, uint32_t version) {
-    auto iter = m_name2op.find(name);
-    if (iter == m_name2op.end()) {
-        return insert(name, version);
-    }
-    return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second);
-}
-
 RunTimeId CustomOpManager::to_id(const std::string& name) const {
     std::shared_ptr<const CustomOp> op = find(name);
     return op->runtime_id();
@@ -135,60 +174,6 @@ std::vector<RunTimeId> CustomOpManager::op_id_list(void) {
     return ret;
 }
 
-CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY)
-        : m_handle(nullptr, [](void* handle) { dlclose(handle); }) {
-    auto op_list_before_load = CustomOpManager::inst()->op_name_list();
-    std::unordered_set<std::string> op_set_before_load(
-            op_list_before_load.begin(), op_list_before_load.end());
-
-    m_handle.reset(dlopen(path.c_str(), mode));
-    mgb_assert(
-            m_handle != nullptr, "open custom op lib failed, error type: %s", dlerror());
-
-    auto op_list_after_load = CustomOpManager::inst()->op_name_list();
-    for (auto& op : op_list_after_load) {
-        if (op_set_before_load.find(op) == op_set_before_load.end()) {
-            m_ops.emplace_back(op);
-        }
-    }
-}
-
-const std::vector<std::string>& CustomLib::ops_in_lib(void) const {
-    return m_ops;
-}
-
-CustomLib::~CustomLib() {
-    for (auto& op : m_ops) {
-        CustomOpManager::inst()->erase(op);
-    }
-}
-
-bool CustomLib::valid() const {
-    return m_handle != nullptr;
-}
-
-LibManager* LibManager::inst(void) {
-    static LibManager custom_libs;
-    return &custom_libs;
-}
-
-const std::vector<std::string>& LibManager::install(
-        const std::string& name, const std::string& path) {
-    MGB_LOCK_GUARD(m_mtx);
-    LibHandle handle = std::make_shared<CustomLib>(path);
-    m_custom_libs.insert({name, handle});
-    return m_custom_libs[name]->ops_in_lib();
-}
-
-bool LibManager::uninstall(const std::string& name) {
-    MGB_LOCK_GUARD(m_mtx);
-    mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error");
-    return true;
-}
-
 std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) {
     return CustomOpManager::inst()->insert(opname, version);
 }
src/custom/impl/op.cpp

@@ -4,8 +4,11 @@
 #include <sstream>
 #include <unordered_set>
 
+#include "megbrain/comp_node_env.h"
+#include "megbrain/custom/adaptor.h"
 #include "megbrain/custom/op.h"
 #include "megbrain/custom/utils.h"
+#include "megbrain/tensor.h"
 #include "megbrain/utils/thin/function.h"
 
 using namespace mgb;
@@ -550,6 +553,45 @@ void CustomOp::compute(
     assert_outputs_size_right(outputs);
 }
 
+void compute_impl(
+        std::shared_ptr<const CustomOp> op, const Param& param,
+        std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> inputs,
+        std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> outputs) {
+    std::vector<custom::Tensor> custom_inputs;
+    for (size_t i = 0; i < inputs->size(); ++i) {
+        custom_inputs.emplace_back(to_custom_tensor(inputs->operator[](i)));
+    }
+    std::vector<custom::Tensor> custom_outputs;
+    for (size_t i = 0; i < outputs->size(); ++i) {
+        custom_outputs.emplace_back(to_custom_tensor(outputs->operator[](i)));
+    }
+    op->compute(custom_inputs, param, custom_outputs);
+}
+
+void dispatch_custom_op(
+        std::shared_ptr<const CustomOp> op, const Param& param,
+        std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> inputs,
+        std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> outputs) {
+    if (outputs->size() == 0) {
+        return;
+    }
+
+    auto compnode = outputs->at(0).comp_node();
+    if (compnode.device_type() == CompNode::DeviceType::CPU) {
+        auto&& cpu_env = CompNodeEnv::from_comp_node(compnode).cpu_env();
+        cpu_env.dispatch([op, param, inputs, outputs]() {
+            compute_impl(op, param, inputs, outputs);
+        });
+    } else {
+        mgb_assert(
+                compnode.device_type() == CompNode::DeviceType::CUDA,
+                "custom op only support cuda/cpu now, but get %s",
+                compnode.to_string().c_str());
+        compute_impl(op, param, inputs, outputs);
+    }
+}
+
 }  // namespace custom
 
 #endif
浏览文件 @
8a692573
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#if MGB_CUSTOM_OP
#if MGB_CUSTOM_OP
#include "megbrain/comp_node.h"
#include "megbrain/comp_node.h"
#include "megbrain/custom/
data_
adaptor.h"
#include "megbrain/custom/adaptor.h"
#include "megbrain/custom/param_val.h"
#include "megbrain/custom/param_val.h"
#include "megbrain/custom/tensor.h"
#include "megbrain/custom/tensor.h"
...
@@ -40,7 +40,7 @@ namespace custom {
...
@@ -40,7 +40,7 @@ namespace custom {
#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \
#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \
mgb_assert( \
mgb_assert( \
lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \
lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \
type2name[lhs.m_type].c_str(), #op, type2name[rhs.m_type]
.c_str())
ptype2name(lhs.m_type).c_str(), #op, ptype2name(rhs.m_type)
.c_str())
#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \
#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \
case (ParamDynType::dyn_type): { \
case (ParamDynType::dyn_type): { \
...
@@ -177,6 +177,18 @@ namespace custom {
...
@@ -177,6 +177,18 @@ namespace custom {
break; \
break; \
}
}
std
::
string
ptype2name
(
ParamDynType
ptype
)
{
#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) \
{ParamDynType::dyn_type, #dyn_type},
static
std
::
unordered_map
<
ParamDynType
,
std
::
string
,
EnumHash
<
ParamDynType
>
,
EnumCmp
<
ParamDynType
>>
type2name
=
{
CUSTOM_FOR_EACH_VALID_PARAMTYPE
(
CUSTOM_REG_DYN_PARAMTYPE_NAME
){
ParamDynType
::
Invalid
,
"Invalid"
}};
#undef CUSTOM_REG_DYN_PARAMTYPE_NAME
return
type2name
[
ptype
];
}
ParamVal
::
ParamVal
()
:
m_ptr
(
nullptr
,
[](
void
*
)
->
void
{})
{
ParamVal
::
ParamVal
()
:
m_ptr
(
nullptr
,
[](
void
*
)
->
void
{})
{
m_type
=
ParamDynType
::
Invalid
;
m_type
=
ParamDynType
::
Invalid
;
}
}
...
@@ -265,7 +277,7 @@ ParamDynType ParamVal::type(void) const {
...
@@ -265,7 +277,7 @@ ParamDynType ParamVal::type(void) const {
std
::
string
ParamVal
::
str
()
const
{
std
::
string
ParamVal
::
str
()
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
"type: "
<<
type2name
[
m_type
]
<<
"
\n
"
ss
<<
"type: "
<<
ptype2name
(
m_type
)
<<
"
\n
"
<<
"value: "
;
<<
"value: "
;
switch
(
m_type
)
{
switch
(
m_type
)
{
CUSTOM_FOR_EACH_BASIC_PARAMTYPE
(
CUSTOM_CASE_TO_PRINT_NONLIST
)
CUSTOM_FOR_EACH_BASIC_PARAMTYPE
(
CUSTOM_CASE_TO_PRINT_NONLIST
)
...
...
src/custom/impl/platform/custom_cuda.cpp

@@ -4,7 +4,7 @@
 #if MGB_CUSTOM_OP
 
 #include "megbrain/comp_node_env.h"
-#include "megbrain/custom/data_adaptor.h"
+#include "megbrain/custom/adaptor.h"
 #include "megbrain/custom/platform/custom_cuda.h"
 
 using namespace mgb;
src/custom/impl/tensor.cpp
浏览文件 @
8a692573
...
@@ -42,31 +42,33 @@ using TensorImpl = DeviceTensorND;
...
@@ -42,31 +42,33 @@ using TensorImpl = DeviceTensorND;
#define TensorImplConstRef(rawptr) \
#define TensorImplConstRef(rawptr) \
static_cast<const TensorImpl&>(*reinterpret_cast<const TensorImpl*>(rawptr))
static_cast<const TensorImpl&>(*reinterpret_cast<const TensorImpl*>(rawptr))
static
std
::
unordered_map
<
struct
DeviceMapper
{
DeviceImpl
::
DeviceType
,
std
::
string
,
EnumHash
<
DeviceImpl
::
DeviceType
>
,
using
DeviceTy
=
DeviceImpl
::
DeviceType
;
EnumCmp
<
DeviceImpl
::
DeviceType
>>
std
::
unordered_map
<
std
::
string
,
std
::
string
>
dev_cstr2bstr
;
dev_benum2cstr
;
EnumMap
<
DeviceTy
,
std
::
string
>
dev_benum2cstr
;
static
std
::
unordered_map
<
EnumMap
<
DeviceTy
,
DeviceEnum
>
dev_benum2cenum
;
DeviceImpl
::
DeviceType
,
DeviceEnum
,
EnumHash
<
DeviceImpl
::
DeviceType
>
,
EnumMap
<
DeviceEnum
,
std
::
string
>
dev_cenum2bstr
;
EnumCmp
<
DeviceImpl
::
DeviceType
>>
static
DeviceMapper
&
inst
();
dev_benum2cenum
;
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
dev_cstr2bstr
;
private:
static
std
::
unordered_map
<
DeviceMapper
();
DeviceEnum
,
std
::
string
,
EnumHash
<
DeviceEnum
>
,
EnumCmp
<
DeviceEnum
>>
};
dev_cenum2bstr
;
DeviceMapper
::
DeviceMapper
()
{
#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \
#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \
auto be2cs##custom_impl = dev_benum2cstr.emplace( \
dev_benum2cstr.emplace(DeviceTy::builtin_device, std::string(#custom_impl)); \
DeviceImpl::DeviceType::builtin_device, std::string(#custom_impl)); \
dev_benum2cenum.emplace(DeviceTy::builtin_device, DeviceEnum::custom_impl); \
auto be2ce##custom_impl = dev_benum2cenum.emplace( \
dev_cstr2bstr.emplace(std::string(#custom_impl), std::string(builtin_str)); \
DeviceImpl::DeviceType::builtin_device, DeviceEnum::custom_impl); \
dev_cenum2bstr.emplace(DeviceEnum::custom_impl, std::string(builtin_str));
auto cs2bs##custom_impl = dev_cstr2bstr.emplace( \
std::string(#custom_impl), std::string(builtin_str)); \
CUSTOM_FOR_EACH_DEVICE_TYPE
(
CUSTOM_BIND_DEVICE
)
auto ce2bs##custom_impl = \
dev_cenum2bstr.emplace(DeviceEnum::custom_impl, std::string(builtin_str));
CUSTOM_FOR_EACH_DEVICE_TYPE
(
CUSTOM_BIND_DEVICE
)
#undef CUSTOM_BIND_DEVICE
#undef CUSTOM_BIND_DEVICE
}
DeviceMapper
&
DeviceMapper
::
inst
()
{
static
DeviceMapper
dm
;
return
dm
;
}
CUSTOM_PIMPL_CLS_DEFINE
(
Device
)
CUSTOM_PIMPL_CLS_DEFINE
(
Device
)
...
@@ -81,6 +83,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter<DeviceImpl>) {
...
@@ -81,6 +83,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter<DeviceImpl>) {
return
;
return
;
}
}
auto
&&
dev_benum2cenum
=
DeviceMapper
::
inst
().
dev_benum2cenum
;
auto
builtin_device_enum
=
DeviceImplConstRef
(
impl
).
device_type
();
auto
builtin_device_enum
=
DeviceImplConstRef
(
impl
).
device_type
();
mgb_assert
(
mgb_assert
(
dev_benum2cenum
.
find
(
builtin_device_enum
)
!=
dev_benum2cenum
.
end
(),
dev_benum2cenum
.
find
(
builtin_device_enum
)
!=
dev_benum2cenum
.
end
(),
...
@@ -91,7 +94,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter<DeviceImpl>) {
...
@@ -91,7 +94,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter<DeviceImpl>) {
Device
::
Device
(
const
std
::
string
&
device
)
:
m_impl
(
nullptr
,
impl_deleter
<
DeviceImpl
>
)
{
Device
::
Device
(
const
std
::
string
&
device
)
:
m_impl
(
nullptr
,
impl_deleter
<
DeviceImpl
>
)
{
mgb_assert
(
is_legal
(
device
),
"invalid device type: %s"
,
device
.
c_str
());
mgb_assert
(
is_legal
(
device
),
"invalid device type: %s"
,
device
.
c_str
());
std
::
string
builtin_device
=
dev_cstr2bstr
[
device
];
std
::
string
builtin_device
=
DeviceMapper
::
inst
().
dev_cstr2bstr
[
device
];
m_impl
.
reset
(
new
DeviceImpl
(
DeviceImpl
::
load
(
builtin_device
)));
m_impl
.
reset
(
new
DeviceImpl
(
DeviceImpl
::
load
(
builtin_device
)));
}
}
...
@@ -100,7 +103,7 @@ Device::Device(const char* device) : Device(std::string(device)) {}
...
@@ -100,7 +103,7 @@ Device::Device(const char* device) : Device(std::string(device)) {}
Device
::
Device
(
DeviceEnum
device
)
:
m_impl
(
nullptr
,
impl_deleter
<
DeviceImpl
>
)
{
Device
::
Device
(
DeviceEnum
device
)
:
m_impl
(
nullptr
,
impl_deleter
<
DeviceImpl
>
)
{
mgb_assert
(
is_legal
(
device
),
"invalid device type"
);
mgb_assert
(
is_legal
(
device
),
"invalid device type"
);
std
::
string
builtin_device
=
dev_cenum2bstr
[
device
];
std
::
string
builtin_device
=
DeviceMapper
::
inst
().
dev_cenum2bstr
[
device
];
m_impl
.
reset
(
new
DeviceImpl
(
DeviceImpl
::
load
(
builtin_device
)));
m_impl
.
reset
(
new
DeviceImpl
(
DeviceImpl
::
load
(
builtin_device
)));
}
}
...
@@ -110,6 +113,7 @@ std::string Device::str(void) const {
...
@@ -110,6 +113,7 @@ std::string Device::str(void) const {
}
}
auto
builtin_device_type
=
DeviceImplRef
(
m_impl
.
get
()).
device_type
();
auto
builtin_device_type
=
DeviceImplRef
(
m_impl
.
get
()).
device_type
();
auto
&&
dev_benum2cstr
=
DeviceMapper
::
inst
().
dev_benum2cstr
;
auto
iter
=
dev_benum2cstr
.
find
(
builtin_device_type
);
auto
iter
=
dev_benum2cstr
.
find
(
builtin_device_type
);
mgb_assert
(
mgb_assert
(
iter
!=
dev_benum2cstr
.
end
(),
"invalid device type %s
\n
"
,
iter
!=
dev_benum2cstr
.
end
(),
"invalid device type %s
\n
"
,
...
@@ -123,6 +127,7 @@ DeviceEnum Device::enumv(void) const {
...
@@ -123,6 +127,7 @@ DeviceEnum Device::enumv(void) const {
"cannot get the enum value of invalid device"
);
"cannot get the enum value of invalid device"
);
auto
builtin_device_type
=
DeviceImplRef
(
m_impl
.
get
()).
device_type
();
auto
builtin_device_type
=
DeviceImplRef
(
m_impl
.
get
()).
device_type
();
auto
&&
dev_benum2cenum
=
DeviceMapper
::
inst
().
dev_benum2cenum
;
auto
iter
=
dev_benum2cenum
.
find
(
builtin_device_type
);
auto
iter
=
dev_benum2cenum
.
find
(
builtin_device_type
);
mgb_assert
(
mgb_assert
(
iter
!=
dev_benum2cenum
.
end
(),
"invalid device type %s
\n
"
,
iter
!=
dev_benum2cenum
.
end
(),
"invalid device type %s
\n
"
,
...
@@ -131,16 +136,18 @@ DeviceEnum Device::enumv(void) const {
...
@@ -131,16 +136,18 @@ DeviceEnum Device::enumv(void) const {
}
}
bool
Device
::
is_legal
(
const
std
::
string
&
device_type
)
{
bool
Device
::
is_legal
(
const
std
::
string
&
device_type
)
{
auto
&&
dev_cstr2bstr
=
DeviceMapper
::
inst
().
dev_cstr2bstr
;
return
dev_cstr2bstr
.
find
(
device_type
)
!=
dev_cstr2bstr
.
end
();
return
dev_cstr2bstr
.
find
(
device_type
)
!=
dev_cstr2bstr
.
end
();
}
}
bool
Device
::
is_legal
(
DeviceEnum
device_type
)
{
bool
Device
::
is_legal
(
DeviceEnum
device_type
)
{
auto
&&
dev_cenum2bstr
=
DeviceMapper
::
inst
().
dev_cenum2bstr
;
return
dev_cenum2bstr
.
find
(
device_type
)
!=
dev_cenum2bstr
.
end
();
return
dev_cenum2bstr
.
find
(
device_type
)
!=
dev_cenum2bstr
.
end
();
}
}
std
::
vector
<
std
::
string
>
Device
::
legal_devices
(
void
)
{
std
::
vector
<
std
::
string
>
Device
::
legal_devices
(
void
)
{
std
::
vector
<
std
::
string
>
ret
;
std
::
vector
<
std
::
string
>
ret
;
for
(
const
auto
&
kv
:
dev_cstr2bstr
)
{
for
(
const
auto
&
kv
:
DeviceMapper
::
inst
().
dev_cstr2bstr
)
{
ret
.
emplace_back
(
kv
.
first
);
ret
.
emplace_back
(
kv
.
first
);
}
}
return
ret
;
return
ret
;
...
@@ -197,36 +204,37 @@ bool operator==(const Shape& lhs, const Shape& rhs) {
...
@@ -197,36 +204,37 @@ bool operator==(const Shape& lhs, const Shape& rhs) {
return
ShapeImplRef
(
lhs
.
m_impl
.
get
()).
eq_shape
(
ShapeImplRef
(
rhs
.
m_impl
.
get
()));
return
ShapeImplRef
(
lhs
.
m_impl
.
get
()).
eq_shape
(
ShapeImplRef
(
rhs
.
m_impl
.
get
()));
}
}
static
std
::
unordered_map
<
std
::
string
,
megdnn
::
DTypeEnum
>
dtype_cstr2benum
;
struct
DTypeMapper
{
static
std
::
unordered_map
<
using
CustomEnum
=
DTypeEnum
;
DTypeEnum
,
megdnn
::
DTypeEnum
,
EnumHash
<
DTypeEnum
>
,
EnumCmp
<
DTypeEnum
>>
using
BuiltinEnum
=
megdnn
::
DTypeEnum
;
dtype_cenum2benum
;
static
std
::
unordered_map
<
std
::
unordered_map
<
std
::
string
,
BuiltinEnum
>
dtype_cstr2benum
;
megdnn
::
DTypeEnum
,
std
::
string
,
EnumHash
<
megdnn
::
DTypeEnum
>
,
EnumMap
<
DTypeEnum
,
BuiltinEnum
>
dtype_cenum2benum
;
EnumCmp
<
megdnn
::
DTypeEnum
>>
EnumMap
<
BuiltinEnum
,
std
::
string
>
dtype_benum2cstr
;
dtype_benum2cstr
;
EnumMap
<
BuiltinEnum
,
DTypeEnum
>
dtype_benum2cenum
;
static
std
::
unordered_map
<
EnumMap
<
DTypeEnum
,
std
::
string
>
dtype_cenum2cstr
;
megdnn
::
DTypeEnum
,
DTypeEnum
,
EnumHash
<
megdnn
::
DTypeEnum
>
,
static
DTypeMapper
&
inst
();
EnumCmp
<
megdnn
::
DTypeEnum
>>
dtype_benum2cenum
;
private:
static
std
::
unordered_map
<
DTypeMapper
();
DTypeEnum
,
std
::
string
,
EnumHash
<
DTypeEnum
>
,
EnumCmp
<
DTypeEnum
>>
};
dtype_cenum2cstr
;
DTypeMapper
::
DTypeMapper
()
{
#define CUSTOM_BIND_DTYPE(custom_impl, builtin_dtype, ctype) \
#define CUSTOM_BIND_DTYPE(custom_dty, builtin_dty, ctype) \
auto cs2be##custom_impl = dtype_cstr2benum.emplace( \
dtype_cstr2benum.emplace(std::string(#custom_dty), BuiltinEnum::builtin_dty); \
std::string(#custom_impl), megdnn::DTypeEnum::builtin_dtype); \
dtype_cenum2benum.emplace(DTypeEnum::custom_dty, BuiltinEnum::builtin_dty); \
auto ce2be##custom_impl = dtype_cenum2benum.emplace( \
dtype_benum2cstr.emplace(BuiltinEnum::builtin_dty, std::string(#custom_dty)); \
DTypeEnum::custom_impl, megdnn::DTypeEnum::builtin_dtype); \
dtype_benum2cenum.emplace(BuiltinEnum::builtin_dty, DTypeEnum::custom_dty); \
auto be2cs##custom_impl = dtype_benum2cstr.emplace( \
dtype_cenum2cstr.emplace(DTypeEnum::custom_dty, std::string(#custom_dty));
megdnn::DTypeEnum::builtin_dtype, std::string(#custom_impl)); \
auto be2ce##custom_impl = dtype_benum2cenum.emplace( \
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE
(
CUSTOM_BIND_DTYPE
)
megdnn::DTypeEnum::builtin_dtype, DTypeEnum::custom_impl); \
auto ce2cs##custom_impl = dtype_cenum2cstr.emplace( \
DTypeEnum::custom_impl, std::string(#custom_impl));
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE
(
CUSTOM_BIND_DTYPE
)
#undef CUSTOM_BIND_DTYPE
#undef CUSTOM_BIND_DTYPE
}
DTypeMapper
&
DTypeMapper
::
inst
()
{
static
DTypeMapper
dm
;
return
dm
;
}
CUSTOM_PIMPL_CLS_DEFINE
(
DType
)
CUSTOM_PIMPL_CLS_DEFINE
(
DType
)
...
@@ -240,6 +248,7 @@ DType::DType(const void* impl) : m_impl(nullptr, impl_deleter<DTypeImpl>) {
...
@@ -240,6 +248,7 @@ DType::DType(const void* impl) : m_impl(nullptr, impl_deleter<DTypeImpl>) {
}
}
DType
::
DType
(
const
std
::
string
&
dtype
)
:
m_impl
(
nullptr
,
impl_deleter
<
DTypeImpl
>
)
{
DType
::
DType
(
const
std
::
string
&
dtype
)
:
m_impl
(
nullptr
,
impl_deleter
<
DTypeImpl
>
)
{
auto
&&
dtype_cstr2benum
=
DTypeMapper
::
inst
().
dtype_cstr2benum
;
auto
iter
=
dtype_cstr2benum
.
find
(
dtype
);
auto
iter
=
dtype_cstr2benum
.
find
(
dtype
);
mgb_assert
(
iter
!=
dtype_cstr2benum
.
end
(),
"invalid dtype %s"
,
dtype
.
c_str
());
mgb_assert
(
iter
!=
dtype_cstr2benum
.
end
(),
"invalid dtype %s"
,
dtype
.
c_str
());
mgb_assert
(
mgb_assert
(
...
@@ -254,6 +263,7 @@ DType::DType(const char* dtype) : DType(std::string(dtype)) {}
...
@@ -254,6 +263,7 @@ DType::DType(const char* dtype) : DType(std::string(dtype)) {}
DType
::
DType
(
const
std
::
string
&
dtype
,
float
scale
,
uint8_t
zero_point
)
DType
::
DType
(
const
std
::
string
&
dtype
,
float
scale
,
uint8_t
zero_point
)
:
m_impl
(
nullptr
,
impl_deleter
<
DTypeImpl
>
)
{
:
m_impl
(
nullptr
,
impl_deleter
<
DTypeImpl
>
)
{
auto
&&
dtype_cstr2benum
=
DTypeMapper
::
inst
().
dtype_cstr2benum
;
auto
iter
=
dtype_cstr2benum
.
find
(
dtype
);
auto
iter
=
dtype_cstr2benum
.
find
(
dtype
);
mgb_assert
(
iter
!=
dtype_cstr2benum
.
end
(),
"invalid dtype %s"
,
dtype
.
c_str
());
mgb_assert
(
iter
!=
dtype_cstr2benum
.
end
(),
"invalid dtype %s"
,
dtype
.
c_str
());
mgb_assert
(
mgb_assert
(
...
@@ -289,6 +299,7 @@ DType::DType(const char* dtype, float scale, uint8_t zero_point)
        : DType(std::string(dtype), scale, zero_point) {}

DType::DType(DTypeEnum dtype) : m_impl(nullptr, impl_deleter<DTypeImpl>) {
+    auto&& dtype_cenum2benum = DTypeMapper::inst().dtype_cenum2benum;
    auto iter = dtype_cenum2benum.find(dtype);
    mgb_assert(iter != dtype_cenum2benum.end(), "invalid dtype");
    mgb_assert(
...
@@ -298,11 +309,13 @@ DType::DType(DTypeEnum dtype) : m_impl(nullptr, impl_deleter<DTypeImpl>) {
}

DType::DType(DTypeEnum dtype, float scale, uint8_t zero_point)
-        : DType(dtype_cenum2cstr.find(dtype)->second, scale, zero_point) {}
+        : DType(DTypeMapper::inst().dtype_cenum2cstr.find(dtype)->second, scale, zero_point) {}

std::string DType::str(void) const {
    if (!DTypeImplRef(m_impl.get()).valid())
        return "invalid";
+    auto&& dtype_benum2cstr = DTypeMapper::inst().dtype_benum2cstr;
    auto iter = dtype_benum2cstr.find(DTypeImplRef(m_impl.get()).enumv());
    if (iter == dtype_benum2cstr.end())
        return "invalid";
...
@@ -310,6 +323,7 @@ std::string DType::str(void) const {
}

DTypeEnum DType::enumv(void) const {
+    auto&& dtype_benum2cenum = DTypeMapper::inst().dtype_benum2cenum;
    auto iter = dtype_benum2cenum.find(DTypeImplRef(m_impl.get()).enumv());
    mgb_assert(iter != dtype_benum2cenum.end(), "invalid dtype");
    return iter->second;
...
@@ -337,16 +351,18 @@ uint8_t DType::zero_point() const {
}

bool DType::is_legal(const std::string& dtype) {
+    auto&& dtype_cstr2benum = DTypeMapper::inst().dtype_cstr2benum;
    return dtype_cstr2benum.find(dtype) != dtype_cstr2benum.end();
}

bool DType::is_legal(const DTypeEnum& dtype) {
+    auto&& dtype_cenum2benum = DTypeMapper::inst().dtype_cenum2benum;
    return dtype_cenum2benum.find(dtype) != dtype_cenum2benum.end();
}

std::vector<std::string> DType::legal_dtypes(void) {
    std::vector<std::string> ret;
-    for (const auto& kv : dtype_cstr2benum)
+    for (const auto& kv : DTypeMapper::inst().dtype_cstr2benum)
        ret.emplace_back(kv.first);
    return ret;
}
...
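The constructors and queries above all route through the same mapper, so the string name, the DTypeEnum value, and the printable form of a dtype stay consistent. A hedged usage sketch, assuming "float32" is one of the registered dtype names as elsewhere in this commit:

#include <cassert>
#include "megbrain/custom/tensor.h"

// Sketch only: round-trips a dtype through the string and enum constructors.
void dtype_roundtrip() {
    if (custom::DType::is_legal("float32")) {
        custom::DType from_str("float32");
        custom::DTypeEnum e = from_str.enumv();  // string -> enum via the mapper
        custom::DType from_enum(e);              // enum -> impl via the same tables
        assert(from_enum.str() == "float32");
    }
}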
src/custom/include/megbrain/custom/data_adaptor.h → src/custom/include/megbrain/custom/adaptor.h    View file @ 8a692573
#pragma once

+#include "megbrain/custom/op.h"
+#include "megbrain/custom/tensor.h"
+#include "megbrain/tensor.h"
#include "megdnn/thin/small_vector.h"

namespace custom {
...
@@ -11,27 +14,32 @@ BuiltinT to_builtin(const CustomT& custom) {

template <typename BuiltinT, typename CustomT>
CustomT to_custom(const BuiltinT& builtin) {
-    return std::move(CustomT(&builtin));
+    return CustomT(&builtin);
}

template <typename BuiltinT, typename CustomT>
megdnn::SmallVector<BuiltinT> to_builtin(const std::vector<CustomT>& customs) {
    megdnn::SmallVector<BuiltinT> builtins;
    for (size_t i = 0; i < customs.size(); ++i) {
-        builtins.push_back(std::move(to_builtin<BuiltinT, CustomT>(customs[i])));
+        builtins.emplace_back(to_builtin<BuiltinT, CustomT>(customs[i]));
    }
-    return std::move(builtins);
+    return builtins;
}

template <typename BuiltinT, typename CustomT>
std::vector<CustomT> to_custom(const megdnn::SmallVector<BuiltinT>& builtins) {
    std::vector<CustomT> customs;
    for (size_t i = 0; i < builtins.size(); ++i) {
-        customs.push_back(std::move(to_custom<BuiltinT, CustomT>(builtins[i])));
+        customs.emplace_back(to_custom<BuiltinT, CustomT>(builtins[i]));
    }
-    return std::move(customs);
+    return customs;
}

+MGE_WIN_DECLSPEC_FUC void dispatch_custom_op(
+        std::shared_ptr<const CustomOp> op, const Param& param,
+        std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> inputs,
+        std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> outputs);
+
}  // namespace custom

#define to_custom_device(expr) \
...
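These adaptor templates are the bridge between MegBrain's builtin tensor containers and the types a custom op sees. A hedged sketch of how they would be called; the tensors themselves are assumed to be allocated by the caller:

#include <vector>
#include "megbrain/custom/adaptor.h"

// Sketch only: converts builtin DeviceTensorND values to the custom::Tensor type a
// kernel expects, and back again after the kernel has run.
std::vector<custom::Tensor> wrap_for_kernel(
        const megdnn::SmallVector<mgb::DeviceTensorND>& builtin_inputs) {
    return custom::to_custom<mgb::DeviceTensorND, custom::Tensor>(builtin_inputs);
}

megdnn::SmallVector<mgb::DeviceTensorND> unwrap_after_kernel(
        const std::vector<custom::Tensor>& custom_outputs) {
    return custom::to_builtin<mgb::DeviceTensorND, custom::Tensor>(custom_outputs);
}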
src/custom/include/megbrain/custom/manager.h    View file @ 8a692573
...
@@ -5,10 +5,26 @@
namespace custom {

+class CustomLib {
+    std::unique_ptr<void, void_deleter> m_handle;
+    std::vector<std::string> m_ops;
+
+public:
+    PREVENT_COPY_AND_ASSIGN(CustomLib);
+    CustomLib(const std::string& path, int mode);
+    ~CustomLib();
+    MGE_WIN_DECLSPEC_FUC const std::vector<std::string>& ops_in_lib(void) const;
+    bool valid(void) const;
+};
+
+using LibHandle = std::shared_ptr<CustomLib>;
+
class CustomOpManager {
+    std::unordered_map<std::string, LibHandle> m_custom_libs;
    std::unordered_map<std::string, std::shared_ptr<const CustomOp>> m_name2op;
    std::unordered_map<RunTimeId, std::shared_ptr<const CustomOp>> m_id2op;
-    MGB_MUTEX m_mtx;
+    MGB_MUTEX m_lib_mtx;
+    MGB_MUTEX m_op_mtx;

    CustomOpManager() = default;

public:
...
@@ -16,13 +32,15 @@ public:
    MGE_WIN_DECLSPEC_FUC static CustomOpManager* inst(void);
    MGE_WIN_DECLSPEC_FUC ~CustomOpManager();

+    MGE_WIN_DECLSPEC_FUC const std::vector<std::string>& install(
+            const std::string& name, const std::string& path);
+    MGE_WIN_DECLSPEC_FUC std::vector<std::string> uninstall(const std::string& name);
+    MGE_WIN_DECLSPEC_FUC const std::unordered_map<std::string, LibHandle>& lib_info(
+            void) const;
    MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> insert(
            const std::string& name, uint32_t version);
    MGE_WIN_DECLSPEC_FUC bool erase(const std::string& name);
-    MGE_WIN_DECLSPEC_FUC bool erase(const RunTimeId& id);
-    MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> find_or_reg(
-            const std::string& name, uint32_t version);
    MGE_WIN_DECLSPEC_FUC RunTimeId to_id(const std::string& name) const;
    MGE_WIN_DECLSPEC_FUC std::string to_name(const RunTimeId& id) const;
...
@@ -36,35 +54,4 @@ public:
    MGE_WIN_DECLSPEC_FUC std::vector<RunTimeId> op_id_list(void);
};

-class CustomLib {
-    std::unique_ptr<void, void_deleter> m_handle;
-    std::vector<std::string> m_ops;
-
-public:
-    PREVENT_COPY_AND_ASSIGN(CustomLib);
-    CustomLib(const std::string& path, int mode);
-    const std::vector<std::string>& ops_in_lib(void) const;
-    ~CustomLib();
-    bool valid(void) const;
-};
-
-using LibHandle = std::shared_ptr<CustomLib>;
-
-class LibManager {
-    std::unordered_map<std::string, LibHandle> m_custom_libs;
-    MGB_MUTEX m_mtx;
-
-    LibManager() = default;
-
-public:
-    PREVENT_COPY_AND_ASSIGN(LibManager);
-    MGE_WIN_DECLSPEC_FUC static LibManager* inst(void);
-    MGE_WIN_DECLSPEC_FUC const std::vector<std::string>& install(
-            const std::string& name, const std::string& path);
-    MGE_WIN_DECLSPEC_FUC bool uninstall(const std::string& name);
-
-    friend class CustomOpManager;
-};
-
}  // namespace custom
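With the separate LibManager folded into CustomOpManager, one object now tracks both loaded libraries and registered ops. A hedged usage sketch; the library path and op names are placeholders and error handling is omitted:

#include "megbrain/custom/manager.h"

// Sketch only: load a custom-op library, register and look up an op, then unload.
void demo_manager() {
    auto* mgr = custom::CustomOpManager::inst();

    // install() loads the shared library and returns the op names it registered
    const std::vector<std::string>& ops = mgr->install("my_lib", "/path/to/libmy_ops.so");

    std::shared_ptr<custom::CustomOp> op = mgr->insert("MyOp", CUSTOM_OP_VERSION);
    custom::RunTimeId id = mgr->to_id("MyOp");  // name <-> runtime id mapping
    mgr->erase(mgr->to_name(id));               // drop the op we just registered

    // uninstall() returns the names of the ops that came from that library
    std::vector<std::string> removed = mgr->uninstall("my_lib");
    (void)ops;
    (void)op;
    (void)removed;
}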
src/custom/include/megbrain/custom/param_val.h    View file @ 8a692573
...
@@ -76,8 +76,6 @@ class Device;
 * Macro Callback for Register
 */
#define CUSTOM_REG_DYN_PARAMTYPE(dyn_type, static_type) dyn_type,
-#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) \
-    {ParamDynType::dyn_type, #dyn_type},
#define CUSTOM_REG_DYN_PARAMTYPE_GETTER(dyn_type, static_type) \
    template <>                                                \
...
@@ -95,10 +93,7 @@ enum class ParamDynType : uint32_t {
    CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE) Invalid = 255
};

-static std::unordered_map<
-        ParamDynType, std::string, EnumHash<ParamDynType>, EnumCmp<ParamDynType>>
-        type2name = {CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME){
-                ParamDynType::Invalid, "Invalid"}};
+MGE_WIN_DECLSPEC_FUC std::string ptype2name(ParamDynType);

/**
 * get the dynamic data type according to the builtin static data type
...
@@ -124,7 +119,6 @@ CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_GETTER)
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_STATIC_PARAMTYPE_GETTER)

#undef CUSTOM_REG_DYN_PARAMTYPE
-#undef CUSTOM_REG_DYN_PARAMTYPE_NAME
#undef CUSTOM_REG_DYN_PARAMTYPE_GETTER
#undef CUSTOM_REG_STATIC_PARAMTYPE_GETTER
...
@@ -290,7 +284,7 @@ T& ParamVal::as(void) {
    ParamDynType t_dyn_type = get_dyn_type<DecayType>::type;
    custom_assert(
            t_dyn_type == m_type, "type mismatch, type %s cannot be cast to type %s\n",
-            type2name[m_type].c_str(), type2name[t_dyn_type].c_str());
+            ptype2name(m_type).c_str(), ptype2name(t_dyn_type).c_str());
    return TypedRef(T, m_ptr.get());
}
...
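Replacing the header-level static type2name map with an exported ptype2name() function keeps a single copy of the name table inside the library instead of one per translation unit that includes the header. The real implementation lives in param_val.cpp and is not shown in this diff; a plausible hedged sketch of its shape:

// Sketch, not the actual implementation: maps a dynamic param type to a printable
// name, falling back to "Invalid".  The concrete enumerators below are illustrative.
std::string ptype2name(ParamDynType type) {
    static const EnumMap<ParamDynType, std::string> type2name = {
            {ParamDynType::Int32, "Int32"},
            {ParamDynType::Float32, "Float32"},
            {ParamDynType::Invalid, "Invalid"},
    };
    auto iter = type2name.find(type);
    return iter == type2name.end() ? "Invalid" : iter->second;
}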
src/custom/include/megbrain/custom/utils.h    View file @ 8a692573
...
@@ -3,6 +3,7 @@
#include <cassert>
#include <memory>
#include <string>
+#include <unordered_map>
#include <vector>

namespace custom {
...
@@ -108,4 +109,7 @@ struct EnumCmp {
    }
};

+template <typename Key, typename Value>
+using EnumMap = std::unordered_map<Key, Value, EnumHash<Key>, EnumCmp<Key>>;
+
}  // namespace custom
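EnumMap bundles the enum-aware hash and comparison functors, so call sites no longer spell out four template arguments for every enum-keyed table. A small hedged example with an illustrative enum:

#include <cstdint>
#include <string>
#include "megbrain/custom/utils.h"

enum class Color : uint32_t { Red, Green };  // illustrative enum, not part of the library

// Expands to std::unordered_map<Color, std::string, EnumHash<Color>, EnumCmp<Color>>.
custom::EnumMap<Color, std::string> color_names = {
        {Color::Red, "red"}, {Color::Green, "green"}};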
src/custom/test/manager.cpp    View file @ 8a692573
...
@@ -12,16 +12,17 @@ namespace custom {

TEST(TestOpManager, TestOpManager) {
    CustomOpManager* com = CustomOpManager::inst();
+    std::vector<std::string> builtin_op_names = com->op_name_list();
+    size_t builtin_op_num = builtin_op_names.size();
    com->insert("Op1", CUSTOM_OP_VERSION);
    com->insert("Op2", CUSTOM_OP_VERSION);
-    std::shared_ptr<CustomOp> ptr = com->find_or_reg("Op3", CUSTOM_OP_VERSION);
-    ASSERT_TRUE(ptr != nullptr);

    std::vector<std::string> op_names = com->op_name_list();
    std::vector<RunTimeId> op_ids = com->op_id_list();
-    ASSERT_TRUE(op_names.size() == 3);
-    ASSERT_TRUE(op_ids.size() == 3);
+    ASSERT_TRUE(op_names.size() == builtin_op_num + 2);
+    ASSERT_TRUE(op_ids.size() == builtin_op_num + 2);

#if MANAGER_TEST_LOG
    for (std::string& name : op_names) {
...
@@ -52,12 +53,9 @@ TEST(TestOpManager, TestOpManager) {
    }
#endif

    ASSERT_TRUE(com->erase("Op1"));
-    ASSERT_TRUE(com->erase(com->to_id("Op2")));
-    ASSERT_TRUE(com->op_id_list().size() == 1);
-    ASSERT_TRUE(com->op_name_list().size() == 1);
-    ASSERT_TRUE(com->op_name_list()[0] == "Op3");
-    ptr.reset();
-    ASSERT_TRUE(com->erase("Op3"));
+    ASSERT_TRUE(com->op_id_list().size() == builtin_op_num + 1);
+    ASSERT_TRUE(com->op_name_list().size() == builtin_op_num + 1);
+    ASSERT_TRUE(com->erase("Op2"));
}

TEST(TestOpManager, TestOpReg) {
...
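Because builtin ops can now be registered through the same manager, the registry is no longer guaranteed to start empty; the updated test therefore asserts relative to a baseline count. A sketch of that pattern, with a placeholder op name:

// Sketch of the baseline pattern used above: snapshot whatever is already registered
// (builtin ops may be present), then check growth relative to that snapshot.
size_t baseline = CustomOpManager::inst()->op_name_list().size();
CustomOpManager::inst()->insert("MyOp", CUSTOM_OP_VERSION);
ASSERT_TRUE(CustomOpManager::inst()->op_name_list().size() == baseline + 1);
ASSERT_TRUE(CustomOpManager::inst()->erase("MyOp"));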
src/custom/test/op.cpp    View file @ 8a692573
...
@@ -4,9 +4,10 @@
#include "gtest/gtest.h"
#include "megbrain/comp_node.h"
-#include "megbrain/custom/data_adaptor.h"
+#include "megbrain/custom/adaptor.h"
#include "megbrain/custom/op.h"
#include "megbrain/tensor.h"
+#include "megbrain/test/helper.h"
#include "megbrain_build_config.h"

#define OP_TEST_LOG 0
...
@@ -93,60 +94,6 @@ void format_infer(
    outputs[1] = inputs[0];
}

-void cpu_kernel(
-        const std::vector<Tensor>& inputs, const Param& params,
-        std::vector<Tensor>& outputs) {
-    (void)inputs;
-    (void)params;
-    (void)outputs;
-#if OP_TEST_LOG
-    std::cout << "Checking CPU Forward - " << params["device"].as<std::string>()
-              << std::endl;
-#endif
-    ASSERT_TRUE(params["device"] == "x86");
-}
-
-void gpu_kernel(
-        const std::vector<Tensor>& inputs, const Param& params,
-        std::vector<Tensor>& outputs) {
-    (void)inputs;
-    (void)params;
-    (void)outputs;
-#if OP_TEST_LOG
-    std::cout << "Checking GPU Forward - " << params["device"].as<std::string>()
-              << std::endl;
-#endif
-    ASSERT_TRUE(params["device"] == "cuda");
-}
-
-void cpu_kernel_with_runtime_args(
-        const std::vector<Tensor>& inputs, const Param& params,
-        std::vector<Tensor>& outputs, const RuntimeArgs& args) {
-    (void)inputs;
-    (void)params;
-    (void)outputs;
-    (void)args;
-#if OP_TEST_LOG
-    std::cout << "Checking CPU Forward - " << params["device"].as<std::string>()
-              << std::endl;
-#endif
-    ASSERT_TRUE(params["device"] == "x86");
-}
-
-void gpu_kernel_with_runtime_args(
-        const std::vector<Tensor>& inputs, const Param& params,
-        std::vector<Tensor>& outputs, const RuntimeArgs& args) {
-    (void)inputs;
-    (void)params;
-    (void)outputs;
-    (void)args;
-#if OP_TEST_LOG
-    std::cout << "Checking GPU Forward - " << params["device"].as<std::string>()
-              << std::endl;
-#endif
-    ASSERT_TRUE(params["device"] == "cuda");
-}
-
TEST(TestCustomOp, TestCustomOpFuncSetter) {
#if MGB_CUDA
    CustomOp test("TestOp", CUSTOM_OP_VERSION);
...
@@ -155,7 +102,8 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
            .add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2)
            .add_output("outl", "outl of Test op", {"float32", "int32"}, 2)
            .add_output("outr", "outr of Test op", {"float32", "int32"}, 2)
-            .add_param("smooth", "smooth", 0.f)
+            .add_param("scale_f", "scale_f", 1.f)
+            .add_param("offset_i", "offset_i", 0)
            .add_param("device", "using for judge device", "x86");

    std::vector<Device> idevices = {"x86", "cuda"};
...
@@ -206,35 +154,93 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
    ASSERT_TRUE(odtypes[1] == "int32");
    ASSERT_TRUE(iformats[0].is_default());
    ASSERT_TRUE(iformats[1].is_default());
-    test.set_compute(cpu_kernel_with_runtime_args);
-    test.set_compute(cpu_kernel);
-    DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{});
-    DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{});
-    DeviceTensorND cdev_otensor0(CompNode::load("cpux"), {3, 2}, dtype::Float32{});
-    DeviceTensorND cdev_otensor1(CompNode::load("cpux"), {3, 2}, dtype::Int32{});
-    std::vector<Tensor> cinputs = {
-            to_custom_tensor(cdev_itensor0), to_custom_tensor(cdev_itensor1)};
-    std::vector<Tensor> coutputs = {
-            to_custom_tensor(cdev_otensor0), to_custom_tensor(cdev_otensor1)};
-    param["device"] = "x86";
-    test.compute(cinputs, param, coutputs);
-    test.set_compute("cuda", gpu_kernel_with_runtime_args);
-    test.set_compute("cuda", gpu_kernel);
-    DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{});
-    DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{});
-    DeviceTensorND gdev_otensor0(CompNode::load("gpux"), {3, 2}, dtype::Float32{});
-    DeviceTensorND gdev_otensor1(CompNode::load("gpux"), {3, 2}, dtype::Int32{});
-    std::vector<Tensor> ginputs = {
-            to_custom_tensor(gdev_itensor0), to_custom_tensor(gdev_itensor1)};
-    std::vector<Tensor> goutputs = {
-            to_custom_tensor(gdev_otensor0), to_custom_tensor(gdev_otensor1)};
-    param["device"] = "cuda";
-    test.compute(ginputs, param, goutputs);
#endif
}

+void cpu_kernel(
+        const std::vector<Tensor>& inputs, const Param& params,
+        std::vector<Tensor>& outputs) {
+    ASSERT_TRUE(inputs.size() == 2);
+    ASSERT_TRUE(outputs.size() == 2);
+    ASSERT_TRUE(params["device"] == "x86");
+    ASSERT_TRUE(params["scale_f"] == 2.12f);
+    ASSERT_TRUE(params["offset_i"] == 6);
+    ASSERT_TRUE(inputs[0].shape() == Shape({3, 4}));
+    ASSERT_TRUE(inputs[1].shape() == Shape({5, 6}));
+    ASSERT_TRUE(outputs[0].shape() == Shape({5, 6}));
+    ASSERT_TRUE(outputs[1].shape() == Shape({3, 4}));
+    ASSERT_TRUE(inputs[0].device() == "x86");
+    ASSERT_TRUE(inputs[1].device() == "x86");
+    ASSERT_TRUE(outputs[0].device() == "x86");
+    ASSERT_TRUE(outputs[1].device() == "x86");
+    float scale_f = params["scale_f"].as<float>();
+    int offset_i = params["offset_i"].as<int>();
+
+    for (size_t i = 0; i < 5 * 6; ++i) {
+        ASSERT_TRUE(inputs[1].data<float>()[i] == static_cast<float>(i));
+        outputs[0].data<float>()[i] = inputs[1].data<float>()[i] * scale_f;
+    }
+    for (size_t i = 0; i < 3 * 4; ++i) {
+        ASSERT_TRUE(inputs[0].data<int>()[i] == static_cast<int>(i));
+        outputs[1].data<int>()[i] = inputs[0].data<int>()[i] + offset_i;
+    }
+}
+
+TEST(TestCustomOp, TestCustomOpCompute) {
+    std::shared_ptr<CustomOp> op =
+            std::make_shared<CustomOp>("TestOp", CUSTOM_OP_VERSION);
+    op->set_description("Test Op Forward Backward Union")
+            .add_input("lhs", "lhs of Test op", {"float32", "int32"}, 2)
+            .add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2)
+            .add_output("outl", "outl of Test op", {"float32", "int32"}, 2)
+            .add_output("outr", "outr of Test op", {"float32", "int32"}, 2)
+            .add_param("scale_f", "scale_f", 1.f)
+            .add_param("offset_i", "offset_i", 0)
+            .add_param("device", "using for judge device", "x86")
+            .set_shape_infer(shape_infer)
+            .set_dtype_infer(dtype_infer)
+            .set_compute("x86", cpu_kernel);
+
+    Param param(op->param_info());
+    param["device"] = "x86";
+    param["scale_f"] = 2.12f;
+    param["offset_i"] = 6;
+
+    HostTensorGenerator<dtype::Float32> gen_f;
+    HostTensorGenerator<dtype::Int32> gen_i;
+    auto host_i0 = gen_i({3, 4}), host_i1 = gen_f({5, 6});
+    auto expect_o0 = gen_f({5, 6}), expect_o1 = gen_i({3, 4});
+    for (size_t i = 0; i < 5 * 6; ++i) {
+        host_i1->ptr<float>()[i] = static_cast<float>(i);
+        expect_o0->ptr<float>()[i] = host_i1->ptr<float>()[i] * 2.12f;
+    }
+    for (size_t i = 0; i < 3 * 4; ++i) {
+        host_i0->ptr<int>()[i] = static_cast<int>(i);
+        expect_o1->ptr<int>()[i] = host_i0->ptr<int>()[i] + 6;
+    }
+
+    auto cn = CompNode::load("cpux");
+    std::shared_ptr<SmallVector<DeviceTensorND>> x86_inps =
+            std::make_shared<SmallVector<DeviceTensorND>>(2);
+    x86_inps->at(0) = DeviceTensorND{cn};
+    x86_inps->at(1) = DeviceTensorND{cn};
+    x86_inps->at(0).copy_from(*host_i0).sync();
+    x86_inps->at(1).copy_from(*host_i1).sync();
+
+    std::shared_ptr<SmallVector<DeviceTensorND>> x86_oups =
+            std::make_shared<SmallVector<DeviceTensorND>>(2);
+    x86_oups->at(0) = DeviceTensorND{cn, {5, 6}, dtype::Float32{}};
+    x86_oups->at(1) = DeviceTensorND{cn, {3, 4}, dtype::Int32{}};
+
+    dispatch_custom_op(op, param, x86_inps, x86_oups);
+    cn.sync();
+
+    HostTensorND host_o0, host_o1;
+    host_o0.copy_from(x86_oups->at(0)).sync();
+    host_o1.copy_from(x86_oups->at(1)).sync();
+    MGB_ASSERT_TENSOR_NEAR(*expect_o0, host_o0, 1e-6);
+    MGB_ASSERT_TENSOR_NEAR(*expect_o1, host_o1, 1e-6);
+}

}  // namespace custom
...
src/custom/test/tensor.cpp    View file @ 8a692573
...
@@ -4,7 +4,7 @@
#include "gtest/gtest.h"
#include "megbrain/comp_node.h"
-#include "megbrain/custom/data_adaptor.h"
+#include "megbrain/custom/adaptor.h"
#include "megbrain/custom/tensor.h"
#include "megbrain/tensor.h"
#include "megbrain_build_config.h"
...
src/opr/impl/custom_opnode.cpp    View file @ 8a692573
...
@@ -114,24 +114,21 @@ void CustomOpNode::init_output_comp_node() {

void CustomOpNode::do_execute(ExecEnv& env) {
    auto runner = [this]() {
+        std::shared_ptr<SmallVector<DeviceTensorND>> inputs =
+                std::make_shared<SmallVector<DeviceTensorND>>();
+        std::shared_ptr<SmallVector<DeviceTensorND>> outputs =
+                std::make_shared<SmallVector<DeviceTensorND>>();
+        for (size_t i = 0; i < input_num(); i++) {
+            inputs->emplace_back(input(i)->dev_tensor());
+        }
+        for (size_t i = 0; i < output_num(); i++) {
+            outputs->emplace_back(output(i)->dev_tensor());
+        }
+
        this->owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
                this, m_comp_node);
        m_comp_node.activate();
-
-        SmallVector<DeviceTensorND> inputs, outputs;
-        for (size_t i = 0; i < input_num(); i++)
-            inputs.push_back(input(i)->dev_tensor());
-        for (size_t i = 0; i < output_num(); i++)
-            outputs.push_back(output(i)->dev_tensor());
-
-        std::vector<custom::Tensor> custom_inputs =
-                custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
-        std::vector<custom::Tensor> custom_outputs =
-                custom::to_custom<DeviceTensorND, custom::Tensor>(outputs);
-        m_op->compute(custom_inputs, m_param, custom_outputs);
-        // [TODO] sync should be modified
-        CompNode::sync_all();
+        custom::dispatch_custom_op(m_op, m_param, inputs, outputs);
+
        this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
                this, m_comp_node);
    };
...
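The runner lambda now captures the input and output tensors in shared_ptr<SmallVector<DeviceTensorND>> and hands them to dispatch_custom_op, so their lifetime no longer depends on the enclosing stack frame if the kernel is queued asynchronously. A reduced sketch of that ownership pattern (names are illustrative, not the real implementation):

#include <functional>
#include <memory>
#include <vector>

// Illustrative only: shows why the tensor lists are wrapped in shared_ptr before being
// handed to an asynchronous dispatcher -- the enqueued lambda keeps them alive.
using TensorList = std::vector<int>;  // stand-in for SmallVector<DeviceTensorND>

void async_queue(std::function<void()> task) {
    task();  // the real dispatcher would defer this onto a device queue
}

void dispatch_sketch(
        std::shared_ptr<TensorList> inputs, std::shared_ptr<TensorList> outputs) {
    async_queue([inputs, outputs]() {
        // inputs/outputs are still valid here, even after the caller's frame is gone
        outputs->assign(inputs->begin(), inputs->end());
    });
}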
src/opr/include/megbrain/opr/custom_opnode.h    View file @ 8a692573
...
@@ -4,8 +4,8 @@
#if MGB_CUSTOM_OP

+#include "megbrain/custom/adaptor.h"
#include "megbrain/custom/custom.h"
-#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/manager.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/helper.h"
...