#include "model_mdl.h"
#include <gflags/gflags.h>
#include <iostream>

DECLARE_bool(share_param_mem);

using namespace lar;

ModelMdl::ModelMdl(const std::string& path) : model_path(path) {
    mgb_log("creat mdl model use XPU as default comp node");
    m_load_config.comp_graph = mgb::ComputingGraph::make();
    testcase_num = 0;
}

void ModelMdl::load_model() {
    //! read dump file
    if (share_model_mem) {
        mgb_log("enable share model memory");
        FILE* fin = fopen(model_path.c_str(), "rb");
        mgb_assert(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
        fseek(fin, 0, SEEK_END);
        size_t size = ftell(fin);
        fseek(fin, 0, SEEK_SET);

        void* ptr = malloc(size);
        std::shared_ptr<void> buf{ptr, free};
        auto nr = fread(buf.get(), 1, size, fin);
        mgb_assert(nr == size, "read model file failed");
        fclose(fin);

        m_model_file = mgb::serialization::InputFile::make_mem_proxy(buf, size);
    } else {
        m_model_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
    }

    //! check for the dump_with_testcase header ("mgbtest0") and read the testcase
    //! number if present; otherwise rewind to the beginning of the file
    char magic[8];
    m_model_file->read(magic, sizeof(magic));
    if (strncmp(magic, "mgbtest0", 8)) {
        m_model_file->rewind();
    } else {
        m_model_file->read(&testcase_num, sizeof(testcase_num));
    }

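    //! identify the serialization format of the model dump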
    m_format =
            mgb::serialization::GraphLoader::identify_graph_dump_format(*m_model_file);
    mgb_assert(
            m_format.valid(),
            "invalid format, please make sure model is dumped by GraphDumper");

    //! load computing graph of model
    m_loader = mgb::serialization::GraphLoader::make(
            std::move(m_model_file), m_format.val());
    m_load_result = m_loader->load(m_load_config, false);
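    //! the load result keeps the computing graph alive, so the config's handle
    //! can be released here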
    m_load_config.comp_graph.reset();

    // get testcase input generated by dump_with_testcase.py
    if (testcase_num) {
        for (auto&& i : m_load_result.tensor_map) {
            test_input_tensors.emplace_back(i.first, i.second.get());
        }
        std::sort(test_input_tensors.begin(), test_input_tensors.end());
    }
    // initialize output callback
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        mgb::ComputingGraph::Callback cb;
        m_callbacks.push_back(cb);
    }
}

void ModelMdl::make_output_spec() {
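    //! pair every output var with its callback and compile the graph into an
    //! asynchronous function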
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        auto item = m_load_result.output_var_list[i];
        m_output_spec.emplace_back(item, std::move(m_callbacks[i]));
    }

    m_asyc_exec = m_load_result.graph_compile(m_output_spec);
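    //! graph compilation may replace the output vars (e.g. through optimization
    //! passes); refresh output_var_list with the vars the compiled function produces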
    auto new_output_vars = m_asyc_exec->get_output_vars();
    mgb::cg::SymbolVarArray symbol_var_array;
    symbol_var_array.reserve(new_output_vars.size());
    for (auto output_var : new_output_vars) {
        symbol_var_array.emplace_back(output_var);
    }
    m_load_result.output_var_list = symbol_var_array;
}

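//! rebuild the graph loader, either from the given input file or from the file
//! held by the current loader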
std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader(
        std::unique_ptr<mgb::serialization::InputFile> input_file) {
    if (input_file) {
        m_loader = mgb::serialization::GraphLoader::make(
                std::move(input_file), m_loader->format());
    } else {
        m_loader = mgb::serialization::GraphLoader::make(
                m_loader->reset_file(), m_loader->format());
    }
    return m_loader;
}

void ModelMdl::run_model() {
    mgb_assert(
            m_asyc_exec != nullptr,
            "empty asychronous function to execute after graph compiled");
    m_asyc_exec->execute();
}

void ModelMdl::wait() {
    m_asyc_exec->wait();
}
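
//! typical call sequence for this wrapper (illustrative sketch only, not taken
//! from this file; the model path below is hypothetical):
//!     lar::ModelMdl model("example.mdl");
//!     model.load_model();
//!     model.make_output_spec();
//!     model.run_model();
//!     model.wait();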

#if MGB_ENABLE_JSON
std::shared_ptr<mgb::json::Object> ModelMdl::get_io_info() {
    std::shared_ptr<mgb::json::Array> inputs = mgb::json::Array::make();
    std::shared_ptr<mgb::json::Array> outputs = mgb::json::Array::make();
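    //! map the supported dtypes to readable names for the JSON description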
    auto get_dtype = [&](megdnn::DType data_type) {
        std::map<megdnn::DTypeEnum, std::string> type_map = {
                {mgb::dtype::Float32().enumv(), "float32"},
                {mgb::dtype::Int32().enumv(), "int32"},
                {mgb::dtype::Int16().enumv(), "int16"},
                {mgb::dtype::Uint16().enumv(), "uint16"},
                {mgb::dtype::Int8().enumv(), "int8"},
                {mgb::dtype::Uint8().enumv(), "uint8"}};
        return type_map[data_type.enumv()];
    };
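    //! describe a TensorShape as {"dimN-1": ..., ..., "dim0": ...} entries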
    auto make_shape = [](mgb::TensorShape& shape_) {
        std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
                shape;
        for (size_t i = 0; i < shape_.ndim; ++i) {
            std::string label = "dim";
            label += std::to_string(shape_.ndim - i - 1);
            shape.push_back(
                    {mgb::json::String(label),
                     mgb::json::NumberInt::make(shape_[shape_.ndim - i - 1])});
        }
        return shape;
    };
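    //! collect shape, dtype and name for every input tensor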
    for (auto&& i : m_load_result.tensor_map) {
        std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
                json_inp;
        auto shape_ = i.second->shape();
        json_inp.push_back(
                {mgb::json::String("shape"),
                 mgb::json::Object::make(make_shape(shape_))});
        json_inp.push_back(
                {mgb::json::String("dtype"),
                 mgb::json::String::make(get_dtype(i.second->dtype()))});
        json_inp.push_back(
                {mgb::json::String("name"), mgb::json::String::make(i.first)});
        inputs->add(mgb::json::Object::make(json_inp));
    }

    for (auto&& i : m_load_result.output_var_list) {
        std::vector<std::pair<mgb::json::String, std::shared_ptr<mgb::json::Value>>>
                json_out;
        auto shape_ = i.shape();
        json_out.push_back(
                {mgb::json::String("shape"),
                 mgb::json::Object::make(make_shape(shape_))});
        json_out.push_back(
                {mgb::json::String("dtype"),
                 mgb::json::String::make(get_dtype(i.dtype()))});

        json_out.push_back(
                {mgb::json::String("name"), mgb::json::String::make(i.node()->name())});
        outputs->add(mgb::json::Object::make(json_out));
    }
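    //! resulting layout: {"IO": {"outputs": [...], "inputs": [...]}}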
    return mgb::json::Object::make(
            {{"IO",
              mgb::json::Object::make({{"outputs", outputs}, {"inputs", inputs}})}});
}
#endif

std::vector<uint8_t> ModelMdl::get_model_data() {
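    //! re-serialize the loaded graph into a byte vector using the original dump format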
    std::vector<uint8_t> out_data;
    auto out_file = mgb::serialization::OutputFile::make_vector_proxy(&out_data);
    using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
    DumpConfig config{1, false, false};
    auto dumper =
            mgb::serialization::GraphDumper::make(std::move(out_file), m_format.val());
    dumper->dump(m_load_result.output_var_list, config);
    return out_data;
}

void ModelMdl::update_io() {
    //! update the output var list when input shapes may have changed (some passes'
    //! execution time depends on the shape of the initial input)
    mgb::thin_hash_table::ThinHashMap<mgb::cg::SymbolVar, mgb::cg::SymbolVar> varmap;
    auto&& network = m_load_result;
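    //! map each input host tensor's raw buffer to its name so Host2DeviceCopy oprs
    //! can be matched back to model inputs below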
    std::unordered_map<void*, std::string> tensor_name_map;
    for (auto& input : network.tensor_map) {
        tensor_name_map.insert({input.second->raw_ptr(), input.first});
    }
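    //! for every Host2DeviceCopy fed by a model input, create a replacement opr
    //! bound to a fresh copy of the (possibly reshaped) host tensor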
    mgb::cg::DepOprIter dep([&](mgb::cg::OperatorNodeBase* opr) {
        if (auto h2d = opr->try_cast_final<mgb::opr::Host2DeviceCopy>()) {
            if (tensor_name_map.find(h2d->host_data()->raw_ptr()) !=
                tensor_name_map.end()) {
                //! make new h2d opr with new host tensor shape
                std::string name = tensor_name_map[h2d->host_data()->raw_ptr()];
                std::shared_ptr<mgb::HostTensorND> new_tensor =
                        std::make_shared<mgb::HostTensorND>();
                new_tensor->copy_from(*h2d->host_data());

                auto h2d_opr = mgb::opr::Host2DeviceCopy::make(
                        *h2d->owner_graph(), new_tensor, h2d->param(), h2d->config());
                //! rename new h2d with given name
                h2d_opr.node()->owner_opr()->name(name);
                varmap[h2d->output(0)] = h2d_opr;
            }
        }
    });
    //! get replace var map
    for (auto&& i : network.output_var_list)
        dep.add(i);
    //! replace new h2d and update related var shape
    if (!varmap.empty()) {
        auto output_vars = mgb::cg::replace_vars(network.output_var_list, varmap);
        network.output_var_list = output_vars;
    }
}