提交 8a692573 编写于 作者: M Megvii Engine Team

refactor(customop): support write builtin op with custom op

GitOrigin-RevId: cd90002fe851a025b002e918f4b6f638936e660f
上级 8db64303
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
from .._imperative_rt.ops._custom import ( from .._imperative_rt.ops._custom import (
_get_custom_op_lib_info,
_get_custom_op_list, _get_custom_op_list,
_install, _install,
_make_custom_op, _make_custom_op,
...@@ -22,8 +23,7 @@ def _gen_custom_op_maker(custom_op_name): ...@@ -22,8 +23,7 @@ def _gen_custom_op_maker(custom_op_name):
def load(lib_path): def load(lib_path):
lib_path = os.path.abspath(lib_path) lib_path = os.path.abspath(lib_path)
lib_name = os.path.splitext(lib_path)[0] op_in_this_lib = _install(lib_path, lib_path)
op_in_this_lib = _install(lib_name, lib_path)
for op in op_in_this_lib: for op in op_in_this_lib:
op_maker = _gen_custom_op_maker(op) op_maker = _gen_custom_op_maker(op)
globals()[op] = op_maker globals()[op] = op_maker
...@@ -32,5 +32,19 @@ def load(lib_path): ...@@ -32,5 +32,19 @@ def load(lib_path):
def unload(lib_path): def unload(lib_path):
lib_path = os.path.abspath(lib_path) lib_path = os.path.abspath(lib_path)
lib_name = os.path.splitext(lib_path)[0] op_in_lib = _uninstall(lib_path)
_uninstall(lib_name) for op in op_in_lib:
del globals()[op]
__all__.remove(op)
def _make_official_custom_op():
official_opr_list = _get_custom_op_list()
for op in official_opr_list:
op_maker = _gen_custom_op_maker(op)
if op not in globals():
globals()[op] = op_maker
__all__.append(op)
_make_official_custom_op()
...@@ -782,6 +782,10 @@ def build( ...@@ -782,6 +782,10 @@ def build(
with_cudnn, with_cudnn,
abi_tag, abi_tag,
) )
target_libpath = "{}_v{}".format(name, version) + str(
".dll" if IS_WINDOWS else ".so"
)
if verbose: if verbose:
if version != old_version and old_version != None: if version != old_version and old_version != None:
print( print(
...@@ -795,8 +799,7 @@ def build( ...@@ -795,8 +799,7 @@ def build(
print( print(
"No modifications detected for {}, skipping build step...".format(name) "No modifications detected for {}, skipping build step...".format(name)
) )
return return os.path.join(build_dir, "{}".format(target_libpath))
name = "{}_v{}".format(name, version)
# phase 3: compiler and ninja check # phase 3: compiler and ninja check
_check_ninja_availability() _check_ninja_availability()
...@@ -830,8 +833,6 @@ def build( ...@@ -830,8 +833,6 @@ def build(
try: try:
# phase 5: generate ninja build file # phase 5: generate ninja build file
objs = [_obj_file_path(src) for src in sources] objs = [_obj_file_path(src) for src in sources]
name += ".dll" if IS_WINDOWS else ".so"
build_file_path = os.path.join(build_dir, "build.ninja") build_file_path = os.path.join(build_dir, "build.ninja")
if verbose: if verbose:
print("Emitting ninja build file {}".format(build_file_path)) print("Emitting ninja build file {}".format(build_file_path))
...@@ -844,7 +845,7 @@ def build( ...@@ -844,7 +845,7 @@ def build(
sources=sources, sources=sources,
objects=objs, objects=objs,
ldflags=ldflags, ldflags=ldflags,
library_target=name, library_target=target_libpath,
with_cuda=with_cuda, with_cuda=with_cuda,
) )
...@@ -852,7 +853,7 @@ def build( ...@@ -852,7 +853,7 @@ def build(
if verbose: if verbose:
print( print(
"Compiling and linking your custom op {}".format( "Compiling and linking your custom op {}".format(
os.path.join(build_dir, name) os.path.join(build_dir, target_libpath)
) )
) )
_build_with_ninja(build_dir, verbose, "compiling error") _build_with_ninja(build_dir, verbose, "compiling error")
...@@ -861,7 +862,7 @@ def build( ...@@ -861,7 +862,7 @@ def build(
else: else:
baton.wait() baton.wait()
return os.path.join(build_dir, name) return os.path.join(build_dir, target_libpath)
def build_and_load( def build_and_load(
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "./tensor.h" #include "./tensor.h"
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/custom/data_adaptor.h" #include "megbrain/custom/adaptor.h"
#include "megbrain/imperative.h" #include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h" #include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
...@@ -725,9 +725,7 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) { ...@@ -725,9 +725,7 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
return obj; return obj;
#else #else
mgb_assert( mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
false,
"Custom Op is disabled now, please build megengine with Custom Op open");
return nullptr; return nullptr;
#endif #endif
} }
...@@ -737,46 +735,49 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) { ...@@ -737,46 +735,49 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
py::list install_custom(const std::string& name, const std::string& path) { py::list install_custom(const std::string& name, const std::string& path) {
#if MGB_CUSTOM_OP #if MGB_CUSTOM_OP
py::list ret; const auto& ops_in_lib = custom::CustomOpManager::inst()->install(name, path);
const auto& ops_in_lib = custom::LibManager::inst()->install(name, path); py::list ret = py::cast(ops_in_lib);
for (const auto& op : ops_in_lib) {
ret.append(op);
}
return ret; return ret;
#else #else
mgb_assert( mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
false, return py::list{};
"Custom Op is disabled now, please build megengine with Custom Op open");
py::list ret;
return ret;
#endif #endif
} }
bool uninstall_custom(const std::string& name) { py::list uninstall_custom(const std::string& name) {
#if MGB_CUSTOM_OP #if MGB_CUSTOM_OP
return custom::LibManager::inst()->uninstall(name); const auto& ops_in_lib = custom::CustomOpManager::inst()->uninstall(name);
py::list ret = py::cast(ops_in_lib);
return ret;
#else #else
mgb_assert( mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
false,
"Custom Op is disabled now, please build megengine with Custom Op open");
return false; return false;
#endif #endif
} }
py::list get_custom_op_list(void) { py::list get_custom_op_list(void) {
#if MGB_CUSTOM_OP #if MGB_CUSTOM_OP
std::vector<std::string> all_ops = CustomOpDefFactory::inst()->op_list(); std::vector<std::string> all_ops = custom::CustomOpManager::inst()->op_name_list();
py::list ret; py::list ret = py::cast(all_ops);
for (auto& op : all_ops) {
ret.append(op);
}
return ret; return ret;
#else #else
mgb_assert( mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
false, return py::list{};
"Custom Op is disabled now, please build megengine with Custom Op open"); #endif
py::list ret; }
py::dict get_custom_op_lib_info(void) {
#if MGB_CUSTOM_OP
auto&& libs = custom::CustomOpManager::inst()->lib_info();
py::dict ret;
for (auto&& [lib_name, lib_handle] : libs) {
py::list ops = py::cast(lib_handle->ops_in_lib());
ret[py::str(lib_name)] = ops;
}
return ret; return ret;
#else
mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
return py::list{};
#endif #endif
} }
...@@ -792,6 +793,7 @@ void init_custom(pybind11::module m) { ...@@ -792,6 +793,7 @@ void init_custom(pybind11::module m) {
m.def("_install", &install_custom); m.def("_install", &install_custom);
m.def("_uninstall", &uninstall_custom); m.def("_uninstall", &uninstall_custom);
m.def("_get_custom_op_list", &get_custom_op_list); m.def("_get_custom_op_list", &get_custom_op_list);
m.def("_get_custom_op_lib_info", &get_custom_op_lib_info);
m.def("get_custom_op_abi_tag", [](void) -> int { m.def("get_custom_op_abi_tag", [](void) -> int {
int ret = 0; int ret = 0;
#ifdef _GLIBCXX_USE_CXX11_ABI #ifdef _GLIBCXX_USE_CXX11_ABI
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <stdio.h> #include <stdio.h>
#include "./matmul_scale.h" #include "./matmul_scale.h"
#include "megbrain/custom/platform/custom_cuda.h"
using namespace custom; using namespace custom;
...@@ -51,12 +52,13 @@ void matmul_forward_helper( ...@@ -51,12 +52,13 @@ void matmul_forward_helper(
float scale) { float scale) {
dim3 block(1, 1); dim3 block(1, 1);
dim3 grid(N / block.x, M / block.y); dim3 grid(N / block.x, M / block.y);
auto stream = get_cuda_stream(lhs.device());
DISPATCH_INT_AND_FLOAT_TYPES(res.dtype(), "matmul_forward", ([&]() { DISPATCH_INT_AND_FLOAT_TYPES(
matmul_forward_naive<scalar_t><<<grid, block>>>( res.dtype(), "matmul_forward", ([&]() {
lhs.data<scalar_t>(), rhs.data<scalar_t>(), matmul_forward_naive<scalar_t><<<grid, block, 0, stream>>>(
res.data<scalar_t>(), M, K, N, scale); lhs.data<scalar_t>(), rhs.data<scalar_t>(),
})); res.data<scalar_t>(), M, K, N, scale);
}));
} }
void matmul_backward_lhs_helper( void matmul_backward_lhs_helper(
...@@ -64,9 +66,10 @@ void matmul_backward_lhs_helper( ...@@ -64,9 +66,10 @@ void matmul_backward_lhs_helper(
size_t N, float scale) { size_t N, float scale) {
dim3 block(1, 1); dim3 block(1, 1);
dim3 grid(K / block.x, M / block.y); dim3 grid(K / block.x, M / block.y);
auto stream = get_cuda_stream(rhs.device());
DISPATCH_INT_AND_FLOAT_TYPES( DISPATCH_INT_AND_FLOAT_TYPES(
lhs_grad.dtype(), "matmul_backward_lhs", ([&]() { lhs_grad.dtype(), "matmul_backward_lhs", ([&]() {
matmul_backward_lhs_naive<scalar_t><<<grid, block>>>( matmul_backward_lhs_naive<scalar_t><<<grid, block, 0, stream>>>(
rhs.data<scalar_t>(), ograd.data<scalar_t>(), rhs.data<scalar_t>(), ograd.data<scalar_t>(),
lhs_grad.data<scalar_t>(), M, K, N, scale); lhs_grad.data<scalar_t>(), M, K, N, scale);
})); }));
...@@ -77,9 +80,10 @@ void matmul_backward_rhs_helper( ...@@ -77,9 +80,10 @@ void matmul_backward_rhs_helper(
size_t N, float scale) { size_t N, float scale) {
dim3 block(1, 1); dim3 block(1, 1);
dim3 grid(N / block.x, K / block.y); dim3 grid(N / block.x, K / block.y);
auto stream = get_cuda_stream(lhs.device());
DISPATCH_INT_AND_FLOAT_TYPES( DISPATCH_INT_AND_FLOAT_TYPES(
rhs_grad.dtype(), "matmul_backward_rhs", ([&]() { rhs_grad.dtype(), "matmul_backward_rhs", ([&]() {
matmul_backward_rhs_naive<scalar_t><<<grid, block>>>( matmul_backward_rhs_naive<scalar_t><<<grid, block, 0, stream>>>(
lhs.data<scalar_t>(), ograd.data<scalar_t>(), lhs.data<scalar_t>(), ograd.data<scalar_t>(),
rhs_grad.data<scalar_t>(), M, K, N, scale); rhs_grad.data<scalar_t>(), M, K, N, scale);
})); }));
......
...@@ -6,92 +6,132 @@ import sys ...@@ -6,92 +6,132 @@ import sys
import numpy as np import numpy as np
import pytest import pytest
import megengine
import megengine.functional as F import megengine.functional as F
import megengine.optimizer as optim
from megengine import jit from megengine import jit
from megengine.autodiff import Function, GradManager from megengine.autodiff import Function, GradManager
from megengine.core._imperative_rt.core2 import apply from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import custom from megengine.core.ops import custom
from megengine.device import get_device_count from megengine.device import get_device_count
from megengine.module import Conv2d, Linear, Module from megengine.tensor import Tensor
from megengine.random import normal
from megengine.tensor import Parameter, Tensor
from megengine.utils import custom_op_tools from megengine.utils import custom_op_tools
build_path = os.path.join(
custom_op_tools._get_default_build_root(), "custom_opsrc", "build"
)
cur_dir_path = os.path.dirname(os.path.abspath(__file__))
mgb_root_path = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(cur_dir_path))))
)
extra_include_paths = [os.path.join(mgb_root_path, "src", "custom", "include")]
def compare(ref, real): extra_ld_flags = []
if ref.shape != real.shape: if sys.platform != "win32":
real = real.T ld_path = os.environ.get("LD_LIBRARY_PATH")
np.testing.assert_allclose(ref, real, rtol=1e-3, atol=1e-5) if ld_path != None:
ld_dirs = ld_path.split(":")
for ld_dir in ld_dirs:
if os.path.exists(ld_dir) and os.path.isdir(ld_dir):
for lib in os.listdir(ld_dir):
if "megengine_shared" in lib:
extra_ld_flags += ["-L{} -Wl,-rpath,{}".format(ld_dir, ld_dir)]
break
def build_and_clean(test_func):
def wrapper():
cur_dir_path = os.path.dirname(os.path.abspath(__file__))
build_root_dir = custom_op_tools._get_default_build_root()
build_path = os.path.join(build_root_dir, "custom_opsrc", "build")
if os.path.exists(build_path): def build_and_clean(*srcs):
shutil.rmtree(build_path) def deco(test_func):
custom_op_srcs = [os.path.join(cur_dir_path, "custom_opsrc", s) for s in srcs]
mgb_root_path = os.path.dirname( def wrapper(*args, **kwargs):
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(cur_dir_path)))
)
)
extra_include_paths = [os.path.join(mgb_root_path, "src", "custom", "include")]
extra_ld_flags = []
if sys.platform != "win32":
ld_path = os.environ.get("LD_LIBRARY_PATH")
if ld_path != None:
ld_dirs = ld_path.split(":")
for ld_dir in ld_dirs:
if os.path.exists(ld_dir) and os.path.isdir(ld_dir):
for lib in os.listdir(ld_dir):
if "megengine_shared" in lib:
extra_ld_flags += [
"-L{} -Wl,-rpath,{}".format(ld_dir, ld_dir)
]
break
if get_device_count("gpu") > 0:
custom_opsrc = [
os.path.join(cur_dir_path, "custom_opsrc", "matmul_scale.cpp"),
os.path.join(cur_dir_path, "custom_opsrc", "matmul_scale.cu"),
]
else:
custom_opsrc = [os.path.join(cur_dir_path, "custom_opsrc", "elem_add.cpp")]
try:
lib_path = custom_op_tools.build_and_load( lib_path = custom_op_tools.build_and_load(
"test_op", "test_op",
custom_opsrc, custom_op_srcs,
extra_include_paths=extra_include_paths, extra_include_paths=extra_include_paths,
extra_ldflags=extra_ld_flags,
build_dir=build_path, build_dir=build_path,
verbose=False, extra_ldflags=extra_ld_flags,
verbose=True,
) )
test_func() test_func(*args, **kwargs)
custom.unload(lib_path) custom.unload(lib_path)
finally: return wrapper
if os.path.exists(build_path):
shutil.rmtree(build_path)
return wrapper return deco
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count("gpu") > 0, reason="elem_add operator is only supported on CPU" get_device_count("gpu") > 0, reason="elem_add operator is only supported on CPU"
) )
@build_and_clean @build_and_clean("elem_add.cpp")
def test_custom_op_cpu_build(): def test_cpu_func():
assert "ElemAddSmoothForward" in custom._get_custom_op_list() class ElemAddSmooth(Function):
assert "ElemAddSmoothBackward" in custom._get_custom_op_list() def __init__(self, smooth):
assert hasattr(custom, "ElemAddSmoothForward") super().__init__()
assert hasattr(custom, "ElemAddSmoothBackward") self.smooth = smooth
def forward(self, lhs, rhs):
op = custom.ElemAddSmoothForward(smooth=self.smooth)
return apply(op, lhs, rhs)[0]
def backward(self, ograd):
op = custom.ElemAddSmoothBackward()
return apply(op, ograd)
def gen_elemadd_data(seed, shape, low=-1, high=1):
rng = np.random.RandomState(seed=seed)
lhs_np = rng.uniform(low=low, high=high, size=shape).astype(np.float32)
rhs_np = rng.uniform(low=low, high=high, size=shape).astype(np.float32)
ograd_np = rng.uniform(low=low, high=high, size=shape).astype(np.float32)
return lhs_np, rhs_np, ograd_np
def builtin_func(lhs, rhs, smooth):
out = lhs + rhs
return F.where(out < 0, out + smooth, out - smooth)
def test_elemadd_smooth_train(smooth=0.5, m=4, n=2, seed=2021):
lhs_np, rhs_np, ograd_np = gen_elemadd_data(seed, (m, n))
custom_lhs, custom_rhs = Tensor(lhs_np), Tensor(rhs_np)
builtin_lhs, builtin_rhs = Tensor(lhs_np), Tensor(rhs_np)
ograd_tensor = Tensor(ograd_np)
custom_func = ElemAddSmooth(smooth=smooth)
gm = GradManager().attach([custom_lhs, custom_rhs])
with gm:
custom_out = custom_func(custom_lhs, custom_rhs)
gm.backward(custom_out, ograd_tensor)
gm = GradManager().attach([builtin_lhs, builtin_rhs])
with gm:
builtin_out = builtin_func(builtin_lhs, builtin_rhs, smooth)
gm.backward(builtin_out, ograd_tensor)
np.testing.assert_allclose(custom_out, builtin_out, rtol=1e-3, atol=1e-5)
np.testing.assert_allclose(
custom_lhs.grad.numpy(), builtin_lhs.grad.numpy(), rtol=1e-3, atol=1e-5
)
np.testing.assert_allclose(
custom_rhs.grad.numpy(), builtin_rhs.grad.numpy(), rtol=1e-3, atol=1e-5
)
def test_elemadd_smooth_trace(smooth=0.5, m=4, n=2, seed=2021):
@jit.trace(capture_as_const=True)
def func_dumper(lhs, rhs, *, net):
return net(lhs, rhs)
lhs_np, rhs_np, _ = gen_elemadd_data(seed, (m, n))
lhs_tensor = Tensor(lhs_np)
rhs_tensor = Tensor(rhs_np)
func = ElemAddSmooth(smooth=smooth)
real = func_dumper(lhs_tensor, rhs_tensor, net=func)
real = func_dumper(lhs_tensor, rhs_tensor, net=func)
ref = builtin_func(Tensor(lhs_np), Tensor(rhs_np), smooth)
np.testing.assert_allclose(real.numpy(), ref.numpy(), rtol=1e-3, atol=1e-5)
test_elemadd_smooth_train(0.2, 128, 256, 2027)
test_elemadd_smooth_train(0.3, 256, 128, 2028)
test_elemadd_smooth_train(0.4, 128, 512, 2029)
test_elemadd_smooth_trace(0.2, 256, 64, 2030)
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -101,9 +141,136 @@ def test_custom_op_cpu_build(): ...@@ -101,9 +141,136 @@ def test_custom_op_cpu_build():
@pytest.mark.skipif( @pytest.mark.skipif(
get_device_count("gpu") < 1, reason="matmul scale operator is only supported on GPU" get_device_count("gpu") < 1, reason="matmul scale operator is only supported on GPU"
) )
@build_and_clean @build_and_clean("matmul_scale.cpp", "matmul_scale.cu")
def test_custom_op_gpu_build(): def test_gpu_func():
class MatMulScale(Function):
def __init__(self, scale):
super().__init__()
self.scale = scale
def forward(self, lhs, rhs):
op = custom.MatMulScaleForward(scale=self.scale)
self.lhs = lhs
self.rhs = rhs
return apply(op, lhs, rhs)[0]
def backward(self, ograd):
op = custom.MatMulScaleBackward(scale=self.scale)
return apply(op, ograd, self.lhs, self.rhs)
def gen_matmul_data(seed, m, k, n, low=-0.5, high=0.5, dtype=np.float32):
rng = np.random.RandomState(seed=seed)
lhs_np = rng.uniform(low=low, high=high, size=(m, k)).astype(dtype)
rhs_np = rng.uniform(low=low, high=high, size=(k, n)).astype(dtype)
ograd_np = rng.uniform(low=low, high=high, size=(m, n)).astype(dtype)
scale = rng.uniform(low=0.1, high=0.9, size=(1)).astype(np.float32)[0]
return lhs_np, rhs_np, ograd_np, scale
def builtin_func(lhs, rhs, scale):
out = F.matmul(lhs, rhs) * scale
return out
def test_matmul_scale(m=1, k=1, n=1, seed=2021):
lhs_np, rhs_np, _, scale = gen_matmul_data(seed, m, k, n)
custom_lhs, custom_rhs = Tensor(lhs_np), Tensor(rhs_np)
builtin_lhs, builtin_rhs = Tensor(lhs_np), Tensor(rhs_np)
custom_func = MatMulScale(scale=scale)
custom_out = custom_func(custom_lhs, custom_rhs)
builtin_out = builtin_func(builtin_lhs, builtin_rhs, scale)
np.testing.assert_allclose(custom_out, builtin_out, rtol=1e-3, atol=1e-5)
def test_matmul_scale_trace(m=1, k=1, n=1, seed=2021):
@jit.trace(capture_as_const=True)
def func_dumper(lhs, rhs, *, net):
return net(lhs, rhs)
lhs_np, rhs_np, _, scale = gen_matmul_data(seed, m, k, n)
lhs_tensor, rhs_tensor = Tensor(lhs_np), Tensor(rhs_np)
func = MatMulScale(scale=scale)
real = func_dumper(lhs_tensor, rhs_tensor, net=func)
real = func_dumper(lhs_tensor, rhs_tensor, net=func)
ref = builtin_func(Tensor(lhs_np), Tensor(rhs_np), scale)
np.testing.assert_allclose(real.numpy(), ref.numpy(), rtol=1e-3, atol=1e-5)
test_matmul_scale(128, 256, 64, 2028)
test_matmul_scale(64, 32, 16, 2029)
test_matmul_scale_trace(64, 32, 16, 2030)
@pytest.mark.skipif(
get_device_count("gpu") < 1, reason="matmul scale operator is only supported on GPU"
)
def test_custom_op():
org_op_list = custom._get_custom_op_list()
assert len(custom._get_custom_op_lib_info()) == 0
assert "ElemAddSmoothForward" not in custom._get_custom_op_list()
assert not hasattr(custom, "ElemAddSmoothForward")
assert "MatMulScaleForward" not in custom._get_custom_op_list()
assert not hasattr(custom, "MatMulScaleForward")
srcs1 = [os.path.join(cur_dir_path, "custom_opsrc", "elem_add.cpp")]
lib_path1 = custom_op_tools.build_and_load(
"elem",
srcs1,
extra_include_paths=extra_include_paths,
build_dir=build_path,
extra_ldflags=extra_ld_flags,
verbose=True,
)
assert "ElemAddSmoothForward" in custom._get_custom_op_list()
assert hasattr(custom, "ElemAddSmoothForward")
assert lib_path1 in custom._get_custom_op_lib_info()
assert "ElemAddSmoothForward" in custom._get_custom_op_lib_info()[lib_path1]
srcs2 = [
os.path.join(cur_dir_path, "custom_opsrc", src)
for src in ["matmul_scale.cpp", "matmul_scale.cu"]
]
lib_path2 = custom_op_tools.build_and_load(
"matmul",
srcs2,
extra_include_paths=extra_include_paths,
build_dir=build_path,
extra_ldflags=extra_ld_flags,
verbose=True,
)
assert "MatMulScaleForward" in custom._get_custom_op_list()
assert hasattr(custom, "MatMulScaleForward")
assert lib_path2 in custom._get_custom_op_lib_info()
assert "MatMulScaleForward" in custom._get_custom_op_lib_info()[lib_path2]
assert len(custom._get_custom_op_list()) == len(org_op_list) + 4
custom.unload(lib_path1)
assert "ElemAddSmoothForward" not in custom._get_custom_op_list()
assert not hasattr(custom, "ElemAddSmoothForward")
assert lib_path1 not in custom._get_custom_op_lib_info()
custom.unload(lib_path2)
assert "MatMulScaleForward" not in custom._get_custom_op_list()
assert not hasattr(custom, "MatMulScaleForward")
assert lib_path1 not in custom._get_custom_op_lib_info()
assert len(custom._get_custom_op_lib_info()) == 0
assert custom._get_custom_op_list() == org_op_list
custom.load(lib_path2)
assert "MatMulScaleForward" in custom._get_custom_op_list() assert "MatMulScaleForward" in custom._get_custom_op_list()
assert "MatMulScaleBackward" in custom._get_custom_op_list()
assert hasattr(custom, "MatMulScaleForward") assert hasattr(custom, "MatMulScaleForward")
assert hasattr(custom, "MatMulScaleBackward") assert lib_path2 in custom._get_custom_op_lib_info()
assert "MatMulScaleForward" in custom._get_custom_op_lib_info()[lib_path2]
custom.unload(lib_path2)
assert "MatMulScaleForward" not in custom._get_custom_op_list()
assert not hasattr(custom, "MatMulScaleForward")
assert lib_path1 not in custom._get_custom_op_lib_info()
assert len(custom._get_custom_op_lib_info()) == 0
assert custom._get_custom_op_list() == org_op_list
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#if MGB_CUSTOM_OP #if MGB_CUSTOM_OP
#include "../op_trait.h" #include "../op_trait.h"
#include "megbrain/custom/data_adaptor.h" #include "megbrain/custom/adaptor.h"
#include "megbrain/opr/custom_opnode.h" #include "megbrain/opr/custom_opnode.h"
namespace mgb { namespace mgb {
...@@ -51,13 +51,9 @@ const std::shared_ptr<const custom::CustomOp>& CustomOpDef::impl(void) const { ...@@ -51,13 +51,9 @@ const std::shared_ptr<const custom::CustomOp>& CustomOpDef::impl(void) const {
} }
void CustomOpDef::compute( void CustomOpDef::compute(
const SmallVector<DeviceTensorND>& inputs, std::shared_ptr<SmallVector<DeviceTensorND>> inputs,
SmallVector<DeviceTensorND>* outputs) const { std::shared_ptr<SmallVector<DeviceTensorND>> outputs) const {
std::vector<custom::Tensor> custom_inputs = custom::dispatch_custom_op(m_op, m_param, inputs, outputs);
custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
std::vector<custom::Tensor> custom_outputs =
custom::to_custom<DeviceTensorND, custom::Tensor>(*outputs);
m_op->compute(custom_inputs, this->m_param, custom_outputs);
} }
std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs( std::tuple<SmallVector<LogicalTensorDesc>, bool> CustomOpDef::infer_output_attrs(
...@@ -169,13 +165,6 @@ std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef( ...@@ -169,13 +165,6 @@ std::shared_ptr<OpDef> CustomOpDefFactory::create_opdef(
namespace custom_opdef { // avoid name conflict namespace custom_opdef { // avoid name conflict
void apply_on_device_tensornd(
const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) {
auto&& op = static_cast<const CustomOpDef&>(def);
op.compute(inputs, outputs);
}
SmallVector<TensorPtr> apply_on_physical_tensor( SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs, const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
...@@ -194,15 +183,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -194,15 +183,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
output = Tensor::make(output_descs[i].layout, output_descs[i].comp_node); output = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
} }
SmallVector<DeviceTensorND> inp_tensornds(inputs.size()); std::shared_ptr<SmallVector<DeviceTensorND>> inp_tensornds =
SmallVector<DeviceTensorND> oup_tensornds(outputs.size()); std::make_shared<SmallVector<DeviceTensorND>>();
std::shared_ptr<SmallVector<DeviceTensorND>> oup_tensornds =
for (size_t i = 0; i < inputs.size(); ++i) std::make_shared<SmallVector<DeviceTensorND>>();
inp_tensornds[i] = inputs[i]->dev_tensor(); for (size_t i = 0; i < inputs.size(); ++i) {
for (size_t i = 0; i < outputs.size(); ++i) inp_tensornds->emplace_back(inputs[i]->dev_tensor(true));
oup_tensornds[i] = outputs[i]->dev_tensor(); }
for (size_t i = 0; i < outputs.size(); ++i) {
oup_tensornds->emplace_back(outputs[i]->dev_tensor(true));
}
apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds); auto&& op = static_cast<const CustomOpDef&>(def);
op.compute(inp_tensornds, oup_tensornds);
return outputs; return outputs;
} }
...@@ -258,7 +251,6 @@ std::string make_name(const OpDef& def) { ...@@ -258,7 +251,6 @@ std::string make_name(const OpDef& def) {
OP_TRAIT_REG(CustomOpDef, CustomOpDef) OP_TRAIT_REG(CustomOpDef, CustomOpDef)
.apply_on_physical_tensor(apply_on_physical_tensor) .apply_on_physical_tensor(apply_on_physical_tensor)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.apply_on_device_tensornd(apply_on_device_tensornd)
.infer_output_attrs_fallible(infer_output_attrs_fallible) .infer_output_attrs_fallible(infer_output_attrs_fallible)
.hash(hash) .hash(hash)
.is_same_st(is_same_st) .is_same_st(is_same_st)
......
...@@ -31,7 +31,8 @@ public: ...@@ -31,7 +31,8 @@ public:
const std::shared_ptr<const custom::CustomOp>& impl(void) const; const std::shared_ptr<const custom::CustomOp>& impl(void) const;
void compute( void compute(
const SmallVector<DeviceTensorND>&, SmallVector<DeviceTensorND>*) const; std::shared_ptr<SmallVector<DeviceTensorND>>,
std::shared_ptr<SmallVector<DeviceTensorND>>) const;
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs( std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs(
const SmallVector<TensorPtr>& inputs) const; const SmallVector<TensorPtr>& inputs) const;
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs( std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs(
......
...@@ -32,6 +32,39 @@ const char* dlerror(void) { ...@@ -32,6 +32,39 @@ const char* dlerror(void) {
} }
#endif #endif
CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY)
: m_handle(nullptr, [](void* handle) { dlclose(handle); }) {
auto op_list_before_load = CustomOpManager::inst()->op_name_list();
std::unordered_set<std::string> op_set_before_load(
op_list_before_load.begin(), op_list_before_load.end());
m_handle.reset(dlopen(path.c_str(), mode));
mgb_assert(
m_handle != nullptr, "open custom op lib failed, error type: %s",
dlerror());
auto op_list_after_load = CustomOpManager::inst()->op_name_list();
for (auto& op : op_list_after_load) {
if (op_set_before_load.find(op) == op_set_before_load.end()) {
m_ops.emplace_back(op);
}
}
}
CustomLib::~CustomLib() {
for (auto& op : m_ops) {
CustomOpManager::inst()->erase(op);
}
}
const std::vector<std::string>& CustomLib::ops_in_lib(void) const {
return m_ops;
}
bool CustomLib::valid() const {
return m_handle != nullptr;
}
CustomOpManager* CustomOpManager::inst(void) { CustomOpManager* CustomOpManager::inst(void) {
static CustomOpManager op_manager; static CustomOpManager op_manager;
return &op_manager; return &op_manager;
...@@ -39,12 +72,40 @@ CustomOpManager* CustomOpManager::inst(void) { ...@@ -39,12 +72,40 @@ CustomOpManager* CustomOpManager::inst(void) {
CustomOpManager::~CustomOpManager() { CustomOpManager::~CustomOpManager() {
mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!"); mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!");
LibManager::inst()->m_custom_libs.clear(); {
MGB_LOCK_GUARD(m_lib_mtx);
m_custom_libs.clear();
}
mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!");
MGB_LOCK_GUARD(m_op_mtx);
m_name2op.clear();
m_id2op.clear();
}
const std::vector<std::string>& CustomOpManager::install(
const std::string& name, const std::string& path) {
MGB_LOCK_GUARD(m_lib_mtx);
LibHandle handle = std::make_shared<CustomLib>(path);
m_custom_libs.insert({name, handle});
return m_custom_libs[name]->ops_in_lib();
}
std::vector<std::string> CustomOpManager::uninstall(const std::string& name) {
MGB_LOCK_GUARD(m_lib_mtx);
std::vector<std::string> op_names = m_custom_libs[name]->ops_in_lib();
mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error");
return op_names;
}
const std::unordered_map<std::string, LibHandle>& CustomOpManager::lib_info(
void) const {
return m_custom_libs;
} }
std::shared_ptr<CustomOp> CustomOpManager::insert( std::shared_ptr<CustomOp> CustomOpManager::insert(
const std::string& name, uint32_t version) { const std::string& name, uint32_t version) {
MGB_LOCK_GUARD(m_mtx); MGB_LOCK_GUARD(m_op_mtx);
auto iter = m_name2op.find(name); auto iter = m_name2op.find(name);
if (iter != m_name2op.end()) { if (iter != m_name2op.end()) {
mgb_log_warn( mgb_log_warn(
...@@ -59,7 +120,7 @@ std::shared_ptr<CustomOp> CustomOpManager::insert( ...@@ -59,7 +120,7 @@ std::shared_ptr<CustomOp> CustomOpManager::insert(
} }
bool CustomOpManager::erase(const std::string& name) { bool CustomOpManager::erase(const std::string& name) {
MGB_LOCK_GUARD(m_mtx); MGB_LOCK_GUARD(m_op_mtx);
auto iter = m_name2op.find(name); auto iter = m_name2op.find(name);
if (iter == m_name2op.end()) { if (iter == m_name2op.end()) {
mgb_log_warn( mgb_log_warn(
...@@ -72,28 +133,6 @@ bool CustomOpManager::erase(const std::string& name) { ...@@ -72,28 +133,6 @@ bool CustomOpManager::erase(const std::string& name) {
return true; return true;
} }
bool CustomOpManager::erase(const RunTimeId& id) {
MGB_LOCK_GUARD(m_mtx);
auto iter = m_id2op.find(id);
if (iter == m_id2op.end()) {
mgb_log_warn("Erase Custom Op Failed! The Op has not been registered");
return false;
}
std::shared_ptr<const CustomOp> op = iter->second;
m_id2op.erase(op->runtime_id());
m_name2op.erase(op->op_type());
return true;
}
std::shared_ptr<CustomOp> CustomOpManager::find_or_reg(
const std::string& name, uint32_t version) {
auto iter = m_name2op.find(name);
if (iter == m_name2op.end()) {
return insert(name, version);
}
return std::const_pointer_cast<CustomOp, const CustomOp>(iter->second);
}
RunTimeId CustomOpManager::to_id(const std::string& name) const { RunTimeId CustomOpManager::to_id(const std::string& name) const {
std::shared_ptr<const CustomOp> op = find(name); std::shared_ptr<const CustomOp> op = find(name);
return op->runtime_id(); return op->runtime_id();
...@@ -135,60 +174,6 @@ std::vector<RunTimeId> CustomOpManager::op_id_list(void) { ...@@ -135,60 +174,6 @@ std::vector<RunTimeId> CustomOpManager::op_id_list(void) {
return ret; return ret;
} }
CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY)
: m_handle(nullptr, [](void* handle) { dlclose(handle); }) {
auto op_list_before_load = CustomOpManager::inst()->op_name_list();
std::unordered_set<std::string> op_set_before_load(
op_list_before_load.begin(), op_list_before_load.end());
m_handle.reset(dlopen(path.c_str(), mode));
mgb_assert(
m_handle != nullptr, "open custom op lib failed, error type: %s",
dlerror());
auto op_list_after_load = CustomOpManager::inst()->op_name_list();
for (auto& op : op_list_after_load) {
if (op_set_before_load.find(op) == op_set_before_load.end()) {
m_ops.emplace_back(op);
}
}
}
const std::vector<std::string>& CustomLib::ops_in_lib(void) const {
return m_ops;
}
CustomLib::~CustomLib() {
for (auto& op : m_ops) {
CustomOpManager::inst()->erase(op);
}
}
bool CustomLib::valid() const {
return m_handle != nullptr;
}
LibManager* LibManager::inst(void) {
static LibManager custom_libs;
return &custom_libs;
}
const std::vector<std::string>& LibManager::install(
const std::string& name, const std::string& path) {
MGB_LOCK_GUARD(m_mtx);
;
LibHandle handle = std::make_shared<CustomLib>(path);
m_custom_libs.insert({name, handle});
return m_custom_libs[name]->ops_in_lib();
}
bool LibManager::uninstall(const std::string& name) {
MGB_LOCK_GUARD(m_mtx);
;
mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error");
return true;
}
std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) { std::shared_ptr<CustomOp> op_insert(std::string opname, uint32_t version) {
return CustomOpManager::inst()->insert(opname, version); return CustomOpManager::inst()->insert(opname, version);
} }
......
...@@ -4,8 +4,11 @@ ...@@ -4,8 +4,11 @@
#include <sstream> #include <sstream>
#include <unordered_set> #include <unordered_set>
#include "megbrain/comp_node_env.h"
#include "megbrain/custom/adaptor.h"
#include "megbrain/custom/op.h" #include "megbrain/custom/op.h"
#include "megbrain/custom/utils.h" #include "megbrain/custom/utils.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/thin/function.h" #include "megbrain/utils/thin/function.h"
using namespace mgb; using namespace mgb;
...@@ -550,6 +553,45 @@ void CustomOp::compute( ...@@ -550,6 +553,45 @@ void CustomOp::compute(
assert_outputs_size_right(outputs); assert_outputs_size_right(outputs);
} }
void compute_impl(
std::shared_ptr<const CustomOp> op, const Param& param,
std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> inputs,
std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> outputs) {
std::vector<custom::Tensor> custom_inputs;
for (size_t i = 0; i < inputs->size(); ++i) {
custom_inputs.emplace_back(to_custom_tensor(inputs->operator[](i)));
}
std::vector<custom::Tensor> custom_outputs;
for (size_t i = 0; i < outputs->size(); ++i) {
custom_outputs.emplace_back(to_custom_tensor(outputs->operator[](i)));
}
op->compute(custom_inputs, param, custom_outputs);
}
void dispatch_custom_op(
std::shared_ptr<const CustomOp> op, const Param& param,
std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> inputs,
std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> outputs) {
if (outputs->size() == 0) {
return;
}
auto compnode = outputs->at(0).comp_node();
if (compnode.device_type() == CompNode::DeviceType::CPU) {
auto&& cpu_env = CompNodeEnv::from_comp_node(compnode).cpu_env();
cpu_env.dispatch([op, param, inputs, outputs]() {
compute_impl(op, param, inputs, outputs);
});
} else {
mgb_assert(
compnode.device_type() == CompNode::DeviceType::CUDA,
"custom op only support cuda/cpu now, but get %s",
compnode.to_string().c_str());
compute_impl(op, param, inputs, outputs);
}
}
} // namespace custom } // namespace custom
#endif #endif
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#if MGB_CUSTOM_OP #if MGB_CUSTOM_OP
#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
#include "megbrain/custom/data_adaptor.h" #include "megbrain/custom/adaptor.h"
#include "megbrain/custom/param_val.h" #include "megbrain/custom/param_val.h"
#include "megbrain/custom/tensor.h" #include "megbrain/custom/tensor.h"
...@@ -40,7 +40,7 @@ namespace custom { ...@@ -40,7 +40,7 @@ namespace custom {
#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \ #define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \
mgb_assert( \ mgb_assert( \
lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \ lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \
type2name[lhs.m_type].c_str(), #op, type2name[rhs.m_type].c_str()) ptype2name(lhs.m_type).c_str(), #op, ptype2name(rhs.m_type).c_str())
#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \ #define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \
case (ParamDynType::dyn_type): { \ case (ParamDynType::dyn_type): { \
...@@ -177,6 +177,18 @@ namespace custom { ...@@ -177,6 +177,18 @@ namespace custom {
break; \ break; \
} }
std::string ptype2name(ParamDynType ptype) {
#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) \
{ParamDynType::dyn_type, #dyn_type},
static std::unordered_map<
ParamDynType, std::string, EnumHash<ParamDynType>, EnumCmp<ParamDynType>>
type2name = {CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME){
ParamDynType::Invalid, "Invalid"}};
#undef CUSTOM_REG_DYN_PARAMTYPE_NAME
return type2name[ptype];
}
ParamVal::ParamVal() : m_ptr(nullptr, [](void*) -> void {}) { ParamVal::ParamVal() : m_ptr(nullptr, [](void*) -> void {}) {
m_type = ParamDynType::Invalid; m_type = ParamDynType::Invalid;
} }
...@@ -265,7 +277,7 @@ ParamDynType ParamVal::type(void) const { ...@@ -265,7 +277,7 @@ ParamDynType ParamVal::type(void) const {
std::string ParamVal::str() const { std::string ParamVal::str() const {
std::stringstream ss; std::stringstream ss;
ss << "type: " << type2name[m_type] << "\n" ss << "type: " << ptype2name(m_type) << "\n"
<< "value: "; << "value: ";
switch (m_type) { switch (m_type) {
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#if MGB_CUSTOM_OP #if MGB_CUSTOM_OP
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "megbrain/custom/data_adaptor.h" #include "megbrain/custom/adaptor.h"
#include "megbrain/custom/platform/custom_cuda.h" #include "megbrain/custom/platform/custom_cuda.h"
using namespace mgb; using namespace mgb;
......
...@@ -42,31 +42,33 @@ using TensorImpl = DeviceTensorND; ...@@ -42,31 +42,33 @@ using TensorImpl = DeviceTensorND;
#define TensorImplConstRef(rawptr) \ #define TensorImplConstRef(rawptr) \
static_cast<const TensorImpl&>(*reinterpret_cast<const TensorImpl*>(rawptr)) static_cast<const TensorImpl&>(*reinterpret_cast<const TensorImpl*>(rawptr))
static std::unordered_map< struct DeviceMapper {
DeviceImpl::DeviceType, std::string, EnumHash<DeviceImpl::DeviceType>, using DeviceTy = DeviceImpl::DeviceType;
EnumCmp<DeviceImpl::DeviceType>> std::unordered_map<std::string, std::string> dev_cstr2bstr;
dev_benum2cstr; EnumMap<DeviceTy, std::string> dev_benum2cstr;
static std::unordered_map< EnumMap<DeviceTy, DeviceEnum> dev_benum2cenum;
DeviceImpl::DeviceType, DeviceEnum, EnumHash<DeviceImpl::DeviceType>, EnumMap<DeviceEnum, std::string> dev_cenum2bstr;
EnumCmp<DeviceImpl::DeviceType>> static DeviceMapper& inst();
dev_benum2cenum;
static std::unordered_map<std::string, std::string> dev_cstr2bstr; private:
static std::unordered_map< DeviceMapper();
DeviceEnum, std::string, EnumHash<DeviceEnum>, EnumCmp<DeviceEnum>> };
dev_cenum2bstr;
DeviceMapper::DeviceMapper() {
#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \ #define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \
auto be2cs##custom_impl = dev_benum2cstr.emplace( \ dev_benum2cstr.emplace(DeviceTy::builtin_device, std::string(#custom_impl)); \
DeviceImpl::DeviceType::builtin_device, std::string(#custom_impl)); \ dev_benum2cenum.emplace(DeviceTy::builtin_device, DeviceEnum::custom_impl); \
auto be2ce##custom_impl = dev_benum2cenum.emplace( \ dev_cstr2bstr.emplace(std::string(#custom_impl), std::string(builtin_str)); \
DeviceImpl::DeviceType::builtin_device, DeviceEnum::custom_impl); \ dev_cenum2bstr.emplace(DeviceEnum::custom_impl, std::string(builtin_str));
auto cs2bs##custom_impl = dev_cstr2bstr.emplace( \
std::string(#custom_impl), std::string(builtin_str)); \ CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_BIND_DEVICE)
auto ce2bs##custom_impl = \
dev_cenum2bstr.emplace(DeviceEnum::custom_impl, std::string(builtin_str));
CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_BIND_DEVICE)
#undef CUSTOM_BIND_DEVICE #undef CUSTOM_BIND_DEVICE
}
DeviceMapper& DeviceMapper::inst() {
static DeviceMapper dm;
return dm;
}
CUSTOM_PIMPL_CLS_DEFINE(Device) CUSTOM_PIMPL_CLS_DEFINE(Device)
...@@ -81,6 +83,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter<DeviceImpl>) { ...@@ -81,6 +83,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter<DeviceImpl>) {
return; return;
} }
auto&& dev_benum2cenum = DeviceMapper::inst().dev_benum2cenum;
auto builtin_device_enum = DeviceImplConstRef(impl).device_type(); auto builtin_device_enum = DeviceImplConstRef(impl).device_type();
mgb_assert( mgb_assert(
dev_benum2cenum.find(builtin_device_enum) != dev_benum2cenum.end(), dev_benum2cenum.find(builtin_device_enum) != dev_benum2cenum.end(),
...@@ -91,7 +94,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter<DeviceImpl>) { ...@@ -91,7 +94,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter<DeviceImpl>) {
Device::Device(const std::string& device) : m_impl(nullptr, impl_deleter<DeviceImpl>) { Device::Device(const std::string& device) : m_impl(nullptr, impl_deleter<DeviceImpl>) {
mgb_assert(is_legal(device), "invalid device type: %s", device.c_str()); mgb_assert(is_legal(device), "invalid device type: %s", device.c_str());
std::string builtin_device = dev_cstr2bstr[device]; std::string builtin_device = DeviceMapper::inst().dev_cstr2bstr[device];
m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device)));
} }
...@@ -100,7 +103,7 @@ Device::Device(const char* device) : Device(std::string(device)) {} ...@@ -100,7 +103,7 @@ Device::Device(const char* device) : Device(std::string(device)) {}
Device::Device(DeviceEnum device) : m_impl(nullptr, impl_deleter<DeviceImpl>) { Device::Device(DeviceEnum device) : m_impl(nullptr, impl_deleter<DeviceImpl>) {
mgb_assert(is_legal(device), "invalid device type"); mgb_assert(is_legal(device), "invalid device type");
std::string builtin_device = dev_cenum2bstr[device]; std::string builtin_device = DeviceMapper::inst().dev_cenum2bstr[device];
m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device)));
} }
...@@ -110,6 +113,7 @@ std::string Device::str(void) const { ...@@ -110,6 +113,7 @@ std::string Device::str(void) const {
} }
auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type();
auto&& dev_benum2cstr = DeviceMapper::inst().dev_benum2cstr;
auto iter = dev_benum2cstr.find(builtin_device_type); auto iter = dev_benum2cstr.find(builtin_device_type);
mgb_assert( mgb_assert(
iter != dev_benum2cstr.end(), "invalid device type %s\n", iter != dev_benum2cstr.end(), "invalid device type %s\n",
...@@ -123,6 +127,7 @@ DeviceEnum Device::enumv(void) const { ...@@ -123,6 +127,7 @@ DeviceEnum Device::enumv(void) const {
"cannot get the enum value of invalid device"); "cannot get the enum value of invalid device");
auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type();
auto&& dev_benum2cenum = DeviceMapper::inst().dev_benum2cenum;
auto iter = dev_benum2cenum.find(builtin_device_type); auto iter = dev_benum2cenum.find(builtin_device_type);
mgb_assert( mgb_assert(
iter != dev_benum2cenum.end(), "invalid device type %s\n", iter != dev_benum2cenum.end(), "invalid device type %s\n",
...@@ -131,16 +136,18 @@ DeviceEnum Device::enumv(void) const { ...@@ -131,16 +136,18 @@ DeviceEnum Device::enumv(void) const {
} }
bool Device::is_legal(const std::string& device_type) { bool Device::is_legal(const std::string& device_type) {
auto&& dev_cstr2bstr = DeviceMapper::inst().dev_cstr2bstr;
return dev_cstr2bstr.find(device_type) != dev_cstr2bstr.end(); return dev_cstr2bstr.find(device_type) != dev_cstr2bstr.end();
} }
bool Device::is_legal(DeviceEnum device_type) { bool Device::is_legal(DeviceEnum device_type) {
auto&& dev_cenum2bstr = DeviceMapper::inst().dev_cenum2bstr;
return dev_cenum2bstr.find(device_type) != dev_cenum2bstr.end(); return dev_cenum2bstr.find(device_type) != dev_cenum2bstr.end();
} }
std::vector<std::string> Device::legal_devices(void) { std::vector<std::string> Device::legal_devices(void) {
std::vector<std::string> ret; std::vector<std::string> ret;
for (const auto& kv : dev_cstr2bstr) { for (const auto& kv : DeviceMapper::inst().dev_cstr2bstr) {
ret.emplace_back(kv.first); ret.emplace_back(kv.first);
} }
return ret; return ret;
...@@ -197,36 +204,37 @@ bool operator==(const Shape& lhs, const Shape& rhs) { ...@@ -197,36 +204,37 @@ bool operator==(const Shape& lhs, const Shape& rhs) {
return ShapeImplRef(lhs.m_impl.get()).eq_shape(ShapeImplRef(rhs.m_impl.get())); return ShapeImplRef(lhs.m_impl.get()).eq_shape(ShapeImplRef(rhs.m_impl.get()));
} }
static std::unordered_map<std::string, megdnn::DTypeEnum> dtype_cstr2benum; struct DTypeMapper {
static std::unordered_map< using CustomEnum = DTypeEnum;
DTypeEnum, megdnn::DTypeEnum, EnumHash<DTypeEnum>, EnumCmp<DTypeEnum>> using BuiltinEnum = megdnn::DTypeEnum;
dtype_cenum2benum;
static std::unordered_map< std::unordered_map<std::string, BuiltinEnum> dtype_cstr2benum;
megdnn::DTypeEnum, std::string, EnumHash<megdnn::DTypeEnum>, EnumMap<DTypeEnum, BuiltinEnum> dtype_cenum2benum;
EnumCmp<megdnn::DTypeEnum>> EnumMap<BuiltinEnum, std::string> dtype_benum2cstr;
dtype_benum2cstr; EnumMap<BuiltinEnum, DTypeEnum> dtype_benum2cenum;
static std::unordered_map< EnumMap<DTypeEnum, std::string> dtype_cenum2cstr;
megdnn::DTypeEnum, DTypeEnum, EnumHash<megdnn::DTypeEnum>, static DTypeMapper& inst();
EnumCmp<megdnn::DTypeEnum>>
dtype_benum2cenum; private:
static std::unordered_map< DTypeMapper();
DTypeEnum, std::string, EnumHash<DTypeEnum>, EnumCmp<DTypeEnum>> };
dtype_cenum2cstr;
DTypeMapper::DTypeMapper() {
#define CUSTOM_BIND_DTYPE(custom_impl, builtin_dtype, ctype) \ #define CUSTOM_BIND_DTYPE(custom_dty, builtin_dty, ctype) \
auto cs2be##custom_impl = dtype_cstr2benum.emplace( \ dtype_cstr2benum.emplace(std::string(#custom_dty), BuiltinEnum::builtin_dty); \
std::string(#custom_impl), megdnn::DTypeEnum::builtin_dtype); \ dtype_cenum2benum.emplace(DTypeEnum::custom_dty, BuiltinEnum::builtin_dty); \
auto ce2be##custom_impl = dtype_cenum2benum.emplace( \ dtype_benum2cstr.emplace(BuiltinEnum::builtin_dty, std::string(#custom_dty)); \
DTypeEnum::custom_impl, megdnn::DTypeEnum::builtin_dtype); \ dtype_benum2cenum.emplace(BuiltinEnum::builtin_dty, DTypeEnum::custom_dty); \
auto be2cs##custom_impl = dtype_benum2cstr.emplace( \ dtype_cenum2cstr.emplace(DTypeEnum::custom_dty, std::string(#custom_dty));
megdnn::DTypeEnum::builtin_dtype, std::string(#custom_impl)); \
auto be2ce##custom_impl = dtype_benum2cenum.emplace( \ CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_BIND_DTYPE)
megdnn::DTypeEnum::builtin_dtype, DTypeEnum::custom_impl); \
auto ce2cs##custom_impl = dtype_cenum2cstr.emplace( \
DTypeEnum::custom_impl, std::string(#custom_impl));
CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_BIND_DTYPE)
#undef CUSTOM_BIND_DTYPE #undef CUSTOM_BIND_DTYPE
}
DTypeMapper& DTypeMapper::inst() {
static DTypeMapper dm;
return dm;
}
CUSTOM_PIMPL_CLS_DEFINE(DType) CUSTOM_PIMPL_CLS_DEFINE(DType)
...@@ -240,6 +248,7 @@ DType::DType(const void* impl) : m_impl(nullptr, impl_deleter<DTypeImpl>) { ...@@ -240,6 +248,7 @@ DType::DType(const void* impl) : m_impl(nullptr, impl_deleter<DTypeImpl>) {
} }
DType::DType(const std::string& dtype) : m_impl(nullptr, impl_deleter<DTypeImpl>) { DType::DType(const std::string& dtype) : m_impl(nullptr, impl_deleter<DTypeImpl>) {
auto&& dtype_cstr2benum = DTypeMapper::inst().dtype_cstr2benum;
auto iter = dtype_cstr2benum.find(dtype); auto iter = dtype_cstr2benum.find(dtype);
mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str());
mgb_assert( mgb_assert(
...@@ -254,6 +263,7 @@ DType::DType(const char* dtype) : DType(std::string(dtype)) {} ...@@ -254,6 +263,7 @@ DType::DType(const char* dtype) : DType(std::string(dtype)) {}
DType::DType(const std::string& dtype, float scale, uint8_t zero_point) DType::DType(const std::string& dtype, float scale, uint8_t zero_point)
: m_impl(nullptr, impl_deleter<DTypeImpl>) { : m_impl(nullptr, impl_deleter<DTypeImpl>) {
auto&& dtype_cstr2benum = DTypeMapper::inst().dtype_cstr2benum;
auto iter = dtype_cstr2benum.find(dtype); auto iter = dtype_cstr2benum.find(dtype);
mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str());
mgb_assert( mgb_assert(
...@@ -289,6 +299,7 @@ DType::DType(const char* dtype, float scale, uint8_t zero_point) ...@@ -289,6 +299,7 @@ DType::DType(const char* dtype, float scale, uint8_t zero_point)
: DType(std::string(dtype), scale, zero_point) {} : DType(std::string(dtype), scale, zero_point) {}
DType::DType(DTypeEnum dtype) : m_impl(nullptr, impl_deleter<DTypeImpl>) { DType::DType(DTypeEnum dtype) : m_impl(nullptr, impl_deleter<DTypeImpl>) {
auto&& dtype_cenum2benum = DTypeMapper::inst().dtype_cenum2benum;
auto iter = dtype_cenum2benum.find(dtype); auto iter = dtype_cenum2benum.find(dtype);
mgb_assert(iter != dtype_cenum2benum.end(), "invalid dtype"); mgb_assert(iter != dtype_cenum2benum.end(), "invalid dtype");
mgb_assert( mgb_assert(
...@@ -298,11 +309,13 @@ DType::DType(DTypeEnum dtype) : m_impl(nullptr, impl_deleter<DTypeImpl>) { ...@@ -298,11 +309,13 @@ DType::DType(DTypeEnum dtype) : m_impl(nullptr, impl_deleter<DTypeImpl>) {
} }
DType::DType(DTypeEnum dtype, float scale, uint8_t zero_point) DType::DType(DTypeEnum dtype, float scale, uint8_t zero_point)
: DType(dtype_cenum2cstr.find(dtype)->second, scale, zero_point) {} : DType(DTypeMapper::inst().dtype_cenum2cstr.find(dtype)->second, scale,
zero_point) {}
std::string DType::str(void) const { std::string DType::str(void) const {
if (!DTypeImplRef(m_impl.get()).valid()) if (!DTypeImplRef(m_impl.get()).valid())
return "invalid"; return "invalid";
auto&& dtype_benum2cstr = DTypeMapper::inst().dtype_benum2cstr;
auto iter = dtype_benum2cstr.find(DTypeImplRef(m_impl.get()).enumv()); auto iter = dtype_benum2cstr.find(DTypeImplRef(m_impl.get()).enumv());
if (iter == dtype_benum2cstr.end()) if (iter == dtype_benum2cstr.end())
return "invalid"; return "invalid";
...@@ -310,6 +323,7 @@ std::string DType::str(void) const { ...@@ -310,6 +323,7 @@ std::string DType::str(void) const {
} }
DTypeEnum DType::enumv(void) const { DTypeEnum DType::enumv(void) const {
auto&& dtype_benum2cenum = DTypeMapper::inst().dtype_benum2cenum;
auto iter = dtype_benum2cenum.find(DTypeImplRef(m_impl.get()).enumv()); auto iter = dtype_benum2cenum.find(DTypeImplRef(m_impl.get()).enumv());
mgb_assert(iter != dtype_benum2cenum.end(), "invalid dtype"); mgb_assert(iter != dtype_benum2cenum.end(), "invalid dtype");
return iter->second; return iter->second;
...@@ -337,16 +351,18 @@ uint8_t DType::zero_point() const { ...@@ -337,16 +351,18 @@ uint8_t DType::zero_point() const {
} }
bool DType::is_legal(const std::string& dtype) { bool DType::is_legal(const std::string& dtype) {
auto&& dtype_cstr2benum = DTypeMapper::inst().dtype_cstr2benum;
return dtype_cstr2benum.find(dtype) != dtype_cstr2benum.end(); return dtype_cstr2benum.find(dtype) != dtype_cstr2benum.end();
} }
bool DType::is_legal(const DTypeEnum& dtype) { bool DType::is_legal(const DTypeEnum& dtype) {
auto&& dtype_cenum2benum = DTypeMapper::inst().dtype_cenum2benum;
return dtype_cenum2benum.find(dtype) != dtype_cenum2benum.end(); return dtype_cenum2benum.find(dtype) != dtype_cenum2benum.end();
} }
std::vector<std::string> DType::legal_dtypes(void) { std::vector<std::string> DType::legal_dtypes(void) {
std::vector<std::string> ret; std::vector<std::string> ret;
for (const auto& kv : dtype_cstr2benum) for (const auto& kv : DTypeMapper::inst().dtype_cstr2benum)
ret.emplace_back(kv.first); ret.emplace_back(kv.first);
return ret; return ret;
} }
......
#pragma once #pragma once
#include "megbrain/custom/op.h"
#include "megbrain/custom/tensor.h"
#include "megbrain/tensor.h"
#include "megdnn/thin/small_vector.h" #include "megdnn/thin/small_vector.h"
namespace custom { namespace custom {
...@@ -11,27 +14,32 @@ BuiltinT to_builtin(const CustomT& custom) { ...@@ -11,27 +14,32 @@ BuiltinT to_builtin(const CustomT& custom) {
template <typename BuiltinT, typename CustomT> template <typename BuiltinT, typename CustomT>
CustomT to_custom(const BuiltinT& builtin) { CustomT to_custom(const BuiltinT& builtin) {
return std::move(CustomT(&builtin)); return CustomT(&builtin);
} }
template <typename BuiltinT, typename CustomT> template <typename BuiltinT, typename CustomT>
megdnn::SmallVector<BuiltinT> to_builtin(const std::vector<CustomT>& customs) { megdnn::SmallVector<BuiltinT> to_builtin(const std::vector<CustomT>& customs) {
megdnn::SmallVector<BuiltinT> builtins; megdnn::SmallVector<BuiltinT> builtins;
for (size_t i = 0; i < customs.size(); ++i) { for (size_t i = 0; i < customs.size(); ++i) {
builtins.push_back(std::move(to_builtin<BuiltinT, CustomT>(customs[i]))); builtins.emplace_back(to_builtin<BuiltinT, CustomT>(customs[i]));
} }
return std::move(builtins); return builtins;
} }
template <typename BuiltinT, typename CustomT> template <typename BuiltinT, typename CustomT>
std::vector<CustomT> to_custom(const megdnn::SmallVector<BuiltinT>& builtins) { std::vector<CustomT> to_custom(const megdnn::SmallVector<BuiltinT>& builtins) {
std::vector<CustomT> customs; std::vector<CustomT> customs;
for (size_t i = 0; i < builtins.size(); ++i) { for (size_t i = 0; i < builtins.size(); ++i) {
customs.push_back(std::move(to_custom<BuiltinT, CustomT>(builtins[i]))); customs.emplace_back(to_custom<BuiltinT, CustomT>(builtins[i]));
} }
return std::move(customs); return customs;
} }
MGE_WIN_DECLSPEC_FUC void dispatch_custom_op(
std::shared_ptr<const CustomOp> op, const Param& param,
std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> inputs,
std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> outputs);
} // namespace custom } // namespace custom
#define to_custom_device(expr) \ #define to_custom_device(expr) \
......
...@@ -5,10 +5,26 @@ ...@@ -5,10 +5,26 @@
namespace custom { namespace custom {
class CustomLib {
std::unique_ptr<void, void_deleter> m_handle;
std::vector<std::string> m_ops;
public:
PREVENT_COPY_AND_ASSIGN(CustomLib);
CustomLib(const std::string& path, int mode);
~CustomLib();
MGE_WIN_DECLSPEC_FUC const std::vector<std::string>& ops_in_lib(void) const;
bool valid(void) const;
};
using LibHandle = std::shared_ptr<CustomLib>;
class CustomOpManager { class CustomOpManager {
std::unordered_map<std::string, LibHandle> m_custom_libs;
std::unordered_map<std::string, std::shared_ptr<const CustomOp>> m_name2op; std::unordered_map<std::string, std::shared_ptr<const CustomOp>> m_name2op;
std::unordered_map<RunTimeId, std::shared_ptr<const CustomOp>> m_id2op; std::unordered_map<RunTimeId, std::shared_ptr<const CustomOp>> m_id2op;
MGB_MUTEX m_mtx; MGB_MUTEX m_lib_mtx;
MGB_MUTEX m_op_mtx;
CustomOpManager() = default; CustomOpManager() = default;
public: public:
...@@ -16,13 +32,15 @@ public: ...@@ -16,13 +32,15 @@ public:
MGE_WIN_DECLSPEC_FUC static CustomOpManager* inst(void); MGE_WIN_DECLSPEC_FUC static CustomOpManager* inst(void);
MGE_WIN_DECLSPEC_FUC ~CustomOpManager(); MGE_WIN_DECLSPEC_FUC ~CustomOpManager();
MGE_WIN_DECLSPEC_FUC const std::vector<std::string>& install(
const std::string& name, const std::string& path);
MGE_WIN_DECLSPEC_FUC std::vector<std::string> uninstall(const std::string& name);
MGE_WIN_DECLSPEC_FUC const std::unordered_map<std::string, LibHandle>& lib_info(
void) const;
MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> insert( MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> insert(
const std::string& name, uint32_t version); const std::string& name, uint32_t version);
MGE_WIN_DECLSPEC_FUC bool erase(const std::string& name); MGE_WIN_DECLSPEC_FUC bool erase(const std::string& name);
MGE_WIN_DECLSPEC_FUC bool erase(const RunTimeId& id);
MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> find_or_reg(
const std::string& name, uint32_t version);
MGE_WIN_DECLSPEC_FUC RunTimeId to_id(const std::string& name) const; MGE_WIN_DECLSPEC_FUC RunTimeId to_id(const std::string& name) const;
MGE_WIN_DECLSPEC_FUC std::string to_name(const RunTimeId& id) const; MGE_WIN_DECLSPEC_FUC std::string to_name(const RunTimeId& id) const;
...@@ -36,35 +54,4 @@ public: ...@@ -36,35 +54,4 @@ public:
MGE_WIN_DECLSPEC_FUC std::vector<RunTimeId> op_id_list(void); MGE_WIN_DECLSPEC_FUC std::vector<RunTimeId> op_id_list(void);
}; };
class CustomLib {
std::unique_ptr<void, void_deleter> m_handle;
std::vector<std::string> m_ops;
public:
PREVENT_COPY_AND_ASSIGN(CustomLib);
CustomLib(const std::string& path, int mode);
const std::vector<std::string>& ops_in_lib(void) const;
~CustomLib();
bool valid(void) const;
};
using LibHandle = std::shared_ptr<CustomLib>;
class LibManager {
std::unordered_map<std::string, LibHandle> m_custom_libs;
MGB_MUTEX m_mtx;
LibManager() = default;
public:
PREVENT_COPY_AND_ASSIGN(LibManager);
MGE_WIN_DECLSPEC_FUC static LibManager* inst(void);
MGE_WIN_DECLSPEC_FUC const std::vector<std::string>& install(
const std::string& name, const std::string& path);
MGE_WIN_DECLSPEC_FUC bool uninstall(const std::string& name);
friend class CustomOpManager;
};
} // namespace custom } // namespace custom
...@@ -76,8 +76,6 @@ class Device; ...@@ -76,8 +76,6 @@ class Device;
* Macro Callback for Register * Macro Callback for Register
*/ */
#define CUSTOM_REG_DYN_PARAMTYPE(dyn_type, static_type) dyn_type, #define CUSTOM_REG_DYN_PARAMTYPE(dyn_type, static_type) dyn_type,
#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) \
{ParamDynType::dyn_type, #dyn_type},
#define CUSTOM_REG_DYN_PARAMTYPE_GETTER(dyn_type, static_type) \ #define CUSTOM_REG_DYN_PARAMTYPE_GETTER(dyn_type, static_type) \
template <> \ template <> \
...@@ -95,10 +93,7 @@ enum class ParamDynType : uint32_t { ...@@ -95,10 +93,7 @@ enum class ParamDynType : uint32_t {
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE) Invalid = 255 CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE) Invalid = 255
}; };
static std::unordered_map< MGE_WIN_DECLSPEC_FUC std::string ptype2name(ParamDynType);
ParamDynType, std::string, EnumHash<ParamDynType>, EnumCmp<ParamDynType>>
type2name = {CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME){
ParamDynType::Invalid, "Invalid"}};
/** /**
* get the dynamic data type according to the builtin static data type * get the dynamic data type according to the builtin static data type
...@@ -124,7 +119,6 @@ CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_GETTER) ...@@ -124,7 +119,6 @@ CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_GETTER)
CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_STATIC_PARAMTYPE_GETTER) CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_STATIC_PARAMTYPE_GETTER)
#undef CUSTOM_REG_DYN_PARAMTYPE #undef CUSTOM_REG_DYN_PARAMTYPE
#undef CUSTOM_REG_DYN_PARAMTYPE_NAME
#undef CUSTOM_REG_DYN_PARAMTYPE_GETTER #undef CUSTOM_REG_DYN_PARAMTYPE_GETTER
#undef CUSTOM_REG_STATIC_PARAMTYPE_GETTER #undef CUSTOM_REG_STATIC_PARAMTYPE_GETTER
...@@ -290,7 +284,7 @@ T& ParamVal::as(void) { ...@@ -290,7 +284,7 @@ T& ParamVal::as(void) {
ParamDynType t_dyn_type = get_dyn_type<DecayType>::type; ParamDynType t_dyn_type = get_dyn_type<DecayType>::type;
custom_assert( custom_assert(
t_dyn_type == m_type, "type mismatch, type %s cannot be cast to type %s\n", t_dyn_type == m_type, "type mismatch, type %s cannot be cast to type %s\n",
type2name[m_type].c_str(), type2name[t_dyn_type].c_str()); ptype2name(m_type).c_str(), ptype2name(t_dyn_type).c_str());
return TypedRef(T, m_ptr.get()); return TypedRef(T, m_ptr.get());
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <cassert> #include <cassert>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
namespace custom { namespace custom {
...@@ -108,4 +109,7 @@ struct EnumCmp { ...@@ -108,4 +109,7 @@ struct EnumCmp {
} }
}; };
template <typename Key, typename Value>
using EnumMap = std::unordered_map<Key, Value, EnumHash<Key>, EnumCmp<Key>>;
} // namespace custom } // namespace custom
...@@ -12,16 +12,17 @@ namespace custom { ...@@ -12,16 +12,17 @@ namespace custom {
TEST(TestOpManager, TestOpManager) { TEST(TestOpManager, TestOpManager) {
CustomOpManager* com = CustomOpManager::inst(); CustomOpManager* com = CustomOpManager::inst();
std::vector<std::string> builtin_op_names = com->op_name_list();
size_t builtin_op_num = builtin_op_names.size();
com->insert("Op1", CUSTOM_OP_VERSION); com->insert("Op1", CUSTOM_OP_VERSION);
com->insert("Op2", CUSTOM_OP_VERSION); com->insert("Op2", CUSTOM_OP_VERSION);
std::shared_ptr<CustomOp> ptr = com->find_or_reg("Op3", CUSTOM_OP_VERSION);
ASSERT_TRUE(ptr != nullptr);
std::vector<std::string> op_names = com->op_name_list(); std::vector<std::string> op_names = com->op_name_list();
std::vector<RunTimeId> op_ids = com->op_id_list(); std::vector<RunTimeId> op_ids = com->op_id_list();
ASSERT_TRUE(op_names.size() == 3); ASSERT_TRUE(op_names.size() == builtin_op_num + 2);
ASSERT_TRUE(op_ids.size() == 3); ASSERT_TRUE(op_ids.size() == builtin_op_num + 2);
#if MANAGER_TEST_LOG #if MANAGER_TEST_LOG
for (std::string& name : op_names) { for (std::string& name : op_names) {
...@@ -52,12 +53,9 @@ TEST(TestOpManager, TestOpManager) { ...@@ -52,12 +53,9 @@ TEST(TestOpManager, TestOpManager) {
} }
#endif #endif
ASSERT_TRUE(com->erase("Op1")); ASSERT_TRUE(com->erase("Op1"));
ASSERT_TRUE(com->erase(com->to_id("Op2"))); ASSERT_TRUE(com->op_id_list().size() == builtin_op_num + 1);
ASSERT_TRUE(com->op_id_list().size() == 1); ASSERT_TRUE(com->op_name_list().size() == builtin_op_num + 1);
ASSERT_TRUE(com->op_name_list().size() == 1); ASSERT_TRUE(com->erase("Op2"));
ASSERT_TRUE(com->op_name_list()[0] == "Op3");
ptr.reset();
ASSERT_TRUE(com->erase("Op3"));
} }
TEST(TestOpManager, TestOpReg) { TEST(TestOpManager, TestOpReg) {
......
...@@ -4,9 +4,10 @@ ...@@ -4,9 +4,10 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
#include "megbrain/custom/data_adaptor.h" #include "megbrain/custom/adaptor.h"
#include "megbrain/custom/op.h" #include "megbrain/custom/op.h"
#include "megbrain/tensor.h" #include "megbrain/tensor.h"
#include "megbrain/test/helper.h"
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
#define OP_TEST_LOG 0 #define OP_TEST_LOG 0
...@@ -93,60 +94,6 @@ void format_infer( ...@@ -93,60 +94,6 @@ void format_infer(
outputs[1] = inputs[0]; outputs[1] = inputs[0];
} }
void cpu_kernel(
const std::vector<Tensor>& inputs, const Param& params,
std::vector<Tensor>& outputs) {
(void)inputs;
(void)params;
(void)outputs;
#if OP_TEST_LOG
std::cout << "Checking CPU Forward - " << params["device"].as<std::string>()
<< std::endl;
#endif
ASSERT_TRUE(params["device"] == "x86");
}
void gpu_kernel(
const std::vector<Tensor>& inputs, const Param& params,
std::vector<Tensor>& outputs) {
(void)inputs;
(void)params;
(void)outputs;
#if OP_TEST_LOG
std::cout << "Checking GPU Forward - " << params["device"].as<std::string>()
<< std::endl;
#endif
ASSERT_TRUE(params["device"] == "cuda");
}
void cpu_kernel_with_runtime_args(
const std::vector<Tensor>& inputs, const Param& params,
std::vector<Tensor>& outputs, const RuntimeArgs& args) {
(void)inputs;
(void)params;
(void)outputs;
(void)args;
#if OP_TEST_LOG
std::cout << "Checking CPU Forward - " << params["device"].as<std::string>()
<< std::endl;
#endif
ASSERT_TRUE(params["device"] == "x86");
}
void gpu_kernel_with_runtime_args(
const std::vector<Tensor>& inputs, const Param& params,
std::vector<Tensor>& outputs, const RuntimeArgs& args) {
(void)inputs;
(void)params;
(void)outputs;
(void)args;
#if OP_TEST_LOG
std::cout << "Checking GPU Forward - " << params["device"].as<std::string>()
<< std::endl;
#endif
ASSERT_TRUE(params["device"] == "cuda");
}
TEST(TestCustomOp, TestCustomOpFuncSetter) { TEST(TestCustomOp, TestCustomOpFuncSetter) {
#if MGB_CUDA #if MGB_CUDA
CustomOp test("TestOp", CUSTOM_OP_VERSION); CustomOp test("TestOp", CUSTOM_OP_VERSION);
...@@ -155,7 +102,8 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) { ...@@ -155,7 +102,8 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
.add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2) .add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2)
.add_output("outl", "outl of Test op", {"float32", "int32"}, 2) .add_output("outl", "outl of Test op", {"float32", "int32"}, 2)
.add_output("outr", "outr of Test op", {"float32", "int32"}, 2) .add_output("outr", "outr of Test op", {"float32", "int32"}, 2)
.add_param("smooth", "smooth", 0.f) .add_param("scale_f", "scale_f", 1.f)
.add_param("offset_i", "offset_i", 0)
.add_param("device", "using for judge device", "x86"); .add_param("device", "using for judge device", "x86");
std::vector<Device> idevices = {"x86", "cuda"}; std::vector<Device> idevices = {"x86", "cuda"};
...@@ -206,35 +154,93 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) { ...@@ -206,35 +154,93 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
ASSERT_TRUE(odtypes[1] == "int32"); ASSERT_TRUE(odtypes[1] == "int32");
ASSERT_TRUE(iformats[0].is_default()); ASSERT_TRUE(iformats[0].is_default());
ASSERT_TRUE(iformats[1].is_default()); ASSERT_TRUE(iformats[1].is_default());
#endif
}
test.set_compute(cpu_kernel_with_runtime_args); void cpu_kernel(
test.set_compute(cpu_kernel); const std::vector<Tensor>& inputs, const Param& params,
DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); std::vector<Tensor>& outputs) {
DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); ASSERT_TRUE(inputs.size() == 2);
DeviceTensorND cdev_otensor0(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); ASSERT_TRUE(outputs.size() == 2);
DeviceTensorND cdev_otensor1(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); ASSERT_TRUE(params["device"] == "x86");
ASSERT_TRUE(params["scale_f"] == 2.12f);
std::vector<Tensor> cinputs = { ASSERT_TRUE(params["offset_i"] == 6);
to_custom_tensor(cdev_itensor0), to_custom_tensor(cdev_itensor1)}; ASSERT_TRUE(inputs[0].shape() == Shape({3, 4}));
std::vector<Tensor> coutputs = { ASSERT_TRUE(inputs[1].shape() == Shape({5, 6}));
to_custom_tensor(cdev_otensor0), to_custom_tensor(cdev_otensor1)}; ASSERT_TRUE(outputs[0].shape() == Shape({5, 6}));
ASSERT_TRUE(outputs[1].shape() == Shape({3, 4}));
ASSERT_TRUE(inputs[0].device() == "x86");
ASSERT_TRUE(inputs[1].device() == "x86");
ASSERT_TRUE(outputs[0].device() == "x86");
ASSERT_TRUE(outputs[1].device() == "x86");
float scale_f = params["scale_f"].as<float>();
int offset_i = params["offset_i"].as<int>();
for (size_t i = 0; i < 5 * 6; ++i) {
ASSERT_TRUE(inputs[1].data<float>()[i] == static_cast<float>(i));
outputs[0].data<float>()[i] = inputs[1].data<float>()[i] * scale_f;
}
for (size_t i = 0; i < 3 * 4; ++i) {
ASSERT_TRUE(inputs[0].data<int>()[i] == static_cast<int>(i));
outputs[1].data<int>()[i] = inputs[0].data<int>()[i] + offset_i;
}
}
TEST(TestCustomOp, TestCustomOpCompute) {
std::shared_ptr<CustomOp> op =
std::make_shared<CustomOp>("TestOp", CUSTOM_OP_VERSION);
op->set_description("Test Op Forward Backward Union")
.add_input("lhs", "lhs of Test op", {"float32", "int32"}, 2)
.add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2)
.add_output("outl", "outl of Test op", {"float32", "int32"}, 2)
.add_output("outr", "outr of Test op", {"float32", "int32"}, 2)
.add_param("scale_f", "scale_f", 1.f)
.add_param("offset_i", "offset_i", 0)
.add_param("device", "using for judge device", "x86")
.set_shape_infer(shape_infer)
.set_dtype_infer(dtype_infer)
.set_compute("x86", cpu_kernel);
Param param(op->param_info());
param["device"] = "x86"; param["device"] = "x86";
test.compute(cinputs, param, coutputs); param["scale_f"] = 2.12f;
param["offset_i"] = 6;
test.set_compute("cuda", gpu_kernel_with_runtime_args);
test.set_compute("cuda", gpu_kernel); HostTensorGenerator<dtype::Float32> gen_f;
DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); HostTensorGenerator<dtype::Int32> gen_i;
DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); auto host_i0 = gen_i({3, 4}), host_i1 = gen_f({5, 6});
DeviceTensorND gdev_otensor0(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); auto expect_o0 = gen_f({5, 6}), expect_o1 = gen_i({3, 4});
DeviceTensorND gdev_otensor1(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); for (size_t i = 0; i < 5 * 6; ++i) {
host_i1->ptr<float>()[i] = static_cast<float>(i);
std::vector<Tensor> ginputs = { expect_o0->ptr<float>()[i] = host_i1->ptr<float>()[i] * 2.12f;
to_custom_tensor(gdev_itensor0), to_custom_tensor(gdev_itensor1)}; }
std::vector<Tensor> goutputs = { for (size_t i = 0; i < 3 * 4; ++i) {
to_custom_tensor(gdev_otensor0), to_custom_tensor(gdev_otensor1)}; host_i0->ptr<int>()[i] = static_cast<int>(i);
param["device"] = "cuda"; expect_o1->ptr<int>()[i] = host_i0->ptr<int>()[i] + 6;
test.compute(ginputs, param, goutputs); }
#endif
auto cn = CompNode::load("cpux");
std::shared_ptr<SmallVector<DeviceTensorND>> x86_inps =
std::make_shared<SmallVector<DeviceTensorND>>(2);
x86_inps->at(0) = DeviceTensorND{cn};
x86_inps->at(1) = DeviceTensorND{cn};
x86_inps->at(0).copy_from(*host_i0).sync();
x86_inps->at(1).copy_from(*host_i1).sync();
std::shared_ptr<SmallVector<DeviceTensorND>> x86_oups =
std::make_shared<SmallVector<DeviceTensorND>>(2);
x86_oups->at(0) = DeviceTensorND{cn, {5, 6}, dtype::Float32{}};
x86_oups->at(1) = DeviceTensorND{cn, {3, 4}, dtype::Int32{}};
dispatch_custom_op(op, param, x86_inps, x86_oups);
cn.sync();
HostTensorND host_o0, host_o1;
host_o0.copy_from(x86_oups->at(0)).sync();
host_o1.copy_from(x86_oups->at(1)).sync();
MGB_ASSERT_TENSOR_NEAR(*expect_o0, host_o0, 1e-6);
MGB_ASSERT_TENSOR_NEAR(*expect_o1, host_o1, 1e-6);
} }
} // namespace custom } // namespace custom
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "megbrain/comp_node.h" #include "megbrain/comp_node.h"
#include "megbrain/custom/data_adaptor.h" #include "megbrain/custom/adaptor.h"
#include "megbrain/custom/tensor.h" #include "megbrain/custom/tensor.h"
#include "megbrain/tensor.h" #include "megbrain/tensor.h"
#include "megbrain_build_config.h" #include "megbrain_build_config.h"
......
...@@ -114,24 +114,21 @@ void CustomOpNode::init_output_comp_node() { ...@@ -114,24 +114,21 @@ void CustomOpNode::init_output_comp_node() {
void CustomOpNode::do_execute(ExecEnv& env) { void CustomOpNode::do_execute(ExecEnv& env) {
auto runner = [this]() { auto runner = [this]() {
std::shared_ptr<SmallVector<DeviceTensorND>> inputs =
std::make_shared<SmallVector<DeviceTensorND>>();
std::shared_ptr<SmallVector<DeviceTensorND>> outputs =
std::make_shared<SmallVector<DeviceTensorND>>();
for (size_t i = 0; i < input_num(); i++) {
inputs->emplace_back(input(i)->dev_tensor());
}
for (size_t i = 0; i < output_num(); i++) {
outputs->emplace_back(output(i)->dev_tensor());
}
this->owner_graph()->event().signal_inplace<cg::event::BeforeKernel>( this->owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
this, m_comp_node); this, m_comp_node);
m_comp_node.activate(); m_comp_node.activate();
custom::dispatch_custom_op(m_op, m_param, inputs, outputs);
SmallVector<DeviceTensorND> inputs, outputs;
for (size_t i = 0; i < input_num(); i++)
inputs.push_back(input(i)->dev_tensor());
for (size_t i = 0; i < output_num(); i++)
outputs.push_back(output(i)->dev_tensor());
std::vector<custom::Tensor> custom_inputs =
custom::to_custom<DeviceTensorND, custom::Tensor>(inputs);
std::vector<custom::Tensor> custom_outputs =
custom::to_custom<DeviceTensorND, custom::Tensor>(outputs);
m_op->compute(custom_inputs, m_param, custom_outputs);
// [TODO] sync should be modified
CompNode::sync_all();
this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>( this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
this, m_comp_node); this, m_comp_node);
}; };
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
#if MGB_CUSTOM_OP #if MGB_CUSTOM_OP
#include "megbrain/custom/adaptor.h"
#include "megbrain/custom/custom.h" #include "megbrain/custom/custom.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/manager.h" #include "megbrain/custom/manager.h"
#include "megbrain/graph/event.h" #include "megbrain/graph/event.h"
#include "megbrain/graph/helper.h" #include "megbrain/graph/helper.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册