/**
 * \file src/model_parser.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once
#include "lite/global.h"
#include "../network_impl_base.h"

#include "pack_model_generated.h"
#include <flatbuffers/flatbuffers.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

namespace lite {

/*!
 * \brief parse a packed model and decrypt it
 *
 * The header is parsed in the constructor; the info part and the model part
 * can then be extracted with parse_model_info() and parse_model().
 */
class ModelParser {
public:
    //! \param model_ptr buffer holding the whole model; the parser keeps a
    //!        shared reference to it
    //! \param model_length length of the buffer pointed to by model_ptr
    //!        (presumably in bytes -- TODO confirm)
    ModelParser(std::shared_ptr<void> model_ptr, size_t model_length)
            // move the by-value shared_ptr instead of copying it again, which
            // avoids a redundant atomic refcount increment/decrement
            : m_model(std::move(model_ptr)), m_total_length(model_length) {
        //! parse the header; it may rely on m_model and m_total_length, which
        //! are both initialized above
        parse_header();
    }

    //! parse the Info part of the model, and update network_config and
    //! network_io from it
    //! \param isolated_config_map out-param; presumably receives config
    //!        entries that are not part of Config (TODO confirm)
    //! \param extra_info out-param for extra information carried by the model
    //!        (TODO confirm its format)
    //! \return NOTE(review): meaning of the result (e.g. "info found and
    //!         parsed") is defined out of line -- confirm before relying on it
    bool parse_model_info(
            Config& network_config, NetworkIO& network_io,
            std::unordered_map<std::string, LiteAny>& isolated_config_map,
            std::string& extra_info) const;

    //! parse the model part and decrypt the model
    //! \param model_length out-param; presumably set to the length of the
    //!        returned model buffer (TODO confirm)
    //! \param config network config consulted while parsing
    //! \return the model buffer
    std::shared_ptr<void> parse_model(size_t& model_length,
                                      const Config& config) const;

private:
    //! parse the header of the model and store the model-related information
    //! in the member data
    void parse_header();

    //! decrypt `length` bytes starting at `data`, using the decryption method
    //! registered under `decryption_name`; result_length presumably receives
    //! the length of the returned buffer
    //! NOTE(review): decryption_name is taken by const value (a copy); a
    //! const reference would be cheaper, but the out-of-line definition must
    //! be changed together with this declaration.
    std::shared_ptr<void> decrypt_memory(const uint8_t* data, size_t length,
                                         const std::string decryption_name,
                                         size_t& result_length) const;

private:
    std::string m_model_name;
    //! names of the decryption methods for the info part and for the model
    //! part; the decryption function can be looked up through these names
    std::string m_info_decryption_name;
    std::string m_model_decryption_name;
    //! name of the function used to parse the model info
    std::string m_info_parse_func_name;
    //! a model to which no json info has been added and which is not
    //! encrypted is a bare model
    bool m_is_bare_model = true;

    //! non-owning flatbuffers views of the info and data tables; presumably
    //! they point into the buffer held by m_model and are set during
    //! parse_header (TODO confirm)
    const model_parse::ModelInfo* m_info = nullptr;
    const model_parse::ModelData* m_model_data = nullptr;

    //! the whole model buffer, shared with the caller
    std::shared_ptr<void> m_model;
    //! total length of the buffer held by m_model
    size_t m_total_length = 0;

    //! presumably the tag that identifies a packed model; defined out of line
    static std::string sm_model_tag;
};
}  // namespace lite
   // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}