Commit 3018ca51 authored by Megvii Engine Team

fix(mge): add param RuntimeArgs to customop kernel on cuda

GitOrigin-RevId: 7ed44c42ded50a2a07c258c13306dc5dc90bbd93
Parent 086ee045
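This change threads a `RuntimeArgs` parameter through the custom-op process functions (compute, preprocess, postprocess) so that CUDA kernels can recover the device and stream they were scheduled onto. A minimal registration sketch, assuming a hypothetical op name `MyOp` and function name `my_cuda_compute` (the four-argument `ProcessFuncPtr` signature and the `set_compute` overloads are taken from the diff below):

```cpp
#include "megbrain/custom/op.h"

using namespace custom;

// New four-argument signature: the trailing RuntimeArgs carries the device
// this op instance was scheduled onto (see ProcessFuncPtr in op.h below).
void my_cuda_compute(
        const std::vector<Tensor>& inputs, const Param& param,
        std::vector<Tensor>& outputs, const RuntimeArgs& args) {
    (void)inputs; (void)param; (void)outputs; (void)args;
    // CUDA-specific work goes here; see get_cuda_runtime_args further down.
}

void register_my_op() {
    CustomOp op("MyOp", CUSTOM_OP_VERSION);  // hypothetical op name
    op.set_compute("cuda", my_cuda_compute);
}
```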
@@ -90,6 +90,15 @@ endif()
if(MGE_WITH_CUSTOM_OP)
list(APPEND MGB_INC ${CMAKE_CURRENT_LIST_DIR}/custom/include)
file(GLOB_RECURSE SOURCES_ custom/impl/*.cpp)
set(EXCLUDE_PLATFORM_DIR "custom/impl/platform")
foreach(CUSOURCE ${SOURCES_})
string(FIND ${CUSOURCE} ${EXCLUDE_PLATFORM_DIR} EXCLUDE_DIR_FOUND)
if(NOT ${EXCLUDE_DIR_FOUND} EQUAL -1)
list(REMOVE_ITEM SOURCES_ ${CUSOURCE})
endif()
endforeach(CUSOURCE)
list(APPEND SOURCES ${SOURCES_})
endif()
@@ -6,6 +6,7 @@
#include <unordered_set>
#include "megbrain/custom/op.h"
#include "megbrain/custom/utils.h"
#include "megbrain/utils/thin/function.h"
using namespace mgb;
@@ -99,40 +100,6 @@ std::string ArgInfo::str() const {
(arg_info).name().c_str(), static_cast<int>((arg_info).ndim()), \
static_cast<int>((real_shape).ndim()))
template <typename T>
class Function;
template <typename RType, typename... Args>
class Function<RType(Args...)> {
public:
using Functor = RType (*)(Args...);
Function() = default;
Function(Functor f) : m_f(f) {}
Function(const Function& rhs) { m_f = rhs.m_f; }
RType operator()(Args... args) {
custom_assert(m_f != nullptr, "invalid function ptr\n");
return m_f(std::forward<Args>(args)...);
}
void operator=(const Function& rhs) { // returns void to disallow chained assignment
m_f = rhs.m_f;
}
void operator=(const Functor f) { m_f = f; }
private:
Functor m_f = nullptr;
};
template <typename Functions>
class FuncWithSig : public Functions {
public:
using Functions::operator();
using Functions::operator=;
};
class CustomOpImpl {
static constexpr uint32_t CURRENT_VERSION = CUSTOM_OP_VERSION;
const uint32_t m_version;
@@ -143,29 +110,26 @@ class CustomOpImpl {
std::vector<ArgInfo> m_output_infos;
ParamInfo m_param_infos;
using DeviceInfer = FuncWithSig<Function<void(
const std::vector<Device>&, const Param&, std::vector<Device>&)>>;
using ShapeInfer = FuncWithSig<Function<void(
const std::vector<Shape>&, const Param&, std::vector<Shape>&)>>;
using DTypeInfer = FuncWithSig<Function<void(
const std::vector<DType>&, const Param&, std::vector<DType>&)>>;
using FormatInfer = FuncWithSig<Function<void(
const std::vector<Format>&, const Param&, std::vector<Format>&)>>;
using Preprocess = FuncWithSig<Function<void(
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
using Postprocess = FuncWithSig<Function<void(
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
using Compute = FuncWithSig<Function<void(
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&)>>;
using DeviceInfer = thin_function<void(
const std::vector<Device>&, const Param&, std::vector<Device>&)>;
using ShapeInfer = thin_function<void(
const std::vector<Shape>&, const Param&, std::vector<Shape>&)>;
using DTypeInfer = thin_function<void(
const std::vector<DType>&, const Param&, std::vector<DType>&)>;
using FormatInfer = thin_function<void(
const std::vector<Format>&, const Param&, std::vector<Format>&)>;
using Process = thin_function<void(
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&,
const RuntimeArgs&)>;
DeviceInfer infer_output_device_func;
ShapeInfer infer_output_shape_func;
DTypeInfer infer_output_dtype_func;
FormatInfer infer_output_format_func;
std::unordered_map<std::string, Compute> compute_funcs;
std::unordered_map<std::string, Preprocess> preprocess_funcs;
std::unordered_map<std::string, Postprocess> postprocess_funcs;
std::unordered_map<std::string, Process> compute_funcs;
std::unordered_map<std::string, Process> preprocess_funcs;
std::unordered_map<std::string, Process> postprocess_funcs;
public:
CustomOpImpl(const std::string&, uint32_t version);
@@ -215,7 +179,8 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version)
for (const auto& device : Device::legal_devices()) {
compute_funcs[device] = [](const std::vector<Tensor>&, const Param&,
std::vector<Tensor>& outputs) -> void {
std::vector<Tensor>& outputs,
const RuntimeArgs&) -> void {
auto device = outputs[0].device();
mgb_assert(
false,
@@ -224,9 +189,11 @@ CustomOpImpl::CustomOpImpl(const std::string& op_type, uint32_t version)
device.str().c_str());
};
preprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&,
std::vector<Tensor>&) -> void { return; };
std::vector<Tensor>&,
const RuntimeArgs&) -> void { return; };
postprocess_funcs[device] = [](const std::vector<Tensor>&, const Param&,
std::vector<Tensor>&) -> void { return; };
std::vector<Tensor>&,
const RuntimeArgs&) -> void { return; };
}
m_param_infos.set_tag(op_type);
}
@@ -256,33 +223,78 @@ CustomOp& CustomOp::set_format_infer(FormatInferFuncPtr func) {
return *this;
}
CustomOp& CustomOp::set_preprocess(PreprocessFuncPtr func) {
CustomOp& CustomOp::set_preprocess(ProcessFuncPtrWithoutRuntimeArgs func) {
set_preprocess("x86", func);
return *this;
}
CustomOp& CustomOp::set_preprocess(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) {
auto wrap_func = [func](const std::vector<Tensor>& input, const Param& param,
std::vector<Tensor>& output, const RuntimeArgs&) -> void {
return func(input, param, output);
};
OpImplRef(m_impl.get())->preprocess_funcs[device] = wrap_func;
return *this;
}
CustomOp& CustomOp::set_preprocess(ProcessFuncPtr func) {
set_preprocess("x86", func);
return *this;
}
CustomOp& CustomOp::set_preprocess(const std::string& device, PreprocessFuncPtr func) {
CustomOp& CustomOp::set_preprocess(const std::string& device, ProcessFuncPtr func) {
OpImplRef(m_impl.get())->preprocess_funcs[device] = func;
return *this;
}
CustomOp& CustomOp::set_postprocess(PostprocessFuncPtr func) {
CustomOp& CustomOp::set_postprocess(ProcessFuncPtrWithoutRuntimeArgs func) {
set_postprocess("x86", func);
return *this;
}
CustomOp& CustomOp::set_postprocess(
const std::string& device, PostprocessFuncPtr func) {
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) {
auto wrap_func = [func](const std::vector<Tensor>& input, const Param& param,
std::vector<Tensor>& output,
const RuntimeArgs&) -> void { func(input, param, output); };
OpImplRef(m_impl.get())->postprocess_funcs[device] = wrap_func;
return *this;
}
CustomOp& CustomOp::set_postprocess(ProcessFuncPtr func) {
set_postprocess("x86", func);
return *this;
}
CustomOp& CustomOp::set_postprocess(const std::string& device, ProcessFuncPtr func) {
OpImplRef(m_impl.get())->postprocess_funcs[device] = func;
return *this;
}
CustomOp& CustomOp::set_compute(ComputeFuncPtr func) {
CustomOp& CustomOp::set_compute(ProcessFuncPtrWithoutRuntimeArgs func) {
set_compute("x86", func);
return *this;
}
CustomOp& CustomOp::set_compute(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func) {
auto wrap_func = [func](const std::vector<Tensor>& input, const Param& param,
std::vector<Tensor>& output,
const RuntimeArgs&) -> void { func(input, param, output); };
OpImplRef(m_impl.get())->compute_funcs[device] = wrap_func;
return *this;
}
CustomOp& CustomOp::set_compute(ProcessFuncPtr func) {
set_compute("x86", func);
return *this;
}
CustomOp& CustomOp::set_compute(const std::string& device, ComputeFuncPtr func) {
CustomOp& CustomOp::set_compute(const std::string& device, ProcessFuncPtr func) {
OpImplRef(m_impl.get())->compute_funcs[device] = func;
return *this;
}
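These paired overloads keep existing kernels source-compatible: a `ProcessFuncPtrWithoutRuntimeArgs` is wrapped in a lambda that forwards the tensors and param and discards the trailing `RuntimeArgs`. A sketch of the call site, with a hypothetical kernel name:

```cpp
// Legacy three-argument kernel (e.g. a CPU kernel that needs no runtime args).
void my_cpu_compute(
        const std::vector<Tensor>& inputs, const Param& param,
        std::vector<Tensor>& outputs) {
    (void)inputs; (void)param; (void)outputs;
}

// Accepted unchanged; set_compute wraps it into the four-argument form.
// op.set_compute("x86", my_cpu_compute);
```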
@@ -513,23 +525,28 @@ void CustomOp::compute(
return;
}
std::string device = outputs[0].device().str();
Device device = outputs[0].device();
std::string device_str = device.str();
for (size_t i = 1; i < outputs.size(); ++i) {
mgb_assert(
outputs[i].device().str() == device,
outputs[i].device().str() == device_str,
"all output tensors should have the same device attribute");
}
// need to add other input/output check
mgb_assert(Device::is_legal(device), "unsupported device type: %s", device.c_str());
mgb_assert(
Device::is_legal(device_str), "unsupported device type: %s",
device_str.c_str());
auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device_str];
auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device_str];
auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device_str];
auto preprocess_func = OpImplRef(m_impl.get())->preprocess_funcs[device];
auto forward_func = OpImplRef(m_impl.get())->compute_funcs[device];
auto postprocess_func = OpImplRef(m_impl.get())->postprocess_funcs[device];
RuntimeArgs rt_args(device);
preprocess_func(inputs, param, outputs);
forward_func(inputs, param, outputs);
postprocess_func(outputs, param, outputs);
preprocess_func(inputs, param, outputs, rt_args);
forward_func(inputs, param, outputs, rt_args);
postprocess_func(outputs, param, outputs, rt_args);
assert_outputs_size_right(outputs);
}
#include "megbrain/common.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/platform/custom_cuda.h"
using namespace mgb;
namespace custom {
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
mgb_assert(
rt_args.device().enumv() == DeviceEnum::cuda,
"devive type should be cuda.");
const CompNodeEnv& env =
CompNodeEnv::from_comp_node(to_builtin<CompNode, Device>(rt_args.device()));
const CompNodeEnv::CudaEnv& cuda_env = env.cuda_env();
return {cuda_env.device, cuda_env.stream};
}
} // namespace custom
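`get_cuda_runtime_args` resolves the op's `Device` back to its `CompNodeEnv` and exposes the CUDA device id and stream. A compute function registered for "cuda" can use it to launch work on the stream MegEngine assigned to the op instead of the default stream; a sketch extending the `my_cuda_compute` example above, with the kernel launch left as a hypothetical comment:

```cpp
#include "megbrain/custom/platform/custom_cuda.h"

using namespace custom;

void my_cuda_compute(
        const std::vector<Tensor>& inputs, const Param& param,
        std::vector<Tensor>& outputs, const RuntimeArgs& args) {
    // Only valid when args.device() is cuda; asserted inside the helper.
    const CudaRuntimeArgs rt = get_cuda_runtime_args(args);
    // my_kernel<<<grid, block, 0, rt.stream()>>>(/* ... */);  // hypothetical
    (void)rt; (void)inputs; (void)param; (void)outputs;
}
```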
@@ -36,6 +36,18 @@ class MGE_WIN_DECLSPEC_FUC ArgInfo {
std::string str() const;
};
class CudaRuntimeArgs;
class MGE_WIN_DECLSPEC_FUC RuntimeArgs {
Device m_device;
public:
RuntimeArgs() = default;
RuntimeArgs(Device device) : m_device(device) {}
const Device& device() const { return m_device; }
};
class MGE_WIN_DECLSPEC_FUC CustomOp {
std::unique_ptr<void, void_deleter> m_impl;
@@ -51,11 +63,10 @@ public:
void (*)(const std::vector<DType>&, const Param&, std::vector<DType>&);
using FormatInferFuncPtr =
void (*)(const std::vector<Format>&, const Param&, std::vector<Format>&);
using PreprocessFuncPtr =
void (*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&);
using PostprocessFuncPtr =
void (*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&);
using ComputeFuncPtr =
using ProcessFuncPtr = void (*)(
const std::vector<Tensor>&, const Param&, std::vector<Tensor>&,
const RuntimeArgs&);
using ProcessFuncPtrWithoutRuntimeArgs =
void (*)(const std::vector<Tensor>&, const Param&, std::vector<Tensor>&);
// write for forward
@@ -63,12 +74,24 @@ public:
CustomOp& set_shape_infer(ShapeInferFuncPtr func);
CustomOp& set_dtype_infer(DTypeInferFuncPtr func);
CustomOp& set_format_infer(FormatInferFuncPtr func);
CustomOp& set_preprocess(PreprocessFuncPtr func);
CustomOp& set_preprocess(const std::string& device, PreprocessFuncPtr func);
CustomOp& set_postprocess(PostprocessFuncPtr func);
CustomOp& set_postprocess(const std::string& device, PostprocessFuncPtr func);
CustomOp& set_compute(ComputeFuncPtr func);
CustomOp& set_compute(const std::string& device, ComputeFuncPtr func);
//! set process functions that take RuntimeArgs, e.g. for cuda
CustomOp& set_preprocess(ProcessFuncPtr func);
CustomOp& set_preprocess(const std::string& device, ProcessFuncPtr func);
CustomOp& set_postprocess(ProcessFuncPtr func);
CustomOp& set_postprocess(const std::string& device, ProcessFuncPtr func);
CustomOp& set_compute(ProcessFuncPtr func);
CustomOp& set_compute(const std::string& device, ProcessFuncPtr func);
//! set process functions without RuntimeArgs, e.g. for cpu
CustomOp& set_preprocess(ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_preprocess(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_postprocess(ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_postprocess(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_compute(ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_compute(
const std::string& device, ProcessFuncPtrWithoutRuntimeArgs func);
CustomOp& set_description(const std::string& op_desc);
CustomOp& add_input(
#pragma once
#include "megbrain/custom/op.h"
#include <cuda_runtime_api.h>
namespace custom {
class CudaRuntimeArgs {
private:
int m_device;
cudaStream_t m_stream;
public:
CudaRuntimeArgs(int device, cudaStream_t stream)
: m_device(device), m_stream(stream) {}
int device() const { return m_device; }
cudaStream_t stream() const { return m_stream; }
};
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args);
} // namespace custom
@@ -119,6 +119,34 @@ void gpu_kernel(
ASSERT_TRUE(params["device"] == "cuda");
}
void cpu_kernel_with_runtime_args(
const std::vector<Tensor>& inputs, const Param& params,
std::vector<Tensor>& outputs, const RuntimeArgs& args) {
(void)inputs;
(void)params;
(void)outputs;
(void)args;
#if OP_TEST_LOG
std::cout << "Checking CPU Forward - " << params["device"].as<std::string>()
<< std::endl;
#endif
ASSERT_TRUE(params["device"] == "x86");
}
void gpu_kernel_with_runtime_args(
const std::vector<Tensor>& inputs, const Param& params,
std::vector<Tensor>& outputs, const RuntimeArgs& args) {
(void)inputs;
(void)params;
(void)outputs;
(void)args;
#if OP_TEST_LOG
std::cout << "Checking GPU Forward - " << params["device"].as<std::string>()
<< std::endl;
#endif
ASSERT_TRUE(params["device"] == "cuda");
}
TEST(TestCustomOp, TestCustomOpFuncSetter) {
#if MGB_CUDA
CustomOp test("TestOp", CUSTOM_OP_VERSION);
@@ -179,6 +207,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
ASSERT_TRUE(iformats[0].is_default());
ASSERT_TRUE(iformats[1].is_default());
test.set_compute(cpu_kernel_with_runtime_args);
test.set_compute(cpu_kernel);
DeviceTensorND cdev_itensor0(CompNode::load("cpux"), {3, 2}, dtype::Int32{});
DeviceTensorND cdev_itensor1(CompNode::load("cpux"), {3, 2}, dtype::Float32{});
@@ -192,6 +221,7 @@ TEST(TestCustomOp, TestCustomOpFuncSetter) {
param["device"] = "x86";
test.compute(cinputs, param, coutputs);
test.set_compute("cuda", gpu_kernel_with_runtime_args);
test.set_compute("cuda", gpu_kernel);
DeviceTensorND gdev_itensor0(CompNode::load("gpux"), {3, 2}, dtype::Int32{});
DeviceTensorND gdev_itensor1(CompNode::load("gpux"), {3, 2}, dtype::Float32{});