/**
 * \file src/model_parser.cpp
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "model_parser.h"
#include "decryption/decrypt_base.h"
#include "parse_info/parse_info_base.h"

using namespace lite;
using namespace model_parse;

//! magic prefix that marks a buffer as a packed model; buffers that do not
//! start with it are handled as bare models by ModelParser::parse_header
std::string ModelParser::sm_model_tag{"packed_model"};

void ModelParser::parse_header() {
    size_t tag_length = sm_model_tag.size();

    //! parse model tag
    const char* ptr = static_cast<char*>(m_model.get());
    std::string tag(static_cast<const char*>(ptr), tag_length);
    if (sm_model_tag == tag) {
        m_is_bare_model = false;
    } else {
        //! if no tag, the model is bare model, return
        m_is_bare_model = true;
        return;
    }

    uint8_t* buffer = static_cast<uint8_t*>(m_model.get()) + tag_length;
    auto packed_model = GetPackModel(buffer);
    auto models = packed_model->models();
    LITE_ASSERT(models->size() == 1, "Now only support one model");
    auto model = models->Get(0);
    m_model_name = model->header()->name()->c_str();
M
Megvii Engine Team 已提交
41
    m_model_decryption_name = model->header()->model_decryption_method()->c_str();
42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
    m_info_decryption_name = model->header()->info_decryption_method()->c_str();
    m_info_parse_func_name = model->header()->info_parse_method()->c_str();

    m_info = model->info();
    m_model_data = model->data();
}

//! Decrypt the info section of a packed model and hand it to the registered
//! info parse function named in the model header.
//!
//! \param network_config       forwarded to the info parse function
//! \param network_io           forwarded to the info parse function
//! \param isolated_config_map  forwarded to the info parse function
//! \param extra_info           forwarded to the info parse function
//! \return false if there is nothing to parse (bare model or no info
//!         section); true after the info parse function has been called
//! Raises (via LITE_THROW) if the parse method is not registered or its
//! registered function is empty.
bool ModelParser::parse_model_info(
        Config& network_config, NetworkIO& network_io,
        std::unordered_map<std::string, LiteAny>& isolated_config_map,
        std::string& extra_info) const {
    //! no model info, no parse, direct return
    if (m_is_bare_model || !m_info || !m_info->data()) {
        return false;
    }
    size_t info_length = m_info->data()->size();
    const uint8_t* info_data = m_info->data()->Data();
    //! decrypt the info; info_length is overwritten with the decrypted size
    auto info_ptr =
            decrypt_memory(info_data, info_length, m_info_decryption_name, info_length);

    //! the registry lock stays held until return, i.e. also while the
    //! parse function below runs
    LITE_LOCK_GUARD(parse_info_static_data().map_mutex);
    auto& parse_methods = parse_info_static_data().parse_info_methods;
    auto it_parse = parse_methods.find(m_info_parse_func_name);
    if (it_parse == parse_methods.end()) {
        LITE_THROW(ssprintf(
                "can't find model info parse function %s.",
                m_info_parse_func_name.c_str()));
    }
    //! reuse the iterator rather than looking the name up a second time
    auto& model_info_parse_func = it_parse->second;
    if (!model_info_parse_func) {
        LITE_THROW(ssprintf(
                "model info parse function of  %s is empty",
                m_info_parse_func_name.c_str()));
    }
    model_info_parse_func(
            info_ptr.get(), info_length, m_model_name, network_config, network_io,
            isolated_config_map, extra_info);
    return true;
}

M
Megvii Engine Team 已提交
86 87
std::shared_ptr<void> ModelParser::parse_model(
        size_t& model_length, const Config& config) const {
88 89 90 91 92 93 94 95 96 97 98 99 100 101
    if (m_is_bare_model) {
        if (config.bare_model_cryption_name.size() == 0) {
            model_length = m_total_length;
            return m_model;
        } else {
            return decrypt_memory(
                    static_cast<uint8_t*>(m_model.get()), m_total_length,
                    config.bare_model_cryption_name, model_length);
        }
    }
    LITE_ASSERT(m_model_data, "packed model parse error!");
    model_length = m_model_data->data()->size();
    const uint8_t* model_data = m_model_data->data()->Data();
    LITE_ASSERT(model_length > 0, "The loaded model is of zero length.");
M
Megvii Engine Team 已提交
102 103
    return decrypt_memory(
            model_data, model_length, m_model_decryption_name, model_length);
104 105 106 107 108 109 110 111
}

std::shared_ptr<void> ModelParser::decrypt_memory(
        const uint8_t* data, size_t length, const std::string decryption_name,
        size_t& result_length) const {
    const uint8_t* memory_ptr = data;
    if (decryption_name == "NONE") {
        result_length = length;
M
Megvii Engine Team 已提交
112
        return std::shared_ptr<void>(const_cast<uint8_t*>(memory_ptr), [](void*) {});
113 114 115 116
    }
    LITE_LOCK_GUARD(decryption_static_data().map_mutex);
    auto it = decryption_static_data().decryption_methods.find(decryption_name);
    if (it == decryption_static_data().decryption_methods.end()) {
M
Megvii Engine Team 已提交
117 118 119
        LITE_THROW(ssprintf(
                "The decryption method %s is not registed yet.",
                decryption_name.c_str()));
120 121 122 123 124 125
    }
    auto&& func = it->second.first;
    auto&& key = it->second.second;
    if (func) {
        auto model_vector = func(memory_ptr, length, *key);
        result_length = model_vector.size();
M
Megvii Engine Team 已提交
126
        auto tmp_model_vector = new std::vector<uint8_t>(std::move(model_vector));
127 128 129 130
        return std::shared_ptr<void>(
                tmp_model_vector->data(),
                [tmp_model_vector](void*) { delete tmp_model_vector; });
    } else {
M
Megvii Engine Team 已提交
131 132
        LITE_THROW(ssprintf(
                "No decryption function in %s method.", decryption_name.c_str()));
133 134 135 136
    }
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}