Commit ca25f17a authored by: Megvii Engine Team

refactor(src): add cuda helper for custom op

GitOrigin-RevId: ce32ccc9f5f5096b0131618947417e9e5dc6a0f3
Parent 004cb8cd
......@@ -3,6 +3,7 @@
#include "./tensor.h"
#include "megbrain/common.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h"
......@@ -704,6 +705,11 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
case custom::ParamDynType::Device: {
param_val =
to_custom_device(py::handle(kv.second).cast<mgb::CompNode>());
break;
}
default: {
mgb_assert(
false, "param dtype of %s:%s is invalid", op_name.c_str(),
......
......@@ -172,24 +172,23 @@ namespace custom_opdef { // avoid name conflict
void apply_on_device_tensornd(
const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) {
for (auto&& output : (*outputs)) {
auto cn = output.comp_node();
cn.activate();
}
// [TODO] sync should be modified
CompNode::sync_all();
auto&& op = static_cast<const CustomOpDef&>(def);
op.compute(inputs, outputs);
CompNode::sync_all();
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(validated == true, "infer output attributes failed\n");
SmallVector<TensorPtr> outputs(output_descs.size());
if (validated == false) {
auto&& op = static_cast<const CustomOpDef&>(def);
for (size_t i = 0; i < outputs.size(); ++i) {
auto [output_descs, success] = op.infer_output_attrs(inputs);
mgb_assert(success == true, "infer output attributes failed\n");
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
auto& output = outputs[i];
output = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
......@@ -241,12 +240,13 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) {
return a.param() == b.param() && a.runtime_id() == b.runtime_id();
}
// [TODO] to be implemented
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now");
// can be implement with param schema
// auto&& custom_opdef = def.cast_final_safe<CustomOpDef>();
auto&& custom_opdef = def.cast_final_safe<CustomOpDef>();
auto&& param_raw = custom_opdef.param().raw();
std::vector<std::pair<const char*, std::string>> props_;
for (auto&& kv : param_raw) {
props_.emplace_back(kv.first.c_str(), kv.second.str());
}
return props_;
}
......
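The re-enabled props() simply flattens the op's param schema into name/value string pairs. A minimal, hypothetical consumer of that result (the helper name and printf-based output are illustrative, not part of this commit):

```cpp
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Hypothetical debug helper: print the name/value pairs produced by props().
void dump_custom_op_props(
        const std::vector<std::pair<const char*, std::string>>& props) {
    for (auto&& [name, value] : props) {
        std::printf("%s: %s\n", name, value.c_str());
    }
}
```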
......@@ -99,6 +99,10 @@ if(MGE_WITH_CUSTOM_OP)
endif()
endforeach(CUSOURCE)
if(MGE_WITH_CUDA)
list(APPEND SOURCES_ custom/impl/platform/custom_cuda.cpp)
endif()
list(APPEND SOURCES ${SOURCES_})
endif()
......
......@@ -2,7 +2,10 @@
#if MGB_CUSTOM_OP
#include "megbrain/comp_node.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/param_val.h"
#include "megbrain/custom/tensor.h"
#pragma GCC diagnostic ignored "-Wsign-compare"
......@@ -268,6 +271,11 @@ std::string ParamVal::str() const {
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST)
case ParamDynType::Device: {
auto&& rval = TypedRef(Device, m_ptr.get());
ss << to_builtin_device(rval).to_string();
break;
}
default:
mgb_assert(false, "invalid data of assignment operator of ParamVal");
}
......
#include "megbrain/common.h"
#include "megbrain_build_config.h"
#if MGB_CUSTOM_OP
#include "megbrain/comp_node_env.h"
#include "megbrain/custom/data_adaptor.h"
......@@ -8,6 +11,8 @@ using namespace mgb;
namespace custom {
#if MGB_CUDA
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
mgb_assert(
rt_args.device().enumv() == DeviceEnum::cuda,
......@@ -18,4 +23,53 @@ const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
return {cuda_env.device, cuda_env.stream};
}
int get_cuda_device_id(Device device) {
auto cn = to_builtin<CompNode>(device);
return CompNodeEnv::from_comp_node(cn).cuda_env().device;
}
const cudaDeviceProp* get_cuda_device_props(Device device) {
auto cn = to_builtin<CompNode>(device);
return &CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
}
cudaStream_t get_cuda_stream(Device device) {
auto cn = to_builtin<CompNode>(device);
return CompNodeEnv::from_comp_node(cn).cuda_env().stream;
}
#else
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
mgb_assert(
false,
"megbrain is not support cuda now, please rebuild megbrain with CUDA "
"ENABLED");
}
int get_cuda_device_id(Device device) {
mgb_assert(
false,
"megbrain is not support cuda now, please rebuild megbrain with CUDA "
"ENABLED");
}
const cudaDeviceProp* get_cuda_device_props(Device device) {
mgb_assert(
false,
"megbrain is not support cuda now, please rebuild megbrain with CUDA "
"ENABLED");
}
cudaStream_t get_cuda_stream(Device device) {
mgb_assert(
false,
"megbrain is not support cuda now, please rebuild megbrain with CUDA "
"ENABLED");
}
#endif
} // namespace custom
#endif
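With these helpers a custom op's CUDA source can target the stream bound to a tensor's comp node instead of the default stream. A minimal sketch, assuming the header path below and a Tensor::device() accessor as in the custom-op tensor API; the function name and the byte-count parameter are illustrative:

```cpp
#include <cuda_runtime_api.h>

#include "megbrain/custom/platform/custom_cuda.h"  // assumed header path

// Hedged sketch: device-to-device copy issued on the comp node's own stream.
void copy_on_comp_node_stream(
        const custom::Tensor& src, custom::Tensor& dst, size_t nbytes) {
    // assumes Tensor::device() returns the custom::Device of the tensor
    cudaStream_t stream = custom::get_cuda_stream(src.device());
    cudaMemcpyAsync(
            dst.data(), src.data(), nbytes, cudaMemcpyDeviceToDevice, stream);
}
```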
......@@ -470,12 +470,16 @@ uint8_t Tensor::zero_point(void) const {
return dtype().zero_point();
}
void* Tensor::data(void) {
return static_cast<void*>(TensorImplRef(m_tensor).raw_ptr());
bool Tensor::is_contiguous() const {
return TensorImplRef(m_tensor).layout().is_contiguous();
}
bool Tensor::is_empty() const {
return TensorImplRef(m_tensor).layout().is_empty();
}
const void* Tensor::data(void) const {
return static_cast<const void*>(TensorImplRef(m_tensor).raw_ptr());
void* Tensor::data(void) const {
return static_cast<void*>(TensorImplRef(m_tensor).raw_ptr());
}
} // namespace custom
......
......@@ -10,7 +10,7 @@ MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> op_insert(
}
#define CUSTOM_OP_REG(OpName) \
CustomOp& _##OpName = (*(op_insert(#OpName, CUSTOM_OP_VERSION)))
::custom::CustomOp& _##OpName = (*(::custom::op_insert(#OpName, CUSTOM_OP_VERSION)))
#define CUSTOM_OP_REG_BEGIN(OpName) \
namespace custom { \
......
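Spelling out ::custom::CustomOp and ::custom::op_insert lets CUSTOM_OP_REG expand correctly even when it is used outside namespace custom. A hypothetical registration site (the op name and the include path are illustrative):

```cpp
#include "megbrain/custom/custom.h"  // assumed umbrella header of the custom-op API

// Hypothetical user code: with the previous unqualified names, this expansion
// inside another namespace would fail to find op_insert.
namespace my_extension {
CUSTOM_OP_REG(AddOne);  // expands to ::custom::CustomOp& _AddOne = *::custom::op_insert("AddOne", CUSTOM_OP_VERSION)
}  // namespace my_extension
```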
......@@ -34,17 +34,23 @@ std::vector<CustomT> to_custom(const megdnn::SmallVector<BuiltinT>& builtins) {
} // namespace custom
#define to_custom_device(expr) custom::to_custom<CompNode, custom::Device>(expr)
#define to_builtin_device(expr) custom::to_builtin<CompNode, custom::Device>(expr)
#define to_custom_device(expr) \
::custom::to_custom<::mgb::CompNode, ::custom::Device>(expr)
#define to_builtin_device(expr) \
::custom::to_builtin<::mgb::CompNode, ::custom::Device>(expr)
#define to_custom_shape(expr) \
custom::to_custom<megdnn::TensorShape, custom::Shape>(expr)
::custom::to_custom<::megdnn::TensorShape, ::custom::Shape>(expr)
#define to_builtin_shape(expr) \
custom::to_builtin<megdnn::TensorShape, custom::Shape>(expr)
#define to_custom_dtype(expr) custom::to_custom<megdnn::DType, custom::DType>(expr)
#define to_builtin_dtype(expr) custom::to_builtin<megdnn::DType, custom::DType>(expr)
::custom::to_builtin<::megdnn::TensorShape, ::custom::Shape>(expr)
#define to_custom_dtype(expr) \
::custom::to_custom<::megdnn::DType, ::custom::DType>(expr)
#define to_builtin_dtype(expr) \
::custom::to_builtin<::megdnn::DType, ::custom::DType>(expr)
#define to_custom_format(expr) \
custom::to_custom<megdnn::TensorLayout::Format, custom::Format>(expr)
::custom::to_custom<::megdnn::TensorLayout::Format, ::custom::Format>(expr)
#define to_builtin_format(expr) \
custom::to_builtin<megdnn::TensorLayout::Format, custom::Format>(expr)
#define to_custom_tensor(expr) custom::to_custom<DeviceTensorND, custom::Tensor>(expr)
#define to_builtin_tensor(expr) custom::to_builtin<DeviceTensorND, custom::Tensor>(expr)
::custom::to_builtin<::megdnn::TensorLayout::Format, ::custom::Format>(expr)
#define to_custom_tensor(expr) \
::custom::to_custom<::mgb::DeviceTensorND, ::custom::Tensor>(expr)
#define to_builtin_tensor(expr) \
::custom::to_builtin<::mgb::DeviceTensorND, ::custom::Tensor>(expr)
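The conversion macros get the same fully-qualified treatment, so they also work from any namespace. A minimal round-trip sketch between mgb::CompNode and custom::Device, using the standard CompNode::load("xpux") locator:

```cpp
#include "megbrain/comp_node.h"
#include "megbrain/custom/data_adaptor.h"

// Minimal sketch: builtin CompNode -> custom::Device -> CompNode round trip.
void device_round_trip() {
    mgb::CompNode cn = mgb::CompNode::load("xpux");
    custom::Device dev = to_custom_device(cn);    // builtin -> custom
    mgb::CompNode back = to_builtin_device(dev);  // custom -> builtin
    (void)back;
}
```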
......@@ -14,23 +14,28 @@ namespace custom {
* we can add a new basic data type here, basic means we can perform binary
* op such as: +, -, *, /, ==, != between any two of them
*/
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ...) \
cb(Int32, int32_t, ##__VA_ARGS__) cb(Int64, int64_t, ##__VA_ARGS__) \
cb(Uint32, uint32_t, ##__VA_ARGS__) cb(Uint64, uint64_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__)
// clang-format off
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ...) \
cb(Int32, int32_t, ##__VA_ARGS__) \
cb(Int64, int64_t, ##__VA_ARGS__) \
cb(Uint32, uint32_t, ##__VA_ARGS__) \
cb(Uint64, uint64_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__)
// clang-format on
#define CUSTOM_FOR_STRING_PARAMTYPE(cb, ...) cb(String, std::string, ##__VA_ARGS__)
#define CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ...) \
cb(Int32List, std::vector<int32_t>, ##__VA_ARGS__) \
cb(Int64List, std::vector<int64_t>, ##__VA_ARGS__) \
cb(Uint32List, std::vector<uint32_t>, ##__VA_ARGS__) \
cb(Uint64List, std::vector<uint64_t>, ##__VA_ARGS__) \
cb(Float32List, std::vector<float>, ##__VA_ARGS__) \
cb(Float64List, std::vector<double>, \
##__VA_ARGS__)
// clang-format off
#define CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ...) \
cb(Int32List, std::vector<int32_t>, ##__VA_ARGS__) \
cb(Int64List, std::vector<int64_t>, ##__VA_ARGS__) \
cb(Uint32List, std::vector<uint32_t>, ##__VA_ARGS__) \
cb(Uint64List, std::vector<uint64_t>, ##__VA_ARGS__) \
cb(Float32List, std::vector<float>, ##__VA_ARGS__) \
cb(Float64List, std::vector<double>, ##__VA_ARGS__)
// clang-format on
#define CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ...) \
cb(BoolList, std::vector<bool>, ##__VA_ARGS__)
......@@ -41,19 +46,26 @@ namespace custom {
/**
* to avoid the recursive of MACRO
*/
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(cb, ...) \
cb(Int32, int32_t, ##__VA_ARGS__) cb(Int64, int64_t, ##__VA_ARGS__) \
cb(Uint32, uint32_t, ##__VA_ARGS__) cb(Uint64, uint64_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__)
// clang-format off
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(cb, ...) \
cb(Int32, int32_t, ##__VA_ARGS__) \
cb(Int64, int64_t, ##__VA_ARGS__) \
cb(Uint32, uint32_t, ##__VA_ARGS__) \
cb(Uint64, uint64_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__)
// clang-format on
class Device;
#define CUSTOM_FOR_EACH_VALID_PARAMTYPE(cb, ...) \
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_STRING_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__)
CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
cb(Device, ::custom::Device, ##__VA_ARGS__)
#define CUSTOM_FOR_EACH_LIST_PARAMTYPE(cb, ...) \
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
......
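Wrapping the lists in clang-format off/on keeps one cb(...) entry per line; the lists themselves are consumed by defining cb and letting the preprocessor stamp out one item per type. A hypothetical illustration of that X-macro pattern:

```cpp
// Hypothetical consumer of the X-macro list: one enum entry per basic param
// type (assumes the header defining CUSTOM_FOR_EACH_BASIC_PARAMTYPE is included).
#define MAKE_ENUM_ENTRY(name, ctype, ...) name,
enum class BasicParamTag { CUSTOM_FOR_EACH_BASIC_PARAMTYPE(MAKE_ENUM_ENTRY) };
#undef MAKE_ENUM_ENTRY
```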
......@@ -3,10 +3,11 @@
#include "megbrain/custom/op.h"
#include <cuda_runtime_api.h>
#include <driver_types.h>
namespace custom {
class CudaRuntimeArgs {
class MGE_WIN_DECLSPEC_FUC CudaRuntimeArgs {
private:
int m_device;
cudaStream_t m_stream;
......@@ -20,6 +21,10 @@ public:
cudaStream_t stream() const { return m_stream; }
};
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args);
MGE_WIN_DECLSPEC_FUC const CudaRuntimeArgs
get_cuda_runtime_args(const RuntimeArgs& rt_args);
MGE_WIN_DECLSPEC_FUC int get_cuda_device_id(Device device);
MGE_WIN_DECLSPEC_FUC const cudaDeviceProp* get_cuda_device_props(Device device);
MGE_WIN_DECLSPEC_FUC cudaStream_t get_cuda_stream(Device device);
} // namespace custom
......@@ -80,15 +80,21 @@ using bfloat16_t = uint16_t;
cb(custom_dtype, dnn_dtype, c_dtype)
#endif
#define CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(cb) \
cb(float32, Float32, float) cb(uint8, Uint8, uint8_t) cb(int8, Int8, int8_t) cb( \
int16, Int16, int16_t) cb(int32, Int32, int32_t) \
fp16_wrap(cb, float16, Float16, float16_t) fp16_wrap( \
cb, bfloat16, BFloat16, bfloat16_t) cb(uint16, Uint16, uint16_t) \
cb(quint8, Quantized8Asymm, uint8_t) \
cb(qint32, QuantizedS32, int32_t) \
cb(qint8, QuantizedS8, int8_t) \
cb(qint16, QuantizedS16, int16_t)
// clang-format off
#define CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(cb) \
cb(float32, Float32, float) \
cb(uint8, Uint8, uint8_t) \
cb(int8, Int8, int8_t) \
cb(int16, Int16, int16_t) \
cb(int32, Int32, int32_t) \
fp16_wrap(cb, float16, Float16, float16_t) \
fp16_wrap(cb, bfloat16, BFloat16, bfloat16_t) \
cb(uint16, Uint16, uint16_t) \
cb(quint8, Quantized8Asymm, uint8_t) \
cb(qint32, QuantizedS32, int32_t) \
cb(qint8, QuantizedS8, int8_t) \
cb(qint16, QuantizedS16, int16_t)
// clang-format on
#define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type,
......@@ -214,13 +220,12 @@ public:
float scale(void) const;
uint8_t zero_point(void) const;
void* data(void);
const void* data(void) const;
bool is_contiguous() const;
bool is_empty() const;
void* data(void) const;
template <typename T>
T* data(void);
template <typename T>
const T* data(void) const;
T* data(void) const;
template <
typename T, size_t N,
......@@ -238,16 +243,12 @@ public:
};
template <typename T>
T* Tensor::data(void) {
T* Tensor::data(void) const {
custom_assert(
dtype().is_compatible<T>(), "invalid convert, tensor data type is %s",
dtype().str().c_str());
return reinterpret_cast<T*>(data());
}
template <typename T>
const T* Tensor::data(void) const {
return const_cast<Tensor*>(this)->data<T>();
}
template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t>
const TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() const {
......
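With this change a single const-qualified data<T>() serves both const and non-const tensors, and the new is_contiguous()/is_empty() checks can guard raw-pointer access. A minimal sketch (the helper name is illustrative):

```cpp
#include <cstdio>

#include "megbrain/custom/tensor.h"

// Hypothetical helper: read the first float of a tensor, guarding with the
// new emptiness/contiguity checks.
void print_first_element(const custom::Tensor& t) {
    if (t.is_empty() || !t.is_contiguous())
        return;
    float* ptr = t.data<float>();  // data<T>() is usable on const tensors now
    std::printf("first element: %f\n", ptr[0]);
}
```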
......@@ -20,26 +20,28 @@ MGE_WIN_DECLSPEC_FUC void assert_failed_log(
const char* msg_fmt, ...);
#ifndef _WIN32
#define custom_expect(expr, msg...) \
if (!(expr)) { \
assert_failed_log(__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
#define custom_expect(expr, msg...) \
if (!(expr)) { \
::custom::assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
}
#define custom_assert(expr, msg...) \
if (!(expr)) { \
assert_failed_log(__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
} \
#define custom_assert(expr, msg...) \
if (!(expr)) { \
::custom::assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
} \
assert((expr))
#else
#define custom_expect(expr, ...) \
if (!(expr)) { \
assert_failed_log( \
::custom::assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__); \
}
#define custom_assert(expr, ...) \
if (!(expr)) { \
assert_failed_log( \
::custom::assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__); \
} \
assert((expr))
......
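Qualifying assert_failed_log as ::custom::assert_failed_log means the macros also resolve when expanded outside namespace custom. A hypothetical call site:

```cpp
#include <cstddef>

#include "megbrain/custom/utils.h"  // assumed header defining custom_assert

// Hypothetical call site outside namespace custom; the format arguments are
// forwarded to ::custom::assert_failed_log when the condition fails.
void check_ndim(size_t ndim) {
    custom_assert(ndim <= 7, "ndim %zu exceeds the supported maximum", ndim);
}
```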