/**
 * \file lite/load_and_run/src/models/model_mdl.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */

#include "model_mdl.h"

#include <gflags/gflags.h>

#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>

//! gflags flag defined in another translation unit.
DECLARE_bool(share_param_mem);

using namespace lar;

namespace {

//! Read the whole file at \p path into a malloc'ed buffer.
//!
//! \param path     file to read
//! \param[out] size number of bytes read (equals the file size)
//! \return buffer owning the file contents; it is released with free()
//!
//! Every failure (open, seek, size query, empty file, allocation, short read)
//! is reported through mgb_assert. The FILE handle is owned by a unique_ptr,
//! so it is closed on every exit path, including a failing assert.
std::shared_ptr<void> read_whole_file(const std::string& path, size_t& size) {
    std::unique_ptr<FILE, decltype(&fclose)> fin{fopen(path.c_str(), "rb"), &fclose};
    mgb_assert(fin != nullptr, "failed to open %s: %s", path.c_str(), strerror(errno));

    // Side-effecting calls are kept out of mgb_assert expressions on purpose.
    int rc = fseek(fin.get(), 0, SEEK_END);
    mgb_assert(rc == 0, "failed to seek %s: %s", path.c_str(), strerror(errno));
    long end = ftell(fin.get());
    // ftell reports errors as -1, which would wrap to a huge size_t.
    mgb_assert(end >= 0, "failed to get size of %s: %s", path.c_str(), strerror(errno));
    rc = fseek(fin.get(), 0, SEEK_SET);
    mgb_assert(rc == 0, "failed to rewind %s: %s", path.c_str(), strerror(errno));

    size = static_cast<size_t>(end);
    mgb_assert(size > 0, "model file %s is empty", path.c_str());

    void* ptr = malloc(size);
    mgb_assert(ptr != nullptr, "failed to allocate %zu bytes for %s", size, path.c_str());
    // If constructing the control block throws, shared_ptr calls free(ptr).
    std::shared_ptr<void> buf{ptr, free};

    size_t nr = fread(buf.get(), 1, size, fin.get());
    mgb_assert(nr == size, "read model file failed");
    return buf;
}

}  // namespace

//! Build a model wrapper for the dump file at \p path.
//! Creates a fresh computing graph for loading, sets its graph_opt_level to 0
//! and resets the test-case counter.
ModelMdl::ModelMdl(const std::string& path) : model_path(path) {
    mgb_log_warn("creat mdl model use XPU as default comp node");
    m_load_config.comp_graph = mgb::ComputingGraph::make();
    m_load_config.comp_graph->options().graph_opt_level = 0;
    testcase_num = 0;
}

//! Load the serialized graph from model_path.
//!
//! If share_model_mem is set, the whole file is read into memory first and
//! loaded from an in-memory proxy; otherwise it is read through the file system.
//! A leading "mgbtest0" magic marks a dump_with_testcase model: the test-case
//! count that follows it is read into testcase_num, and the loaded tensor_map
//! entries are collected (sorted) into test_input_tensors.
//! One default-constructed callback per output var is appended to m_callbacks.
void ModelMdl::load_model() {
    //! read dump file
    if (share_model_mem) {
        mgb_log_warn("enable share model memory");
        size_t size = 0;
        std::shared_ptr<void> buf = read_whole_file(model_path, size);
        m_model_file = mgb::serialization::InputFile::make_mem_proxy(buf, size);
    } else {
        m_model_file = mgb::serialization::InputFile::make_fs(model_path.c_str());
    }

    //! get dump_with_testcase model testcase number
    char magic[8];
    m_model_file->read(magic, sizeof(magic));
    if (strncmp(magic, "mgbtest0", 8)) {
        // no test-case header: rewind so the graph loader sees the file start
        m_model_file->rewind();
    } else {
        m_model_file->read(&testcase_num, sizeof(testcase_num));
    }

    auto format =
            mgb::serialization::GraphLoader::identify_graph_dump_format(*m_model_file);
    mgb_assert(
            format.valid(),
            "invalid format, please make sure model is dumped by GraphDumper");

    //! load computing graph of model
    m_loader = mgb::serialization::GraphLoader::make(
            std::move(m_model_file), format.val());
    m_load_result = m_loader->load(m_load_config, false);
    m_load_config.comp_graph.reset();

    // get testcase input generated by dump_with_testcase.py
    if (testcase_num) {
        for (auto&& i : m_load_result.tensor_map) {
            test_input_tensors.emplace_back(i.first, i.second.get());
        }
        std::sort(test_input_tensors.begin(), test_input_tensors.end());
    }

    // initialize output callback: one empty callback per output var
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        m_callbacks.emplace_back();
    }
}

//! Pair each output var with its callback (moved out of m_callbacks) and
//! compile the loaded graph into m_asyc_exec.
void ModelMdl::make_output_spec() {
    for (size_t i = 0; i < m_load_result.output_var_list.size(); i++) {
        auto item = m_load_result.output_var_list[i];
        m_output_spec.emplace_back(item, std::move(m_callbacks[i]));
    }

    m_asyc_exec = m_load_result.graph_compile(m_output_spec);
}

//! Replace m_loader with a new loader over the current loader's reset file,
//! using the same dump format. Returns a reference to the new loader.
std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader() {
    m_loader = mgb::serialization::GraphLoader::make(
            m_loader->reset_file(), m_loader->format());
    return m_loader;
}

//! Execute the compiled graph. Requires make_output_spec() to have run.
void ModelMdl::run_model() {
    mgb_assert(
            m_asyc_exec != nullptr,
            "empty asychronous function to execute after graph compiled");
    m_asyc_exec->execute();
}

//! Block until the last execution of the compiled graph finishes.
//! Requires make_output_spec() to have run; previously this dereferenced a
//! null m_asyc_exec without any check.
void ModelMdl::wait() {
    mgb_assert(
            m_asyc_exec != nullptr,
            "no compiled function to wait on, call make_output_spec() first");
    m_asyc_exec->wait();
}