From 8a692573f816c8097bdca65aca87d66089144de5 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 24 Aug 2023 12:19:32 +0800 Subject: [PATCH] refactor(customop): support write builtin op with custom op GitOrigin-RevId: cd90002fe851a025b002e918f4b6f638936e660f --- .../python/megengine/core/ops/custom.py | 22 +- .../python/megengine/utils/custom_op_tools.py | 15 +- imperative/python/src/ops.cpp | 58 ++-- .../unit/core/custom_opsrc/matmul_scale.cu | 20 +- .../python/test/unit/core/test_custom_op.py | 297 ++++++++++++++---- imperative/src/impl/ops/custom_opdef.cpp | 40 +-- .../megbrain/imperative/ops/custom_opdef.h | 3 +- src/custom/impl/manager.cpp | 143 ++++----- src/custom/impl/op.cpp | 42 +++ src/custom/impl/param_val.cpp | 18 +- src/custom/impl/platform/custom_cuda.cpp | 2 +- src/custom/impl/tensor.cpp | 132 ++++---- .../custom/{data_adaptor.h => adaptor.h} | 18 +- src/custom/include/megbrain/custom/manager.h | 59 ++-- .../include/megbrain/custom/param_val.h | 10 +- src/custom/include/megbrain/custom/utils.h | 4 + src/custom/test/manager.cpp | 18 +- src/custom/test/op.cpp | 172 +++++----- src/custom/test/tensor.cpp | 2 +- src/opr/impl/custom_opnode.cpp | 27 +- src/opr/include/megbrain/opr/custom_opnode.h | 2 +- 21 files changed, 667 insertions(+), 437 deletions(-) rename src/custom/include/megbrain/custom/{data_adaptor.h => adaptor.h} (75%) diff --git a/imperative/python/megengine/core/ops/custom.py b/imperative/python/megengine/core/ops/custom.py index ec458c078..27041c64e 100644 --- a/imperative/python/megengine/core/ops/custom.py +++ b/imperative/python/megengine/core/ops/custom.py @@ -3,6 +3,7 @@ import os from .._imperative_rt.ops._custom import ( + _get_custom_op_lib_info, _get_custom_op_list, _install, _make_custom_op, @@ -22,8 +23,7 @@ def _gen_custom_op_maker(custom_op_name): def load(lib_path): lib_path = os.path.abspath(lib_path) - lib_name = os.path.splitext(lib_path)[0] - op_in_this_lib = _install(lib_name, lib_path) + op_in_this_lib = _install(lib_path, lib_path) for op in op_in_this_lib: op_maker = _gen_custom_op_maker(op) globals()[op] = op_maker @@ -32,5 +32,19 @@ def load(lib_path): def unload(lib_path): lib_path = os.path.abspath(lib_path) - lib_name = os.path.splitext(lib_path)[0] - _uninstall(lib_name) + op_in_lib = _uninstall(lib_path) + for op in op_in_lib: + del globals()[op] + __all__.remove(op) + + +def _make_official_custom_op(): + official_opr_list = _get_custom_op_list() + for op in official_opr_list: + op_maker = _gen_custom_op_maker(op) + if op not in globals(): + globals()[op] = op_maker + __all__.append(op) + + +_make_official_custom_op() diff --git a/imperative/python/megengine/utils/custom_op_tools.py b/imperative/python/megengine/utils/custom_op_tools.py index b56604ef3..2d345ebaa 100644 --- a/imperative/python/megengine/utils/custom_op_tools.py +++ b/imperative/python/megengine/utils/custom_op_tools.py @@ -782,6 +782,10 @@ def build( with_cudnn, abi_tag, ) + + target_libpath = "{}_v{}".format(name, version) + str( + ".dll" if IS_WINDOWS else ".so" + ) if verbose: if version != old_version and old_version != None: print( @@ -795,8 +799,7 @@ def build( print( "No modifications detected for {}, skipping build step...".format(name) ) - return - name = "{}_v{}".format(name, version) + return os.path.join(build_dir, "{}".format(target_libpath)) # phase 3: compiler and ninja check _check_ninja_availability() @@ -830,8 +833,6 @@ def build( try: # phase 5: generate ninja build file objs = [_obj_file_path(src) for src in sources] - name += ".dll" if 
IS_WINDOWS else ".so" - build_file_path = os.path.join(build_dir, "build.ninja") if verbose: print("Emitting ninja build file {}".format(build_file_path)) @@ -844,7 +845,7 @@ def build( sources=sources, objects=objs, ldflags=ldflags, - library_target=name, + library_target=target_libpath, with_cuda=with_cuda, ) @@ -852,7 +853,7 @@ def build( if verbose: print( "Compiling and linking your custom op {}".format( - os.path.join(build_dir, name) + os.path.join(build_dir, target_libpath) ) ) _build_with_ninja(build_dir, verbose, "compiling error") @@ -861,7 +862,7 @@ def build( else: baton.wait() - return os.path.join(build_dir, name) + return os.path.join(build_dir, target_libpath) def build_and_load( diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 1abf9733f..239662f5e 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -3,7 +3,7 @@ #include "./tensor.h" #include "megbrain/common.h" -#include "megbrain/custom/data_adaptor.h" +#include "megbrain/custom/adaptor.h" #include "megbrain/imperative.h" #include "megbrain/imperative/graph_builder.h" #include "megbrain/imperative/ops/autogen.h" @@ -725,9 +725,7 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) { return obj; #else - mgb_assert( - false, - "Custom Op is disabled now, please build megengine with Custom Op open"); + mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open"); return nullptr; #endif } @@ -737,46 +735,49 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) { py::list install_custom(const std::string& name, const std::string& path) { #if MGB_CUSTOM_OP - py::list ret; - const auto& ops_in_lib = custom::LibManager::inst()->install(name, path); - for (const auto& op : ops_in_lib) { - ret.append(op); - } + const auto& ops_in_lib = custom::CustomOpManager::inst()->install(name, path); + py::list ret = py::cast(ops_in_lib); return ret; #else - mgb_assert( - false, - "Custom Op is disabled now, please build megengine with Custom Op open"); - py::list ret; - return ret; + mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open"); + return py::list{}; #endif } -bool uninstall_custom(const std::string& name) { +py::list uninstall_custom(const std::string& name) { #if MGB_CUSTOM_OP - return custom::LibManager::inst()->uninstall(name); + const auto& ops_in_lib = custom::CustomOpManager::inst()->uninstall(name); + py::list ret = py::cast(ops_in_lib); + return ret; #else - mgb_assert( - false, - "Custom Op is disabled now, please build megengine with Custom Op open"); + mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open"); return false; #endif } py::list get_custom_op_list(void) { #if MGB_CUSTOM_OP - std::vector all_ops = CustomOpDefFactory::inst()->op_list(); - py::list ret; - for (auto& op : all_ops) { - ret.append(op); - } + std::vector all_ops = custom::CustomOpManager::inst()->op_name_list(); + py::list ret = py::cast(all_ops); return ret; #else - mgb_assert( - false, - "Custom Op is disabled now, please build megengine with Custom Op open"); - py::list ret; + mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open"); + return py::list{}; +#endif +} + +py::dict get_custom_op_lib_info(void) { +#if MGB_CUSTOM_OP + auto&& libs = custom::CustomOpManager::inst()->lib_info(); + py::dict ret; + for (auto&& [lib_name, lib_handle] : libs) { + py::list ops = py::cast(lib_handle->ops_in_lib()); + ret[py::str(lib_name)] = ops; + } 
return ret; +#else + mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open"); + return py::list{}; #endif } @@ -792,6 +793,7 @@ void init_custom(pybind11::module m) { m.def("_install", &install_custom); m.def("_uninstall", &uninstall_custom); m.def("_get_custom_op_list", &get_custom_op_list); + m.def("_get_custom_op_lib_info", &get_custom_op_lib_info); m.def("get_custom_op_abi_tag", [](void) -> int { int ret = 0; #ifdef _GLIBCXX_USE_CXX11_ABI diff --git a/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu b/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu index 181b11c85..9ce0bbf93 100644 --- a/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu +++ b/imperative/python/test/unit/core/custom_opsrc/matmul_scale.cu @@ -2,6 +2,7 @@ #include #include #include "./matmul_scale.h" +#include "megbrain/custom/platform/custom_cuda.h" using namespace custom; @@ -51,12 +52,13 @@ void matmul_forward_helper( float scale) { dim3 block(1, 1); dim3 grid(N / block.x, M / block.y); - - DISPATCH_INT_AND_FLOAT_TYPES(res.dtype(), "matmul_forward", ([&]() { - matmul_forward_naive<<>>( - lhs.data(), rhs.data(), - res.data(), M, K, N, scale); - })); + auto stream = get_cuda_stream(lhs.device()); + DISPATCH_INT_AND_FLOAT_TYPES( + res.dtype(), "matmul_forward", ([&]() { + matmul_forward_naive<<>>( + lhs.data(), rhs.data(), + res.data(), M, K, N, scale); + })); } void matmul_backward_lhs_helper( @@ -64,9 +66,10 @@ void matmul_backward_lhs_helper( size_t N, float scale) { dim3 block(1, 1); dim3 grid(K / block.x, M / block.y); + auto stream = get_cuda_stream(rhs.device()); DISPATCH_INT_AND_FLOAT_TYPES( lhs_grad.dtype(), "matmul_backward_lhs", ([&]() { - matmul_backward_lhs_naive<<>>( + matmul_backward_lhs_naive<<>>( rhs.data(), ograd.data(), lhs_grad.data(), M, K, N, scale); })); @@ -77,9 +80,10 @@ void matmul_backward_rhs_helper( size_t N, float scale) { dim3 block(1, 1); dim3 grid(N / block.x, K / block.y); + auto stream = get_cuda_stream(lhs.device()); DISPATCH_INT_AND_FLOAT_TYPES( rhs_grad.dtype(), "matmul_backward_rhs", ([&]() { - matmul_backward_rhs_naive<<>>( + matmul_backward_rhs_naive<<>>( lhs.data(), ograd.data(), rhs_grad.data(), M, K, N, scale); })); diff --git a/imperative/python/test/unit/core/test_custom_op.py b/imperative/python/test/unit/core/test_custom_op.py index c9b4b6426..e5b093bdd 100644 --- a/imperative/python/test/unit/core/test_custom_op.py +++ b/imperative/python/test/unit/core/test_custom_op.py @@ -6,92 +6,132 @@ import sys import numpy as np import pytest -import megengine import megengine.functional as F -import megengine.optimizer as optim from megengine import jit from megengine.autodiff import Function, GradManager from megengine.core._imperative_rt.core2 import apply from megengine.core.ops import custom from megengine.device import get_device_count -from megengine.module import Conv2d, Linear, Module -from megengine.random import normal -from megengine.tensor import Parameter, Tensor +from megengine.tensor import Tensor from megengine.utils import custom_op_tools +build_path = os.path.join( + custom_op_tools._get_default_build_root(), "custom_opsrc", "build" +) +cur_dir_path = os.path.dirname(os.path.abspath(__file__)) +mgb_root_path = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(cur_dir_path)))) +) +extra_include_paths = [os.path.join(mgb_root_path, "src", "custom", "include")] -def compare(ref, real): - if ref.shape != real.shape: - real = real.T - np.testing.assert_allclose(ref, real, 
rtol=1e-3, atol=1e-5) - +extra_ld_flags = [] +if sys.platform != "win32": + ld_path = os.environ.get("LD_LIBRARY_PATH") + if ld_path != None: + ld_dirs = ld_path.split(":") + for ld_dir in ld_dirs: + if os.path.exists(ld_dir) and os.path.isdir(ld_dir): + for lib in os.listdir(ld_dir): + if "megengine_shared" in lib: + extra_ld_flags += ["-L{} -Wl,-rpath,{}".format(ld_dir, ld_dir)] + break -def build_and_clean(test_func): - def wrapper(): - cur_dir_path = os.path.dirname(os.path.abspath(__file__)) - build_root_dir = custom_op_tools._get_default_build_root() - build_path = os.path.join(build_root_dir, "custom_opsrc", "build") - if os.path.exists(build_path): - shutil.rmtree(build_path) +def build_and_clean(*srcs): + def deco(test_func): + custom_op_srcs = [os.path.join(cur_dir_path, "custom_opsrc", s) for s in srcs] - mgb_root_path = os.path.dirname( - os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(cur_dir_path))) - ) - ) - extra_include_paths = [os.path.join(mgb_root_path, "src", "custom", "include")] - extra_ld_flags = [] - - if sys.platform != "win32": - ld_path = os.environ.get("LD_LIBRARY_PATH") - if ld_path != None: - ld_dirs = ld_path.split(":") - for ld_dir in ld_dirs: - if os.path.exists(ld_dir) and os.path.isdir(ld_dir): - for lib in os.listdir(ld_dir): - if "megengine_shared" in lib: - extra_ld_flags += [ - "-L{} -Wl,-rpath,{}".format(ld_dir, ld_dir) - ] - break - - if get_device_count("gpu") > 0: - custom_opsrc = [ - os.path.join(cur_dir_path, "custom_opsrc", "matmul_scale.cpp"), - os.path.join(cur_dir_path, "custom_opsrc", "matmul_scale.cu"), - ] - else: - custom_opsrc = [os.path.join(cur_dir_path, "custom_opsrc", "elem_add.cpp")] - - try: + def wrapper(*args, **kwargs): lib_path = custom_op_tools.build_and_load( "test_op", - custom_opsrc, + custom_op_srcs, extra_include_paths=extra_include_paths, - extra_ldflags=extra_ld_flags, build_dir=build_path, - verbose=False, + extra_ldflags=extra_ld_flags, + verbose=True, ) - test_func() + test_func(*args, **kwargs) custom.unload(lib_path) - finally: - if os.path.exists(build_path): - shutil.rmtree(build_path) + return wrapper - return wrapper + return deco @pytest.mark.skipif( get_device_count("gpu") > 0, reason="elem_add operator is only supported on CPU" ) -@build_and_clean -def test_custom_op_cpu_build(): - assert "ElemAddSmoothForward" in custom._get_custom_op_list() - assert "ElemAddSmoothBackward" in custom._get_custom_op_list() - assert hasattr(custom, "ElemAddSmoothForward") - assert hasattr(custom, "ElemAddSmoothBackward") +@build_and_clean("elem_add.cpp") +def test_cpu_func(): + class ElemAddSmooth(Function): + def __init__(self, smooth): + super().__init__() + self.smooth = smooth + + def forward(self, lhs, rhs): + op = custom.ElemAddSmoothForward(smooth=self.smooth) + return apply(op, lhs, rhs)[0] + + def backward(self, ograd): + op = custom.ElemAddSmoothBackward() + return apply(op, ograd) + + def gen_elemadd_data(seed, shape, low=-1, high=1): + rng = np.random.RandomState(seed=seed) + lhs_np = rng.uniform(low=low, high=high, size=shape).astype(np.float32) + rhs_np = rng.uniform(low=low, high=high, size=shape).astype(np.float32) + ograd_np = rng.uniform(low=low, high=high, size=shape).astype(np.float32) + return lhs_np, rhs_np, ograd_np + + def builtin_func(lhs, rhs, smooth): + out = lhs + rhs + return F.where(out < 0, out + smooth, out - smooth) + + def test_elemadd_smooth_train(smooth=0.5, m=4, n=2, seed=2021): + lhs_np, rhs_np, ograd_np = gen_elemadd_data(seed, (m, n)) + custom_lhs, custom_rhs = 
Tensor(lhs_np), Tensor(rhs_np) + builtin_lhs, builtin_rhs = Tensor(lhs_np), Tensor(rhs_np) + ograd_tensor = Tensor(ograd_np) + + custom_func = ElemAddSmooth(smooth=smooth) + gm = GradManager().attach([custom_lhs, custom_rhs]) + with gm: + custom_out = custom_func(custom_lhs, custom_rhs) + gm.backward(custom_out, ograd_tensor) + + gm = GradManager().attach([builtin_lhs, builtin_rhs]) + with gm: + builtin_out = builtin_func(builtin_lhs, builtin_rhs, smooth) + gm.backward(builtin_out, ograd_tensor) + + np.testing.assert_allclose(custom_out, builtin_out, rtol=1e-3, atol=1e-5) + np.testing.assert_allclose( + custom_lhs.grad.numpy(), builtin_lhs.grad.numpy(), rtol=1e-3, atol=1e-5 + ) + np.testing.assert_allclose( + custom_rhs.grad.numpy(), builtin_rhs.grad.numpy(), rtol=1e-3, atol=1e-5 + ) + + def test_elemadd_smooth_trace(smooth=0.5, m=4, n=2, seed=2021): + @jit.trace(capture_as_const=True) + def func_dumper(lhs, rhs, *, net): + return net(lhs, rhs) + + lhs_np, rhs_np, _ = gen_elemadd_data(seed, (m, n)) + lhs_tensor = Tensor(lhs_np) + rhs_tensor = Tensor(rhs_np) + func = ElemAddSmooth(smooth=smooth) + real = func_dumper(lhs_tensor, rhs_tensor, net=func) + real = func_dumper(lhs_tensor, rhs_tensor, net=func) + + ref = builtin_func(Tensor(lhs_np), Tensor(rhs_np), smooth) + np.testing.assert_allclose(real.numpy(), ref.numpy(), rtol=1e-3, atol=1e-5) + + test_elemadd_smooth_train(0.2, 128, 256, 2027) + test_elemadd_smooth_train(0.3, 256, 128, 2028) + test_elemadd_smooth_train(0.4, 128, 512, 2029) + + test_elemadd_smooth_trace(0.2, 256, 64, 2030) @pytest.mark.skipif( @@ -101,9 +141,136 @@ def test_custom_op_cpu_build(): @pytest.mark.skipif( get_device_count("gpu") < 1, reason="matmul scale operator is only supported on GPU" ) -@build_and_clean -def test_custom_op_gpu_build(): +@build_and_clean("matmul_scale.cpp", "matmul_scale.cu") +def test_gpu_func(): + class MatMulScale(Function): + def __init__(self, scale): + super().__init__() + self.scale = scale + + def forward(self, lhs, rhs): + op = custom.MatMulScaleForward(scale=self.scale) + self.lhs = lhs + self.rhs = rhs + return apply(op, lhs, rhs)[0] + + def backward(self, ograd): + op = custom.MatMulScaleBackward(scale=self.scale) + return apply(op, ograd, self.lhs, self.rhs) + + def gen_matmul_data(seed, m, k, n, low=-0.5, high=0.5, dtype=np.float32): + rng = np.random.RandomState(seed=seed) + lhs_np = rng.uniform(low=low, high=high, size=(m, k)).astype(dtype) + rhs_np = rng.uniform(low=low, high=high, size=(k, n)).astype(dtype) + ograd_np = rng.uniform(low=low, high=high, size=(m, n)).astype(dtype) + scale = rng.uniform(low=0.1, high=0.9, size=(1)).astype(np.float32)[0] + + return lhs_np, rhs_np, ograd_np, scale + + def builtin_func(lhs, rhs, scale): + out = F.matmul(lhs, rhs) * scale + return out + + def test_matmul_scale(m=1, k=1, n=1, seed=2021): + lhs_np, rhs_np, _, scale = gen_matmul_data(seed, m, k, n) + custom_lhs, custom_rhs = Tensor(lhs_np), Tensor(rhs_np) + builtin_lhs, builtin_rhs = Tensor(lhs_np), Tensor(rhs_np) + + custom_func = MatMulScale(scale=scale) + custom_out = custom_func(custom_lhs, custom_rhs) + builtin_out = builtin_func(builtin_lhs, builtin_rhs, scale) + + np.testing.assert_allclose(custom_out, builtin_out, rtol=1e-3, atol=1e-5) + + def test_matmul_scale_trace(m=1, k=1, n=1, seed=2021): + @jit.trace(capture_as_const=True) + def func_dumper(lhs, rhs, *, net): + return net(lhs, rhs) + + lhs_np, rhs_np, _, scale = gen_matmul_data(seed, m, k, n) + lhs_tensor, rhs_tensor = Tensor(lhs_np), Tensor(rhs_np) + func = 
MatMulScale(scale=scale) + real = func_dumper(lhs_tensor, rhs_tensor, net=func) + real = func_dumper(lhs_tensor, rhs_tensor, net=func) + + ref = builtin_func(Tensor(lhs_np), Tensor(rhs_np), scale) + np.testing.assert_allclose(real.numpy(), ref.numpy(), rtol=1e-3, atol=1e-5) + + test_matmul_scale(128, 256, 64, 2028) + test_matmul_scale(64, 32, 16, 2029) + + test_matmul_scale_trace(64, 32, 16, 2030) + + +@pytest.mark.skipif( + get_device_count("gpu") < 1, reason="matmul scale operator is only supported on GPU" +) +def test_custom_op(): + org_op_list = custom._get_custom_op_list() + assert len(custom._get_custom_op_lib_info()) == 0 + + assert "ElemAddSmoothForward" not in custom._get_custom_op_list() + assert not hasattr(custom, "ElemAddSmoothForward") + assert "MatMulScaleForward" not in custom._get_custom_op_list() + assert not hasattr(custom, "MatMulScaleForward") + + srcs1 = [os.path.join(cur_dir_path, "custom_opsrc", "elem_add.cpp")] + lib_path1 = custom_op_tools.build_and_load( + "elem", + srcs1, + extra_include_paths=extra_include_paths, + build_dir=build_path, + extra_ldflags=extra_ld_flags, + verbose=True, + ) + assert "ElemAddSmoothForward" in custom._get_custom_op_list() + assert hasattr(custom, "ElemAddSmoothForward") + assert lib_path1 in custom._get_custom_op_lib_info() + assert "ElemAddSmoothForward" in custom._get_custom_op_lib_info()[lib_path1] + + srcs2 = [ + os.path.join(cur_dir_path, "custom_opsrc", src) + for src in ["matmul_scale.cpp", "matmul_scale.cu"] + ] + lib_path2 = custom_op_tools.build_and_load( + "matmul", + srcs2, + extra_include_paths=extra_include_paths, + build_dir=build_path, + extra_ldflags=extra_ld_flags, + verbose=True, + ) + + assert "MatMulScaleForward" in custom._get_custom_op_list() + assert hasattr(custom, "MatMulScaleForward") + assert lib_path2 in custom._get_custom_op_lib_info() + assert "MatMulScaleForward" in custom._get_custom_op_lib_info()[lib_path2] + + assert len(custom._get_custom_op_list()) == len(org_op_list) + 4 + + custom.unload(lib_path1) + assert "ElemAddSmoothForward" not in custom._get_custom_op_list() + assert not hasattr(custom, "ElemAddSmoothForward") + assert lib_path1 not in custom._get_custom_op_lib_info() + + custom.unload(lib_path2) + assert "MatMulScaleForward" not in custom._get_custom_op_list() + assert not hasattr(custom, "MatMulScaleForward") + assert lib_path1 not in custom._get_custom_op_lib_info() + + assert len(custom._get_custom_op_lib_info()) == 0 + assert custom._get_custom_op_list() == org_op_list + + custom.load(lib_path2) assert "MatMulScaleForward" in custom._get_custom_op_list() - assert "MatMulScaleBackward" in custom._get_custom_op_list() assert hasattr(custom, "MatMulScaleForward") - assert hasattr(custom, "MatMulScaleBackward") + assert lib_path2 in custom._get_custom_op_lib_info() + assert "MatMulScaleForward" in custom._get_custom_op_lib_info()[lib_path2] + + custom.unload(lib_path2) + assert "MatMulScaleForward" not in custom._get_custom_op_list() + assert not hasattr(custom, "MatMulScaleForward") + assert lib_path1 not in custom._get_custom_op_lib_info() + + assert len(custom._get_custom_op_lib_info()) == 0 + assert custom._get_custom_op_list() == org_op_list diff --git a/imperative/src/impl/ops/custom_opdef.cpp b/imperative/src/impl/ops/custom_opdef.cpp index 932605e9d..a515e1f63 100644 --- a/imperative/src/impl/ops/custom_opdef.cpp +++ b/imperative/src/impl/ops/custom_opdef.cpp @@ -3,7 +3,7 @@ #if MGB_CUSTOM_OP #include "../op_trait.h" -#include "megbrain/custom/data_adaptor.h" +#include 
"megbrain/custom/adaptor.h" #include "megbrain/opr/custom_opnode.h" namespace mgb { @@ -51,13 +51,9 @@ const std::shared_ptr& CustomOpDef::impl(void) const { } void CustomOpDef::compute( - const SmallVector& inputs, - SmallVector* outputs) const { - std::vector custom_inputs = - custom::to_custom(inputs); - std::vector custom_outputs = - custom::to_custom(*outputs); - m_op->compute(custom_inputs, this->m_param, custom_outputs); + std::shared_ptr> inputs, + std::shared_ptr> outputs) const { + custom::dispatch_custom_op(m_op, m_param, inputs, outputs); } std::tuple, bool> CustomOpDef::infer_output_attrs( @@ -169,13 +165,6 @@ std::shared_ptr CustomOpDefFactory::create_opdef( namespace custom_opdef { // avoid name conflict -void apply_on_device_tensornd( - const OpDef& def, const SmallVector& inputs, - SmallVector* outputs) { - auto&& op = static_cast(def); - op.compute(inputs, outputs); -} - SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs, SmallVector& output_descs, const bool& validated) { @@ -194,15 +183,19 @@ SmallVector apply_on_physical_tensor( output = Tensor::make(output_descs[i].layout, output_descs[i].comp_node); } - SmallVector inp_tensornds(inputs.size()); - SmallVector oup_tensornds(outputs.size()); - - for (size_t i = 0; i < inputs.size(); ++i) - inp_tensornds[i] = inputs[i]->dev_tensor(); - for (size_t i = 0; i < outputs.size(); ++i) - oup_tensornds[i] = outputs[i]->dev_tensor(); + std::shared_ptr> inp_tensornds = + std::make_shared>(); + std::shared_ptr> oup_tensornds = + std::make_shared>(); + for (size_t i = 0; i < inputs.size(); ++i) { + inp_tensornds->emplace_back(inputs[i]->dev_tensor(true)); + } + for (size_t i = 0; i < outputs.size(); ++i) { + oup_tensornds->emplace_back(outputs[i]->dev_tensor(true)); + } - apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds); + auto&& op = static_cast(def); + op.compute(inp_tensornds, oup_tensornds); return outputs; } @@ -258,7 +251,6 @@ std::string make_name(const OpDef& def) { OP_TRAIT_REG(CustomOpDef, CustomOpDef) .apply_on_physical_tensor(apply_on_physical_tensor) .apply_on_var_node(apply_on_var_node) - .apply_on_device_tensornd(apply_on_device_tensornd) .infer_output_attrs_fallible(infer_output_attrs_fallible) .hash(hash) .is_same_st(is_same_st) diff --git a/imperative/src/include/megbrain/imperative/ops/custom_opdef.h b/imperative/src/include/megbrain/imperative/ops/custom_opdef.h index 4e8a7b1b0..8a8775029 100644 --- a/imperative/src/include/megbrain/imperative/ops/custom_opdef.h +++ b/imperative/src/include/megbrain/imperative/ops/custom_opdef.h @@ -31,7 +31,8 @@ public: const std::shared_ptr& impl(void) const; void compute( - const SmallVector&, SmallVector*) const; + std::shared_ptr>, + std::shared_ptr>) const; std::tuple, bool> infer_output_attrs( const SmallVector& inputs) const; std::tuple, bool> infer_output_attrs( diff --git a/src/custom/impl/manager.cpp b/src/custom/impl/manager.cpp index bdfc09b41..7fc964b63 100644 --- a/src/custom/impl/manager.cpp +++ b/src/custom/impl/manager.cpp @@ -32,6 +32,39 @@ const char* dlerror(void) { } #endif +CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY) + : m_handle(nullptr, [](void* handle) { dlclose(handle); }) { + auto op_list_before_load = CustomOpManager::inst()->op_name_list(); + std::unordered_set op_set_before_load( + op_list_before_load.begin(), op_list_before_load.end()); + + m_handle.reset(dlopen(path.c_str(), mode)); + mgb_assert( + m_handle != nullptr, "open custom op lib failed, error type: %s", + dlerror()); + + 
auto op_list_after_load = CustomOpManager::inst()->op_name_list(); + for (auto& op : op_list_after_load) { + if (op_set_before_load.find(op) == op_set_before_load.end()) { + m_ops.emplace_back(op); + } + } +} + +CustomLib::~CustomLib() { + for (auto& op : m_ops) { + CustomOpManager::inst()->erase(op); + } +} + +const std::vector& CustomLib::ops_in_lib(void) const { + return m_ops; +} + +bool CustomLib::valid() const { + return m_handle != nullptr; +} + CustomOpManager* CustomOpManager::inst(void) { static CustomOpManager op_manager; return &op_manager; @@ -39,12 +72,40 @@ CustomOpManager* CustomOpManager::inst(void) { CustomOpManager::~CustomOpManager() { mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!"); - LibManager::inst()->m_custom_libs.clear(); + { + MGB_LOCK_GUARD(m_lib_mtx); + m_custom_libs.clear(); + } + + mgb_assert(m_name2op.size() == m_id2op.size(), "Custom Op maintenance error!"); + MGB_LOCK_GUARD(m_op_mtx); + m_name2op.clear(); + m_id2op.clear(); +} + +const std::vector& CustomOpManager::install( + const std::string& name, const std::string& path) { + MGB_LOCK_GUARD(m_lib_mtx); + LibHandle handle = std::make_shared(path); + m_custom_libs.insert({name, handle}); + return m_custom_libs[name]->ops_in_lib(); +} + +std::vector CustomOpManager::uninstall(const std::string& name) { + MGB_LOCK_GUARD(m_lib_mtx); + std::vector op_names = m_custom_libs[name]->ops_in_lib(); + mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error"); + return op_names; +} + +const std::unordered_map& CustomOpManager::lib_info( + void) const { + return m_custom_libs; } std::shared_ptr CustomOpManager::insert( const std::string& name, uint32_t version) { - MGB_LOCK_GUARD(m_mtx); + MGB_LOCK_GUARD(m_op_mtx); auto iter = m_name2op.find(name); if (iter != m_name2op.end()) { mgb_log_warn( @@ -59,7 +120,7 @@ std::shared_ptr CustomOpManager::insert( } bool CustomOpManager::erase(const std::string& name) { - MGB_LOCK_GUARD(m_mtx); + MGB_LOCK_GUARD(m_op_mtx); auto iter = m_name2op.find(name); if (iter == m_name2op.end()) { mgb_log_warn( @@ -72,28 +133,6 @@ bool CustomOpManager::erase(const std::string& name) { return true; } -bool CustomOpManager::erase(const RunTimeId& id) { - MGB_LOCK_GUARD(m_mtx); - auto iter = m_id2op.find(id); - if (iter == m_id2op.end()) { - mgb_log_warn("Erase Custom Op Failed! 
The Op has not been registered"); - return false; - } - std::shared_ptr op = iter->second; - m_id2op.erase(op->runtime_id()); - m_name2op.erase(op->op_type()); - return true; -} - -std::shared_ptr CustomOpManager::find_or_reg( - const std::string& name, uint32_t version) { - auto iter = m_name2op.find(name); - if (iter == m_name2op.end()) { - return insert(name, version); - } - return std::const_pointer_cast(iter->second); -} - RunTimeId CustomOpManager::to_id(const std::string& name) const { std::shared_ptr op = find(name); return op->runtime_id(); @@ -135,60 +174,6 @@ std::vector CustomOpManager::op_id_list(void) { return ret; } -CustomLib::CustomLib(const std::string& path, int mode = RTLD_LAZY) - : m_handle(nullptr, [](void* handle) { dlclose(handle); }) { - auto op_list_before_load = CustomOpManager::inst()->op_name_list(); - std::unordered_set op_set_before_load( - op_list_before_load.begin(), op_list_before_load.end()); - - m_handle.reset(dlopen(path.c_str(), mode)); - mgb_assert( - m_handle != nullptr, "open custom op lib failed, error type: %s", - dlerror()); - - auto op_list_after_load = CustomOpManager::inst()->op_name_list(); - for (auto& op : op_list_after_load) { - if (op_set_before_load.find(op) == op_set_before_load.end()) { - m_ops.emplace_back(op); - } - } -} - -const std::vector& CustomLib::ops_in_lib(void) const { - return m_ops; -} - -CustomLib::~CustomLib() { - for (auto& op : m_ops) { - CustomOpManager::inst()->erase(op); - } -} - -bool CustomLib::valid() const { - return m_handle != nullptr; -} - -LibManager* LibManager::inst(void) { - static LibManager custom_libs; - return &custom_libs; -} - -const std::vector& LibManager::install( - const std::string& name, const std::string& path) { - MGB_LOCK_GUARD(m_mtx); - ; - LibHandle handle = std::make_shared(path); - m_custom_libs.insert({name, handle}); - return m_custom_libs[name]->ops_in_lib(); -} - -bool LibManager::uninstall(const std::string& name) { - MGB_LOCK_GUARD(m_mtx); - ; - mgb_assert(m_custom_libs.erase(name) == 1, "uninstall error"); - return true; -} - std::shared_ptr op_insert(std::string opname, uint32_t version) { return CustomOpManager::inst()->insert(opname, version); } diff --git a/src/custom/impl/op.cpp b/src/custom/impl/op.cpp index 93541f8b6..8b5777caa 100644 --- a/src/custom/impl/op.cpp +++ b/src/custom/impl/op.cpp @@ -4,8 +4,11 @@ #include #include +#include "megbrain/comp_node_env.h" +#include "megbrain/custom/adaptor.h" #include "megbrain/custom/op.h" #include "megbrain/custom/utils.h" +#include "megbrain/tensor.h" #include "megbrain/utils/thin/function.h" using namespace mgb; @@ -550,6 +553,45 @@ void CustomOp::compute( assert_outputs_size_right(outputs); } +void compute_impl( + std::shared_ptr op, const Param& param, + std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> inputs, + std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> outputs) { + std::vector custom_inputs; + for (size_t i = 0; i < inputs->size(); ++i) { + custom_inputs.emplace_back(to_custom_tensor(inputs->operator[](i))); + } + std::vector custom_outputs; + for (size_t i = 0; i < outputs->size(); ++i) { + custom_outputs.emplace_back(to_custom_tensor(outputs->operator[](i))); + } + op->compute(custom_inputs, param, custom_outputs); +} + +void dispatch_custom_op( + std::shared_ptr op, const Param& param, + std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> inputs, + std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> outputs) { + if (outputs->size() == 0) { + return; + } + + auto compnode 
= outputs->at(0).comp_node(); + if (compnode.device_type() == CompNode::DeviceType::CPU) { + auto&& cpu_env = CompNodeEnv::from_comp_node(compnode).cpu_env(); + cpu_env.dispatch([op, param, inputs, outputs]() { + compute_impl(op, param, inputs, outputs); + }); + + } else { + mgb_assert( + compnode.device_type() == CompNode::DeviceType::CUDA, + "custom op only support cuda/cpu now, but get %s", + compnode.to_string().c_str()); + compute_impl(op, param, inputs, outputs); + } +} + } // namespace custom #endif diff --git a/src/custom/impl/param_val.cpp b/src/custom/impl/param_val.cpp index b741ed9fa..a0a997604 100644 --- a/src/custom/impl/param_val.cpp +++ b/src/custom/impl/param_val.cpp @@ -3,7 +3,7 @@ #if MGB_CUSTOM_OP #include "megbrain/comp_node.h" -#include "megbrain/custom/data_adaptor.h" +#include "megbrain/custom/adaptor.h" #include "megbrain/custom/param_val.h" #include "megbrain/custom/tensor.h" @@ -40,7 +40,7 @@ namespace custom { #define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op) \ mgb_assert( \ lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \ - type2name[lhs.m_type].c_str(), #op, type2name[rhs.m_type].c_str()) + ptype2name(lhs.m_type).c_str(), #op, ptype2name(rhs.m_type).c_str()) #define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \ case (ParamDynType::dyn_type): { \ @@ -177,6 +177,18 @@ namespace custom { break; \ } +std::string ptype2name(ParamDynType ptype) { +#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) \ + {ParamDynType::dyn_type, #dyn_type}, + + static std::unordered_map< + ParamDynType, std::string, EnumHash, EnumCmp> + type2name = {CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME){ + ParamDynType::Invalid, "Invalid"}}; +#undef CUSTOM_REG_DYN_PARAMTYPE_NAME + return type2name[ptype]; +} + ParamVal::ParamVal() : m_ptr(nullptr, [](void*) -> void {}) { m_type = ParamDynType::Invalid; } @@ -265,7 +277,7 @@ ParamDynType ParamVal::type(void) const { std::string ParamVal::str() const { std::stringstream ss; - ss << "type: " << type2name[m_type] << "\n" + ss << "type: " << ptype2name(m_type) << "\n" << "value: "; switch (m_type) { CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) diff --git a/src/custom/impl/platform/custom_cuda.cpp b/src/custom/impl/platform/custom_cuda.cpp index 11824b68a..a875552de 100644 --- a/src/custom/impl/platform/custom_cuda.cpp +++ b/src/custom/impl/platform/custom_cuda.cpp @@ -4,7 +4,7 @@ #if MGB_CUSTOM_OP #include "megbrain/comp_node_env.h" -#include "megbrain/custom/data_adaptor.h" +#include "megbrain/custom/adaptor.h" #include "megbrain/custom/platform/custom_cuda.h" using namespace mgb; diff --git a/src/custom/impl/tensor.cpp b/src/custom/impl/tensor.cpp index 6485b5f7b..cf0f0b723 100644 --- a/src/custom/impl/tensor.cpp +++ b/src/custom/impl/tensor.cpp @@ -42,31 +42,33 @@ using TensorImpl = DeviceTensorND; #define TensorImplConstRef(rawptr) \ static_cast(*reinterpret_cast(rawptr)) -static std::unordered_map< - DeviceImpl::DeviceType, std::string, EnumHash, - EnumCmp> - dev_benum2cstr; -static std::unordered_map< - DeviceImpl::DeviceType, DeviceEnum, EnumHash, - EnumCmp> - dev_benum2cenum; -static std::unordered_map dev_cstr2bstr; -static std::unordered_map< - DeviceEnum, std::string, EnumHash, EnumCmp> - dev_cenum2bstr; - -#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \ - auto be2cs##custom_impl = dev_benum2cstr.emplace( \ - DeviceImpl::DeviceType::builtin_device, std::string(#custom_impl)); \ - auto be2ce##custom_impl = dev_benum2cenum.emplace( \ - 
DeviceImpl::DeviceType::builtin_device, DeviceEnum::custom_impl); \ - auto cs2bs##custom_impl = dev_cstr2bstr.emplace( \ - std::string(#custom_impl), std::string(builtin_str)); \ - auto ce2bs##custom_impl = \ - dev_cenum2bstr.emplace(DeviceEnum::custom_impl, std::string(builtin_str)); - -CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_BIND_DEVICE) +struct DeviceMapper { + using DeviceTy = DeviceImpl::DeviceType; + std::unordered_map dev_cstr2bstr; + EnumMap dev_benum2cstr; + EnumMap dev_benum2cenum; + EnumMap dev_cenum2bstr; + static DeviceMapper& inst(); + +private: + DeviceMapper(); +}; + +DeviceMapper::DeviceMapper() { +#define CUSTOM_BIND_DEVICE(custom_impl, builtin_device, builtin_str) \ + dev_benum2cstr.emplace(DeviceTy::builtin_device, std::string(#custom_impl)); \ + dev_benum2cenum.emplace(DeviceTy::builtin_device, DeviceEnum::custom_impl); \ + dev_cstr2bstr.emplace(std::string(#custom_impl), std::string(builtin_str)); \ + dev_cenum2bstr.emplace(DeviceEnum::custom_impl, std::string(builtin_str)); + + CUSTOM_FOR_EACH_DEVICE_TYPE(CUSTOM_BIND_DEVICE) #undef CUSTOM_BIND_DEVICE +} + +DeviceMapper& DeviceMapper::inst() { + static DeviceMapper dm; + return dm; +} CUSTOM_PIMPL_CLS_DEFINE(Device) @@ -81,6 +83,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter) { return; } + auto&& dev_benum2cenum = DeviceMapper::inst().dev_benum2cenum; auto builtin_device_enum = DeviceImplConstRef(impl).device_type(); mgb_assert( dev_benum2cenum.find(builtin_device_enum) != dev_benum2cenum.end(), @@ -91,7 +94,7 @@ Device::Device(const void* impl) : m_impl(nullptr, impl_deleter) { Device::Device(const std::string& device) : m_impl(nullptr, impl_deleter) { mgb_assert(is_legal(device), "invalid device type: %s", device.c_str()); - std::string builtin_device = dev_cstr2bstr[device]; + std::string builtin_device = DeviceMapper::inst().dev_cstr2bstr[device]; m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); } @@ -100,7 +103,7 @@ Device::Device(const char* device) : Device(std::string(device)) {} Device::Device(DeviceEnum device) : m_impl(nullptr, impl_deleter) { mgb_assert(is_legal(device), "invalid device type"); - std::string builtin_device = dev_cenum2bstr[device]; + std::string builtin_device = DeviceMapper::inst().dev_cenum2bstr[device]; m_impl.reset(new DeviceImpl(DeviceImpl::load(builtin_device))); } @@ -110,6 +113,7 @@ std::string Device::str(void) const { } auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); + auto&& dev_benum2cstr = DeviceMapper::inst().dev_benum2cstr; auto iter = dev_benum2cstr.find(builtin_device_type); mgb_assert( iter != dev_benum2cstr.end(), "invalid device type %s\n", @@ -123,6 +127,7 @@ DeviceEnum Device::enumv(void) const { "cannot get the enum value of invalid device"); auto builtin_device_type = DeviceImplRef(m_impl.get()).device_type(); + auto&& dev_benum2cenum = DeviceMapper::inst().dev_benum2cenum; auto iter = dev_benum2cenum.find(builtin_device_type); mgb_assert( iter != dev_benum2cenum.end(), "invalid device type %s\n", @@ -131,16 +136,18 @@ DeviceEnum Device::enumv(void) const { } bool Device::is_legal(const std::string& device_type) { + auto&& dev_cstr2bstr = DeviceMapper::inst().dev_cstr2bstr; return dev_cstr2bstr.find(device_type) != dev_cstr2bstr.end(); } bool Device::is_legal(DeviceEnum device_type) { + auto&& dev_cenum2bstr = DeviceMapper::inst().dev_cenum2bstr; return dev_cenum2bstr.find(device_type) != dev_cenum2bstr.end(); } std::vector Device::legal_devices(void) { std::vector ret; - for (const auto& kv : dev_cstr2bstr) { + 
for (const auto& kv : DeviceMapper::inst().dev_cstr2bstr) { ret.emplace_back(kv.first); } return ret; @@ -197,36 +204,37 @@ bool operator==(const Shape& lhs, const Shape& rhs) { return ShapeImplRef(lhs.m_impl.get()).eq_shape(ShapeImplRef(rhs.m_impl.get())); } -static std::unordered_map dtype_cstr2benum; -static std::unordered_map< - DTypeEnum, megdnn::DTypeEnum, EnumHash, EnumCmp> - dtype_cenum2benum; -static std::unordered_map< - megdnn::DTypeEnum, std::string, EnumHash, - EnumCmp> - dtype_benum2cstr; -static std::unordered_map< - megdnn::DTypeEnum, DTypeEnum, EnumHash, - EnumCmp> - dtype_benum2cenum; -static std::unordered_map< - DTypeEnum, std::string, EnumHash, EnumCmp> - dtype_cenum2cstr; - -#define CUSTOM_BIND_DTYPE(custom_impl, builtin_dtype, ctype) \ - auto cs2be##custom_impl = dtype_cstr2benum.emplace( \ - std::string(#custom_impl), megdnn::DTypeEnum::builtin_dtype); \ - auto ce2be##custom_impl = dtype_cenum2benum.emplace( \ - DTypeEnum::custom_impl, megdnn::DTypeEnum::builtin_dtype); \ - auto be2cs##custom_impl = dtype_benum2cstr.emplace( \ - megdnn::DTypeEnum::builtin_dtype, std::string(#custom_impl)); \ - auto be2ce##custom_impl = dtype_benum2cenum.emplace( \ - megdnn::DTypeEnum::builtin_dtype, DTypeEnum::custom_impl); \ - auto ce2cs##custom_impl = dtype_cenum2cstr.emplace( \ - DTypeEnum::custom_impl, std::string(#custom_impl)); - -CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_BIND_DTYPE) +struct DTypeMapper { + using CustomEnum = DTypeEnum; + using BuiltinEnum = megdnn::DTypeEnum; + + std::unordered_map dtype_cstr2benum; + EnumMap dtype_cenum2benum; + EnumMap dtype_benum2cstr; + EnumMap dtype_benum2cenum; + EnumMap dtype_cenum2cstr; + static DTypeMapper& inst(); + +private: + DTypeMapper(); +}; + +DTypeMapper::DTypeMapper() { +#define CUSTOM_BIND_DTYPE(custom_dty, builtin_dty, ctype) \ + dtype_cstr2benum.emplace(std::string(#custom_dty), BuiltinEnum::builtin_dty); \ + dtype_cenum2benum.emplace(DTypeEnum::custom_dty, BuiltinEnum::builtin_dty); \ + dtype_benum2cstr.emplace(BuiltinEnum::builtin_dty, std::string(#custom_dty)); \ + dtype_benum2cenum.emplace(BuiltinEnum::builtin_dty, DTypeEnum::custom_dty); \ + dtype_cenum2cstr.emplace(DTypeEnum::custom_dty, std::string(#custom_dty)); + + CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(CUSTOM_BIND_DTYPE) #undef CUSTOM_BIND_DTYPE +} + +DTypeMapper& DTypeMapper::inst() { + static DTypeMapper dm; + return dm; +} CUSTOM_PIMPL_CLS_DEFINE(DType) @@ -240,6 +248,7 @@ DType::DType(const void* impl) : m_impl(nullptr, impl_deleter) { } DType::DType(const std::string& dtype) : m_impl(nullptr, impl_deleter) { + auto&& dtype_cstr2benum = DTypeMapper::inst().dtype_cstr2benum; auto iter = dtype_cstr2benum.find(dtype); mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); mgb_assert( @@ -254,6 +263,7 @@ DType::DType(const char* dtype) : DType(std::string(dtype)) {} DType::DType(const std::string& dtype, float scale, uint8_t zero_point) : m_impl(nullptr, impl_deleter) { + auto&& dtype_cstr2benum = DTypeMapper::inst().dtype_cstr2benum; auto iter = dtype_cstr2benum.find(dtype); mgb_assert(iter != dtype_cstr2benum.end(), "invalid dtype %s", dtype.c_str()); mgb_assert( @@ -289,6 +299,7 @@ DType::DType(const char* dtype, float scale, uint8_t zero_point) : DType(std::string(dtype), scale, zero_point) {} DType::DType(DTypeEnum dtype) : m_impl(nullptr, impl_deleter) { + auto&& dtype_cenum2benum = DTypeMapper::inst().dtype_cenum2benum; auto iter = dtype_cenum2benum.find(dtype); mgb_assert(iter != dtype_cenum2benum.end(), "invalid dtype"); mgb_assert( @@ 
-298,11 +309,13 @@ DType::DType(DTypeEnum dtype) : m_impl(nullptr, impl_deleter) { } DType::DType(DTypeEnum dtype, float scale, uint8_t zero_point) - : DType(dtype_cenum2cstr.find(dtype)->second, scale, zero_point) {} + : DType(DTypeMapper::inst().dtype_cenum2cstr.find(dtype)->second, scale, + zero_point) {} std::string DType::str(void) const { if (!DTypeImplRef(m_impl.get()).valid()) return "invalid"; + auto&& dtype_benum2cstr = DTypeMapper::inst().dtype_benum2cstr; auto iter = dtype_benum2cstr.find(DTypeImplRef(m_impl.get()).enumv()); if (iter == dtype_benum2cstr.end()) return "invalid"; @@ -310,6 +323,7 @@ std::string DType::str(void) const { } DTypeEnum DType::enumv(void) const { + auto&& dtype_benum2cenum = DTypeMapper::inst().dtype_benum2cenum; auto iter = dtype_benum2cenum.find(DTypeImplRef(m_impl.get()).enumv()); mgb_assert(iter != dtype_benum2cenum.end(), "invalid dtype"); return iter->second; @@ -337,16 +351,18 @@ uint8_t DType::zero_point() const { } bool DType::is_legal(const std::string& dtype) { + auto&& dtype_cstr2benum = DTypeMapper::inst().dtype_cstr2benum; return dtype_cstr2benum.find(dtype) != dtype_cstr2benum.end(); } bool DType::is_legal(const DTypeEnum& dtype) { + auto&& dtype_cenum2benum = DTypeMapper::inst().dtype_cenum2benum; return dtype_cenum2benum.find(dtype) != dtype_cenum2benum.end(); } std::vector DType::legal_dtypes(void) { std::vector ret; - for (const auto& kv : dtype_cstr2benum) + for (const auto& kv : DTypeMapper::inst().dtype_cstr2benum) ret.emplace_back(kv.first); return ret; } diff --git a/src/custom/include/megbrain/custom/data_adaptor.h b/src/custom/include/megbrain/custom/adaptor.h similarity index 75% rename from src/custom/include/megbrain/custom/data_adaptor.h rename to src/custom/include/megbrain/custom/adaptor.h index 2cf4d774c..4409788db 100644 --- a/src/custom/include/megbrain/custom/data_adaptor.h +++ b/src/custom/include/megbrain/custom/adaptor.h @@ -1,5 +1,8 @@ #pragma once +#include "megbrain/custom/op.h" +#include "megbrain/custom/tensor.h" +#include "megbrain/tensor.h" #include "megdnn/thin/small_vector.h" namespace custom { @@ -11,27 +14,32 @@ BuiltinT to_builtin(const CustomT& custom) { template CustomT to_custom(const BuiltinT& builtin) { - return std::move(CustomT(&builtin)); + return CustomT(&builtin); } template megdnn::SmallVector to_builtin(const std::vector& customs) { megdnn::SmallVector builtins; for (size_t i = 0; i < customs.size(); ++i) { - builtins.push_back(std::move(to_builtin(customs[i]))); + builtins.emplace_back(to_builtin(customs[i])); } - return std::move(builtins); + return builtins; } template std::vector to_custom(const megdnn::SmallVector& builtins) { std::vector customs; for (size_t i = 0; i < builtins.size(); ++i) { - customs.push_back(std::move(to_custom(builtins[i]))); + customs.emplace_back(to_custom(builtins[i])); } - return std::move(customs); + return customs; } +MGE_WIN_DECLSPEC_FUC void dispatch_custom_op( + std::shared_ptr op, const Param& param, + std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> inputs, + std::shared_ptr<::megdnn::SmallVector<::mgb::DeviceTensorND>> outputs); + } // namespace custom #define to_custom_device(expr) \ diff --git a/src/custom/include/megbrain/custom/manager.h b/src/custom/include/megbrain/custom/manager.h index 0399f0537..d2de013fa 100644 --- a/src/custom/include/megbrain/custom/manager.h +++ b/src/custom/include/megbrain/custom/manager.h @@ -5,10 +5,26 @@ namespace custom { +class CustomLib { + std::unique_ptr m_handle; + std::vector m_ops; + +public: + 
PREVENT_COPY_AND_ASSIGN(CustomLib); + CustomLib(const std::string& path, int mode); + ~CustomLib(); + MGE_WIN_DECLSPEC_FUC const std::vector& ops_in_lib(void) const; + bool valid(void) const; +}; + +using LibHandle = std::shared_ptr; + class CustomOpManager { + std::unordered_map m_custom_libs; std::unordered_map> m_name2op; std::unordered_map> m_id2op; - MGB_MUTEX m_mtx; + MGB_MUTEX m_lib_mtx; + MGB_MUTEX m_op_mtx; CustomOpManager() = default; public: @@ -16,13 +32,15 @@ public: MGE_WIN_DECLSPEC_FUC static CustomOpManager* inst(void); MGE_WIN_DECLSPEC_FUC ~CustomOpManager(); + MGE_WIN_DECLSPEC_FUC const std::vector& install( + const std::string& name, const std::string& path); + MGE_WIN_DECLSPEC_FUC std::vector uninstall(const std::string& name); + MGE_WIN_DECLSPEC_FUC const std::unordered_map& lib_info( + void) const; + MGE_WIN_DECLSPEC_FUC std::shared_ptr insert( const std::string& name, uint32_t version); MGE_WIN_DECLSPEC_FUC bool erase(const std::string& name); - MGE_WIN_DECLSPEC_FUC bool erase(const RunTimeId& id); - - MGE_WIN_DECLSPEC_FUC std::shared_ptr find_or_reg( - const std::string& name, uint32_t version); MGE_WIN_DECLSPEC_FUC RunTimeId to_id(const std::string& name) const; MGE_WIN_DECLSPEC_FUC std::string to_name(const RunTimeId& id) const; @@ -36,35 +54,4 @@ public: MGE_WIN_DECLSPEC_FUC std::vector op_id_list(void); }; -class CustomLib { - std::unique_ptr m_handle; - std::vector m_ops; - -public: - PREVENT_COPY_AND_ASSIGN(CustomLib); - - CustomLib(const std::string& path, int mode); - const std::vector& ops_in_lib(void) const; - ~CustomLib(); - bool valid(void) const; -}; - -using LibHandle = std::shared_ptr; - -class LibManager { - std::unordered_map m_custom_libs; - MGB_MUTEX m_mtx; - - LibManager() = default; - -public: - PREVENT_COPY_AND_ASSIGN(LibManager); - - MGE_WIN_DECLSPEC_FUC static LibManager* inst(void); - MGE_WIN_DECLSPEC_FUC const std::vector& install( - const std::string& name, const std::string& path); - MGE_WIN_DECLSPEC_FUC bool uninstall(const std::string& name); - friend class CustomOpManager; -}; - } // namespace custom diff --git a/src/custom/include/megbrain/custom/param_val.h b/src/custom/include/megbrain/custom/param_val.h index fb75e62d8..cb99bba43 100644 --- a/src/custom/include/megbrain/custom/param_val.h +++ b/src/custom/include/megbrain/custom/param_val.h @@ -76,8 +76,6 @@ class Device; * Macro Callback for Register */ #define CUSTOM_REG_DYN_PARAMTYPE(dyn_type, static_type) dyn_type, -#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) \ - {ParamDynType::dyn_type, #dyn_type}, #define CUSTOM_REG_DYN_PARAMTYPE_GETTER(dyn_type, static_type) \ template <> \ @@ -95,10 +93,7 @@ enum class ParamDynType : uint32_t { CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE) Invalid = 255 }; -static std::unordered_map< - ParamDynType, std::string, EnumHash, EnumCmp> - type2name = {CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME){ - ParamDynType::Invalid, "Invalid"}}; +MGE_WIN_DECLSPEC_FUC std::string ptype2name(ParamDynType); /** * get the dynamic data type according to the builtin static data type @@ -124,7 +119,6 @@ CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_GETTER) CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_STATIC_PARAMTYPE_GETTER) #undef CUSTOM_REG_DYN_PARAMTYPE -#undef CUSTOM_REG_DYN_PARAMTYPE_NAME #undef CUSTOM_REG_DYN_PARAMTYPE_GETTER #undef CUSTOM_REG_STATIC_PARAMTYPE_GETTER @@ -290,7 +284,7 @@ T& ParamVal::as(void) { ParamDynType t_dyn_type = get_dyn_type::type; custom_assert( t_dyn_type == m_type, "type 
mismatch, type %s cannot be cast to type %s\n", - type2name[m_type].c_str(), type2name[t_dyn_type].c_str()); + ptype2name(m_type).c_str(), ptype2name(t_dyn_type).c_str()); return TypedRef(T, m_ptr.get()); } diff --git a/src/custom/include/megbrain/custom/utils.h b/src/custom/include/megbrain/custom/utils.h index 5d6ad08bd..65cc00363 100644 --- a/src/custom/include/megbrain/custom/utils.h +++ b/src/custom/include/megbrain/custom/utils.h @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace custom { @@ -108,4 +109,7 @@ struct EnumCmp { } }; +template +using EnumMap = std::unordered_map, EnumCmp>; + } // namespace custom diff --git a/src/custom/test/manager.cpp b/src/custom/test/manager.cpp index 59cdca60b..fbbe98b54 100644 --- a/src/custom/test/manager.cpp +++ b/src/custom/test/manager.cpp @@ -12,16 +12,17 @@ namespace custom { TEST(TestOpManager, TestOpManager) { CustomOpManager* com = CustomOpManager::inst(); + std::vector builtin_op_names = com->op_name_list(); + size_t builtin_op_num = builtin_op_names.size(); + com->insert("Op1", CUSTOM_OP_VERSION); com->insert("Op2", CUSTOM_OP_VERSION); - std::shared_ptr ptr = com->find_or_reg("Op3", CUSTOM_OP_VERSION); - ASSERT_TRUE(ptr != nullptr); std::vector op_names = com->op_name_list(); std::vector op_ids = com->op_id_list(); - ASSERT_TRUE(op_names.size() == 3); - ASSERT_TRUE(op_ids.size() == 3); + ASSERT_TRUE(op_names.size() == builtin_op_num + 2); + ASSERT_TRUE(op_ids.size() == builtin_op_num + 2); #if MANAGER_TEST_LOG for (std::string& name : op_names) { @@ -52,12 +53,9 @@ TEST(TestOpManager, TestOpManager) { } #endif ASSERT_TRUE(com->erase("Op1")); - ASSERT_TRUE(com->erase(com->to_id("Op2"))); - ASSERT_TRUE(com->op_id_list().size() == 1); - ASSERT_TRUE(com->op_name_list().size() == 1); - ASSERT_TRUE(com->op_name_list()[0] == "Op3"); - ptr.reset(); - ASSERT_TRUE(com->erase("Op3")); + ASSERT_TRUE(com->op_id_list().size() == builtin_op_num + 1); + ASSERT_TRUE(com->op_name_list().size() == builtin_op_num + 1); + ASSERT_TRUE(com->erase("Op2")); } TEST(TestOpManager, TestOpReg) { diff --git a/src/custom/test/op.cpp b/src/custom/test/op.cpp index 1e504e985..c6afd91ac 100644 --- a/src/custom/test/op.cpp +++ b/src/custom/test/op.cpp @@ -4,9 +4,10 @@ #include "gtest/gtest.h" #include "megbrain/comp_node.h" -#include "megbrain/custom/data_adaptor.h" +#include "megbrain/custom/adaptor.h" #include "megbrain/custom/op.h" #include "megbrain/tensor.h" +#include "megbrain/test/helper.h" #include "megbrain_build_config.h" #define OP_TEST_LOG 0 @@ -93,60 +94,6 @@ void format_infer( outputs[1] = inputs[0]; } -void cpu_kernel( - const std::vector& inputs, const Param& params, - std::vector& outputs) { - (void)inputs; - (void)params; - (void)outputs; -#if OP_TEST_LOG - std::cout << "Checking CPU Forward - " << params["device"].as() - << std::endl; -#endif - ASSERT_TRUE(params["device"] == "x86"); -} - -void gpu_kernel( - const std::vector& inputs, const Param& params, - std::vector& outputs) { - (void)inputs; - (void)params; - (void)outputs; -#if OP_TEST_LOG - std::cout << "Checking GPU Forward - " << params["device"].as() - << std::endl; -#endif - ASSERT_TRUE(params["device"] == "cuda"); -} - -void cpu_kernel_with_runtime_args( - const std::vector& inputs, const Param& params, - std::vector& outputs, const RuntimeArgs& args) { - (void)inputs; - (void)params; - (void)outputs; - (void)args; -#if OP_TEST_LOG - std::cout << "Checking CPU Forward - " << params["device"].as() - << std::endl; -#endif - ASSERT_TRUE(params["device"] == "x86"); -} - 
-void gpu_kernel_with_runtime_args( - const std::vector& inputs, const Param& params, - std::vector& outputs, const RuntimeArgs& args) { - (void)inputs; - (void)params; - (void)outputs; - (void)args; -#if OP_TEST_LOG - std::cout << "Checking GPU Forward - " << params["device"].as() - << std::endl; -#endif - ASSERT_TRUE(params["device"] == "cuda"); -} - TEST(TestCustomOp, TestCustomOpFuncSetter) { #if MGB_CUDA CustomOp test("TestOp", CUSTOM_OP_VERSION); @@ -155,7 +102,8 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) { .add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2) .add_output("outl", "outl of Test op", {"float32", "int32"}, 2) .add_output("outr", "outr of Test op", {"float32", "int32"}, 2) - .add_param("smooth", "smooth", 0.f) + .add_param("scale_f", "scale_f", 1.f) + .add_param("offset_i", "offset_i", 0) .add_param("device", "using for judge device", "x86"); std::vector idevices = {"x86", "cuda"}; @@ -206,35 +154,93 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) { ASSERT_TRUE(odtypes[1] == "int32"); ASSERT_TRUE(iformats[0].is_default()); ASSERT_TRUE(iformats[1].is_default()); +#endif +} - test.set_compute(cpu_kernel_with_runtime_args); - test.set_compute(cpu_kernel); - DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); - DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); - DeviceTensorND cdev_otensor0(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); - DeviceTensorND cdev_otensor1(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); - - std::vector cinputs = { - to_custom_tensor(cdev_itensor0), to_custom_tensor(cdev_itensor1)}; - std::vector coutputs = { - to_custom_tensor(cdev_otensor0), to_custom_tensor(cdev_otensor1)}; +void cpu_kernel( + const std::vector& inputs, const Param& params, + std::vector& outputs) { + ASSERT_TRUE(inputs.size() == 2); + ASSERT_TRUE(outputs.size() == 2); + ASSERT_TRUE(params["device"] == "x86"); + ASSERT_TRUE(params["scale_f"] == 2.12f); + ASSERT_TRUE(params["offset_i"] == 6); + ASSERT_TRUE(inputs[0].shape() == Shape({3, 4})); + ASSERT_TRUE(inputs[1].shape() == Shape({5, 6})); + ASSERT_TRUE(outputs[0].shape() == Shape({5, 6})); + ASSERT_TRUE(outputs[1].shape() == Shape({3, 4})); + ASSERT_TRUE(inputs[0].device() == "x86"); + ASSERT_TRUE(inputs[1].device() == "x86"); + ASSERT_TRUE(outputs[0].device() == "x86"); + ASSERT_TRUE(outputs[1].device() == "x86"); + + float scale_f = params["scale_f"].as(); + int offset_i = params["offset_i"].as(); + + for (size_t i = 0; i < 5 * 6; ++i) { + ASSERT_TRUE(inputs[1].data()[i] == static_cast(i)); + outputs[0].data()[i] = inputs[1].data()[i] * scale_f; + } + for (size_t i = 0; i < 3 * 4; ++i) { + ASSERT_TRUE(inputs[0].data()[i] == static_cast(i)); + outputs[1].data()[i] = inputs[0].data()[i] + offset_i; + } +} + +TEST(TestCustomOp, TestCustomOpCompute) { + std::shared_ptr op = + std::make_shared("TestOp", CUSTOM_OP_VERSION); + op->set_description("Test Op Forward Backward Union") + .add_input("lhs", "lhs of Test op", {"float32", "int32"}, 2) + .add_input("rhs", "rhs of Test op", {"float32", "int32"}, 2) + .add_output("outl", "outl of Test op", {"float32", "int32"}, 2) + .add_output("outr", "outr of Test op", {"float32", "int32"}, 2) + .add_param("scale_f", "scale_f", 1.f) + .add_param("offset_i", "offset_i", 0) + .add_param("device", "using for judge device", "x86") + .set_shape_infer(shape_infer) + .set_dtype_infer(dtype_infer) + .set_compute("x86", cpu_kernel); + + Param param(op->param_info()); param["device"] = "x86"; - test.compute(cinputs, param, coutputs); 
- - test.set_compute("cuda", gpu_kernel_with_runtime_args); - test.set_compute("cuda", gpu_kernel); - DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); - DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); - DeviceTensorND gdev_otensor0(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); - DeviceTensorND gdev_otensor1(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); - - std::vector ginputs = { - to_custom_tensor(gdev_itensor0), to_custom_tensor(gdev_itensor1)}; - std::vector goutputs = { - to_custom_tensor(gdev_otensor0), to_custom_tensor(gdev_otensor1)}; - param["device"] = "cuda"; - test.compute(ginputs, param, goutputs); -#endif + param["scale_f"] = 2.12f; + param["offset_i"] = 6; + + HostTensorGenerator gen_f; + HostTensorGenerator gen_i; + auto host_i0 = gen_i({3, 4}), host_i1 = gen_f({5, 6}); + auto expect_o0 = gen_f({5, 6}), expect_o1 = gen_i({3, 4}); + for (size_t i = 0; i < 5 * 6; ++i) { + host_i1->ptr()[i] = static_cast(i); + expect_o0->ptr()[i] = host_i1->ptr()[i] * 2.12f; + } + for (size_t i = 0; i < 3 * 4; ++i) { + host_i0->ptr()[i] = static_cast(i); + expect_o1->ptr()[i] = host_i0->ptr()[i] + 6; + } + + auto cn = CompNode::load("cpux"); + std::shared_ptr> x86_inps = + std::make_shared>(2); + x86_inps->at(0) = DeviceTensorND{cn}; + x86_inps->at(1) = DeviceTensorND{cn}; + x86_inps->at(0).copy_from(*host_i0).sync(); + x86_inps->at(1).copy_from(*host_i1).sync(); + + std::shared_ptr> x86_oups = + std::make_shared>(2); + x86_oups->at(0) = DeviceTensorND{cn, {5, 6}, dtype::Float32{}}; + x86_oups->at(1) = DeviceTensorND{cn, {3, 4}, dtype::Int32{}}; + + dispatch_custom_op(op, param, x86_inps, x86_oups); + cn.sync(); + HostTensorND host_o0, host_o1; + host_o0.copy_from(x86_oups->at(0)).sync(); + host_o1.copy_from(x86_oups->at(1)).sync(); + + MGB_ASSERT_TENSOR_NEAR(*expect_o0, host_o0, 1e-6); + MGB_ASSERT_TENSOR_NEAR(*expect_o1, host_o1, 1e-6); } } // namespace custom diff --git a/src/custom/test/tensor.cpp b/src/custom/test/tensor.cpp index 236e20fc6..bbeb7ccb9 100644 --- a/src/custom/test/tensor.cpp +++ b/src/custom/test/tensor.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "megbrain/comp_node.h" -#include "megbrain/custom/data_adaptor.h" +#include "megbrain/custom/adaptor.h" #include "megbrain/custom/tensor.h" #include "megbrain/tensor.h" #include "megbrain_build_config.h" diff --git a/src/opr/impl/custom_opnode.cpp b/src/opr/impl/custom_opnode.cpp index f243edecd..ef970f95c 100644 --- a/src/opr/impl/custom_opnode.cpp +++ b/src/opr/impl/custom_opnode.cpp @@ -114,24 +114,21 @@ void CustomOpNode::init_output_comp_node() { void CustomOpNode::do_execute(ExecEnv& env) { auto runner = [this]() { + std::shared_ptr> inputs = + std::make_shared>(); + std::shared_ptr> outputs = + std::make_shared>(); + for (size_t i = 0; i < input_num(); i++) { + inputs->emplace_back(input(i)->dev_tensor()); + } + for (size_t i = 0; i < output_num(); i++) { + outputs->emplace_back(output(i)->dev_tensor()); + } + this->owner_graph()->event().signal_inplace( this, m_comp_node); m_comp_node.activate(); - - SmallVector inputs, outputs; - for (size_t i = 0; i < input_num(); i++) - inputs.push_back(input(i)->dev_tensor()); - for (size_t i = 0; i < output_num(); i++) - outputs.push_back(output(i)->dev_tensor()); - - std::vector custom_inputs = - custom::to_custom(inputs); - std::vector custom_outputs = - custom::to_custom(outputs); - m_op->compute(custom_inputs, m_param, custom_outputs); - // [TODO] sync should be modified - CompNode::sync_all(); - + 
custom::dispatch_custom_op(m_op, m_param, inputs, outputs); this->owner_graph()->event().signal_inplace( this, m_comp_node); }; diff --git a/src/opr/include/megbrain/opr/custom_opnode.h b/src/opr/include/megbrain/opr/custom_opnode.h index 14b18fdf6..db9709d9b 100644 --- a/src/opr/include/megbrain/opr/custom_opnode.h +++ b/src/opr/include/megbrain/opr/custom_opnode.h @@ -4,8 +4,8 @@ #if MGB_CUSTOM_OP +#include "megbrain/custom/adaptor.h" #include "megbrain/custom/custom.h" -#include "megbrain/custom/data_adaptor.h" #include "megbrain/custom/manager.h" #include "megbrain/graph/event.h" #include "megbrain/graph/helper.h" -- GitLab
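Usage note: the tests in test_custom_op.py above already exercise the refactored Python surface end to end. The snippet below is a minimal lifecycle sketch distilled from those tests, not part of the patch itself; the source path, build directory, and smooth value are hypothetical, and it assumes a MegEngine build with MGB_CUSTOM_OP enabled plus a working ninja toolchain.

import os

from megengine.core.ops import custom
from megengine.utils import custom_op_tools

# Hypothetical locations: point these at your own op sources and build dir.
srcs = [os.path.join("custom_opsrc", "elem_add.cpp")]

# build_and_load() compiles the sources with ninja, dlopens the resulting
# library and registers every op it defines; the returned library path is the
# key reported by _get_custom_op_lib_info() and expected by custom.unload().
lib_path = custom_op_tools.build_and_load(
    "elem",
    srcs,
    build_dir="./build",  # hypothetical build directory
    verbose=True,
)

assert "ElemAddSmoothForward" in custom._get_custom_op_list()
assert "ElemAddSmoothForward" in custom._get_custom_op_lib_info()[lib_path]

# Loaded ops are injected into the custom module namespace and can be
# instantiated and applied like built-in ops (see ElemAddSmooth above).
op = custom.ElemAddSmoothForward(smooth=0.5)

# Unloading removes the ops from both the registry and the module namespace.
custom.unload(lib_path)
assert "ElemAddSmoothForward" not in custom._get_custom_op_list()
assert not hasattr(custom, "ElemAddSmoothForward")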