/**
 * Implementation of custom::ParamVal: a value holding a dynamic type tag
 * (`m_type`, a ParamDynType) plus a type-erased owning pointer (`m_ptr`) to
 * the actual value, together with copy/assignment, printing, (de)serialization
 * and binary operators dispatched on the dynamic type.
 */
#include "megbrain/common.h"

#if MGB_CUSTOM_OP

#include <cstring>
#include <sstream>

#include "megbrain/comp_node.h"
#include "megbrain/custom/adaptor.h"
#include "megbrain/custom/param_val.h"
#include "megbrain/custom/tensor.h"

#pragma GCC diagnostic ignored "-Wsign-compare"

using namespace mgb;

namespace custom {

namespace {

/**
 * Assert that `bytes` holds at least `count * elem_size` bytes starting at
 * `offset`.
 *
 * ParamVal::from_bytes calls this before every read, so a truncated or
 * corrupted buffer fails an assertion instead of reading out of bounds.
 * The comparison divides rather than multiplies, so a huge `count` decoded
 * from the buffer cannot overflow `count * elem_size`.
 */
void assert_readable(
        const std::string& bytes, size_t offset, size_t count, size_t elem_size = 1) {
    mgb_assert(
            offset <= bytes.size() &&
                    (elem_size == 0 || count <= (bytes.size() - offset) / elem_size),
            "ParamVal::from_bytes: buffer too short (offset %zu, need %zu * %zu "
            "bytes, buffer size %zu)",
            offset, count, elem_size, bytes.size());
}

}  // namespace

/**
 * Macro Callback for Case
 *
 * Every CUSTOM_CASE_* macro below expands to a single `case` label of a
 * `switch` over ParamDynType. They are passed as callbacks to the
 * CUSTOM_FOR_EACH_* / CUSTOM_FOR_STRING_* X-macros. Those X-macros are not
 * defined in this file (presumably in param_val.h) and invoke the callback
 * once per (dyn_type, static_type) pair. Each callback relies on the names
 * visible at its expansion site (`rhs`, `lhs`, `m_ptr`, `value`, `res`,
 * `bytes`, `offset`, `ss`).
 */

// Replace m_ptr with a newly allocated copy of rhs's value. Using
// decltype(m_ptr) keeps the temporary's type (including the deleter type)
// identical to the member it is swapped into.
#define CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS(dyn_type, static_type)        \
    case (ParamDynType::dyn_type): {                                     \
        decltype(m_ptr) new_ptr(                                         \
                new static_type(TypedRef(static_type, rhs.m_ptr.get())), \
                impl_deleter<static_type>);                              \
        m_ptr.swap(new_ptr);                                             \
        break;                                                           \
    }

// Copy-assign rhs's value into the object m_ptr already holds (same type).
#define CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS(dyn_type, static_type)                   \
    case (ParamDynType::dyn_type): {                                                 \
        TypedRef(static_type, m_ptr.get()) = TypedRef(static_type, rhs.m_ptr.get()); \
        break;                                                                       \
    }

// Assert that an operand of a binary operator holds a value.
#define CUSTOM_ASSERT_OPERAND_VALID(operand, opr)                                \
    mgb_assert(                                                                  \
            operand.m_ptr != nullptr && operand.m_type != ParamDynType::Invalid, \
            "invalid %s of operator %s of ParamVal", #operand, #opr)

// Assert that both operands have the same dynamic type; the message names the
// rejected expression, e.g. "`Int32` + `String` is not allowed".
#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op)                       \
    mgb_assert(                                                      \
            lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \
            ptype2name(lhs.m_type).c_str(), #op, ptype2name(rhs.m_type).c_str())

// Inner case: rhs is a basic type; evaluate `lval op rval` on the two
// static types.
#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \
    case (ParamDynType::dyn_type): {                                        \
        const auto& rval = TypedRef(static_type, rhs.m_ptr.get());          \
        return lval op rval;                                                \
    }

// Outer case: lhs is a basic type, so dispatch again on rhs's type. The
// _COPY X-macro drives the inner switch, presumably because the preprocessor
// will not re-expand CUSTOM_FOR_EACH_BASIC_PARAMTYPE inside its own expansion.
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC(dyn_type, static_type, op) \
    case (ParamDynType::dyn_type): {                                      \
        const auto& lval = TypedRef(static_type, lhs.m_ptr.get());        \
        switch (rhs.m_type) {                                             \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(                         \
                    CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL, op)         \
            default:                                                      \
                CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op);                   \
        }                                                                 \
        break;                                                            \
    }

// Non-basic lhs (string / list): rhs must have exactly the same type.
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC(dyn_type, static_type, op) \
    case (ParamDynType::dyn_type): {                                         \
        CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op);                              \
        const auto& lval = TypedRef(static_type, lhs.m_ptr.get());           \
        const auto& rval = TypedRef(static_type, rhs.m_ptr.get());           \
        return lval op rval;                                                 \
    }

// Define `ret_type operator op(const ParamVal&, const ParamVal&)` accepting
// only basic operands.
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(op, ret_type)              \
    ret_type operator op(const ParamVal& lhs, const ParamVal& rhs) { \
        CUSTOM_ASSERT_OPERAND_VALID(lhs, op);                        \
        CUSTOM_ASSERT_OPERAND_VALID(rhs, op);                        \
        switch (lhs.m_type) {                                        \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE(                         \
                    CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op)      \
            default:                                                 \
                CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op);              \
        }                                                            \
        return {};                                                   \
    }

// As above, additionally accepting two string operands.
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(op, ret_type)                     \
    ret_type operator op(const ParamVal& lhs, const ParamVal& rhs) {                   \
        CUSTOM_ASSERT_OPERAND_VALID(lhs, op);                                          \
        CUSTOM_ASSERT_OPERAND_VALID(rhs, op);                                          \
        switch (lhs.m_type) {                                                          \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE(                                           \
                    CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op)                        \
            CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
            default:                                                                   \
                CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op);                                \
        }                                                                              \
        return {};                                                                     \
    }

// As above, additionally accepting two list operands of the same list type.
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(op, ret_type)            \
    ret_type operator op(const ParamVal& lhs, const ParamVal& rhs) {                   \
        CUSTOM_ASSERT_OPERAND_VALID(lhs, op);                                          \
        CUSTOM_ASSERT_OPERAND_VALID(rhs, op);                                          \
        switch (lhs.m_type) {                                                          \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE(                                           \
                    CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op)                        \
            CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
            CUSTOM_FOR_EACH_LIST_PARAMTYPE(                                            \
                    CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op)                     \
            default:                                                                   \
                CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op);                                \
        }                                                                              \
        return {};                                                                     \
    }

// Stream a scalar / string value into `ss`.
#define CUSTOM_CASE_TO_PRINT_NONLIST(dyn_type, static_type)    \
    case (ParamDynType::dyn_type): {                           \
        const auto& rval = TypedRef(static_type, m_ptr.get()); \
        ss << rval;                                            \
        break;                                                 \
    }

// Stream a list value into `ss` via vec2str.
#define CUSTOM_CASE_TO_PRINT_LIST(dyn_type, static_type)       \
    case (ParamDynType::dyn_type): {                           \
        const auto& rval = TypedRef(static_type, m_ptr.get()); \
        ss << vec2str(rval);                                   \
        break;                                                 \
    }

// Return the element count of a string / list value.
#define CUSTOM_CASE_TO_RET_SIZE(dyn_type, static_type)     \
    case (ParamDynType::dyn_type): {                       \
        return TypedRef(static_type, m_ptr.get()).size();  \
    }

// Serialize a basic value as [ParamDynType][raw value bytes].
#define CUSTOM_CASE_TO_DUMP_BASIC(dyn_type, static_type)                            \
    case (ParamDynType::dyn_type): {                                                \
        res.resize(sizeof(ParamDynType) + sizeof(static_type));                     \
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));                     \
        memcpy(&res[sizeof(ParamDynType)], value.m_ptr.get(), sizeof(static_type)); \
        break;                                                                      \
    }

// Serialize a contiguous string / list as [ParamDynType][size_t len][elements].
// The element copy is skipped for empty containers, whose data() may be null.
#define CUSTOM_CASE_TO_DUMP_LIST(dyn_type, static_type)                     \
    case (ParamDynType::dyn_type): {                                        \
        auto& ref = TypedRef(static_type, value.m_ptr.get());               \
        size_t len = ref.size();                                            \
        size_t elem_size = len != 0 ? sizeof(ref[0]) : 0;                   \
        res.resize(sizeof(ParamDynType) + sizeof(len) + len * elem_size);   \
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));             \
        memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));              \
        if (len != 0) {                                                     \
            memcpy(&res[sizeof(ParamDynType) + sizeof(len)], ref.data(),    \
                   len * elem_size);                                        \
        }                                                                   \
        break;                                                              \
    }

// Deserialize a basic value written by CUSTOM_CASE_TO_DUMP_BASIC.
#define CUSTOM_CASE_TO_LOAD_BASIC(dyn_type, static_type) \
    case (ParamDynType::dyn_type): {                     \
        static_type val;                                 \
        assert_readable(bytes, offset, sizeof(val));     \
        memcpy(&val, &bytes[offset], sizeof(val));       \
        offset += sizeof(val);                           \
        return val;                                      \
    }

// Deserialize a string / list written by CUSTOM_CASE_TO_DUMP_LIST. The
// payload size is validated before resizing, so a corrupt length can neither
// trigger a huge allocation nor an out-of-bounds read, and the element copy
// is skipped for empty containers (indexing an empty vector is UB).
#define CUSTOM_CASE_TO_LOAD_LIST(dyn_type, static_type)                \
    case (ParamDynType::dyn_type): {                                   \
        size_t len = 0;                                                \
        assert_readable(bytes, offset, sizeof(len));                   \
        memcpy(&len, &bytes[offset], sizeof(len));                     \
        offset += sizeof(len);                                         \
        size_t elem_size = sizeof(static_type::value_type);            \
        assert_readable(bytes, offset, len, elem_size);                \
        static_type vals;                                              \
        vals.resize(len);                                              \
        if (len != 0) {                                                \
            memcpy(&vals[0], &bytes[offset], len * elem_size);         \
        }                                                              \
        offset += len * elem_size;                                     \
        return vals;                                                   \
    }

/**
 * Return the enumerator name of `ptype` (e.g. "Invalid"), or an empty string
 * for a value that is not registered. The table is a const static, so
 * lookups never mutate shared state.
 */
std::string ptype2name(ParamDynType ptype) {
#define CUSTOM_REG_DYN_PARAMTYPE_NAME(dyn_type, static_type) \
    {ParamDynType::dyn_type, #dyn_type},

    static const std::unordered_map<ParamDynType, std::string, EnumHash, EnumCmp>
            type2name = {CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_REG_DYN_PARAMTYPE_NAME){
                    ParamDynType::Invalid, "Invalid"}};

#undef CUSTOM_REG_DYN_PARAMTYPE_NAME
    auto it = type2name.find(ptype);
    return it != type2name.end() ? it->second : std::string();
}

/// Construct an empty (Invalid) ParamVal that owns nothing.
ParamVal::ParamVal() : m_ptr(nullptr, [](void*) -> void {}) {
    m_type = ParamDynType::Invalid;
}

/// Construct a string value from a C string.
ParamVal::ParamVal(const char* str) : ParamVal(std::string(str)) {}

/// Construct a StringList value from a braced list of C strings.
ParamVal::ParamVal(const std::initializer_list<const char*>& strs)
        : ParamVal(std::vector<const char*>(strs)) {}

/// Construct a StringList value, copying every C string into a std::string.
ParamVal::ParamVal(const std::vector<const char*>& strs)
        : m_ptr(new std::vector<std::string>(),
                impl_deleter<std::vector<std::string>>) {
    m_type = ParamDynType::StringList;
    auto& dst = TypedRef(std::vector<std::string>, m_ptr.get());
    for (const auto& str : strs) {
        dst.emplace_back(str);
    }
}

/// Deep-copy `rhs`, which must hold a valid value.
ParamVal::ParamVal(const ParamVal& rhs) : m_ptr(nullptr, [](void*) -> void {}) {
    mgb_assert(
            rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr,
            "invalid rhs of copy constructor of ParamVal");
    m_type = rhs.m_type;
    switch (m_type) {
        CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS)
        default: {
            mgb_assert(false, "invalid rhs of copy constructor of ParamVal");
        }
    }
}

/// Assign a string value from a C string.
ParamVal& ParamVal::operator=(const char* str) {
    this->operator=(std::string(str));
    return *this;
}

/// Assign a StringList value from a braced list of C strings.
ParamVal& ParamVal::operator=(const std::initializer_list<const char*>& strs) {
    this->operator=(std::vector<const char*>(strs));
    return *this;
}

/// Assign a StringList value, copying every C string into a std::string.
ParamVal& ParamVal::operator=(const std::vector<const char*>& strs) {
    std::vector<std::string> tmp_strs(strs.begin(), strs.end());
    this->operator=(tmp_strs);
    return *this;
}

/**
 * Deep-copy assignment. If the dynamic types match, the held object is
 * assigned in place; otherwise a new object of rhs's type is allocated.
 */
ParamVal& ParamVal::operator=(const ParamVal& rhs) {
    if (&rhs == this)
        return *this;
    mgb_assert(
            rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr,
            "invalid rhs of assignment operator of ParamVal");
    if (rhs.m_type == m_type) {
        switch (m_type) {
            CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS)
            default:
                mgb_assert(false, "invalid rhs of assignment operator of ParamVal");
        }
    } else {
        // Allocate the new value before touching m_type, so that if the copy
        // throws, m_type still describes the object m_ptr actually holds.
        switch (rhs.m_type) {
            CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS)
            default:
                mgb_assert(false, "invalid rhs of assignment operator of ParamVal");
        }
        m_type = rhs.m_type;
    }
    return *this;
}

/// Type-erased pointer to the held value (null when Invalid).
const void* ParamVal::raw_ptr(void) const {
    return m_ptr.get();
}

/// Type-erased pointer to the held value (null when Invalid).
void* ParamVal::raw_ptr(void) {
    return m_ptr.get();
}

/// Dynamic type tag of the held value.
ParamDynType ParamVal::type(void) const {
    return m_type;
}

/// Human-readable "type: <name>\nvalue: <value>" description.
std::string ParamVal::str() const {
    std::stringstream ss;
    ss << "type: " << ptype2name(m_type) << "\n"
       << "value: ";
    switch (m_type) {
        CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
        CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST)
        case ParamDynType::Device: {
            auto&& rval = TypedRef(Device, m_ptr.get());
            ss << to_builtin_device(rval).to_string();
            break;
        }
        default:
            mgb_assert(false, "invalid data type in str() of ParamVal");
    }
    return ss.str();
}

/// Element count of a string or list value; asserts for basic types.
size_t ParamVal::size(void) const {
    switch (m_type) {
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE)
        CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE)
        default:
            mgb_assert(false, "there is no size() for basic data types");
    }
    return 0;  // not reached for valid types; keeps every path returning
}

/**
 * Serialize `value` into a byte string that ParamVal::from_bytes can read.
 * Every encoding starts with the raw ParamDynType tag.
 */
std::string ParamVal::to_bytes(const ParamVal& value) {
    std::string res;
    // std::vector<bool> is a packed-bit specialization without data(), so its
    // elements are written one bool at a time: [type][size_t len][len x bool].
    if (value.type() == ParamDynType::BoolList) {
        const auto& ref = TypedRef(std::vector<bool>, value.m_ptr.get());
        size_t len = ref.size();
        res.resize(sizeof(ParamDynType) + sizeof(len) + len * sizeof(bool));
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));
        memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));
        size_t startpos = sizeof(ParamDynType) + sizeof(len);
        for (size_t idx = 0; idx < len; ++idx) {
            bool b = ref[idx];
            memcpy(&res[startpos + idx * sizeof(b)], &b, sizeof(b));
        }
        return res;
    } else if (value.type() == ParamDynType::StringList) {
        // Layout: [type][size_t count] then count x ([size_t str_len][chars]).
        const auto& ref = TypedRef(std::vector<std::string>, value.m_ptr.get());
        size_t len = ref.size();
        res.resize(sizeof(ParamDynType) + sizeof(len));
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));
        memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));
        for (const auto& str : ref) {
            size_t str_len = str.size();
            res.append(reinterpret_cast<const char*>(&str_len), sizeof(str_len));
            res.append(str);
        }
        return res;
    }

    switch (value.type()) {
        CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_DUMP_BASIC)
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST)
        CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST)
        default:
            mgb_assert(false, "invalid param type");
    }
    return res;
}

/**
 * Deserialize one ParamVal from `bytes` starting at `offset`, advancing
 * `offset` past the consumed bytes. Asserts (instead of reading out of
 * bounds) when `bytes` is too short for the encoded value.
 */
ParamVal ParamVal::from_bytes(const std::string& bytes, size_t& offset) {
    ParamDynType data_type = ParamDynType::Invalid;
    assert_readable(bytes, offset, sizeof(ParamDynType));
    memcpy(&data_type, &bytes[offset], sizeof(ParamDynType));
    offset += sizeof(ParamDynType);

    if (data_type == ParamDynType::BoolList) {
        std::vector<bool> ret;
        size_t len = 0;
        assert_readable(bytes, offset, sizeof(len));
        memcpy(&len, &bytes[offset], sizeof(len));
        offset += sizeof(len);
        assert_readable(bytes, offset, len, sizeof(bool));
        for (size_t idx = 0; idx < len; ++idx) {
            bool b = true;
            memcpy(&b, &bytes[offset], sizeof(bool));
            offset += sizeof(bool);
            ret.push_back(b);
        }
        return ret;
    } else if (data_type == ParamDynType::StringList) {
        std::vector<std::string> ret;
        size_t len = 0;
        assert_readable(bytes, offset, sizeof(len));
        memcpy(&len, &bytes[offset], sizeof(len));
        offset += sizeof(len);
        for (size_t idx = 0; idx < len; ++idx) {
            size_t str_len = 0;
            assert_readable(bytes, offset, sizeof(str_len));
            memcpy(&str_len, &bytes[offset], sizeof(str_len));
            offset += sizeof(str_len);
            assert_readable(bytes, offset, str_len);
            // substring constructor: copies bytes[offset, offset + str_len)
            ret.emplace_back(bytes, offset, str_len);
            offset += str_len;
        }
        return ret;
    }

    switch (data_type) {
        CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_LOAD_BASIC)
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST)
        CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST)
        default:
            mgb_assert(false, "invalid param type");
    }
    return {};
}

CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(+, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(-, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(*, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(/, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(==, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(!=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<, bool)

}  // namespace custom

#endif