提交 26b52a61 编写于 作者: M Megvii Engine Team

feat(lite): add get model infomation before create network interface

GitOrigin-RevId: e499f3ebf8e03ccbe25e9b698c9e351fd19f0ed6
上级 5e17b3e4
......@@ -373,6 +373,14 @@ public:
//! dump network after global layout transform optimization
static void dump_layout_transform_model(
std::shared_ptr<Network> network, std::string optimized_model_path);
//! get the model io information before model loaded by model path.
static NetworkIO get_model_io_info(
const std::string& model_path, const Config& config = {});
//! get the model io information before model loaded by model memory.
static NetworkIO get_model_io_info(
const void* model_mem, size_t size, const Config& config = {});
};
} // namespace lite
......
......@@ -588,6 +588,28 @@ LITE_API int LITE_enable_global_layout_transform(LiteNetwork network);
LITE_API int LITE_dump_layout_transform_model(
LiteNetwork network, const char* dump_file_path);
/**! get the model io information before model loaded by model path.
* \param[in] model_path The model file path
* \param[in] config The model config for loading
* \param[out] ios The model io infermation
* \return int if the return is not zero, error happened, the error message
* can get by LITE_get_last_error
*/
LITE_API int LITE_get_model_io_info_by_path(
const char* model_path, const LiteConfig config, LiteNetworkIO* ios);
/** get the model io information before model loaded by model memory.
* \param[in] model_mem The model memory ptr
* \param[in] size The model memory ptr length
* \param[in] config The model config for loading
* \param[out] ios The model io infermation
* \return int if the return is not zero, error happened, the error message
* can get by LITE_get_last_error
*/
LITE_API int LITE_get_model_io_info_by_memory(
const void* model_mem, size_t size, const LiteConfig config,
LiteNetworkIO* ios);
#ifdef __cplusplus
}
#endif
......
......@@ -167,6 +167,31 @@ lite::NetworkIO convert_to_lite_io(const LiteNetworkIO c_network_io) {
return network_io;
}
struct InnerIO {
std::vector<std::string> names;
std::vector<LiteIO> inputs;
std::vector<LiteIO> outputs;
};
InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
InnerIO innner_io;
for (size_t i = 0; i < network_io.inputs.size(); i++) {
lite::IO io = network_io.inputs[i];
innner_io.names.push_back(io.name);
innner_io.inputs.push_back(
{innner_io.names.back().c_str(), io.is_host, io.io_type,
convert_to_clayout(io.config_layout)});
}
for (size_t i = 0; i < network_io.outputs.size(); i++) {
lite::IO io = network_io.outputs[i];
innner_io.names.push_back(io.name);
innner_io.outputs.push_back(
{innner_io.names.back().c_str(), io.is_host, io.io_type,
convert_to_clayout(io.config_layout)});
}
return innner_io;
}
int LITE_make_default_network(LiteNetwork* network) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
......@@ -665,4 +690,59 @@ int LITE_dump_layout_transform_model(LiteNetwork network, const char* dump_file_
lite::Runtime::dump_layout_transform_model(network_shared, dump_file_path);
LITE_CAPI_END();
}
namespace {
static LITE_MUTEX mtx_io;
static std::unordered_map<const void*, InnerIO>& get_global_io_holder() {
static std::unordered_map<const void*, InnerIO> global_holder;
return global_holder;
}
int write_ios_from_cpp_io(
const lite::NetworkIO& cpp_io, LiteNetworkIO* ios, const void* key) {
LITE_CAPI_BEGIN();
LITE_LOCK_GUARD(mtx_io);
get_global_io_holder()[key] = convert_to_inner_io(cpp_io);
auto&& inner_io = get_global_io_holder()[key];
ios->input_size = inner_io.inputs.size();
ios->output_size = inner_io.outputs.size();
ios->inputs = inner_io.inputs.data();
ios->outputs = inner_io.outputs.data();
size_t i = 0;
for (; i < ios->input_size; i++) {
auto io_ptr = ios->inputs + i;
io_ptr->name = inner_io.names[i].c_str();
}
for (; i < ios->output_size; i++) {
auto io_ptr = ios->outputs + i;
io_ptr->name = inner_io.names[i].c_str();
}
LITE_CAPI_END();
}
} // namespace
int LITE_get_model_io_info_by_path(
const char* model_path, const LiteConfig config, LiteNetworkIO* ios) {
LITE_CAPI_BEGIN();
LITE_ASSERT(model_path, "The model_path pass to LITE api is null");
auto&& cpp_ios = lite::Runtime::get_model_io_info(
std::string{model_path}, convert_to_lite_config(config));
return write_ios_from_cpp_io(
cpp_ios, ios, reinterpret_cast<const void*>(model_path));
LITE_CAPI_END();
}
int LITE_get_model_io_info_by_memory(
const void* model_mem, size_t size, const LiteConfig config,
LiteNetworkIO* ios) {
LITE_CAPI_BEGIN();
LITE_ASSERT(model_mem, "The model_mem pass to LITE api is null");
auto&& cpp_ios = lite::Runtime::get_model_io_info(
model_mem, size, convert_to_lite_config(config));
return write_ios_from_cpp_io(
cpp_ios, ios, reinterpret_cast<const void*>(model_mem));
LITE_CAPI_END();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -364,6 +364,14 @@ class _NetworkAPI(_LiteCObjBase):
("LITE_get_static_memory_alloc_info", [_Cnetwork, c_char_p]),
("LITE_enable_global_layout_transform", [_Cnetwork]),
("LITE_dump_layout_transform_model", [_Cnetwork, c_char_p]),
(
"LITE_get_model_io_info_by_path",
[c_char_p, LiteConfig, POINTER(_LiteNetworkIO)],
),
(
"LITE_get_model_io_info_by_memory",
[c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
),
]
......@@ -619,3 +627,27 @@ class LiteNetwork(object):
def dump_layout_transform_model(self, model_file):
c_file = model_file.encode("utf-8")
self._api.LITE_dump_layout_transform_model(self._network, c_file)
def get_model_io_info(model_path, config=None):
"""
get the model IO information before create the NetWork, this IO
information can be used to configuration the NetWork.
"""
api = _NetworkAPI()._lib
c_path = c_char_p(model_path.encode("utf-8"))
ios = _LiteNetworkIO()
if config is not None:
api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
else:
config = LiteConfig()
api.LITE_get_model_io_info_by_path(c_path, config, byref(ios))
ret_ios = LiteNetworkIO()
for i in range(ios.input_size):
ret_ios.add_input(ios.inputs[i])
for i in range(ios.output_size):
ret_ios.add_output(ios.outputs[i])
return ret_ios
......@@ -8,6 +8,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import os
import numpy as np
......@@ -200,3 +201,20 @@ def test_tensor_collect_batch_device_numpy():
for i in range(4):
for j in range(48):
assert data[i][j // 8][j % 8] == i + 1
def test_get_model_io_ahead():
source_dir = os.getenv("LITE_TEST_RESOURCE")
model_path = os.path.join(source_dir, "shufflenet.mge")
ios = get_model_io_info(model_path)
assert len(ios.inputs) == 1
assert ios.inputs[0].name == "data"
assert ios.inputs[0].config_layout.shapes[1] == 3
assert ios.inputs[0].config_layout.shapes[2] == 224
assert ios.inputs[0].config_layout.shapes[3] == 224
assert len(ios.outputs) == 1
assert ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]"
assert ios.outputs[0].config_layout.shapes[0] == 1
assert ios.outputs[0].config_layout.shapes[1] == 1000
......@@ -34,7 +34,7 @@ ADD_STATEMENT(NetworkImplDft, Dft);
} // namespace
// if it can't find the function, ignore
template <typename tensor_type, typename ret_type, typename... Args>
template <typename type, typename ret_type, typename... Args>
ret_type try_call_func(std::string func_name, Args... args) {
mark_used_variable(func_name);
mark_used_variable(args...);
......@@ -42,10 +42,10 @@ ret_type try_call_func(std::string func_name, Args... args) {
}
// if it can't find the function, throw error
template <typename tensor_type, typename ret_type, typename... Args>
template <typename type, typename ret_type, typename... Args>
ret_type call_func(std::string func_name, Args... args) {
mark_used_variable(args...);
auto backend_name = class_type_name<tensor_type>()();
auto backend_name = class_type_name<type>()();
auto msg_info = func_name + " is not aviliable in " + backend_name + " backend.";
LITE_THROW(msg_info.c_str());
}
......
......@@ -206,6 +206,26 @@ inline void call_func<NetworkImplDft, void>(
THROW_FUNC_ERROR(func_name);
}
}
template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
std::string func_name, std::string model_path, Config config) {
if (func_name == "get_model_io_info") {
return get_model_io_info_dft(model_path, config);
} else {
THROW_FUNC_ERROR(func_name);
}
}
template <>
inline NetworkIO call_func<NetworkImplDft, NetworkIO>(
std::string func_name, const void* model_mem, size_t size, Config config) {
if (func_name == "get_model_io_info") {
return get_model_io_info_dft(model_mem, size, config);
} else {
THROW_FUNC_ERROR(func_name);
}
}
#undef THROW_FUNC_ERROR
} // namespace lite
......
......@@ -929,5 +929,75 @@ void NetworkImplDft::dump_layout_transform_model(std::string optimized_model_pat
"enable_global_layout_transform before"));
}
}
NetworkIO lite::get_model_io_info_dft(
const std::string& model_path, const Config& config) {
FILE* fin = fopen(model_path.c_str(), "rb");
LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
fseek(fin, 0, SEEK_END);
size_t size = ftell(fin);
fseek(fin, 0, SEEK_SET);
void* ptr = malloc(size);
std::shared_ptr<void> buf{ptr, ::free};
auto nr = fread(buf.get(), 1, size, fin);
LITE_ASSERT(nr == size);
fclose(fin);
return get_model_io_info_dft(ptr, size, config);
}
NetworkIO lite::get_model_io_info_dft(
const void* model_mem, size_t size, const Config& config) {
std::shared_ptr<void> model{const_cast<void*>(model_mem), [](void*) {}};
auto input_file = mgb::serialization::InputFile::make_mem_proxy(model, size, false);
auto format =
mgb::serialization::GraphLoader::identify_graph_dump_format(*input_file);
if (!format.valid()) {
LITE_THROW("invalid model format");
}
auto loader =
mgb::serialization::GraphLoader::make(std::move(input_file), format.val());
mgb::serialization::GraphLoadConfig load_config;
load_config.comp_graph = mgb::ComputingGraph::make();
if (config.has_compression) {
load_config.tensor_value_loader = decompressed_tensor_value_loader;
}
auto compnode_locator = to_compnode_locator(config.device_type);
load_config.comp_node_mapper = [=](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
loc.type = compnode_locator.type;
}
loc.device = compnode_locator.device;
};
auto load_result = loader->load(load_config, true);
NetworkIO IOs;
for (auto&& in_tensor_iter : load_result.tensor_map) {
IO in_io;
in_io.name = in_tensor_iter.first;
in_io.config_layout = to_lite_layout(in_tensor_iter.second->layout());
IOs.inputs.push_back(in_io);
}
auto infer_shape = [=](mgb::cg::SymbolVar var) -> const megdnn::TensorShape* {
auto&& static_infer_mgr = load_config.comp_graph->static_infer_manager();
using InferType = mgb::cg::static_infer::InferType;
if (static_infer_mgr.get_infer_type(var.node()).shape &
(InferType::CONST | InferType::RT_STATIC)) {
return static_infer_mgr.infer_shape_fallible(var.node());
} else {
return nullptr;
}
};
for (auto&& out : load_result.output_var_list) {
IO out_io;
out_io.name = out.node()->name();
if (auto shape = infer_shape(out)) {
out_io.config_layout = to_lite_layout(TensorLayout{*shape, out.dtype()});
} else {
out_io.config_layout = to_lite_layout(TensorLayout{{}, out.dtype()});
}
IOs.outputs.push_back(out_io);
}
return IOs;
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -262,6 +262,13 @@ private:
#endif
std::unique_ptr<mgb::OprIODumpBase> m_iodump;
};
//! get the model information before model loaded by Network
NetworkIO get_model_io_info_dft(const std::string& model_path, const Config& config);
//! get the model information before model loaded by Network by model memory and
//! size
NetworkIO get_model_io_info_dft(
const void* model_mem, size_t size, const Config& config);
} // namespace lite
......
......@@ -534,4 +534,26 @@ void Runtime::dump_layout_transform_model(
LITE_THROW("dump_layout_transform_model is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
NetworkIO Runtime::get_model_io_info(
const std::string& model_path, const Config& config) {
LITE_ERROR_HANDLER_BEGIN
if (config.backend == LiteBackend::LITE_DEFAULT) {
return call_func<NetworkImplDft, NetworkIO>(
"get_model_io_info", model_path, config);
}
LITE_THROW("get_model_io_info is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
NetworkIO Runtime::get_model_io_info(
const void* model_mem, size_t size, const Config& config) {
LITE_ERROR_HANDLER_BEGIN
if (config.backend == LiteBackend::LITE_DEFAULT) {
return call_func<NetworkImplDft, NetworkIO>(
"get_model_io_info", model_mem, size, config);
}
LITE_THROW("get_model_io_info is not aviliable in the backend.");
LITE_ERROR_HANDLER_END
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -106,6 +106,54 @@ TEST(TestNetWork, GetAllName) {
ASSERT_TRUE(output_names[0] == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
}
TEST(TestNetWork, GetAllIoInfoAhead) {
Config config;
std::string model_path = "./shufflenet.mge";
auto ios = Runtime::get_model_io_info(model_path);
FILE* fin = fopen(model_path.c_str(), "rb");
ASSERT_TRUE(fin);
fseek(fin, 0, SEEK_END);
size_t size = ftell(fin);
fseek(fin, 0, SEEK_SET);
void* ptr = malloc(size);
std::shared_ptr<void> buf{ptr, ::free};
auto nr = fread(buf.get(), 1, size, fin);
LITE_ASSERT(nr == size);
fclose(fin);
auto ios_mem = Runtime::get_model_io_info(ptr, size);
ASSERT_EQ(ios.inputs.size(), ios_mem.inputs.size());
ASSERT_EQ(ios.inputs.size(), 1);
ASSERT_EQ(ios.outputs.size(), ios_mem.outputs.size());
ASSERT_EQ(ios.outputs.size(), 1);
ASSERT_TRUE(ios.inputs[0].name == "data");
ASSERT_TRUE(ios.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
ASSERT_TRUE(ios_mem.inputs[0].name == "data");
ASSERT_TRUE(
ios_mem.outputs[0].name == "TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
ASSERT_EQ(ios.inputs[0].config_layout.ndim, 4);
ASSERT_EQ(ios.inputs[0].config_layout.shapes[1], 3);
ASSERT_EQ(ios.inputs[0].config_layout.shapes[2], 224);
ASSERT_EQ(ios.outputs[0].config_layout.ndim, 2);
ASSERT_EQ(ios.outputs[0].config_layout.shapes[0], 1);
ASSERT_EQ(ios.outputs[0].config_layout.shapes[1], 1000);
ASSERT_EQ(ios_mem.inputs[0].config_layout.ndim, 4);
ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[1], 3);
ASSERT_EQ(ios_mem.inputs[0].config_layout.shapes[2], 224);
ASSERT_EQ(ios_mem.outputs[0].config_layout.ndim, 2);
ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[0], 1);
ASSERT_EQ(ios_mem.outputs[0].config_layout.shapes[1], 1000);
}
TEST(TestNetWork, LoadFBSModel) {
Config config;
std::string model_path = "./ax.mge";
......
......@@ -252,6 +252,55 @@ TEST(TestCapiNetWork, GetAllName) {
LITE_destroy_network(c_network);
}
TEST(TestCapiNetWork, GetAllNameAhead) {
std::string model_path = "./shufflenet.mge";
LiteNetworkIO ios, ios_mem;
LITE_CAPI_CHECK(LITE_get_model_io_info_by_path(
model_path.c_str(), *default_config(), &ios));
FILE* fin = fopen(model_path.c_str(), "rb");
ASSERT_TRUE(fin);
fseek(fin, 0, SEEK_END);
size_t size = ftell(fin);
fseek(fin, 0, SEEK_SET);
void* ptr = malloc(size);
std::shared_ptr<void> buf{ptr, ::free};
auto nr = fread(buf.get(), 1, size, fin);
LITE_ASSERT(nr == size);
fclose(fin);
LITE_CAPI_CHECK(
LITE_get_model_io_info_by_memory(ptr, size, *default_config(), &ios_mem));
ASSERT_EQ(ios.input_size, 1);
ASSERT_EQ(ios.output_size, 1);
ASSERT_EQ(ios_mem.input_size, 1);
ASSERT_EQ(ios_mem.output_size, 1);
ASSERT_TRUE(std::string(ios.inputs->name) == "data");
ASSERT_TRUE(ios.inputs->config_layout.ndim == 4);
ASSERT_TRUE(ios.inputs->config_layout.shapes[1] == 3);
ASSERT_TRUE(ios.inputs->config_layout.shapes[2] == 224);
ASSERT_TRUE(ios.inputs->config_layout.shapes[3] == 224);
ASSERT_TRUE(
std::string(ios.outputs->name) ==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
ASSERT_TRUE(ios.outputs->config_layout.ndim == 2);
ASSERT_TRUE(ios.outputs->config_layout.shapes[0] == 1);
ASSERT_TRUE(ios.outputs->config_layout.shapes[1] == 1000);
ASSERT_TRUE(std::string(ios_mem.inputs->name) == "data");
ASSERT_TRUE(ios_mem.inputs->config_layout.ndim == 4);
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[1] == 3);
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[2] == 224);
ASSERT_TRUE(ios_mem.inputs->config_layout.shapes[3] == 224);
ASSERT_TRUE(
std::string(ios_mem.outputs->name) ==
"TRUE_DIV(EXP[12065],reduce0[12067])[12077]");
ASSERT_TRUE(ios_mem.outputs->config_layout.ndim == 2);
ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[0] == 1);
ASSERT_TRUE(ios_mem.outputs->config_layout.shapes[1] == 1000);
}
#if LITE_BUILD_WITH_RKNPU
static int GetTop(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册