提交 ca25f17a 编写于 作者: M Megvii Engine Team

refactor(src): add cuda helper for custom op

GitOrigin-RevId: ce32ccc9f5f5096b0131618947417e9e5dc6a0f3
上级 004cb8cd
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "./tensor.h" #include "./tensor.h"
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/imperative.h" #include "megbrain/imperative.h"
#include "megbrain/imperative/graph_builder.h" #include "megbrain/imperative/graph_builder.h"
#include "megbrain/imperative/ops/autogen.h" #include "megbrain/imperative/ops/autogen.h"
...@@ -704,6 +705,11 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) { ...@@ -704,6 +705,11 @@ PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST) CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST) CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST) CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
case custom::ParamDynType::Device: {
param_val =
to_custom_device(py::handle(kv.second).cast<mgb::CompNode>());
break;
}
default: { default: {
mgb_assert( mgb_assert(
false, "param dtype of %s:%s is invalid", op_name.c_str(), false, "param dtype of %s:%s is invalid", op_name.c_str(),
......
...@@ -172,24 +172,23 @@ namespace custom_opdef { // avoid name conflict ...@@ -172,24 +172,23 @@ namespace custom_opdef { // avoid name conflict
void apply_on_device_tensornd( void apply_on_device_tensornd(
const OpDef& def, const SmallVector<DeviceTensorND>& inputs, const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) { SmallVector<DeviceTensorND>* outputs) {
for (auto&& output : (*outputs)) {
auto cn = output.comp_node();
cn.activate();
}
// [TODO] sync should be modified
CompNode::sync_all();
auto&& op = static_cast<const CustomOpDef&>(def); auto&& op = static_cast<const CustomOpDef&>(def);
op.compute(inputs, outputs); op.compute(inputs, outputs);
CompNode::sync_all();
} }
SmallVector<TensorPtr> apply_on_physical_tensor( SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs, const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(validated == true, "infer output attributes fall\n");
SmallVector<TensorPtr> outputs(output_descs.size()); SmallVector<TensorPtr> outputs(output_descs.size());
if (validated == false) {
auto&& op = static_cast<const CustomOpDef&>(def);
for (size_t i = 0; i < outputs.size(); ++i) {
auto [output_descs, success] = op.infer_output_attrs(inputs);
mgb_assert(success == true, "infer output attributes fall\n");
}
}
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
auto& output = outputs[i]; auto& output = outputs[i];
output = Tensor::make(output_descs[i].layout, output_descs[i].comp_node); output = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
...@@ -241,12 +240,13 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) { ...@@ -241,12 +240,13 @@ bool is_same_st(const OpDef& lhs, const OpDef& rhs) {
return a.param() == b.param() && a.runtime_id() == b.runtime_id(); return a.param() == b.param() && a.runtime_id() == b.runtime_id();
} }
// [TODO] to be implemented
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) { std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
mgb_assert(false, "Custom OpDef Props Function is not IMPLEMENTED now"); auto&& custom_opdef = def.cast_final_safe<CustomOpDef>();
// can be implement with param schema auto&& param_raw = custom_opdef.param().raw();
// auto&& custom_opdef = def.cast_final_safe<CustomOpDef>();
std::vector<std::pair<const char*, std::string>> props_; std::vector<std::pair<const char*, std::string>> props_;
for (auto&& kv : param_raw) {
props_.emplace_back(kv.first.c_str(), kv.second.str());
}
return props_; return props_;
} }
......
...@@ -99,6 +99,10 @@ if(MGE_WITH_CUSTOM_OP) ...@@ -99,6 +99,10 @@ if(MGE_WITH_CUSTOM_OP)
endif() endif()
endforeach(CUSOURCE) endforeach(CUSOURCE)
if(MGE_WITH_CUDA)
list(APPEND SOURCES_ custom/impl/platform/custom_cuda.cpp)
endif()
list(APPEND SOURCES ${SOURCES_}) list(APPEND SOURCES ${SOURCES_})
endif() endif()
......
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
#if MGB_CUSTOM_OP #if MGB_CUSTOM_OP
#include "megbrain/comp_node.h"
#include "megbrain/custom/data_adaptor.h"
#include "megbrain/custom/param_val.h" #include "megbrain/custom/param_val.h"
#include "megbrain/custom/tensor.h"
#pragma GCC diagnostic ignored "-Wsign-compare" #pragma GCC diagnostic ignored "-Wsign-compare"
...@@ -268,6 +271,11 @@ std::string ParamVal::str() const { ...@@ -268,6 +271,11 @@ std::string ParamVal::str() const {
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST) CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST) CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST)
case ParamDynType::Device: {
auto&& rval = TypedRef(Device, m_ptr.get());
ss << to_builtin_device(rval).to_string();
break;
}
default: default:
mgb_assert(false, "invalid data of assignment operator of ParamVal"); mgb_assert(false, "invalid data of assignment operator of ParamVal");
} }
......
#include "megbrain/common.h" #include "megbrain/common.h"
#include "megbrain_build_config.h"
#if MGB_CUSTOM_OP
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "megbrain/custom/data_adaptor.h" #include "megbrain/custom/data_adaptor.h"
...@@ -8,6 +11,8 @@ using namespace mgb; ...@@ -8,6 +11,8 @@ using namespace mgb;
namespace custom { namespace custom {
#if MGB_CUDA
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) { const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
mgb_assert( mgb_assert(
rt_args.device().enumv() == DeviceEnum::cuda, rt_args.device().enumv() == DeviceEnum::cuda,
...@@ -18,4 +23,53 @@ const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) { ...@@ -18,4 +23,53 @@ const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
return {cuda_env.device, cuda_env.stream}; return {cuda_env.device, cuda_env.stream};
} }
int get_cuda_device_id(Device device) {
auto cn = to_builtin<CompNode>(device);
return CompNodeEnv::from_comp_node(cn).cuda_env().device;
}
const cudaDeviceProp* get_cuda_device_props(Device device) {
auto cn = to_builtin<CompNode>(device);
return &CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
}
cudaStream_t get_cuda_stream(Device device) {
auto cn = to_builtin<CompNode>(device);
return CompNodeEnv::from_comp_node(cn).cuda_env().stream;
}
#else
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args) {
mgb_assert(
false,
"megbrain is not support cuda now, please rebuild megbrain with CUDA "
"ENABLED");
}
int get_cuda_device_id(Device device) {
mgb_assert(
false,
"megbrain is not support cuda now, please rebuild megbrain with CUDA "
"ENABLED");
}
const cudaDeviceProp* get_cuda_device_props(Device device) {
mgb_assert(
false,
"megbrain is not support cuda now, please rebuild megbrain with CUDA "
"ENABLED");
}
cudaStream_t get_cuda_stream(Device device) {
mgb_assert(
false,
"megbrain is not support cuda now, please rebuild megbrain with CUDA "
"ENABLED");
}
#endif
} // namespace custom } // namespace custom
#endif
...@@ -470,12 +470,16 @@ uint8_t Tensor::zero_point(void) const { ...@@ -470,12 +470,16 @@ uint8_t Tensor::zero_point(void) const {
return dtype().zero_point(); return dtype().zero_point();
} }
void* Tensor::data(void) { bool Tensor::is_contiguous() const {
return static_cast<void*>(TensorImplRef(m_tensor).raw_ptr()); return TensorImplRef(m_tensor).layout().is_contiguous();
}
bool Tensor::is_empty() const {
return TensorImplRef(m_tensor).layout().is_empty();
} }
const void* Tensor::data(void) const { void* Tensor::data(void) const {
return static_cast<const void*>(TensorImplRef(m_tensor).raw_ptr()); return static_cast<void*>(TensorImplRef(m_tensor).raw_ptr());
} }
} // namespace custom } // namespace custom
......
...@@ -10,7 +10,7 @@ MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> op_insert( ...@@ -10,7 +10,7 @@ MGE_WIN_DECLSPEC_FUC std::shared_ptr<CustomOp> op_insert(
} }
#define CUSTOM_OP_REG(OpName) \ #define CUSTOM_OP_REG(OpName) \
CustomOp& _##OpName = (*(op_insert(#OpName, CUSTOM_OP_VERSION))) ::custom::CustomOp& _##OpName = (*(::custom::op_insert(#OpName, CUSTOM_OP_VERSION)))
#define CUSTOM_OP_REG_BEGIN(OpName) \ #define CUSTOM_OP_REG_BEGIN(OpName) \
namespace custom { \ namespace custom { \
......
...@@ -34,17 +34,23 @@ std::vector<CustomT> to_custom(const megdnn::SmallVector<BuiltinT>& builtins) { ...@@ -34,17 +34,23 @@ std::vector<CustomT> to_custom(const megdnn::SmallVector<BuiltinT>& builtins) {
} // namespace custom } // namespace custom
#define to_custom_device(expr) custom::to_custom<CompNode, custom::Device>(expr) #define to_custom_device(expr) \
#define to_builtin_device(expr) custom::to_builtin<CompNode, custom::Device>(expr) ::custom::to_custom<::mgb::CompNode, ::custom::Device>(expr)
#define to_builtin_device(expr) \
::custom::to_builtin<::mgb::CompNode, ::custom::Device>(expr)
#define to_custom_shape(expr) \ #define to_custom_shape(expr) \
custom::to_custom<megdnn::TensorShape, custom::Shape>(expr) ::custom::to_custom<::megdnn::TensorShape, ::custom::Shape>(expr)
#define to_builtin_shape(expr) \ #define to_builtin_shape(expr) \
custom::to_builtin<megdnn::TensorShape, custom::Shape>(expr) ::custom::to_builtin<::megdnn::TensorShape, ::custom::Shape>(expr)
#define to_custom_dtype(expr) custom::to_custom<megdnn::DType, custom::DType>(expr) #define to_custom_dtype(expr) \
#define to_builtin_dtype(expr) custom::to_builtin<megdnn::DType, custom::DType>(expr) ::custom::to_custom<::megdnn::DType, ::custom::DType>(expr)
#define to_builtin_dtype(expr) \
::custom::to_builtin<::megdnn::DType, ::custom::DType>(expr)
#define to_custom_format(expr) \ #define to_custom_format(expr) \
custom::to_custom<megdnn::TensorLayout::Format, custom::Format>(expr) ::custom::to_custom<::megdnn::TensorLayout::Format, ::custom::Format>(expr)
#define to_builtin_format(expr) \ #define to_builtin_format(expr) \
custom::to_builtin<megdnn::TensorLayout::Format, custom::Format>(expr) ::custom::to_builtin<::megdnn::TensorLayout::Format, ::custom::Format>(expr)
#define to_custom_tensor(expr) custom::to_custom<DeviceTensorND, custom::Tensor>(expr) #define to_custom_tensor(expr) \
#define to_builtin_tensor(expr) custom::to_builtin<DeviceTensorND, custom::Tensor>(expr) ::custom::to_custom<::mgb::DeviceTensorND, ::custom::Tensor>(expr)
#define to_builtin_tensor(expr) \
::custom::to_builtin<::mgb::DeviceTensorND, ::custom::Tensor>(expr)
...@@ -14,23 +14,28 @@ namespace custom { ...@@ -14,23 +14,28 @@ namespace custom {
* we can add a new basic data type here, basic means we can perform binary * we can add a new basic data type here, basic means we can perform binary
* op such as: +, -, *, /, ==, != between any two of them * op such as: +, -, *, /, ==, != between any two of them
*/ */
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ...) \ // clang-format off
cb(Int32, int32_t, ##__VA_ARGS__) cb(Int64, int64_t, ##__VA_ARGS__) \ #define CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ...) \
cb(Uint32, uint32_t, ##__VA_ARGS__) cb(Uint64, uint64_t, ##__VA_ARGS__) \ cb(Int32, int32_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \ cb(Int64, int64_t, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \ cb(Uint32, uint32_t, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__) cb(Uint64, uint64_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__)
// clang-format on
#define CUSTOM_FOR_STRING_PARAMTYPE(cb, ...) cb(String, std::string, ##__VA_ARGS__) #define CUSTOM_FOR_STRING_PARAMTYPE(cb, ...) cb(String, std::string, ##__VA_ARGS__)
#define CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ...) \ // clang-format off
cb(Int32List, std::vector<int32_t>, ##__VA_ARGS__) \ #define CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ...) \
cb(Int64List, std::vector<int64_t>, ##__VA_ARGS__) \ cb(Int32List, std::vector<int32_t>, ##__VA_ARGS__) \
cb(Uint32List, std::vector<uint32_t>, ##__VA_ARGS__) \ cb(Int64List, std::vector<int64_t>, ##__VA_ARGS__) \
cb(Uint64List, std::vector<uint64_t>, ##__VA_ARGS__) \ cb(Uint32List, std::vector<uint32_t>, ##__VA_ARGS__) \
cb(Float32List, std::vector<float>, ##__VA_ARGS__) \ cb(Uint64List, std::vector<uint64_t>, ##__VA_ARGS__) \
cb(Float64List, std::vector<double>, \ cb(Float32List, std::vector<float>, ##__VA_ARGS__) \
##__VA_ARGS__) cb(Float64List, std::vector<double>, ##__VA_ARGS__)
// clang-format on
#define CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ...) \ #define CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ...) \
cb(BoolList, std::vector<bool>, ##__VA_ARGS__) cb(BoolList, std::vector<bool>, ##__VA_ARGS__)
...@@ -41,19 +46,26 @@ namespace custom { ...@@ -41,19 +46,26 @@ namespace custom {
/** /**
* to avoid the recursive of MACRO * to avoid the recursive of MACRO
*/ */
#define CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(cb, ...) \ // clang-format off
cb(Int32, int32_t, ##__VA_ARGS__) cb(Int64, int64_t, ##__VA_ARGS__) \ #define CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(cb, ...) \
cb(Uint32, uint32_t, ##__VA_ARGS__) cb(Uint64, uint64_t, ##__VA_ARGS__) \ cb(Int32, int32_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \ cb(Int64, int64_t, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \ cb(Uint32, uint32_t, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__) cb(Uint64, uint64_t, ##__VA_ARGS__) \
cb(Float32, float, ##__VA_ARGS__) \
cb(Float64, double, ##__VA_ARGS__) \
cb(Bool, bool, ##__VA_ARGS__)
// clang-format on
class Device;
#define CUSTOM_FOR_EACH_VALID_PARAMTYPE(cb, ...) \ #define CUSTOM_FOR_EACH_VALID_PARAMTYPE(cb, ...) \
CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ##__VA_ARGS__) \ CUSTOM_FOR_EACH_BASIC_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_STRING_PARAMTYPE(cb, ##__VA_ARGS__) \ CUSTOM_FOR_STRING_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ CUSTOM_FOR_BOOL_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) CUSTOM_FOR_STRING_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
cb(Device, ::custom::Device, ##__VA_ARGS__)
#define CUSTOM_FOR_EACH_LIST_PARAMTYPE(cb, ...) \ #define CUSTOM_FOR_EACH_LIST_PARAMTYPE(cb, ...) \
CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \ CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(cb, ##__VA_ARGS__) \
......
...@@ -3,10 +3,11 @@ ...@@ -3,10 +3,11 @@
#include "megbrain/custom/op.h" #include "megbrain/custom/op.h"
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <driver_types.h>
namespace custom { namespace custom {
class CudaRuntimeArgs { class MGE_WIN_DECLSPEC_FUC CudaRuntimeArgs {
private: private:
int m_device; int m_device;
cudaStream_t m_stream; cudaStream_t m_stream;
...@@ -20,6 +21,10 @@ public: ...@@ -20,6 +21,10 @@ public:
cudaStream_t stream() const { return m_stream; } cudaStream_t stream() const { return m_stream; }
}; };
const CudaRuntimeArgs get_cuda_runtime_args(const RuntimeArgs& rt_args); MGE_WIN_DECLSPEC_FUC const CudaRuntimeArgs
get_cuda_runtime_args(const RuntimeArgs& rt_args);
MGE_WIN_DECLSPEC_FUC int get_cuda_device_id(Device device);
MGE_WIN_DECLSPEC_FUC const cudaDeviceProp* get_cuda_device_props(Device device);
MGE_WIN_DECLSPEC_FUC cudaStream_t get_cuda_stream(Device device);
} // namespace custom } // namespace custom
...@@ -80,15 +80,21 @@ using bfloat16_t = uint16_t; ...@@ -80,15 +80,21 @@ using bfloat16_t = uint16_t;
cb(custom_dtype, dnn_dtype, c_dtype) cb(custom_dtype, dnn_dtype, c_dtype)
#endif #endif
#define CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(cb) \ // clang-format off
cb(float32, Float32, float) cb(uint8, Uint8, uint8_t) cb(int8, Int8, int8_t) cb( \ #define CUSTOM_FOR_EACH_TENSOR_DATA_TYPE(cb) \
int16, Int16, int16_t) cb(int32, Int32, int32_t) \ cb(float32, Float32, float) \
fp16_wrap(cb, float16, Float16, float16_t) fp16_wrap( \ cb(uint8, Uint8, uint8_t) \
cb, bfloat16, BFloat16, bfloat16_t) cb(uint16, Uint16, uint16_t) \ cb(int8, Int8, int8_t) \
cb(quint8, Quantized8Asymm, uint8_t) \ cb(int16, Int16, int16_t) \
cb(qint32, QuantizedS32, int32_t) \ cb(int32, Int32, int32_t) \
cb(qint8, QuantizedS8, int8_t) \ fp16_wrap(cb, float16, Float16, float16_t) \
cb(qint16, QuantizedS16, int16_t) fp16_wrap(cb, bfloat16, BFloat16, bfloat16_t) \
cb(uint16, Uint16, uint16_t) \
cb(quint8, Quantized8Asymm, uint8_t) \
cb(qint32, QuantizedS32, int32_t) \
cb(qint8, QuantizedS8, int8_t) \
cb(qint16, QuantizedS16, int16_t)
// clang-format on
#define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type, #define CUSTOM_DTYPE_ENUM_DECL(custom_type, builtin_type, ctype) custom_type,
...@@ -214,13 +220,12 @@ public: ...@@ -214,13 +220,12 @@ public:
float scale(void) const; float scale(void) const;
uint8_t zero_point(void) const; uint8_t zero_point(void) const;
void* data(void); bool is_contiguous() const;
const void* data(void) const; bool is_empty() const;
void* data(void) const;
template <typename T> template <typename T>
T* data(void); T* data(void) const;
template <typename T>
const T* data(void) const;
template < template <
typename T, size_t N, typename T, size_t N,
...@@ -238,16 +243,12 @@ public: ...@@ -238,16 +243,12 @@ public:
}; };
template <typename T> template <typename T>
T* Tensor::data(void) { T* Tensor::data(void) const {
custom_assert( custom_assert(
dtype().is_compatible<T>(), "invalid convert, tensor data type is %s", dtype().is_compatible<T>(), "invalid convert, tensor data type is %s",
dtype().str().c_str()); dtype().str().c_str());
return reinterpret_cast<T*>(data()); return reinterpret_cast<T*>(data());
} }
template <typename T>
const T* Tensor::data(void) const {
return const_cast<Tensor*>(this)->data<T>();
}
template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t> template <typename T, size_t N, template <typename U> class PtrTraits, typename index_t>
const TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() const { const TensorAccessor<T, N, PtrTraits, index_t> Tensor::accessor() const {
......
...@@ -20,26 +20,28 @@ MGE_WIN_DECLSPEC_FUC void assert_failed_log( ...@@ -20,26 +20,28 @@ MGE_WIN_DECLSPEC_FUC void assert_failed_log(
const char* msg_fmt, ...); const char* msg_fmt, ...);
#ifndef _WIN32 #ifndef _WIN32
#define custom_expect(expr, msg...) \ #define custom_expect(expr, msg...) \
if (!(expr)) { \ if (!(expr)) { \
assert_failed_log(__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \ ::custom::assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
} }
#define custom_assert(expr, msg...) \ #define custom_assert(expr, msg...) \
if (!(expr)) { \ if (!(expr)) { \
assert_failed_log(__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \ ::custom::assert_failed_log( \
} \ __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, ##msg); \
} \
assert((expr)) assert((expr))
#else #else
#define custom_expect(expr, ...) \ #define custom_expect(expr, ...) \
if (!(expr)) { \ if (!(expr)) { \
assert_failed_log( \ ::custom::assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__); \ __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__); \
} }
#define custom_assert(expr, ...) \ #define custom_assert(expr, ...) \
if (!(expr)) { \ if (!(expr)) { \
assert_failed_log( \ ::custom::assert_failed_log( \
__FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__); \ __FILE__, __LINE__, __PRETTY_FUNCTION__, #expr, __VA_ARGS__); \
} \ } \
assert((expr)) assert((expr))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册