提交 3018ca51 编写于 作者: M Megvii Engine Team

fix(mge): add param RuntimeArgs to customop kernel on cuda

GitOrigin-RevId: 7ed44c42ded50a2a07c258c13306dc5dc90bbd93
上级 086ee045
...@@ -90,6 +90,15 @@ endif() ...@@ -90,6 +90,15 @@ endif()
if(MGE_WITH_CUSTOM_OP) if(MGE_WITH_CUSTOM_OP)
list(APPEND MGB_INC ${CMAKE_CURRENT_LIST_DIR}/custom/include) list(APPEND MGB_INC ${CMAKE_CURRENT_LIST_DIR}/custom/include)
file(GLOB_RECURSE SOURCES_ custom/impl/*.cpp) file(GLOB_RECURSE SOURCES_ custom/impl/*.cpp)
set(EXCLUDE_PLATFORM_DIR "custom/impl/platform")
foreach(CUSOURCE ${SOURCES_})
string(FIND ${CUSOURCE} ${EXCLUDE_PLATFORM_DIR} EXCLUDE_DIR_FOUND)
if(NOT ${EXCLUDE_DIR_FOUND} EQUAL -1)
list(REMOVE_ITEM SOURCES_ ${CUSOURCE})
endif()
endforeach(CUSOURCE)
list(APPEND SOURCES ${SOURCES_}) list(APPEND SOURCES ${SOURCES_})
endif() endif()
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <unordered_set> #include <unordered_set>
#include "megbrain/custom/op.h" #include "megbrain/custom/op.h"
#include "megbrain/custom/utils.h" #include "megbrain/custom/utils.h"
#include "megbrain/utils/thin/function.h"
using namespace mgb; using namespace mgb;
...@@ -99,40 +100,6 @@ std::string ArgInfo::str() const { ...@@ -99,40 +100,6 @@ std::string ArgInfo::str() const {
(arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \ (arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \
static_cast<int>((real_shape).ndim())) static_cast<int>((real_shape).ndim()))
template <typename T>
class Function;
template <typename RType, typename... Args>
class Function<RType(Args...)> {
public:
using Functor = RType (*)(Args...);
Function() = default;
Function(Functor f) : m_f(f) {}
Function(const Function& rhs) { m_f = rhs.m_f; }
RType operator()(Args... args) {
custom_assert(m_f != nullptr, "invalid function ptr\n");
return m_f(std::forward<Args>(args)...);
}
void operator=(const Function& rhs) { // not allowed continuous assignment
m_f = rhs.m_f;
}
void operator=(const Functor f) { m_f = f; }
private:
Functor m_f = nullptr;
};
template <typename Functions>
class FuncWithSig : public Functions {
public:
using Functions::operator();
using Functions::operator=;
};
class CustomOpImpl { class CustomOpImpl {
static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION; static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION;
const uint32_t m_version; const uint32_t m_version;
...@@ -143,29 +110,26 @@ class CustomOpImpl { ...@@ -143,29 +110,26 @@ class CustomOpImpl {
std::vector<ArgInfo> m_output_infos; std::vector<ArgInfo> m_output_infos;
ParamInfo m_param_infos; ParamInfo m_param_infos;
using DeviceInfer = FuncWithSig<Function<void( using DeviceInfer = thin_function<void(
const std::vector<Device>&, const Param&, std::vector<Device>&)>>; const std::vector<Device>&, const Param&, std::vector<Device>&)>;
using ShapeInfer = FuncWithSig<Function<void( using ShapeInfer = thin_function<void(
const std::vector<Shape>&, const Param&, std::vector<Shape>&)>>; const std::vector<Shape>&, const Param&, std::vector<Shape>&)>;
using DTypeInfer = FuncWithSig<Function<void( using DTypeInfer = thin_function<void(
const std::vector<DType>&, const Param&, std::vector<DType>&)>>; const std::vector<DType>&, const Param&, std::vector<DType>&)>;
using FormatInfer = FuncWithSig<Function<void( using FormatInfer = thin_function<void(
const std::vector<Format>&, const Param&, std::vector<Format>&)>>; const std::vector<Format>&, const Param&, std::vector<Format>&)>;
using Preprocess = FuncWithSig<Function<void( using Process = thin_function<void(
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>; const std::vector<Tensor>&, const Param&, std::vector<Tensor>&,
using Postprocess = FuncWithSig<Function<void( const RuntimeArgs&)>;
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
using Compute = FuncWithSig<Function<void(
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
DeviceInfer infer_output_device_func; DeviceInfer infer_output_device_func;
ShapeInfer infer_output_shape_func; ShapeInfer infer_output_shape_func;
DTypeInfer infer_output_dtype_func; DTypeInfer infer_output_dtype_func;
FormatInfer infer_output_format_func; FormatInfer infer_output_format_func;
std::unordered_map<std::string, Compute> compute_funcs; std::unordered_map<std::string, Process> compute_funcs;
std::unordered_map<std::string, Preprocess> preprocess_funcs; std::unordered_map<std::string, Process> preprocess_funcs;
std::unordered_map<std::string, Postprocess> postprocess_funcs; std::unordered_map<std::string, Process> postprocess_funcs;
public: public:
CustomOpImpl(const std::string&, uint32_t version); CustomOpImpl(const std::string&, uint32_t version);
...@@ -215,7 +179,8 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version) ...@@ -215,7 +179,8 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version)
for (const auto& device : Device::legal_devices()) { for (const auto& device : Device::legal_devices()) {
compute_funcs[device] = [](const std::vector<Tensor>&, const Param&, compute_funcs[device] = [](const std::vector<Tensor>&, const Param&,
std::vector<Tensor>& outputs) -> void { std::vector<Tensor>& outputs,
const RuntimeArgs&) -> void {
auto device = outputs[0].device(); auto device = outputs[0].device();
mgb_assert( mgb_assert(
false, false,
...@@ -224,9 +189,11 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version) ...@@ -224,9 +189,11 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version)
device.str().c_str()); device.str().c_str());
}; };
preprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, preprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&,
std::vector<Tensor>&) -> void { return; }; std::vector<Tensor>&,
const RuntimeArgs&) -> void { return; };
postprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&, postprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&,
std::vector<Tensor>&) -> void { return; }; std::vector<Tensor>&,
const RuntimeArgs&) -> void { return; };
} }
m_param_infos.set_tag(op_type); m_param_infos.set_tag(op_type);
} }
...@@ -256,33 +223,78 @@ CustomOp& CustomOp::set_format_infer(FormatInferFuncPtr func) { ...@@ -256,33 +223,78 @@ CustomOp& CustomOp::set_format_infer(FormatInferFuncPtr func) {
return *this; return *this;
} }
CustomOp& CustomOp::set_preprocess(PreprocessFuncPtr func) { CustomOp& CustomOp::set_preprocess(ProcessFuncPtrWithoutRuntimeArgs func) {
set_preprocess("x86", func);
return *this;
}
CustomOp& CustomOp::set_preprocess(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) {
auto wrap_func = [func](const std::vector<Tensor>& input, const Param& param,
std::vector<Tensor>& output, const RuntimeArgs&) -> void {
return func(input, param, output);
};
OpImplRef(m_impl.get())->preprocess_funcs[device] = wrap_func;
return *this;
}
CustomOp& CustomOp::set_preprocess(ProcessFuncPtr func) {
set_preprocess("x86", func); set_preprocess("x86", func);
return *this; return *this;
} }
CustomOp& CustomOp::set_preprocess(const std::string& device, PreprocessFuncPtr func) { CustomOp& CustomOp::set_preprocess(const std::string& device, ProcessFuncPtr func) {
OpImplRef(m_impl.get())->preprocess_funcs[device] = func; OpImplRef(m_impl.get())->preprocess_funcs[device] = func;
return *this; return *this;
} }
CustomOp& CustomOp::set_postprocess(PostprocessFuncPtr func) { CustomOp& CustomOp::set_postprocess(ProcessFuncPtrWithoutRuntimeArgs func) {
set_postprocess("x86", func); set_postprocess("x86", func);
return *this; return *this;
} }
CustomOp& CustomOp::set_postprocess( CustomOp& CustomOp::set_postprocess(
const std::string& device, PostprocessFuncPtr func) { const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) {
auto wrap_func = [func](const std::vector<Tensor>& input, const Param& param,
std::vector<Tensor>& output,
const RuntimeArgs&) -> void { func(input, param, output); };
OpImplRef(m_impl.get())->postprocess_funcs[device] = wrap_func;
return *this;
}
CustomOp& CustomOp::set_postprocess(ProcessFuncPtr func) {
set_postprocess("x86", func);
return *this;
}
CustomOp& CustomOp::set_postprocess(const std::string& device, ProcessFuncPtr func) {
OpImplRef(m_impl.get())->postprocess_funcs[device] = func; OpImplRef(m_impl.get())->postprocess_funcs[device] = func;
return *this; return *this;
} }
CustomOp& CustomOp::set_compute(ComputeFuncPtr func) { CustomOp& CustomOp::set_compute(ProcessFuncPtrWithoutRuntimeArgs func) {
set_compute("x86", func);
return *this;
}
CustomOp& CustomOp::set_compute(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) {
auto wrap_func = [func](const std::vector<Tensor>& input, const Param& param,
std::vector<Tensor>& output,
const RuntimeArgs&) -> void { func(input, param, output); };
OpImplRef(m_impl.get())->compute_funcs[device] = wrap_func;
return *this;
}
CustomOp& CustomOp::set_compute(ProcessFuncPtr func) {
set_compute("x86", func); set_compute("x86", func);
return *this; return *this;
} }
CustomOp& CustomOp::set_compute(const std::string& device, ComputeFuncPtr func) { CustomOp& CustomOp::set_compute(const std::string& device, ProcessFuncPtr func) {
OpImplRef(m_impl.get())->compute_funcs[device] = func; OpImplRef(m_impl.get())->compute_funcs[device] = func;
return *this; return *this;
} }
...@@ -513,23 +525,28 @@ void CustomOp::compute( ...@@ -513,23 +525,28 @@ void CustomOp::compute(
return; return;
} }
std::string device = outputs[0].device().str(); Device device = outputs[0].device();
std::string device_str = device.str();
for (size_t i = 1; i < outputs.size(); ++i) { for (size_t i = 1; i < outputs.size(); ++i) {
mgb_assert( mgb_assert(
outputs[i].device().str() == device, outputs[i].device().str() == device_str,
"all output tensors should have the same device attribute"); "all output tensors should have the same device attribute");
} }
// need to add other input/output check // need to add other input/output check
mgb_assert(Device::is_legal(device), "unsupported device type: %s", device.c_str()); mgb_assert(
Device::is_legal(device_str), "unsupported device type: %s",
device_str.c_str());
auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device_str];
auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device_str];
auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device_str];
auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device]; RuntimeArgs rt_args(device);
auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device];
auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device];
preprocess_func(inputs, param, outputs); preprocess_func(inputs, param, outputs, rt_args);
forward_func(inputs, param, outputs); forward_func(inputs, param, outputs, rt_args);
postprocess_func(outputs, param, outputs); postprocess_func(outputs, param, outputs, rt_args);
assert_outputs_size_right(outputs); assert_outputs_size_right(outputs);
} }
......
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/platform/custom_cuda.h"
using namespace mgb;
namespace custom {
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
mgb_assert(
rt_args.device().enumv() == DeviceEnum::cuda,
"devive type should be cuda.");
const CompNodeEnv& env =
CompNodeEnv::from_comp_node(to_builtin<CompNode, Device>(rt_args.device()));
const CompNodeEnv::CudaEnv& cuda_env = env.cuda_env();
return {cuda_env.device, cuda_env.stream};
}
} // namespace custom
...@@ -36,6 +36,18 @@ class MGE_WIN_DECLSPEC_FUC ArgInfo { ...@@ -36,6 +36,18 @@ class MGE_WIN_DECLSPEC_FUC ArgInfo {
std::string str() const; std::string str() const;
}; };
class CudaRuntimeArgs;
class MGE_WIN_DECLSPEC_FUC RuntimeArgs {
Device m_device;
public:
RuntimeArgs() = default;
RuntimeArgs(Device device) : m_device(device){};
const Device& device() const { return m_device; }
};
class MGE_WIN_DECLSPEC_FUC CustomOp { class MGE_WIN_DECLSPEC_FUC CustomOp {
std::unique_ptr<void, void_deleter> m_impl; std::unique_ptr<void, void_deleter> m_impl;
...@@ -51,11 +63,10 @@ public: ...@@ -51,11 +63,10 @@ public:
void (*)(const std::vector<DType>&, const Param&, std::vector<DType>&); void (*)(const std::vector<DType>&, const Param&, std::vector<DType>&);
using FormatInferFuncPtr = using FormatInferFuncPtr =
void (*)(const std::vector<Format>&, const Param&, std::vector<Format>&); void (*)(const std::vector<Format>&, const Param&, std::vector<Format>&);
using PreprocessFuncPtr = using ProcessFuncPtr = void (*)(
void (*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); const std::vector<Tensor>&, const Param&, std::vector<Tensor>&,
using PostprocessFuncPtr = const RuntimeArgs&);
void (*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); using ProcessFuncPtrWithoutRuntimeArgs =
using ComputeFuncPtr =
void (*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&); void (*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&);
// write for forward // write for forward
...@@ -63,12 +74,24 @@ public: ...@@ -63,12 +74,24 @@ public:
CustomOp& set_shape_infer(ShapeInferFuncPtr func); CustomOp& set_shape_infer(ShapeInferFuncPtr func);
CustomOp& set_dtype_infer(DTypeInferFuncPtr func); CustomOp& set_dtype_infer(DTypeInferFuncPtr func);
CustomOp& set_format_infer(FormatInferFuncPtr func); CustomOp& set_format_infer(FormatInferFuncPtr func);
CustomOp& set_preprocess(PreprocessFuncPtr func); //! set process function with RuntimeArgs e.g. cuda
CustomOp& set_preprocess(const std::string& device, PreprocessFuncPtr func); CustomOp& set_preprocess(ProcessFuncPtr func);
CustomOp& set_postprocess(PostprocessFuncPtr func); CustomOp& set_preprocess(const std::string& device, ProcessFuncPtr func);
CustomOp& set_postprocess(const std::string& device, PostprocessFuncPtr func); CustomOp& set_postprocess(ProcessFuncPtr func);
CustomOp& set_compute(ComputeFuncPtr func); CustomOp& set_postprocess(const std::string& device, ProcessFuncPtr func);
CustomOp& set_compute(const std::string& device, ComputeFuncPtr func); CustomOp& set_compute(ProcessFuncPtr func);
CustomOp& set_compute(const std::string& device, ProcessFuncPtr func);
//! set process function without RuntimeArgs e.g. cpu
CustomOp& set_preprocess(ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_preprocess(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_postprocess(ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_postprocess(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_compute(ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_compute(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_description(const std::string& op_desc); CustomOp& set_description(const std::string& op_desc);
CustomOp& add_input( CustomOp& add_input(
......
#pragma once
#include "megbrain/custom/op.h"
#include <cuda_runtime_api.h>
namespace custom {
class CudaRuntimeArgs {
private:
int m_device;
cudaStream_t m_stream;
public:
CudaRuntimeArgs(int device, cudaStream_t stream)
: m_device(device), m_stream(stream) {}
int device() const { return m_device; }
cudaStream_t stream() const { return m_stream; }
};
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args);
} // namespace custom
...@@ -119,6 +119,34 @@ void gpu_kernel( ...@@ -119,6 +119,34 @@ void gpu_kernel(
ASSERT_TRUE(params["device"] == "cuda"); ASSERT_TRUE(params["device"] == "cuda");
} }
void cpu_kernel_with_runtime_args(
const std::vector<Tensor>& inputs, const Param& params,
std::vector<Tensor>& outputs, const RuntimeArgs& args) {
(void)inputs;
(void)params;
(void)outputs;
(void)args;
#if OP_TEST_LOG
std::cout << "Checking CPU Forward - " << params["device"].as<std::string>()
<< std::endl;
#endif
ASSERT_TRUE(params["device"] == "x86");
}
void gpu_kernel_with_runtime_args(
const std::vector<Tensor>& inputs, const Param& params,
std::vector<Tensor>& outputs, const RuntimeArgs& args) {
(void)inputs;
(void)params;
(void)outputs;
(void)args;
#if OP_TEST_LOG
std::cout << "Checking GPU Forward - " << params["device"].as<std::string>()
<< std::endl;
#endif
ASSERT_TRUE(params["device"] == "cuda");
}
TEST(TestCustomOp, TestCustomOpFuncSetter) { TEST(TestCustomOp, TestCustomOpFuncSetter) {
#if MGB_CUDA #if MGB_CUDA
CustomOp test("TestOp", CUSTOM_OP_VERSION); CustomOp test("TestOp", CUSTOM_OP_VERSION);
...@@ -179,6 +207,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) { ...@@ -179,6 +207,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
ASSERT_TRUE(iformats[0].is_default()); ASSERT_TRUE(iformats[0].is_default());
ASSERT_TRUE(iformats[1].is_default()); ASSERT_TRUE(iformats[1].is_default());
test.set_compute(cpu_kernel_with_runtime_args);
test.set_compute(cpu_kernel); test.set_compute(cpu_kernel);
DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{}); DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{});
DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{}); DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{});
...@@ -192,6 +221,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) { ...@@ -192,6 +221,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
param["device"] = "x86"; param["device"] = "x86";
test.compute(cinputs, param, coutputs); test.compute(cinputs, param, coutputs);
test.set_compute("cuda", gpu_kernel_with_runtime_args);
test.set_compute("cuda", gpu_kernel); test.set_compute("cuda", gpu_kernel);
DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{}); DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{});
DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{}); DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册