diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 01efb2ef1df9a6ca6dbbfbe1edf04d634a86ef54..046cf99c17610f4d73c2ab030329fc9f7d0b7fb9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -90,6 +90,15 @@ endif() if(MGE_WITH_CUSTOM_OP) list(APPEND MGB_INC ${CMAKE_CURRENT_LIST_DIR}/custom/include) file(GLOB_RECURSE SOURCES_ custom/impl/*.cpp) + + set(EXCLUDE_PLATFORM_DIR "custom/impl/platform") + foreach(CUSOURCE ${SOURCES_}) + string(FIND ${CUSOURCE} ${EXCLUDE_PLATFORM_DIR} EXCLUDE_DIR_FOUND) + if(NOT ${EXCLUDE_DIR_FOUND} EQUAL -1) + list(REMOVE_ITEM SOURCES_ ${CUSOURCE}) + endif() + endforeach(CUSOURCE) + list(APPEND SOURCES ${SOURCES_}) endif() diff --git a/src/custom/impl/op.cpp b/src/custom/impl/op.cpp index e3e29f05a31e55a81e7c1e00cbfe4319d7f37c7b..93541f8b69d2073df99e198150336b2e5e231df8 100644 --- a/src/custom/impl/op.cpp +++ b/src/custom/impl/op.cpp @@ -6,6 +6,7 @@ #include #include "megbrain/custom/op.h" #include "megbrain/custom/utils.h" +#include "megbrain/utils/thin/function.h" using namespace mgb; @@ -99,40 +100,6 @@ std::string ArgInfo::str() const { (arg_info).name().c_str(), static_cast((arg_info).ndim()), \ static_cast((real_shape).ndim())) -template -class Function; - -template -class Function { -public: - using Functor = RType (*)(Args...); - - Function() = default; - Function(Functor f) : m_f(f) {} - Function(const Function& rhs) { m_f = rhs.m_f; } - - RType operator()(Args... args) { - custom_assert(m_f != nullptr, "invalid function ptr\n"); - return m_f(std::forward(args)...); - } - - void operator=(const Function& rhs) { // not allowed continuous assignment - m_f = rhs.m_f; - } - - void operator=(const Functor f) { m_f = f; } - -private: - Functor m_f = nullptr; -}; - -template -class FuncWithSig : public Functions { -public: - using Functions::operator(); - using Functions::operator=; -}; - class CustomOpImpl { static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION; const uint32_t m_version; @@ -143,29 +110,26 @@ class CustomOpImpl { std::vector m_output_infos; ParamInfo m_param_infos; - using DeviceInfer = FuncWithSig&, const Param&, std::vector&)>>; - using ShapeInfer = FuncWithSig&, const Param&, std::vector&)>>; - using DTypeInfer = FuncWithSig&, const Param&, std::vector&)>>; - using FormatInfer = FuncWithSig&, const Param&, std::vector&)>>; - using Preprocess = FuncWithSig&, const Param&, std::vector&)>>; - using Postprocess = FuncWithSig&, const Param&, std::vector&)>>; - using Compute = FuncWithSig&, const Param&, std::vector&)>>; + using DeviceInfer = thin_function&, const Param&, std::vector&)>; + using ShapeInfer = thin_function&, const Param&, std::vector&)>; + using DTypeInfer = thin_function&, const Param&, std::vector&)>; + using FormatInfer = thin_function&, const Param&, std::vector&)>; + using Process = thin_function&, const Param&, std::vector&, + const RuntimeArgs&)>; DeviceInfer infer_output_device_func; ShapeInfer infer_output_shape_func; DTypeInfer infer_output_dtype_func; FormatInfer infer_output_format_func; - std::unordered_map compute_funcs; - std::unordered_map preprocess_funcs; - std::unordered_map postprocess_funcs; + std::unordered_map compute_funcs; + std::unordered_map preprocess_funcs; + std::unordered_map postprocess_funcs; public: CustomOpImpl(const std::string&, uint32_t version); @@ -215,7 +179,8 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version) for (const auto& device : Device::legal_devices()) { compute_funcs[device] = [](const std::vector&, const Param&, - std::vector& outputs) -> void { + std::vector& outputs, + const RuntimeArgs&) -> void { auto device = outputs[0].device(); mgb_assert( false, @@ -224,9 +189,11 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version) device.str().c_str()); }; preprocess_funcs[device] = [](const std::vector&, const Param&, - std::vector&) -> void { return; }; + std::vector&, + const RuntimeArgs&) -> void { return; }; postprocess_funcs[device] = [](const std::vector&, const Param&, - std::vector&) -> void { return; }; + std::vector&, + const RuntimeArgs&) -> void { return; }; } m_param_infos.set_tag(op_type); } @@ -256,33 +223,78 @@ CustomOp& CustomOp::set_format_infer(FormatInferFuncPtr func) { return *this; } -CustomOp& CustomOp::set_preprocess(PreprocessFuncPtr func) { +CustomOp& CustomOp::set_preprocess(ProcessFuncPtrWithoutRuntimeArgs func) { + set_preprocess("x86", func); + return *this; +} + +CustomOp& CustomOp::set_preprocess( + const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) { + auto wrap_func = [func](const std::vector& input, const Param& param, + std::vector& output, const RuntimeArgs&) -> void { + return func(input, param, output); + }; + + OpImplRef(m_impl.get())->preprocess_funcs[device] = wrap_func; + return *this; +} + +CustomOp& CustomOp::set_preprocess(ProcessFuncPtr func) { set_preprocess("x86", func); return *this; } -CustomOp& CustomOp::set_preprocess(const std::string& device, PreprocessFuncPtr func) { +CustomOp& CustomOp::set_preprocess(const std::string& device, ProcessFuncPtr func) { OpImplRef(m_impl.get())->preprocess_funcs[device] = func; return *this; } -CustomOp& CustomOp::set_postprocess(PostprocessFuncPtr func) { +CustomOp& CustomOp::set_postprocess(ProcessFuncPtrWithoutRuntimeArgs func) { set_postprocess("x86", func); return *this; } CustomOp& CustomOp::set_postprocess( - const std::string& device, PostprocessFuncPtr func) { + const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) { + auto wrap_func = [func](const std::vector& input, const Param& param, + std::vector& output, + const RuntimeArgs&) -> void { func(input, param, output); }; + + OpImplRef(m_impl.get())->postprocess_funcs[device] = wrap_func; + return *this; +} + +CustomOp& CustomOp::set_postprocess(ProcessFuncPtr func) { + set_postprocess("x86", func); + return *this; +} + +CustomOp& CustomOp::set_postprocess(const std::string& device, ProcessFuncPtr func) { OpImplRef(m_impl.get())->postprocess_funcs[device] = func; return *this; } -CustomOp& CustomOp::set_compute(ComputeFuncPtr func) { +CustomOp& CustomOp::set_compute(ProcessFuncPtrWithoutRuntimeArgs func) { + set_compute("x86", func); + return *this; +} + +CustomOp& CustomOp::set_compute( + const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) { + auto wrap_func = [func](const std::vector& input, const Param& param, + std::vector& output, + const RuntimeArgs&) -> void { func(input, param, output); }; + + OpImplRef(m_impl.get())->compute_funcs[device] = wrap_func; + return *this; +} + +CustomOp& CustomOp::set_compute(ProcessFuncPtr func) { set_compute("x86", func); return *this; } -CustomOp& CustomOp::set_compute(const std::string& device, ComputeFuncPtr func) { +CustomOp& CustomOp::set_compute(const std::string& device, ProcessFuncPtr func) { OpImplRef(m_impl.get())->compute_funcs[device] = func; return *this; } @@ -513,23 +525,28 @@ void CustomOp::compute( return; } - std::string device = outputs[0].device().str(); + Device device = outputs[0].device(); + std::string device_str = device.str(); for (size_t i = 1; i < outputs.size(); ++i) { mgb_assert( - outputs[i].device().str() == device, + outputs[i].device().str() == device_str, "all output tensors should have the same device attribute"); } // need to add other input/output check - mgb_assert(Device::is_legal(device), "unsupported device type: %s", device.c_str()); + mgb_assert( + Device::is_legal(device_str), "unsupported device type: %s", + device_str.c_str()); + + auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device_str]; + auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device_str]; + auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device_str]; - auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device]; - auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device]; - auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device]; + RuntimeArgs rt_args(device); - preprocess_func(inputs, param, outputs); - forward_func(inputs, param, outputs); - postprocess_func(outputs, param, outputs); + preprocess_func(inputs, param, outputs, rt_args); + forward_func(inputs, param, outputs, rt_args); + postprocess_func(outputs, param, outputs, rt_args); assert_outputs_size_right(outputs); } diff --git a/src/custom/impl/platform/custom_cuda.cpp b/src/custom/impl/platform/custom_cuda.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fb456ddb61ec5deebfbdf318cc13fe8efbb1c56b --- /dev/null +++ b/src/custom/impl/platform/custom_cuda.cpp @@ -0,0 +1,21 @@ +#include "megbrain/common.h" + +#include "megbrain/comp_node_env.h" +#include "megbrain/custom/data_adaptor.h" +#include "megbrain/custom/platform/custom_cuda.h" + +using namespace mgb; + +namespace custom { + +const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) { + mgb_assert( + rt_args.device().enumv() == DeviceEnum::cuda, + "devive type should be cuda."); + const CompNodeEnv& env = + CompNodeEnv::from_comp_node(to_builtin(rt_args.device())); + const CompNodeEnv::CudaEnv& cuda_env = env.cuda_env(); + return {cuda_env.device, cuda_env.stream}; +} + +} // namespace custom diff --git a/src/custom/include/megbrain/custom/op.h b/src/custom/include/megbrain/custom/op.h index 69c46bf074de90c83df493094a34a3592ae82fd9..109b9c2d06b21469e1534b2e821175497562e11d 100644 --- a/src/custom/include/megbrain/custom/op.h +++ b/src/custom/include/megbrain/custom/op.h @@ -36,6 +36,18 @@ class MGE_WIN_DECLSPEC_FUC ArgInfo { std::string str() const; }; +class CudaRuntimeArgs; + +class MGE_WIN_DECLSPEC_FUC RuntimeArgs { + Device m_device; + +public: + RuntimeArgs() = default; + RuntimeArgs(Device device) : m_device(device){}; + + const Device& device() const { return m_device; } +}; + class MGE_WIN_DECLSPEC_FUC CustomOp { std::unique_ptr m_impl; @@ -51,11 +63,10 @@ public: void (*)(const std::vector&, const Param&, std::vector&); using FormatInferFuncPtr = void (*)(const std::vector&, const Param&, std::vector&); - using PreprocessFuncPtr = - void (*)(const std::vector&, const Param&, std::vector&); - using PostprocessFuncPtr = - void (*)(const std::vector&, const Param&, std::vector&); - using ComputeFuncPtr = + using ProcessFuncPtr = void (*)( + const std::vector&, const Param&, std::vector&, + const RuntimeArgs&); + using ProcessFuncPtrWithoutRuntimeArgs = void (*)(const std::vector&, const Param&, std::vector&); // write for forward @@ -63,12 +74,24 @@ public: CustomOp& set_shape_infer(ShapeInferFuncPtr func); CustomOp& set_dtype_infer(DTypeInferFuncPtr func); CustomOp& set_format_infer(FormatInferFuncPtr func); - CustomOp& set_preprocess(PreprocessFuncPtr func); - CustomOp& set_preprocess(const std::string& device, PreprocessFuncPtr func); - CustomOp& set_postprocess(PostprocessFuncPtr func); - CustomOp& set_postprocess(const std::string& device, PostprocessFuncPtr func); - CustomOp& set_compute(ComputeFuncPtr func); - CustomOp& set_compute(const std::string& device, ComputeFuncPtr func); + //! set process function with RuntimeArgs e.g. cuda + CustomOp& set_preprocess(ProcessFuncPtr func); + CustomOp& set_preprocess(const std::string& device, ProcessFuncPtr func); + CustomOp& set_postprocess(ProcessFuncPtr func); + CustomOp& set_postprocess(const std::string& device, ProcessFuncPtr func); + CustomOp& set_compute(ProcessFuncPtr func); + CustomOp& set_compute(const std::string& device, ProcessFuncPtr func); + + //! set process function without RuntimeArgs e.g. cpu + CustomOp& set_preprocess(ProcessFuncPtrWithoutRuntimeArgs func); + CustomOp& set_preprocess( + const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func); + CustomOp& set_postprocess(ProcessFuncPtrWithoutRuntimeArgs func); + CustomOp& set_postprocess( + const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func); + CustomOp& set_compute(ProcessFuncPtrWithoutRuntimeArgs func); + CustomOp& set_compute( + const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func); CustomOp& set_description(const std::string& op_desc); CustomOp& add_input( diff --git a/src/custom/include/megbrain/custom/platform/custom_cuda.h b/src/custom/include/megbrain/custom/platform/custom_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..52b070839c32b2288e4a7aaa814a101ba302db84 --- /dev/null +++ b/src/custom/include/megbrain/custom/platform/custom_cuda.h @@ -0,0 +1,25 @@ +#pragma once + +#include "megbrain/custom/op.h" + +#include + +namespace custom { + +class CudaRuntimeArgs { +private: + int m_device; + cudaStream_t m_stream; + +public: + CudaRuntimeArgs(int device, cudaStream_t stream) + : m_device(device), m_stream(stream) {} + + int device() const { return m_device; } + + cudaStream_t stream() const { return m_stream; } +}; + +const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args); + +} // namespace custom diff --git a/src/custom/test/op.cpp b/src/custom/test/op.cpp index b47a5d74463d5ded726f3dc989bd18408f84e76d..1e504e9856ff51d6ba6d28688ccb29dd81434f58 100644 --- a/src/custom/test/op.cpp +++ b/src/custom/test/op.cpp @@ -119,6 +119,34 @@ void gpu_kernel( ASSERT_TRUE(params["device"] == "cuda"); } +void cpu_kernel_with_runtime_args( + const std::vector& inputs, const Param& params, + std::vector& outputs, const RuntimeArgs& args) { + (void)inputs; + (void)params; + (void)outputs; + (void)args; +#if OP_TEST_LOG + std::cout << "Checking CPU Forward - " << params["device"].as() + << std::endl; +#endif + ASSERT_TRUE(params["device"] == "x86"); +} + +void gpu_kernel_with_runtime_args( + const std::vector& inputs, const Param& params, + std::vector& outputs, const RuntimeArgs& args) { + (void)inputs; + (void)params; + (void)outputs; + (void)args; +#if OP_TEST_LOG + std::cout << "Checking GPU Forward - " << params["device"].as() + << std::endl; +#endif + ASSERT_TRUE(params["device"] == "cuda"); +} + TEST(TestCustomOp, TestCustomOpFuncSetter) { #if MGB_CUDA CustomOp test("TestOp", CUSTOM_OP_VERSION); @@ -179,6 +207,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) { ASSERT_TRUE(iformats[0].is_default()); ASSERT_TRUE(iformats[1].is_default()); + test.set_compute(cpu_kernel_with_runtime_args); test.set_compute(cpu_kernel); DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); @@ -192,6 +221,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) { param["device"] = "x86"; test.compute(cinputs, param, coutputs); + test.set_compute("cuda", gpu_kernel_with_runtime_args); test.set_compute("cuda", gpu_kernel); DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{});