model_parser.h 2.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
/**
 * \file src/model_parser.h
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */

#pragma once
#include "lite/global.h"
#include "../network_impl_base.h"

#include "pack_model_generated.h"
#include <flatbuffers/flatbuffers.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

namespace lite {

/*!
 * \brief parse a model buffer: read its header, expose the Info part and
 * return the parsed (and decrypted) model
 */
class ModelParser {
public:
    /*!
     * \brief keep a shared reference to the model memory and parse its header
     * \param model_ptr shared pointer to the model memory; it is moved into
     * the parser so no extra reference-count round trip is paid
     * \param model_length length associated with \p model_ptr (presumably in
     * bytes -- confirm against the callers)
     */
    ModelParser(std::shared_ptr<void> model_ptr, size_t model_length)
            : m_model(std::move(model_ptr)), m_total_length(model_length) {
        //! parse the header right away, the other members depend on it
        parse_header();
    }

    /*!
     * \brief parse the Info part of the model and update \p network_config
     * and \p network_io
     * \param isolated_config_map,extra_info non-const references, presumably
     * additional outputs filled from the Info part -- confirm in the
     * implementation
     * \return bool; the exact meaning is defined by the implementation, which
     * is not visible in this header
     */
    bool parse_model_info(
            Config& network_config, NetworkIO& network_io,
            std::unordered_map<std::string, LiteAny>& isolated_config_map,
            std::string& extra_info) const;

    /*!
     * \brief parse the model and decrypt it
     * \param model_length non-const reference, presumably receives the length
     * of the returned memory -- confirm in the implementation
     * \param config configuration used while parsing the model
     * \return shared pointer to the resulting model memory
     */
    std::shared_ptr<void> parse_model(size_t& model_length,
                                      const Config& config) const;

private:
    //! parse the header of the model and store the model related information
    //! in the member data
    void parse_header();

    //! decrypt the memory at \p data of \p length, using the decryption
    //! method registered under \p decryption_name; \p result_length is a
    //! non-const reference, presumably set to the length of the returned
    //! memory
    std::shared_ptr<void> decrypt_memory(const uint8_t* data, size_t length,
                                         const std::string decryption_name,
                                         size_t& result_length) const;

private:
    std::string m_model_name;
    //! the info and model decryption method names, the decryption function
    //! can be found through these names
    std::string m_info_decryption_name;
    std::string m_model_decryption_name;
    //! the name of the function used to parse the model info
    std::string m_info_parse_func_name;
    //! a model that has no json info added to it and is not encrypted is a
    //! bare model
    bool m_is_bare_model = true;

    //! non-owning pointers, null until set by the implementation
    const model_parse::ModelInfo* m_info = nullptr;
    const model_parse::ModelData* m_model_data = nullptr;

    //! the model memory handed to the constructor and its length
    std::shared_ptr<void> m_model;
    size_t m_total_length;

    //! model tag shared by all parsers, defined out of line
    static std::string sm_model_tag;
};
}  // namespace lite
   // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}