// custom_opnode.sereg.h
#include <cstdint>
#include <cstring>
#include <string>

#include "megbrain/opr/custom_opnode.h"
#include "megbrain/serialization/sereg.h"

namespace mgb {
namespace serialization {
// Serialize a CustomOpNode.
//
// Three length-prefixed buffers are written, in this order (custom_loader
// reads them back in the same order):
//   1. the op type name (CustomOpNode::op_type()),
//   2. the param tag (CustomOpNode::param_tag()) as raw uint32_t bytes,
//   3. the param payload produced by param().to_bytes().
void custom_dumper(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
    const auto& custom_op = opr.cast_final_safe<opr::CustomOpNode>();

    const std::string op_type = custom_op.op_type();
    ctx.dump_buf_with_len(op_type.c_str(), op_type.size());

    // Written as native-endian raw bytes; the loader checks that exactly
    // sizeof(uint32_t) bytes come back.
    const uint32_t tag = custom_op.param_tag();
    ctx.dump_buf_with_len(&tag, sizeof(tag));

    const std::string bytes = custom_op.param().to_bytes();
    ctx.dump_buf_with_len(bytes.c_str(), bytes.size());
}

// Rebuild a CustomOpNode from the stream written by custom_dumper.
//
// Reads, in order: the op type name, the param tag, and the serialized param
// bytes. The op type is resolved through custom::CustomOpManager. The loaded
// tag must match the tag of the registered op's param info, otherwise
// loading aborts via mgb_assert.
//
// ctx:     load context for this operator's payload
// inputs:  input vars of the operator being reconstructed
// config:  node config forwarded to CustomOpNode::make
// returns: the owner operator of the first output var from CustomOpNode::make
mgb::cg::OperatorNodeBase* custom_loader(
        OprLoadContext& ctx, const cg::VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    std::string op_type = ctx.load_buf_with_len();
    auto* op_manager = custom::CustomOpManager::inst();
    auto op = op_manager->find(op_type);
    // Fail with a readable message instead of dereferencing an empty handle
    // when the model references an op that is not registered.
    mgb_assert(
            op, "Custom Op %s is not registered, cannot load it\n",
            op_type.c_str());

    // The dumper writes the tag as exactly sizeof(uint32_t) raw bytes.
    // Validate the length first, so a truncated or corrupted file cannot
    // trigger an out-of-bounds read. Then copy with memcpy rather than
    // reinterpret_cast, which would have alignment and strict-aliasing UB.
    std::string tag_str = ctx.load_buf_with_len();
    mgb_assert(
            tag_str.size() == sizeof(uint32_t),
            "Wrong Param TAG size of Op %s, should be %zu bytes, but load %zu "
            "bytes\n",
            op_type.c_str(), sizeof(uint32_t), tag_str.size());
    uint32_t tag = 0;
    std::memcpy(&tag, tag_str.data(), sizeof(tag));
    mgb_assert(
            tag == op->param_info().tag(),
            "Wrong Param TAG of Op %s, should be %u, but load %u\n", op_type.c_str(),
            op->param_info().tag(), tag);

    custom::Param param(op->param_info());
    std::string bytes = ctx.load_buf_with_len();
    param.from_bytes(bytes);
    return opr::CustomOpNode::make(op, inputs, param, config)[0]->owner_opr();
}

}  // namespace serialization
}  // namespace mgb
// Register `cls` with the operator serialization registry through
// MGB_SEREG_OPR_INTL_CALL_ADD, using ::mgb::serialization::custom_dumper and
// ::mgb::serialization::custom_loader.
// `cls` is token-pasted into the helper struct name (_OprReg##cls), so it
// must be a single unqualified identifier.
// No `//` comments inside the macro body: line splicing happens before
// comment removal, so they would swallow the following lines.
#define CUSTOM_OP_SEREG_REG(cls)                              \
    namespace {                                               \
    struct _OprReg##cls {                                     \
        static void entry() {                                 \
            MGB_SEREG_OPR_INTL_CALL_ADD(                      \
                    cls, ::mgb::serialization::custom_dumper, \
                    ::mgb::serialization::custom_loader);     \
        }                                                     \
    };                                                        \
    }                                                         \
    MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls)

// Like CUSTOM_OP_SEREG_REG, but registers `cls` through
// MGB_SEREG_OPR_INTL_CALL_ADD_V2 / MGB_SEREG_OPR_INTL_CALL_ENTRY_V2, using the
// same custom_dumper / custom_loader pair.
// _version_min / _version_max are forwarded unchanged as the version bounds.
// The third hook argument is nullptr. NOTE(review): presumably an optional
// converter hook; confirm against MGB_SEREG_OPR_INTL_CALL_ADD_V2 in sereg.h.
// `cls` is token-pasted (_OprRegV2##cls), so it must be an unqualified
// identifier.
#define CUSTOM_OP_SEREG_REG_V2(cls, _version_min, _version_max)                 \
    namespace {                                                                 \
    struct _OprRegV2##cls {                                                     \
        static void entry() {                                                   \
            MGB_SEREG_OPR_INTL_CALL_ADD_V2(                                     \
                    cls, ::mgb::serialization::custom_dumper,                   \
                    ::mgb::serialization::custom_loader, nullptr, _version_min, \
                    _version_max);                                              \
        }                                                                       \
    };                                                                          \
    }                                                                           \
    MGB_SEREG_OPR_INTL_CALL_ENTRY_V2(cls, _OprRegV2##cls)

// The registration macros token-paste their `cls` argument, so CustomOpNode
// is exposed under an unqualified alias.
// NOTE(review): `using namespace mgb;` at header scope leaks into every
// includer; it is kept because the expanded MGB_SEREG_* macros may rely on
// it. Confirm before removing.
using namespace mgb;
using CustomOpNode = opr::CustomOpNode;

// Registration through MGB_SEREG_OPR_INTL_CALL_ADD.
CUSTOM_OP_SEREG_REG(CustomOpNode);

// V2 registration with version bounds 2 .. CURRENT_VERSION.
CUSTOM_OP_SEREG_REG_V2(CustomOpNode, 2, CURRENT_VERSION);