#include <cstring>

#include "megbrain/opr/custom_opnode.h"
#include "megbrain/serialization/sereg.h"

namespace mgb {
namespace serialization {

//! Serialize a CustomOpNode into \p ctx.
//!
//! Three buffers are written, in this order, each through
//! OprDumpContext::dump_buf_with_len:
//!   1. the op type name (no trailing NUL),
//!   2. the param tag, as the raw bytes of a uint32_t,
//!   3. the param payload produced by custom::Param::to_bytes().
//!
//! custom_loader() reads the buffers back in the same order, so the two
//! functions must be kept in sync.
//!
//! \param ctx dump context receiving the buffers
//! \param opr operator to dump; it is cast to opr::CustomOpNode with
//!            cast_final_safe
void custom_dumper(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
    auto&& custom_op = opr.cast_final_safe<opr::CustomOpNode>();

    std::string op_type = custom_op.op_type();
    ctx.dump_buf_with_len(op_type.c_str(), op_type.size());

    uint32_t tag = custom_op.param_tag();
    ctx.dump_buf_with_len(&tag, sizeof(tag));

    std::string bytes = custom_op.param().to_bytes();
    ctx.dump_buf_with_len(bytes.c_str(), bytes.size());
}

//! Rebuild a CustomOpNode from the buffers written by custom_dumper().
//!
//! The op is looked up by its type name in custom::CustomOpManager, the stored
//! param tag is checked against the tag of the registered op, and the param
//! payload is decoded with custom::Param::from_bytes().
//!
//! \param ctx load context positioned at the data written by custom_dumper()
//! \param inputs input vars of the operator being loaded
//! \param config operator config forwarded to opr::CustomOpNode::make
//! \return owner operator of the first output var returned by
//!         opr::CustomOpNode::make
mgb::cg::OperatorNodeBase* custom_loader(
        OprLoadContext& ctx, const cg::VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    std::string op_type = ctx.load_buf_with_len();
    auto* op_manager = custom::CustomOpManager::inst();
    // NOTE(review): this assumes find() fails loudly for an op type that is not
    // registered; `op` is dereferenced below without a null check -- confirm.
    auto op = op_manager->find(op_type);

    std::string tag_str = ctx.load_buf_with_len();
    // The tag buffer comes from the model file; make sure it really holds a
    // uint32_t before reading it, instead of reading past a short buffer.
    mgb_assert(
            tag_str.size() == sizeof(uint32_t),
            "Wrong Param TAG size of Op %s, should be %zu bytes, but load %zu "
            "bytes\n",
            op_type.c_str(), sizeof(uint32_t), tag_str.size());
    // memcpy rather than a reinterpret_cast'ed load: the string storage is not
    // guaranteed to be suitably aligned for uint32_t.
    uint32_t tag = 0;
    std::memcpy(&tag, tag_str.data(), sizeof(tag));
    mgb_assert(
            tag == op->param_info().tag(),
            "Wrong Param TAG of Op %s, should be %u, but load %u\n", op_type.c_str(),
            op->param_info().tag(), tag);

    custom::Param param(op->param_info());
    std::string bytes = ctx.load_buf_with_len();
    param.from_bytes(bytes);
    return opr::CustomOpNode::make(op, inputs, param, config)[0]->owner_opr();
}

}  // namespace serialization
}  // namespace mgb

// Registers custom_dumper/custom_loader for operator class `cls` through
// MGB_SEREG_OPR_INTL_CALL_ADD, wrapped in a file-local struct whose entry() is
// handed to MGB_SEREG_OPR_INTL_CALL_ENTRY.
#define CUSTOM_OP_SEREG_REG(cls)                                  \
    namespace {                                                   \
    struct _OprReg##cls {                                         \
        static void entry() {                                     \
            MGB_SEREG_OPR_INTL_CALL_ADD(                          \
                    cls, ::mgb::serialization::custom_dumper,     \
                    ::mgb::serialization::custom_loader, true);   \
        }                                                         \
    };                                                            \
    }                                                             \
    MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls)

// Same as CUSTOM_OP_SEREG_REG but goes through the V2 registration macros and
// forwards the version bounds `_version_min` / `_version_max`. The nullptr
// argument sits between the loader and the version bounds.
// NOTE(review): presumably the nullptr is an optional converter hook -- confirm
// against MGB_SEREG_OPR_INTL_CALL_ADD_V2.
#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max)                   \
    namespace {                                                                   \
    struct _OprRegV2##cls {                                                       \
        static void entry() {                                                     \
            MGB_SEREG_OPR_INTL_CALL_ADD_V2(                                       \
                    cls, ::mgb::serialization::custom_dumper,                     \
                    ::mgb::serialization::custom_loader, nullptr, _version_min,   \
                    _version_max);                                                \
        }                                                                         \
    };                                                                            \
    }                                                                             \
    MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls)

using namespace mgb;
using CustomOpNode = opr::CustomOpNode;

// Register the dumper/loader pair for CustomOpNode.
CUSTOM_OP_SEREG_REG(CustomOpNode);
// Also register the same custom_dumper/custom_loader pair for CustomOpNode
// through the V2 registration macro, with version bounds 2 and CURRENT_VERSION.
CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION);