diff --git a/src/serialization/impl/batched_device_value_loader.cpp b/src/serialization/impl/batched_device_value_loader.cpp
index 255c4e0ffaf737f52cda0cf08c7d22c6177c3705..32c2a5df0eada61481462327a818759a51c68889 100644
--- a/src/serialization/impl/batched_device_value_loader.cpp
+++ b/src/serialization/impl/batched_device_value_loader.cpp
@@ -1,5 +1,4 @@
-#include "batched_device_value_loader.h"
-
+#include "megbrain/serialization/batched_device_value_loader.h"
 #include "megbrain/utils/arith_helper.h"
 
 namespace mgb {
diff --git a/src/serialization/impl/serializer.cpp b/src/serialization/impl/serializer.cpp
index e220bd3512eafe611e53160f3ca553e1586b0f1f..17fbc6227d01d5b40b51b4c453c26f250c3866ee 100644
--- a/src/serialization/impl/serializer.cpp
+++ b/src/serialization/impl/serializer.cpp
@@ -57,7 +57,11 @@ GraphLoader::SharedTensorNameMap GraphLoader::shared_tensor_name_map() {
 }
 
 std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
 std::unique_ptr<GraphDumper> make_fbs_dumper(std::unique_ptr<OutputFile> file);
+
+std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file);
+std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file);
 bool is_fbs_file(InputFile& file);
+bool is_fbs_v2_file(InputFile& file);
 
 bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
 #if MGB_ENABLE_GRAD
@@ -73,6 +77,11 @@ std::unique_ptr<GraphDumper> GraphDumper::make(
         case GraphDumpFormat::FLATBUFFERS:
 #if MGB_ENABLE_FBS_SERIALIZATION
             return make_fbs_dumper(std::move(file));
+#endif
+            MGB_FALLTHRU
+        case GraphDumpFormat::FLATBUFFERS_V2:
+#if MGB_ENABLE_FBS_SERIALIZATION
+            return make_fbs_v2_dumper(std::move(file));
 #endif
             MGB_FALLTHRU
         default:
@@ -87,6 +96,11 @@ std::unique_ptr<GraphLoader> GraphLoader::make(
         case GraphDumpFormat::FLATBUFFERS:
 #if MGB_ENABLE_FBS_SERIALIZATION
             return make_fbs_loader(std::move(file));
+#endif
+            MGB_FALLTHRU
+        case GraphDumpFormat::FLATBUFFERS_V2:
+#if MGB_ENABLE_FBS_SERIALIZATION
+            return make_fbs_v2_loader(std::move(file));
 #endif
             MGB_FALLTHRU
         default:
@@ -100,6 +114,9 @@ Maybe<GraphDumpFormat> GraphLoader::identify_graph_dump_format(InputFile& file)
     if (is_fbs_file(file)) {
         return GraphDumpFormat::FLATBUFFERS;
     }
+    if (is_fbs_v2_file(file)) {
+        return GraphDumpFormat::FLATBUFFERS_V2;
+    }
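+    // both probes only peek at the file header and seek back afterwards
+    // (is_fbs_file() reads 8 bytes, is_fbs_v2_file() reads 25), so the
+    // checks above can run in either order without consuming any input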
 #endif
     return {};
 }
diff --git a/src/serialization/impl/serializer_oss.cpp b/src/serialization/impl/serializer_oss.cpp
index d7fed1df5b4f2e75de603d1d510f34f3a3f0f856..39e0bda6e288356897c8eaa110ded8f84f1b8d49 100644
--- a/src/serialization/impl/serializer_oss.cpp
+++ b/src/serialization/impl/serializer_oss.cpp
@@ -11,17 +11,16 @@
  */
 
 #if MGB_ENABLE_FBS_SERIALIZATION
 
-#include "batched_device_value_loader.h"
-
 #include "megbrain/graph/exc_extra_info.h"
 #include "megbrain/opr/io.h"
+#include "megbrain/serialization/batched_device_value_loader.h"
 #include "megbrain/serialization/helper.h"
 #include "megbrain/serialization/internal/flatbuffers_helper.h"
 #include "megbrain/serialization/internal/schema_generated.h"
 #include "megbrain/serialization/metadata.h"
 #include "megbrain/serialization/opr_load_dump.h"
 #include "megbrain/serialization/serializer.h"
-#include "megbrain/version.h"
+#include "serializer_oss_common.h"
 
 #include 
 
@@ -33,47 +32,8 @@ using namespace mgb;
 using namespace mgb::serialization;
 
 namespace {
-
-constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH;
-
-constexpr uint32_t MGB_MAGIC = 0x4342474D;
-// In order to maintain compatibility and to allow old models to be loaded, we keep
-// the old magic(MAGIC_V0) value and creat a new magic(MGB_MAGIC)
-constexpr uint32_t MAGIC_V0 = 0x5342474D;
-// Used to judge whether Magic is old or new, the new magic(MGB_MAGIC) is true and the
-// old magic(MAGIC_V0) is false.
-bool magic_compare = true;
-
-template <typename T>
-bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) {
-    for (const auto& x : list) {
-        if (set.count(x)) {
-            return true;
-        }
-    }
-    return false;
-}
-
-void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor) {
-    bool cond_normal = tensor.layout().format.is_default() &&
-                       tensor.layout().is_physical_contiguous();
-    bool cond_lowbit = tensor.layout().dtype.is_quantized_lowbit() &&
-                       tensor.layout().format.is_lowbit_aligned() &&
-                       tensor.layout().is_contiguous();
-    mgb_assert(
-            cond_normal || cond_lowbit, "non-contiguous tensor: name=%s layout=%s",
-            name.c_str(), tensor.layout().to_string().c_str());
-    if (tensor.dtype() == dtype::Float32()) {
-        auto ptr = tensor.ptr<float>();
-        for (size_t i = 0, it = tensor.shape().total_nr_elems(); i < it; ++i) {
-            if (!std::isfinite(ptr[i])) {
-                mgb_log_warn("invalid tensor value in %s: %g", name.c_str(), ptr[i]);
-                break;
-            }
-        }
-    }
-}
-
 //! feature bits for backward compatibility; default value should be 0
 struct FeatureBits64 {
     //! reserved for new fields
@@ -947,13 +907,6 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file) {
     return std::make_unique<GraphLoaderOSS>(std::move(file));
 }
 
-bool is_fbs_file(InputFile& file) {
-    uint64_t magic_with_reserved = 0;
-    file.read(&magic_with_reserved, sizeof(magic_with_reserved));
-    file.skip(-sizeof(magic_with_reserved));
-    return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0);
-}
-
 } // namespace serialization
 } // namespace mgb
 
diff --git a/src/serialization/impl/serializer_oss_common.cpp b/src/serialization/impl/serializer_oss_common.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3519043c1aadb43a986935a11c8f0aac0aba4ba8
--- /dev/null
+++ b/src/serialization/impl/serializer_oss_common.cpp
@@ -0,0 +1,39 @@
+#if MGB_ENABLE_FBS_SERIALIZATION
+
+#include "serializer_oss_common.h"
+
+namespace mgb {
+namespace serialization {
+
+bool is_fbs_file(InputFile& file) {
+    //! check whether the file starts with one of the legacy (v1) magic numbers
+    uint64_t magic_with_reserved = 0;
+    file.read(&magic_with_reserved, sizeof(magic_with_reserved));
+    file.skip(-sizeof(magic_with_reserved));
+    return (magic_with_reserved == MGB_MAGIC) || (magic_with_reserved == MAGIC_V0);
+}
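+
+// Both magic values are ASCII tags in little-endian byte order: MGB_MAGIC
+// (0x4342474D) reads as "MGBC" on disk and MAGIC_V0 (0x5342474D) as "MGBS".
+// Comparing the full uint64_t also requires the 4 reserved bytes that follow
+// the tag to be zero. A rough equivalent of the check above, assuming a
+// little-endian host (illustrative only):
+//
+//     char head[8];
+//     file.read(head, 8);
+//     file.skip(-8);
+//     bool v1 = !memcmp(head, "MGBC\0\0\0\0", 8) ||
+//               !memcmp(head, "MGBS\0\0\0\0", 8);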
+
+void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor) {
+    bool cond_normal = tensor.layout().format.is_default() &&
+                       tensor.layout().is_physical_contiguous();
+    bool cond_lowbit = tensor.layout().dtype.is_quantized_lowbit() &&
+                       tensor.layout().format.is_lowbit_aligned() &&
+                       tensor.layout().is_contiguous();
+    mgb_assert(
+            cond_normal || cond_lowbit, "non-contiguous tensor: name=%s layout=%s",
+            name.c_str(), tensor.layout().to_string().c_str());
+    if (tensor.dtype() == dtype::Float32()) {
+        auto ptr = tensor.ptr<float>();
+        for (size_t i = 0, it = tensor.shape().total_nr_elems(); i < it; ++i) {
+            if (!std::isfinite(ptr[i])) {
+                mgb_log_warn("invalid tensor value in %s: %g", name.c_str(), ptr[i]);
+                break;
+            }
+        }
+    }
+}
+
+} // namespace serialization
+} // namespace mgb
+
+#endif
diff --git a/src/serialization/impl/serializer_oss_common.h b/src/serialization/impl/serializer_oss_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..94dc2da751fff7896550be73d8ab36a7561d6e62
--- /dev/null
+++ b/src/serialization/impl/serializer_oss_common.h
@@ -0,0 +1,32 @@
+#if MGB_ENABLE_FBS_SERIALIZATION
+
+#include "megbrain/serialization/serializer.h"
+#include "megbrain/version.h"
+
+namespace mgb {
+namespace serialization {
+
+constexpr uint32_t MGB_VERSION = (MGE_MAJOR * 1000 + MGE_MINOR) * 100 + MGE_PATCH;
+
+constexpr uint32_t MGB_MAGIC = 0x4342474D;
+
+// In order to maintain compatibility and to allow old models to be loaded, we keep
+// the old magic (MAGIC_V0) value and create a new magic (MGB_MAGIC)
+constexpr uint32_t MAGIC_V0 = 0x5342474D;
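+
+// Example: a runtime built from MegEngine 1.10.2 packs its version as
+// (1 * 1000 + 10) * 100 + 2 = 101002, so version numbers stay totally
+// ordered and the loader can detect a model written by a newer runtime
+// with a single integer comparison (see GraphLoaderOSSV2::load).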
+
+void check_tensor_value_valid(const std::string& name, const HostTensorND& tensor);
+
+template <typename T>
+bool contains_any_in_set(const SmallVector<T>& list, const ThinHashSet<T>& set) {
+    for (const auto& x : list) {
+        if (set.count(x)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+} // namespace serialization
+} // namespace mgb
+
+#endif
diff --git a/src/serialization/impl/serializer_oss_v2.cpp b/src/serialization/impl/serializer_oss_v2.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f5a9a9fe49adecfbb73b3ea7e1947a5e05fbdc0a
--- /dev/null
+++ b/src/serialization/impl/serializer_oss_v2.cpp
@@ -0,0 +1,847 @@
+#if MGB_ENABLE_FBS_SERIALIZATION
+
+#include "megbrain/comp_node_env.h"
+#include "megbrain/opr/io.h"
+#include "megbrain/serialization/helper.h"
+#include "megbrain/serialization/internal/flatbuffers_helper.h"
+#include "megbrain/serialization/internal/schema_v2_generated.h"
+#include "megbrain/serialization/metadata.h"
+#include "megbrain/serialization/opr_load_dump.h"
+#include "megbrain/serialization/oss_opr_load_dump.h"
+#include "megbrain/utils/hash_ct.h"
+#include "megdnn/tensor_format.h"
+#include "serializer_oss_common.h"
+
+#include "megbrain/gopt/framework.h"
+
+namespace mgb {
+namespace serialization {
+
+/*!
+ * \brief replace each opr that has a converter (replace_opr) method
+ * registered in OprRegistryV2 with its dump-compatible counterpart
+ */
+class PassConvertToCompatible : public gopt::Pass {
+    ThinHashMap<
+            Typeinfo*,
+            thin_function<cg::OperatorNodeBase*(
+                    cg::OperatorNodeBase*, const VarNodeArray&)>>
+            m_opr_replace_func;
+    gopt::VarReplaceCheckFlag m_var_replace_check_flag =
+            gopt::VarReplaceCheckFlag::CHECK_ALL;
+
+public:
+    const char* name() const override { return "PassConvertToCompatible"; }
+
+    PassConvertToCompatible& set_var_replace_check_flag(
+            gopt::VarReplaceCheckFlag flag) {
+        m_var_replace_check_flag = flag;
+        return *this;
+    }
+
+    void apply(gopt::OptState& state) const override {
+        state.set_var_replace_check_flag(m_var_replace_check_flag);
+        auto rewriter = state.graph().make_rewriter();
+
+        auto on_opr = [this, &rewriter](cg::OperatorNodeBase* opr) {
+            auto it = m_opr_replace_func.find(opr->dyn_typeinfo());
+            if (it != m_opr_replace_func.end()) {
+                VarNodeArray new_inp;
+                new_inp.reserve(opr->input().size());
+                for (auto i : opr->input()) {
+                    new_inp.push_back(rewriter.get_var(i));
+                }
+                auto new_opr = (it->second)(opr, new_inp);
+
+                auto &&origin_out = opr->output(), &&cur_out = new_opr->output();
+                for (size_t i = 0; i < std::min(origin_out.size(), cur_out.size());
+                     i++) {
+                    rewriter.replace_var(origin_out[i], cur_out[i], nullptr);
+                }
+            } else {
+                rewriter.auto_replace_outputs(opr);
+            }
+        };
+        state.graph().iter(on_opr);
+        rewriter.apply_inplace();
+    }
+
+    static std::unique_ptr<PassConvertToCompatible> make(
+            const SymbolVarArray& output_vars) {
+        auto ret = std::make_unique<PassConvertToCompatible>();
+        // iterate oprs to init
+        auto on_opr = [&](cg::OperatorNodeBase* opr) {
+            if (!GraphDumper::should_remove_in_dump(opr)) {
+                auto registry = OprRegistryV2::versioned_find_by_typeinfo(
+                        opr->dyn_typeinfo(), CURRENT_VERSION);
+                mgb_throw_if(
+                        !registry,
+                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}
+                                .make<SerializationError>,
+                        "serialization as FlatBuffers is not supported for "
+                        "operator %s, typeinfo %p",
+                        opr->dyn_typeinfo()->name, opr->dyn_typeinfo());
+                if (registry->converter) {
+                    ret->m_opr_replace_func[opr->dyn_typeinfo()] = registry->converter;
+                }
+            }
+        };
+        cg::DepOprIter dep_opr_iter{on_opr};
+        for (auto i : output_vars) {
+            dep_opr_iter.add(i.node()->owner_opr());
+        }
+        return ret;
+    }
+};
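+
+// A converter registered in OprRegistryV2 has exactly the signature stored in
+// m_opr_replace_func above: it receives the original opr together with its
+// already-rewritten inputs and returns the replacement opr. A minimal sketch
+// (the opr type and helper are illustrative, not a real registration):
+//
+//     cg::OperatorNodeBase* convert_my_opr(
+//             cg::OperatorNodeBase* opr, const VarNodeArray& new_inputs) {
+//         // rebuild a dump-compatible opr on top of the rewritten inputs
+//         return make_compatible_opr(opr, new_inputs);  // hypothetical helper
+//     }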
+
+namespace {
+fbs::v2::TensorFormat get_flatbuffer_tensor_format_type(
+        const TensorLayout::Format& format) {
+    using Type = megdnn::TensorFormat::Type;
+    switch (format.type()) {
+        case Type::DEFAULT:
+            return fbs::v2::TensorFormat::TensorFormat_DefaultTensorFormat;
+        case Type::IMAGE2D_PACK4:
+            return fbs::v2::TensorFormat::TensorFormat_Image2DPackedTensorFormat;
+        case Type::LOWBITS_ALIGNED_TO_BYTE:
+            return fbs::v2::TensorFormat::TensorFormat_LowbitsAlignedTensorFormat;
+        default:
+            mgb_throw(
+                    SerializationError, "invalid tensor format type in serialization.");
+    }
+}
+}  // namespace
+
+flatbuffers::Offset<fbs::DType> GraphDumperOSSV2::build_dtype(DType dtype) {
+    return fbs::intl::build_dtype(m_builder, dtype);
+}
+
+flatbuffers::Offset<void> GraphDumperOSSV2::build_tensor_format(
+        const TensorLayout::Format& format) {
+    using Type = megdnn::TensorFormat::Type;
+    switch (format.type()) {
+        case Type::DEFAULT:
+            return fbs::v2::CreateDefaultTensorFormat(m_builder).Union();
+        case Type::IMAGE2D_PACK4:
+            return fbs::v2::CreateImage2DPackedTensorFormat(
+                           m_builder,
+                           format.as_impl<megdnn::Image2DPack4TensorFormat>()
+                                   .align_axis())
+                    .Union();
+        case Type::LOWBITS_ALIGNED_TO_BYTE: {
+            auto size_bits =
+                    format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
+                            .size_nbits();
+            auto align_size_in_bits =
+                    format.as_impl<megdnn::LowbitsAlignedToBytesTensorFormat>()
+                            .align_size_in_bits();
+            return fbs::v2::CreateLowbitsAlignedTensorFormat(
+                           m_builder, size_bits, align_size_in_bits)
+                    .Union();
+        }
+        default:
+            mgb_throw(
+                    SerializationError, "invalid tensor format type in serialization.");
+    }
+}
+
+flatbuffers::Offset<fbs::v2::MiddleTensor> GraphDumperOSSV2::build_middle_tensor(
+        const SymbolVar var) {
+    mgb_assert(var.node());
+    auto fbname = m_builder.CreateSharedString(var.node()->name());
+    flatbuffers::Offset<fbs::v2::MiddleTensor> serialized_middle_tensor;
+    if (var.node()->dev_tensor_valid()) {
+        auto layout = var.node()->layout();
+        auto fshape =
+                m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
+
+        auto fcomp_node = fbs::v2::CreateCompNode(
+                m_builder, m_builder.CreateSharedString(
+                                   var.node()->comp_node().to_string_logical()));
+
+        auto fdtype = build_dtype(layout.dtype);
+        auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
+        auto fformat = build_tensor_format(layout.format);
+        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(
+                m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat);
+    } else {
+        // vars without a valid dev tensor only record their name
+        serialized_middle_tensor = fbs::v2::CreateMiddleTensor(m_builder, fbname);
+    }
+    return serialized_middle_tensor;
+}
+
+flatbuffers::Offset<fbs::v2::OutputVar> GraphDumperOSSV2::build_output_var(
+        const SymbolVar var) {
+    auto out_node = var.node();
+    if (m_var2midtensor_id.find(var.node()) == m_var2midtensor_id.end()) {
+        mgb_assert(m_var_remove_in_dump.find(var.node()) != m_var_remove_in_dump.end());
+        out_node = m_var_remove_in_dump[var.node()];
+    }
+    return fbs::v2::CreateOutputVar(
+            m_builder, m_var2midtensor_id.at(out_node), var.node()->id());
+}
+
+void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
+    m_oprs_to_dump.clear();
+
+    // iterate oprs to init
+    auto on_opr = [&](cg::OperatorNodeBase* opr) {
+        if (should_remove_in_dump(opr)) {
+            mgb_assert(opr->input().size() == 1);
+            // Copy input ID to output
+            for (auto i : opr->output()) {
+                if (m_var_remove_in_dump.find(opr->input(0)) !=
+                    m_var_remove_in_dump.end()) {
+                    m_var_remove_in_dump[i] = m_var_remove_in_dump[opr->input(0)];
+                } else {
+                    m_var_remove_in_dump[i] = opr->input(0);
+                }
+            }
+        } else {
+            auto registry = OprRegistryV2::versioned_find_by_typeinfo(
+                    opr->dyn_typeinfo(), CURRENT_VERSION);
+            if (!registry || !registry->dumper) {
+                mgb_throw(
+                        cg::OperatorNodeExcExtraInfo::ExcMaker{opr}
+                                .make<SerializationError>,
+                        "serialization as FlatBuffers is not supported for "
+                        "operator %s",
+                        opr->dyn_typeinfo()->name);
+            }
+            m_oprs_to_dump.emplace_back(opr, registry);
+        }
+    };
+    cg::DepOprIter dep_opr_iter{on_opr};
+    for (auto i : endpoints) {
+        dep_opr_iter.add(i.node()->owner_opr());
+    }
+}
+
+flatbuffers::Offset<fbs::v2::Metadata> GraphDumperOSSV2::build_metadata(
+        const Metadata& metadata) {
+    auto user_info = m_builder.CreateSharedString(metadata.user_info);
+    fbs::v2::MetadataBuilder builder(m_builder);
+    builder.add_is_valid(metadata.is_valid);
+    builder.add_graph_modified(metadata.graph_modified);
+    builder.add_optimize_options(metadata.optimize_options);
+    builder.add_user_info(user_info);
+    return builder.Finish();
+}
+
+flatbuffers::Offset<fbs::v2::Operator> GraphDumperOSSV2::build_single_opr(
+        cg::OperatorNodeBase* opr, const OprRegistryV2* registry) {
+    m_cur_opr = opr;
+    ++m_cur_rst.nr_opr;
+
+    using namespace flatbuffers;
+    Offset<Vector<uoffset_t>> inputs;
+    if (m_cur_opr->input().size()) {
+        std::vector<uoffset_t> v;
+        v.reserve(m_cur_opr->input().size());
+        for (auto inp : m_cur_opr->input()) {
+            if (m_var2midtensor_id.find(inp) != m_var2midtensor_id.end()) {
+                v.emplace_back(m_var2midtensor_id.at(inp));
+            } else {
+                mgb_assert(
+                        m_var_remove_in_dump.find(inp) != m_var_remove_in_dump.end(),
+                        "invalid var dependency when dumping the model");
+                v.emplace_back(m_var2midtensor_id.at(m_var_remove_in_dump[inp]));
+            }
+        }
+        inputs = m_builder.CreateVector(v);
+    }
+
+    m_cur_opr_tensor.clear();
+    m_blobs.clear();
+    m_cur_opr_param.clear();
+    m_cur_opr_param_type.clear();
+    registry->dumper(*this, *m_cur_opr);
+
+    Offset<Vector<Offset<fbs::v2::CompNode>>> comp_node;
+    auto& config = m_cur_opr->config();
+    if (config.has_comp_node_set()) {
+        std::vector<Offset<fbs::v2::CompNode>> cns;
+        for (const auto& cn : config.comp_node()) {
+            cns.emplace_back(fbs::v2::CreateCompNode(
+                    m_builder, m_builder.CreateSharedString(cn.to_string_logical())));
+        }
+        comp_node = m_builder.CreateVector(cns);
+    }
+    Offset<String> operator_name;
+    if (m_config.keep_op_name) {
+        operator_name = m_builder.CreateSharedString(m_cur_opr->name());
+    }
+
+    auto output_dtype = build_dtype(config.output_dtype());
+
+    Offset<Vector<uoffset_t>> outputs;
+    if (m_cur_opr->output().size()) {
+        std::vector<uoffset_t> v;
+        v.reserve(m_cur_opr->output().size());
+        for (auto out : m_cur_opr->output()) {
+            if (!out->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
+                auto fbs_out = build_middle_tensor(out);
+                m_model_middle_tensors.push_back(fbs_out);
+                m_var2midtensor_id[out] = m_model_middle_tensors.size() - 1;
+                v.emplace_back(m_var2midtensor_id.at(out));
+            }
+        }
+        outputs = m_builder.CreateVector(v);
+    }
+
+    Offset<Vector<Offset<fbs::v2::Tensor>>> tensors;
+    if (m_cur_opr_tensor.size())
+        tensors = m_builder.CreateVector(m_cur_opr_tensor);
+
+    //! the blob data is custom data; m_blobs is filled by the operator's
+    //! dumper function
+    Offset<Vector<Offset<fbs::v2::Blob>>> blobs;
+    if (m_blobs.size())
+        blobs = m_builder.CreateVector(m_blobs);
+
+    Offset<Vector<uint8_t>> additional_params_type;
+    Offset<Vector<Offset<void>>> additional_params;
+    auto param_cnt = m_cur_opr_param_type.size();
+    if (param_cnt > 1) {
+        additional_params_type = m_builder.CreateVectorScalarCast<uint8_t>(
+                m_cur_opr_param_type.data() + 1, param_cnt - 1);
+        additional_params =
+                m_builder.CreateVector(m_cur_opr_param.data() + 1, param_cnt - 1);
+    }
+    auto opr_type = m_builder.CreateSharedString(registry->name);
+
+    fbs::v2::OperatorBuilder builder(m_builder);
+    builder.add_type(opr_type);
+    builder.add_type_id(registry->type_id);
+    builder.add_inputs(inputs);
+    builder.add_outputs(outputs);
+    if (m_config.keep_opr_priority) {
+        builder.add_priority(opr->node_prop().attribute().priority);
+    }
+    builder.add_comp_node(comp_node);
+    builder.add_opr_version(registry->get_version());
+    builder.add_name(operator_name);
+    builder.add_output_dtype(output_dtype);
+    if (param_cnt > 0) {
+        builder.add_param_type(m_cur_opr_param_type[0]);
+        builder.add_param(m_cur_opr_param[0]);
+    }
+    if (param_cnt > 1) {
+        builder.add_additional_params_type(additional_params_type);
+        builder.add_additional_params(additional_params);
+    }
+    builder.add_tensors(tensors);
+    builder.add_custom_data(blobs);
+    m_cur_opr = nullptr;
+    return builder.Finish();
+}
+
+SymbolVarArray GraphDumperOSSV2::convert_all_oprs_to_compatible(
+        const SymbolVarArray& output_vars) {
+    gopt::GraphOptimizer optimizer;
+    VarNodeArray rets_var;
+    for (auto& symbolvar : output_vars) {
+        rets_var.push_back(symbolvar.node());
+    }
+    optimizer.add_pass(PassConvertToCompatible::make(output_vars));
+    optimizer.apply_inplace(rets_var);
+
+    SymbolVarArray dst_vars;
+    for (auto& var : rets_var) {
+        dst_vars.push_back({var});
+    }
+    return dst_vars;
+}
+
+GraphDumper::DumpResult GraphDumperOSSV2::dump(
+        const SymbolVarArray& output_vars, const DumpConfig& config,
+        const Metadata& metadata) {
+    mgb_throw_if(output_vars.empty(), SerializationError, "Can't dump empty graph");
+
+    auto&& new_output_vars = convert_all_oprs_to_compatible(output_vars);
+
+    auto begin_pos = m_file->tell();
+    m_config = config;
+    m_builder.Reset();
+
+    m_output_vars.clear();
+    m_cur_rst = {};
+    m_used_input_names.clear();
+    m_used_param_names.clear();
+    m_var_remove_in_dump.clear();
+    m_model_middle_tensors.clear();
+    m_var2midtensor_id.clear();
+    m_nr_shared_tensor = 0;
+
+    // process output vars
+    bool keep_output_var_name = m_config.keep_var_name >= 1;
+    std::unordered_set<std::string> output_var_names;
+    for (auto i : new_output_vars) {
+        mgb_assert(
+                !i.node()->contain_flag(VarNode::Flag::VOLATILE_CONTENT),
+                "can not dump var with VOLATILE_CONTENT flag: %s",
+                cg::dump_var_info({i.node()}).c_str());
+        if (m_output_vars.insert(i.node()).second && keep_output_var_name) {
+            auto name_ins = output_var_names.insert(i.node()->name()).second;
+            mgb_assert(name_ins, "duplicated output var name: %s", i.node()->cname());
+        }
+    }
+
+    // Dump metadata
+    auto fbmeta = build_metadata(metadata);
+
+    // Dump operators
+    init_oprs_to_dump(new_output_vars);
+    std::vector<flatbuffers::Offset<fbs::v2::Operator>> oprs;
+    for (auto&& i : m_oprs_to_dump) {
+        oprs.emplace_back(build_single_opr(i.first, i.second));
+    }
+    auto fb_oprs = m_builder.CreateVector(oprs);
+
+    // Dump output vars
+    std::vector<flatbuffers::Offset<fbs::v2::OutputVar>> output_vars_idx;
+    output_vars_idx.reserve(new_output_vars.size());
+    for (auto i : new_output_vars) {
+        auto foutput_vars_idx = build_output_var(i);
+        output_vars_idx.push_back(foutput_vars_idx);
+    }
+    auto fb_output_vars = m_builder.CreateVector(output_vars_idx);
+    auto fb_mid_tensor = m_builder.CreateVector(m_model_middle_tensors);
+
+    fbs::v2::ModelBuilder model(m_builder);
+    model.add_mge_version(MGB_VERSION);
+    model.add_oprs(fb_oprs);
+    model.add_middle_tensors(fb_mid_tensor);
+    model.add_output_vars_idx(fb_output_vars);
+    model.add_nr_shared_tensor(m_nr_shared_tensor);
+    model.add_metadata(fbmeta);
+    m_builder.FinishSizePrefixed(model.Finish(), fbs::v2::ModelIdentifier());
+
+    // Write the serialized fbs::v2::Model
+    m_file->write(m_builder.GetBufferPointer(), m_builder.GetSize());
+
+    // Finalize DumpResult
+    auto&& ret = m_cur_rst;
+    for (size_t i = 0; i < new_output_vars.size(); i++) {
+        ret.outputs.emplace_back(
+                keep_output_var_name ? new_output_vars[i].node()->cname()
+                                     : ssprintf("unnamed%zu", i));
+    }
+    std::sort(ret.inputs.begin(), ret.inputs.end());
+    mgb_assert(ret.nr_opr == m_oprs_to_dump.size());
+    ret.tot_bytes = m_file->tell() - begin_pos;
+    return ret;
+}
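+
+/*
+ * Typical round trip through the V2 format (a sketch; it assumes the
+ * InputFile/OutputFile factory helpers from serializer.h and an on-disk
+ * path "model.mge"):
+ *
+ *     auto dumper = GraphDumper::make(
+ *             OutputFile::make_fs("model.mge"), GraphDumpFormat::FLATBUFFERS_V2);
+ *     auto rst = dumper->dump({y});  // y: SymbolVar endpoint of the graph
+ *
+ *     auto loader = GraphLoader::make(
+ *             InputFile::make_fs("model.mge"), GraphDumpFormat::FLATBUFFERS_V2);
+ *     auto load_rst = loader->load();
+ */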
+
+void GraphDumperOSSV2::dump_tensor(
+        const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) {
+    using namespace flatbuffers;
+    using Meth = TensorWriteMethod;
+    mgb_assert(
+            (method == Meth::VALUE_ANONYMOUS) ^ (!name.empty()),
+            "tensor name must be empty exactly for Meth::VALUE_ANONYMOUS tensors");
+
+    bool has_value = method != Meth::META_INPUT;
+    bool should_keep_name = true;
+    switch (method) {
+        case Meth::VALUE_ANONYMOUS:
+            should_keep_name = false;
+            break;
+        case Meth::VALUE_SHARED:
+            should_keep_name = m_config.keep_param_name;
+            ++m_nr_shared_tensor;
+            if (m_config.keep_param_name) {
+                mgb_assert(
+                        m_used_param_names.insert(name).second,
+                        "duplicated VALUE_SHARED tensor name: %s", name.c_str());
+                m_cur_rst.params.emplace_back(name);
+            }
+            break;
+        case Meth::META_INPUT:
+        case Meth::VALUE_INPUT:
+            mgb_assert(!name.empty(), "empty input tensor name");
+            mgb_assert(
+                    m_used_input_names.insert(name).second,
+                    "duplicated input tensor name: %s", name.c_str());
+            m_cur_rst.inputs.emplace_back(name);
+            break;
+    }
+
+    auto& layout = tensor.layout();
+    flatbuffers::Offset<Vector<uint8_t>> data;
+    if (has_value) {
+        check_tensor_value_valid(name, tensor);
+        auto&& dumper = m_config.tensor_value_dumper;
+        if (dumper) {
+            mgb_log_warn(
+                    "the v2 serialization format is pure FlatBuffers; a "
+                    "user-supplied tensor value dumper is not supported");
+        }
+        data = m_builder.CreateVector(
+                reinterpret_cast<const uint8_t*>(tensor.raw_ptr()),
+                layout.span().high_byte);
+        m_cur_rst.tensor_value_bytes += layout.span().high_byte;
+    }
+
+    auto fbname = should_keep_name ? m_builder.CreateSharedString(name) : 0;
+    auto fshape = m_builder.CreateVectorScalarCast<uint32_t>(layout.shape, layout.ndim);
+    auto fcomp_node = fbs::v2::CreateCompNode(
+            m_builder,
+            m_builder.CreateSharedString(tensor.comp_node().to_string_logical()));
+    auto fdtype = build_dtype(layout.dtype);
+
+    auto fformat_type = get_flatbuffer_tensor_format_type(layout.format);
+    auto fformat = build_tensor_format(layout.format);
+    auto serialized_tensor = fbs::v2::CreateTensor(
+            m_builder, fbname, fshape, fcomp_node, fdtype, fformat_type, fformat, data);
+    m_cur_opr_tensor.emplace_back(serialized_tensor);
+}
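+
+// Summary of how the write methods above affect the dump result (call sites
+// are illustrative):
+//   VALUE_SHARED:    weight-like value, counted in m_nr_shared_tensor,
+//                    e.g. dump_tensor("conv1:W", w, Meth::VALUE_SHARED);
+//   META_INPUT:      input declared by shape/dtype only, no value bytes;
+//   VALUE_INPUT:     input dumped together with its current value;
+//   VALUE_ANONYMOUS: value dumped without any name (name must be empty).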
+
+void GraphDumperOSSV2::dump_buf_with_len(const void* data, uint32_t size) {
+    auto blob = fbs::v2::CreateBlob(
+            m_builder,
+            m_builder.CreateVector(static_cast<const uint8_t*>(data), size));
+    m_blobs.emplace_back(blob);
+}
+
+// ----------------------------- Loader --------------------------------------
+CompNode GraphLoaderOSSV2::OprLoadContextImpl::load_comp_node(
+        const fbs::v2::CompNode* comp_node) {
+    mgb_assert(comp_node);
+    if (!comp_node->logical_locator())
+        return {};
+    auto loc = CompNode::Locator::parse(comp_node->logical_locator()->str());
+    m_loader->m_cur_load_config->comp_node_mapper(loc);
+    return CompNode::load(loc);
+}
+
+TensorFormat load_tensor_format(
+        const fbs::v2::TensorFormat fformat_type, const void* fformat,
+        const CompNode& comp_node) {
+    switch (fformat_type) {
+        case fbs::v2::TensorFormat_DefaultTensorFormat:
+            return megdnn::DefaultTensorFormat::make();
+        case fbs::v2::TensorFormat_Image2DPackedTensorFormat: {
+            auto image_format =
+                    static_cast<const fbs::v2::Image2DPackedTensorFormat*>(fformat);
+            auto handle =
+                    MegDNNHandle::get(CompNodeEnv::from_comp_node(comp_node)).handle();
+            return megdnn::Image2DPack4TensorFormat::make(
+                    image_format->align_axis(), handle);
+        }
+        case fbs::v2::TensorFormat_LowbitsAlignedTensorFormat: {
+            auto lowbit_format =
+                    static_cast<const fbs::v2::LowbitsAlignedTensorFormat*>(fformat);
+            return megdnn::LowbitsAlignedToBytesTensorFormat::make(
+                    lowbit_format->size_nbits());
+        }
+        default:
+            mgb_throw(
+                    SerializationError, "invalid tensor format type in serialization.");
+    }
+}
+
+TensorLayout load_tensor_layout(
+        const fbs::v2::Tensor* tensor, const CompNode& comp_node) {
+    TensorLayout layout;
+    if (tensor->shape()) {
+        layout.ndim = tensor->shape()->size();
+        std::copy(tensor->shape()->begin(), tensor->shape()->end(), layout.shape);
+    }
+    if (tensor->dtype()) {
+        // modify data type inplace for TensorLayout
+        layout.modify_dtype_inplace(fbs::intl::load_dtype(tensor->dtype()));
+    }
+    if (tensor->format() && tensor->format_type()) {
+        layout.format =
+                load_tensor_format(tensor->format_type(), tensor->format(), comp_node);
+    }
+    layout.init_contiguous_stride();
+    return layout;
+}
+
+//! the opr loader is responsible for ensuring that the tensors exist and that
+//! their number matches; here we only assert it
+std::shared_ptr<HostTensorND> GraphLoaderOSSV2::OprLoadContextImpl::load_tensor() {
+    mgb_assert(
+            m_current_opr->tensors() &&
+            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
+    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
+    auto comp_node = load_comp_node(tensor->comp_node());
+    auto layout = load_tensor_layout(tensor, comp_node);
+    auto ret = std::make_shared<HostTensorND>(comp_node, layout);
+
+    auto&& loader = m_loader->m_cur_load_config->tensor_value_loader;
+    if (tensor->data() && tensor->data()->size() > 0) {
+        if (loader) {
+            mgb_log_warn(
+                    "the v2 serialization format is pure FlatBuffers; a "
+                    "user-supplied tensor value loader is not supported");
+        }
+        memcpy(ret->raw_ptr(), tensor->data()->data(), tensor->data()->size());
+    }
+    if (tensor->name()) {
+        m_tensor_map[tensor->name()->str()] = ret;
+    }
+    if (auto&& mod = m_loader->m_cur_load_config->tensor_modifier) {
+        bool has_value = false;
+        if (tensor && tensor->data()) {
+            has_value = tensor->data()->size() != 0;
+        }
+        mod(tensor->name() ? tensor->name()->str() : "", has_value, *ret);
+    }
+    return ret;
+}
+
+std::shared_ptr<DeviceTensorND> GraphLoaderOSSV2::OprLoadContextImpl::
+        load_tensor_shared() {
+    mgb_assert(
+            m_current_opr->tensors() &&
+            m_cur_opr_tensor_cnt < m_current_opr->tensors()->size());
+    auto tensor = m_current_opr->tensors()->Get(m_cur_opr_tensor_cnt++);
+    auto comp_node = load_comp_node(tensor->comp_node());
+    auto layout = load_tensor_layout(tensor, comp_node);
+    mgb_assert(tensor->data());
+    auto&& shared_pair = m_loader->m_shared_tensor_map.at(m_cur_shared_tensor_idx++);
+    auto&& shared_tensor_ref = shared_pair.second[comp_node.mem_node()];
+    if (shared_tensor_ref) {
+        if (shared_tensor_ref->comp_node() == comp_node)
+            return shared_tensor_ref;
+        // same mem node but different comp node: change the comp node and
+        // share the value
+        auto ret = std::make_shared<DeviceTensorND>(*shared_tensor_ref);
+        ret->comp_node(comp_node);
+        return ret;
+    }
+    if (tensor->name()) {
+        shared_pair.first = tensor->name()->str();
+    }
+
+    if (comp_node.mem_node() == CompNode::default_cpu().mem_node()) {
+        // directly forward CPU memory
+        HostTensorND hv{comp_node};
+        if (tensor->data() && tensor->data()->size() > 0) {
+            hv.dtype(layout.dtype).resize(layout);
+            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
+        }
+        shared_tensor_ref = std::make_shared<DeviceTensorND>();
+        *shared_tensor_ref = DeviceTensorND::make_proxy(hv);
+    } else {
+        // use lazy load for non-CPU devices
+        HostTensorND hv{CompNode::default_cpu()};
+        if (tensor->data() && tensor->data()->size() > 0) {
+            hv.dtype(layout.dtype).resize(layout);
+            memcpy(hv.raw_ptr(), tensor->data()->data(), tensor->data()->size());
+        }
+        shared_tensor_ref = m_device_value_loader.make(comp_node, std::move(hv));
+    }
+    return shared_tensor_ref;
+}
+
+Metadata GraphLoaderOSSV2::OprLoadContextImpl::load_metadata() {
+    const auto* fbmeta = m_loader->m_model->metadata();
+    Metadata ret;
+    if (fbmeta) {
+        ret.is_valid = fbmeta->is_valid();
+        ret.graph_modified = fbmeta->graph_modified();
+        if (fbmeta->user_info()) {
+            ret.user_info = fbmeta->user_info()->str();
+            ret.has_user_info = true;
+        }
+        if (fbmeta->optimize_options()) {
+            ret.optimize_options = fbmeta->optimize_options();
+            ret.optimized_for_inference = true;
+        }
+    }
+    return ret;
+}
+
+void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
+        const fbs::v2::Operator* fbopr) {
+    m_cur_opr_tensor_cnt = 0;
+    m_cur_opr_blob_cnt = 0;
+    m_cur_opr_param_cnt = 0;
+
+    OperatorNodeConfig config;
+    if (fbopr->output_dtype()) {
+        config.output_dtype(fbs::intl::load_dtype(fbopr->output_dtype()));
+    }
+    if (fbopr->name()) {
+        config.name(fbopr->name()->str());
+    }
+    if (fbopr->comp_node()) {
+        auto cnt = fbopr->comp_node()->size();
+        cg::OperatorNodeConfig::CompNodeArray comp_node_arr(cnt);
+        for (size_t i = 0; i < cnt; i++) {
+            CompNode cn{};
+            auto node = fbopr->comp_node()->Get(i);
+            if (node) {
+                cn = load_comp_node(node);
+            }
+            comp_node_arr[i] = cn;
+        }
+        config.comp_node_arr(comp_node_arr);
+    }
+    //! the opr version must exist
+    uint8_t opr_version = fbopr->opr_version();
+    auto type_id = fbopr->type_id();
+    auto opr_type = fbopr->type()->str();
+    const OprRegistryV2* registry =
+            OprRegistryV2::versioned_find_by_id(type_id, opr_version);
+    mgb_throw_if(
+            !registry, SerializationError,
+            "failed to find opr of type %s with id %zu; use the python env "
+            "config.dump_registered_oprs() to get a dict that maps from "
+            "opr id to opr name",
+            opr_type.c_str(), type_id);
+
+    // load inputs
+    VarNodeArray inputs;
+    if (fbopr->inputs()) {
+        inputs.resize(fbopr->inputs()->size());
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            inputs[i] = m_id2varnode.at(fbopr->inputs()->Get(i));
+        }
+    }
+
+    // call loader
+    auto accessor = registry->loader(*this, inputs, config);
+    auto opr = accessor.opr();
+
+    // check opr type; note that:
+    // 1. registry->type may be empty for dynamic opr loaders or legacy oprs
+    // 2. due to some optimization, an opr may be replaced by ImmutableTensor
+    mgb_assert(
+            opr && (opr->dyn_typeinfo() == registry->type || !registry->type ||
+                    opr->same_type<opr::ImmutableTensor>()),
+            "got_type=%s expected_type=%s",
+            opr ? opr->dyn_typeinfo()->name : "<null>", registry->type->name);
+    // record output vars; read output names
+    size_t i = 0;
+    for (auto ovar : accessor.output()) {
+        if (!ovar->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
+            m_id2varnode.push_back(ovar);
+            if (fbopr->outputs()) {
+                auto id = fbopr->outputs()->Get(i);
+                mgb_assert(
+                        m_id2varnode.size() - 1 == id,
+                        "id2varnode size is %zu, but fbs output id is %u",
+                        m_id2varnode.size() - 1, id);
+                if (m_middle_tensors.size() > id) {
+                    auto name = m_middle_tensors[id]->name()->str();
+                    ovar->name(name);
+                }
+            }
+            i++;
+        }
+    }
+
+    opr->node_prop().attribute().priority = fbopr->priority();
+}
+
+GraphLoader::LoadResult GraphLoaderOSSV2::OprLoadContextImpl::load_oprs() {
+    // load oprs
+    const auto* oprs = m_loader->m_model->oprs();
+    {
+        // inplace arith graph optimization is disabled during opr load:
+        // it tries to restore the same graph as it was dumped;
+        // see test TestSerializer2.LOGEXP for example
+        GraphLoader::ScopedGraphOptDisabler _(m_graph);
+        for (flatbuffers::uoffset_t i = 0; i < oprs->size(); ++i) {
+            m_current_opr = oprs->Get(i);
+            load_single_opr(m_current_opr);
+        }
+    }
+
+    // batched loading of device values
+    m_device_value_loader.apply();
+
+    LoadResult ret;
+    ret.graph = m_graph;
+    ret.tensor_map = m_tensor_map;
+
+    const auto* outputs = m_loader->m_model->output_vars_idx();
+    ret.output_var_list.resize(outputs->size());
+    for (flatbuffers::uoffset_t i = 0; i < outputs->size(); i++) {
+        auto out = outputs->Get(i);
+        auto var = m_id2varnode.at(out->compact_id());
+        ret.output_var_map[var->name()] = var;
+        ret.output_var_map_id[out->original_id()] = var;
+        ret.output_var_list[i] = var;
+    }
+    mgb_assert(m_cur_shared_tensor_idx == m_loader->m_shared_tensor_map.size());
+    return ret;
+}
+
+void GraphLoaderOSSV2::OprLoadContextImpl::load_middle_tensor() {
+    auto model = m_loader->m_model;
+    if (model->middle_tensors()) {
+        for (unsigned int i = 0; i < model->middle_tensors()->size(); i++) {
+            m_middle_tensors.push_back(model->middle_tensors()->Get(i));
+        }
+    }
+}
+
+GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool rewind) {
+    mgb_assert(m_file);
+    m_cur_load_config = &config;
+    if (rewind) {
+        m_file->rewind();
+    }
+    // read the size-prefixed fbs::v2::Model
+    uint32_t size;
+    m_file->read(&size, sizeof(size));
+    m_model_buf = m_file->read_shared(size);
+
+    mgb_throw_if(
+            !fbs::v2::ModelBufferHasIdentifier(m_model_buf.data()), SerializationError,
+            "invalid fbs model");
+
+    {
+        flatbuffers::Verifier verifier(
+                static_cast<const uint8_t*>(m_model_buf.data()), m_model_buf.size());
+        mgb_throw_if(
+                !fbs::v2::VerifyModelBuffer(verifier), SerializationError,
+                "model verification failed (invalid or corrupted model?)");
+    }
+
+    m_model = fbs::v2::GetModel(m_model_buf.data());
+    m_mgb_version = m_model->mge_version();
+    if (m_model->mge_version() > MGB_VERSION) {
+        mgb_log_warn(
+                "loading model from a future runtime: runtime version=%u, "
+                "model version=%u",
+                MGB_VERSION, m_model->mge_version());
+    }
+
+    if (m_shared_tensor_map.empty()) {
+        m_shared_tensor_map.resize(m_model->nr_shared_tensor());
+    } else {
+        mgb_assert(m_shared_tensor_map.size() == m_model->nr_shared_tensor());
+    }
+
+    OprLoadContextImpl ctx{this, m_model->mge_version()};
+    ctx.load_middle_tensor();
+    auto metadata = ctx.load_metadata();
+    auto result = ctx.load_oprs();
+    result.metadata = metadata;
+    m_model_loaded = true;
+    result.graph_compile_ahead();
+    return result;
+}
+
+std::unique_ptr<GraphDumper> make_fbs_v2_dumper(std::unique_ptr<OutputFile> file) {
+    return std::make_unique<GraphDumperOSSV2>(std::move(file));
+}
+
+std::unique_ptr<GraphLoader> make_fbs_v2_loader(std::unique_ptr<InputFile> file) {
+    return std::make_unique<GraphLoaderOSSV2>(std::move(file));
+}
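+
+// On-disk layout probed by is_fbs_v2_file() below, as produced by
+// FinishSizePrefixed() in the dumper (byte offsets within the file):
+//   [0..3]   uint32_t size prefix
+//   [4..7]   root table offset
+//   [8..11]  FlatBuffers file identifier checked by ModelBufferHasIdentifier()
+// The 25-byte read is simply large enough to cover this region.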
+
+bool is_fbs_v2_file(InputFile& file) {
+    constexpr size_t identifier_length = 25;
+    char identifier[identifier_length];
+    file.read(identifier, identifier_length);
+    file.skip(-identifier_length);
+    //! skip the uint32_t size prefix; the identifier starts right after it
+    return fbs::v2::ModelBufferHasIdentifier(identifier + sizeof(uint32_t));
+}
+
+} // namespace serialization
+} // namespace mgb
+
+#endif
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/serialization/impl/batched_device_value_loader.h b/src/serialization/include/megbrain/serialization/batched_device_value_loader.h
similarity index 100%
rename from src/serialization/impl/batched_device_value_loader.h
rename to src/serialization/include/megbrain/serialization/batched_device_value_loader.h
diff --git a/src/serialization/include/megbrain/serialization/dump_format.h b/src/serialization/include/megbrain/serialization/dump_format.h
index 381571c74f63ff9658725eecea19a30bfb168f80..4872baa28482cc2f5092e46ffd63d5f767722f54 100644
--- a/src/serialization/include/megbrain/serialization/dump_format.h
+++ b/src/serialization/include/megbrain/serialization/dump_format.h
@@ -5,6 +5,7 @@ namespace serialization {
 
 enum class GraphDumpFormat {
     FLATBUFFERS,
+    FLATBUFFERS_V2,
 };
 
 } // namespace serialization
diff --git a/src/serialization/include/megbrain/serialization/opr_load_dump.h b/src/serialization/include/megbrain/serialization/opr_load_dump.h
index 16c1f1c5bad9f9cb8301f4e3cfcc4fd68b1fafab..47ae370b70fe9caabdcade19751faf8e57b99d24 100644
--- a/src/serialization/include/megbrain/serialization/opr_load_dump.h
+++ b/src/serialization/include/megbrain/serialization/opr_load_dump.h
@@ -20,8 +20,12 @@ class FlatBufferBuilder;
 } // namespace flatbuffers
 
 namespace mgb {
-namespace serialization {
+constexpr uint8_t CURRENT_VERSION = 2u;
+constexpr uint8_t BEGIN_VERSION = 0u;
+constexpr uint8_t VERSION_1 = 1u;
+constexpr uint8_t VERSION_2 = 2u;
 
+namespace serialization {
 namespace fbs {
 template <class Opr>
 struct OperatorParamTraits;
@@ -187,6 +191,9 @@ class OprLoadContext : public UserDataContainer::UserData {
     friend class OprLoadContextRawPOD;
     friend class OprLoadContextFlatBuffers;
 
+protected:
+    virtual ~OprLoadContext() = default;
+
 public:
     //! get current computing graph
     virtual ComputingGraph& graph() = 0;
@@ -224,6 +231,12 @@ public:
      */
     virtual SharedBuffer load_shared_buf_with_len() = 0;
 
+    /*!
+     * \brief get the serialization data of the current opr
+     */
+    virtual const void* get_current_opr_data() { return nullptr; }
+
     /*!
      * \brief read a param and check that tag matches
      */
diff --git a/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h b/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
new file mode 100644
index 0000000000000000000000000000000000000000..655adf016c61e04614e5d166b94e90e360a55c5e
--- /dev/null
+++ b/src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
@@ -0,0 +1,257 @@
+#pragma once
+
+#if MGB_ENABLE_FBS_SERIALIZATION
+#include "megbrain/comp_node_env.h"
+#include "megbrain/graph/exc_extra_info.h"
+#include "megbrain/serialization/batched_device_value_loader.h"
+#include "megbrain/serialization/internal/schema_v2_generated.h"
+#include "megbrain/serialization/opr_load_dump.h"
+#include "megbrain/serialization/serializer.h"
+
+#define CAST_TO_FBS_V2_CTX(ctx) \
+    static_cast<GraphLoaderOSSV2::OprLoadContextImpl&>(ctx)
+
+namespace mgb {
+namespace serialization {
+
+class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers {
+    const std::unique_ptr<OutputFile> m_file;
+    flatbuffers::FlatBufferBuilder m_builder;
+
+    DumpConfig m_config;
+    DumpResult m_cur_rst;
+
+    size_t m_nr_shared_tensor;
+
+    std::vector<std::pair<cg::OperatorNodeBase*, const OprRegistryV2*>>
+            m_oprs_to_dump;
+    ThinHashMap<VarNode*, VarNode*> m_var_remove_in_dump;
+
+    //! set of output vars specified by user
+    ThinHashSet<VarNode*> m_output_vars;
+    std::unordered_set<std::string> m_used_input_names, m_used_param_names;
+
+    //! current opr to be dumped
+    cg::OperatorNodeBase* m_cur_opr = nullptr;
+    // Will be filled in dump_tensor
+    std::vector<flatbuffers::Offset<fbs::v2::Tensor>> m_cur_opr_tensor;
+    std::vector<flatbuffers::Offset<fbs::v2::Blob>> m_blobs;
+    std::vector<fbs::v2::OperatorParam> m_cur_opr_param_type;
+    std::vector<flatbuffers::Offset<void>> m_cur_opr_param;
+
+    std::vector<flatbuffers::Offset<fbs::v2::MiddleTensor>> m_model_middle_tensors;
+    ThinHashMap<VarNode*, size_t> m_var2midtensor_id;
+
+    SymbolVarArray convert_all_oprs_to_compatible(const SymbolVarArray& output_vars);
+
+    void init_oprs_to_dump(const SymbolVarArray& endpoints);
+
+    flatbuffers::Offset<fbs::v2::Metadata> build_metadata(const Metadata& metadata);
+    flatbuffers::Offset<fbs::v2::Operator> build_single_opr(
+            cg::OperatorNodeBase* opr, const OprRegistryV2* registry);
+
+    flatbuffers::Offset<fbs::DType> build_dtype(DType dtype);
+
+public:
+    GraphDumperOSSV2(std::unique_ptr<OutputFile> file) : m_file{std::move(file)} {}
+
+    DumpResult dump(
+            const SymbolVarArray& output_vars, const DumpConfig& config = {},
+            const Metadata& metadata = {}) override;
+
+    const GraphDumpConfig& config() const override { return m_config; }
+
+    void dump_tensor(
+            const std::string& name, const HostTensorND& tensor,
+            TensorWriteMethod method) override;
+
+    void append_param(uint32_t type, uint32_t value) override {
+        static_assert(
+                std::is_same<uint32_t, flatbuffers::uoffset_t>::value,
+                "append_param depends on uoffset_t being uint32_t");
+        static_assert(
+                std::is_standard_layout<flatbuffers::Offset<void>>::value,
+                "append_param depends on flatbuffers::Offset having "
+                "standard memory layout");
+        mgb_assert(type != fbs::v2::OperatorParam_NONE);
+        m_cur_opr_param_type.emplace_back(static_cast<fbs::v2::OperatorParam>(type));
+        m_cur_opr_param.emplace_back(value);
+    }
+
+    flatbuffers::FlatBufferBuilder& builder() override { return m_builder; }
+    void dump_buf_with_len(const void* data, uint32_t size) override;
+
+    GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS_V2; }
+    flatbuffers::Offset<fbs::v2::MiddleTensor> build_middle_tensor(const SymbolVar var);
+    flatbuffers::Offset<fbs::v2::OutputVar> build_output_var(const SymbolVar var);
+    flatbuffers::Offset<void> build_tensor_format(const TensorLayout::Format& format);
+
+    void set_current_opr(cg::OperatorNodeBase* cur_opr) { m_cur_opr = cur_opr; }
+};
+
+// ----------------------------- Loader --------------------------------------
+class GraphLoaderOSSV2 final : public GraphLoader {
+    const LoadConfig* m_cur_load_config = nullptr;
+    std::unique_ptr<InputFile> m_file;
+    SharedBuffer m_model_buf{{}, 0};
+    const fbs::v2::Model* m_model;
+    SharedTensorIDMap m_shared_tensor_map;
+    uint32_t m_mgb_version = 0;
+    bool m_model_loaded = false;
+
+    void verify();
+
+public:
+    class OprLoadContextImpl;
+    friend class OprLoadContextImpl;
+
+    GraphLoaderOSSV2(std::unique_ptr<InputFile> input_file)
+            : m_file{std::move(input_file)} {}
+
+    std::unique_ptr<InputFile> reset_file(std::unique_ptr<InputFile> file) override {
+        file.swap(m_file);
+        return file;
+    }
+
+    LoadResult load(const LoadConfig& config, bool rewind) override;
+
+    const SharedTensorIDMap& shared_tensor_id_map() const override {
+        mgb_assert(m_model_loaded, "graph not loaded yet");
+        return m_shared_tensor_map;
+    }
+
+    GraphDumpFormat format() const override { return GraphDumpFormat::FLATBUFFERS_V2; }
+};
+
+class GraphLoaderOSSV2::OprLoadContextImpl final : public OprLoadContextFlatBuffers {
+    GraphLoaderOSSV2* const m_loader;
+    size_t m_cur_shared_tensor_idx = 0;
+    std::shared_ptr<ComputingGraph> m_graph;
+    LoadResult::TensorMap m_tensor_map;
+    VarNodeArray m_id2varnode;
+    std::vector<const fbs::v2::MiddleTensor*> m_middle_tensors;
+    BatchedDeviceValueLoader m_device_value_loader;
+    const fbs::v2::Operator* m_current_opr;
+    size_t m_cur_opr_tensor_cnt;
+    size_t m_cur_opr_blob_cnt;
+    size_t m_cur_opr_param_cnt;
+
+public:
+    ComputingGraph& graph() override { return *m_graph; }
+
+    const GraphLoadConfig& config() const override {
+        return *m_loader->m_cur_load_config;
+    }
+
+    std::shared_ptr<HostTensorND> load_tensor() override;
+
+    std::shared_ptr<DeviceTensorND> load_tensor_shared() override;
+
+    void load_single_opr(const fbs::v2::Operator* opr);
+
+    OprLoadContextImpl(GraphLoaderOSSV2* loader, uint32_t version)
+            : OprLoadContextFlatBuffers(version), m_loader{loader} {
+        m_graph = loader->m_cur_load_config->comp_graph;
+        if (!m_graph) {
+            m_graph = ComputingGraph::make();
+        }
+        auto maker = [this]() {
+            return std::shared_ptr<OprLoadContext>{
+                    std::shared_ptr<OprLoadContext>{}, this};
+        };
+        auto got = m_graph->options()
+                           .user_data.get_user_data_or_create<OprLoadContext>(maker);
+        mgb_assert(got == this);
+    }
+
+    ~OprLoadContextImpl() noexcept {
+        auto nr = m_graph->options().user_data.pop_user_data<OprLoadContext>();
+        mgb_assert(nr == 1);
+    }
+
+    Metadata load_metadata();
+    LoadResult load_oprs();
+    CompNode load_comp_node(const fbs::v2::CompNode* comp_node);
+
+    void load_middle_tensor();
+
+    const void* get_next_param(uint32_t enumv) override {
+        auto type = static_cast<fbs::v2::OperatorParam>(enumv);
+        if (m_cur_opr_param_cnt == 0) {
+            m_cur_opr_param_cnt++;
+            if (m_current_opr->param_type() == type) {
+                return m_current_opr->param();
+            } else {
+                mgb_throw(
+                        SerializationError,
+                        "the param type does not match when loading the opr");
+            }
+        }
+        mgb_throw(
+                SerializationError,
+                "use the read_param(index) interface to load multiple params "
+                "in one operator");
+    }
+
+    std::string load_buf_with_len() override {
+        mgb_assert(
+                m_current_opr->custom_data() &&
+                m_cur_opr_blob_cnt < m_current_opr->custom_data()->size());
+        auto blob = m_current_opr->custom_data()->Get(m_cur_opr_blob_cnt++);
+        mgb_assert(blob && blob->data());
+        auto data = blob->data()->data();
+        return {reinterpret_cast<const char*>(data), blob->data()->size()};
+    }
+
+    SharedBuffer load_shared_buf_with_len() override {
+        mgb_assert(
+                m_current_opr->custom_data() &&
+                m_cur_opr_blob_cnt < m_current_opr->custom_data()->size());
+        auto blob = m_current_opr->custom_data()->Get(m_cur_opr_blob_cnt++);
+        mgb_assert(blob && blob->data());
+        auto size = blob->data()->size();
+        std::shared_ptr<uint8_t> shptr{
+                new uint8_t[size], [](uint8_t* p) { delete[] p; }};
+        memcpy(shptr.get(), blob->data()->data(), size);
+        return {std::move(shptr), size};
+    }
+
+    const void* get_current_opr_data() override {
+        return reinterpret_cast<const void*>(m_current_opr);
+    }
+
+    template <class T>
+    T read_param(int index) {
+        using SourceType = typename fbs::ParamConverter<T>::FlatBufferType;
+        auto enumv = fbs::OperatorParamTraits<SourceType>::enum_value;
+        auto type = static_cast<fbs::v2::OperatorParam>(enumv);
+        if (index == 0) {
+            mgb_assert(
+                    m_current_opr->param_type() == type,
+                    "failed to load param: the main param type does not match");
+            return fbs::ParamConverter<T>::to_param(
+                    static_cast<const SourceType*>(m_current_opr->param()));
+        } else {
+            int addition_index = index - 1;
+            if (addition_index >=
+                static_cast<int>(m_current_opr->additional_params()->size())) {
+                mgb_log_warn(
+                        "the model has no additional param at index %d; "
+                        "constructing a default one",
+                        addition_index);
+                return T{};
+            } else {
+                mgb_assert(
+                        m_current_opr->additional_params_type()->Get(addition_index) ==
+                                type,
+                        "failed to load param: the additional param type does "
+                        "not match");
+                return fbs::ParamConverter<T>::to_param(static_cast<const SourceType*>(
+                        m_current_opr->additional_params()->Get(addition_index)));
+            }
+        }
+    }
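+
+    // Usage sketch inside an opr loader (the param types are illustrative;
+    // anything with a fbs::ParamConverter specialization works):
+    //
+    //     auto&& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
+    //     auto param = fbs_ctx.read_param<megdnn::param::Convolution>(0);
+    //     auto policy = fbs_ctx.read_param<megdnn::param::ExecutionPolicy>(1);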
+};
+
+} // namespace serialization
+} // namespace mgb
+
+#endif
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}