提交 d610c987 编写于 作者: M Megvii Engine Team

feat(lite): add disable configure by model info interface

GitOrigin-RevId: cd155a1fcf8bf6b845fa9118a6a791d662b2b624
上级 07bdb3bf
......@@ -117,6 +117,17 @@ struct LITE_API Config {
Options options = {};
};
/*!
* \brief Extra Configuration for a network
*
* \param disable_configure_by_model_info disable the configuration dumped with model,
* if set true, all configuration in the model will not apply, users should configure
* the network.
*/
struct LITE_API ExtraConfig {
bool disable_configure_by_model_info = false;
};
/*!
* \brief config the network input and output item
*
......@@ -275,6 +286,12 @@ public:
//! get static peak memory info showed by Graph visualization
void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const;
/** @brief the extra configuration
*
* @param extra_config the extra configuration to set into the network
*/
void extra_configure(const ExtraConfig& extra_config);
public:
friend class NetworkHelper;
......@@ -288,6 +305,7 @@ private:
private:
bool m_loaded = false;
Config m_config;
ExtraConfig m_extra_config;
NetworkIO m_network_io;
std::unique_ptr<NetworkImplBase> m_impl;
std::string m_extra_info;
......
......@@ -113,6 +113,17 @@ typedef struct LiteConfig {
//! get default config
LITE_API LiteConfig* default_config();
/*!
* \brief Exetra Configuration for a network
*
* \param disable_configure_by_model_info disable the configuration dumped with model,
* if set true, all configuration in the model will not apply, users should configure
* the network.
*/
typedef struct LiteExtraConfig {
int disable_configure_by_model_info;
} LiteExtraConfig;
/*!
* \brief config the network input and output item
*
......@@ -599,6 +610,12 @@ LITE_API int LITE_get_model_io_info_by_memory(
const void* model_mem, size_t size, const LiteConfig config,
LiteNetworkIO* ios);
/** @brief the extra configuration
*
* @param extra_config the extra configuration to set into the network
*/
LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config);
#ifdef __cplusplus
}
#endif
......
......@@ -181,6 +181,12 @@ InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
return innner_io;
}
lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) {
lite::ExtraConfig ret;
ret.disable_configure_by_model_info = extra_config.disable_configure_by_model_info;
return ret;
}
int LITE_make_default_network(LiteNetwork* network) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
......@@ -734,4 +740,12 @@ int LITE_get_model_io_info_by_memory(
LITE_CAPI_END();
}
LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
static_cast<lite::Network*>(network)->extra_configure(
convert_extra_config(extra_config));
LITE_CAPI_END();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -134,6 +134,31 @@ class LiteConfig(Structure):
return data.__repr__()
class LiteExtraConfig(Structure):
"""
Extra configuration when load and compile the graph
disable_configure_by_model_info: disable the configuration dumped with
model, if set true, all configuration in the model will not apply, users
should configure the network.
"""
_fields_ = [
("disable_configure_by_model_info", c_int),
]
def __init__(self, disable_model_config=False):
self.disable_configure_by_model_info = disable_model_config
def __repr__(self):
data = {
"disable_configure_by_model_info": bool(
self.disable_configure_by_model_info
),
}
return data.__repr__()
class LiteIO(Structure):
"""
config the network input and output item
......@@ -365,6 +390,7 @@ class _NetworkAPI(_LiteCObjBase):
"LITE_get_model_io_info_by_memory",
[c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
),
("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
]
......@@ -541,6 +567,12 @@ class LiteNetwork(object):
ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
return ret_name
def extra_configure(self, extra_config):
"""
Extra Configuration to the network.
"""
self._api.LITE_extra_configure(self._network, extra_config)
def share_weights_with(self, src_network):
"""
share weights with the loaded network
......
......@@ -112,6 +112,13 @@ class TestNetwork(TestShuffleNet):
network.load(model_path)
self.do_forward(network)
def test_disable_model_config(self):
model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite")
network = LiteNetwork()
network.extra_configure(LiteExtraConfig(True))
network.load(model_path)
self.do_forward(network)
def test_pack_cache_to_model(self):
model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite")
network = LiteNetwork()
......
......@@ -31,7 +31,6 @@ using namespace mgb;
LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);
void NetworkImplDft::set_config(const Config& config) {
m_user_config = std::make_unique<Config>();
*m_user_config = config;
m_compnode_locator = to_compnode_locator(m_user_config->device_type);
m_compnode_locator.device = config.device_id;
......@@ -428,8 +427,11 @@ void NetworkImplDft::load_model(
global_layout_transform();
//! some optimization option maybe invalid in some case, so here just
//! auto determine whether some options will apply.
adapt_option_valid();
//! find how many compnode the model has, this should call before update_io
cross_compnode_model_detect();
//! update the IO of the network
......@@ -496,7 +498,6 @@ void NetworkImplDft::finish() const {
}
void NetworkImplDft::set_io(const NetworkIO& network_io) {
m_network_io = std::make_unique<NetworkIOInner>();
for (auto&& in : network_io.inputs) {
m_network_io->inputs.emplace_back(in);
}
......
......@@ -29,7 +29,11 @@ class NetworkImplDft final : public Network::NetworkImplBase {
LITE_DYN_TYPE_OBJ_FINAL_DECL;
public:
NetworkImplDft() { m_load_config.comp_graph = mgb::ComputingGraph::make(); }
NetworkImplDft() {
m_load_config.comp_graph = mgb::ComputingGraph::make();
m_user_config = std::make_unique<Config>();
m_network_io = std::make_unique<NetworkIOInner>();
}
using S = megdnn::param::ExecutionPolicy::Strategy;
using Var = mgb::cg::SymbolVar;
//! set the config of the network, include:
......
......@@ -80,14 +80,17 @@ void Network::prase_model(std::shared_ptr<void> model_data, size_t size) {
ModelParser model_parser(model_data, size);
//! parse the model info
if (model_parser.parse_model_info(
m_config, m_network_io, separate_config_map, m_extra_info)) {
m_config, m_network_io, separate_config_map, m_extra_info,
!m_extra_config.disable_configure_by_model_info)) {
if (m_config.backend == LiteBackend::LITE_DEFAULT &&
m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) {
m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>(
"parse_model"));
}
m_impl->set_config(m_config);
m_impl->set_io(m_network_io);
if (!m_extra_config.disable_configure_by_model_info) {
m_impl->set_config(m_config);
m_impl->set_io(m_network_io);
}
}
//! decryption the model
size_t model_length;
......@@ -290,6 +293,18 @@ void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
LITE_ERROR_HANDLER_END
}
void Network::extra_configure(const ExtraConfig& extra_config) {
LITE_ERROR_HANDLER_BEGIN
if (!extra_config.disable_configure_by_model_info) {
LITE_ASSERT(
!m_loaded,
"disable_configure_by_model_info should be configured before model "
"loaded.");
}
m_extra_config = extra_config;
LITE_ERROR_HANDLER_END
}
/*********************** MGE special network function ***************/
void Runtime::set_cpu_threads_number(
......
......@@ -43,7 +43,7 @@ void ModelParser::parse_header() {
bool ModelParser::parse_model_info(
Config& network_config, NetworkIO& network_io,
std::unordered_map<std::string, LiteAny>& isolated_config_map,
std::string& extra_info) const {
std::string& extra_info, bool configure_valid) const {
//! no model info, no parse, direct return
if (m_is_bare_model || !m_info) {
return false;
......@@ -78,7 +78,7 @@ bool ModelParser::parse_model_info(
}
}
//! parse ModelInfo::algo_policy
if (m_info->algo_policy()) {
if (m_info->algo_policy() && configure_valid) {
size_t cache_length = m_info->algo_policy()->size();
const uint8_t* cache = m_info->algo_policy()->Data();
if (m_info_cache_parse_func_name == "LITE_parse_cache") {
......@@ -93,6 +93,10 @@ bool ModelParser::parse_model_info(
} else {
LITE_THROW("opencl binary cache is not given");
}
} else {
LITE_THROW(ssprintf(
"model cache parse function of %s is not defined.",
m_info_cache_parse_func_name.c_str()));
}
}
return true;
......
......@@ -25,7 +25,7 @@ public:
bool parse_model_info(
Config& network_config, NetworkIO& network_io,
std::unordered_map<std::string, LiteAny>& isolated_config_map,
std::string& extra_info) const;
std::string& extra_info, bool configure_valid) const;
//! parse the model and decrypt the model
std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const;
......
......@@ -7,6 +7,8 @@
#include "lite/global.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/infile_persistent_cache.h"
#include "megbrain/utils/persistent_cache.h"
#include "test_common.h"
#include <string.h>
......@@ -173,6 +175,29 @@ TEST(TestNetWorkOptions, test_cache) {
compare_lite_tensor<float>(output_tensor, result_mgb);
}
TEST(TestNetWorkOptions, DisableModelInfo) {
//! clear the cache set by other test
mgb::PersistentCache::inst().set_impl(
std::make_shared<mgb::InMemoryPersistentCache>());
Config config;
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./test_pack_cache_to_model.lite";
std::string model_path2 = "./test_pack_cache_to_model.lite";
std::string input_name = "data";
std::shared_ptr<Network> network = std::make_shared<Network>(config);
network->extra_configure({true});
Runtime::set_cpu_inplace_mode(network);
network->load_model(model_path);
//! the fast-run cache will not configure, so it is not support dump
ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), false);
ASSERT_EQ(Runtime::is_cpu_inplace_mode(network), true);
std::shared_ptr<Network> network2 = std::make_shared<Network>(config);
network2->load_model(model_path2);
//! the fast-run cache is configured by the model information
ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), true);
}
TEST(TestNetWorkOptions, FastRunIgnorBatch) {
Config config;
auto tensor = get_input_data("./input_data.npy");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册