Commit d610c987, authored by Megvii Engine Team

feat(lite): add disable configure by model info interface

GitOrigin-RevId: cd155a1fcf8bf6b845fa9118a6a791d662b2b624
Parent 07bdb3bf
...@@ -117,6 +117,17 @@ struct LITE_API Config {
Options options = {};
};
/*!
* \brief Extra Configuration for a network
*
* \param disable_configure_by_model_info disable the configuration dumped with the model;
* if set to true, none of the configuration packed in the model will be applied, and users
* must configure the network themselves.
*/
struct LITE_API ExtraConfig {
bool disable_configure_by_model_info = false;
};
/*!
* \brief config the network input and output item
*
...@@ -275,6 +286,12 @@ public:
//! get static peak memory info showed by Graph visualization
void get_static_memory_alloc_info(const std::string& log_dir = "logs/test") const;
/** @brief set the extra configuration
*
* @param extra_config the extra configuration to set into the network
*/
void extra_configure(const ExtraConfig& extra_config);
public:
friend class NetworkHelper;
...@@ -288,6 +305,7 @@ private:
private:
bool m_loaded = false;
Config m_config;
ExtraConfig m_extra_config;
NetworkIO m_network_io;
std::unique_ptr<NetworkImplBase> m_impl;
std::string m_extra_info;
...
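A minimal C++ usage sketch of the extra_configure interface declared above (not part of this commit; the header path and model file name are assumptions). As in the tests added by this commit, the extra configuration is set before load_model:

// Usage sketch only; "lite/network.h" and "./model.lite" are assumed names.
#include <memory>
#include "lite/network.h"

void load_without_model_config() {
    lite::Config config;
    auto network = std::make_shared<lite::Network>(config);

    lite::ExtraConfig extra_config;
    extra_config.disable_configure_by_model_info = true;  // ignore the config packed in the model
    network->extra_configure(extra_config);               // set before load_model, as in the tests

    network->load_model("./model.lite");  // hypothetical model path
}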
...@@ -113,6 +113,17 @@ typedef struct LiteConfig {
//! get default config
LITE_API LiteConfig* default_config();
/*!
* \brief Extra configuration for a network
*
* \param disable_configure_by_model_info disable the configuration dumped with the model;
* if set to a non-zero value, none of the configuration packed in the model will be applied,
* and users must configure the network themselves.
*/
typedef struct LiteExtraConfig {
int disable_configure_by_model_info;
} LiteExtraConfig;
/*!
* \brief config the network input and output item
*
...@@ -599,6 +610,12 @@ LITE_API int LITE_get_model_io_info_by_memory(
const void* model_mem, size_t size, const LiteConfig config,
LiteNetworkIO* ios);
/** @brief set the extra configuration
*
* @param network the network to set the extra configuration into
* @param extra_config the extra configuration to set into the network
*/
LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config);
#ifdef __cplusplus
}
#endif
...
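The C interface mirrors the C++ one. A minimal sketch of the call order, using only the entry points visible in this diff (the header path is an assumption; model loading and error handling are omitted):

// Sketch only; "lite-c/network_c.h" is an assumed header path.
#include "lite-c/network_c.h"

void c_api_disable_model_config() {
    LiteNetwork network;
    LITE_make_default_network(&network);   // create a network handle with the default config

    LiteExtraConfig extra = {1};           // disable_configure_by_model_info = true (non-zero)
    LITE_extra_configure(network, extra);  // set before the model is loaded

    // ... load the model into `network` and run forward afterwards
}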
...@@ -181,6 +181,12 @@ InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
return innner_io;
}
lite::ExtraConfig convert_extra_config(const LiteExtraConfig& extra_config) {
lite::ExtraConfig ret;
ret.disable_configure_by_model_info = extra_config.disable_configure_by_model_info;
return ret;
}
int LITE_make_default_network(LiteNetwork* network) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
...@@ -734,4 +740,12 @@ int LITE_get_model_io_info_by_memory(
LITE_CAPI_END();
}
LITE_API int LITE_extra_configure(LiteNetwork network, LiteExtraConfig extra_config) {
LITE_CAPI_BEGIN();
LITE_ASSERT(network, "The network pass to LITE api is null");
static_cast<lite::Network*>(network)->extra_configure(
convert_extra_config(extra_config));
LITE_CAPI_END();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -134,6 +134,31 @@ class LiteConfig(Structure):
        return data.__repr__()
class LiteExtraConfig(Structure):
    """
    Extra configuration used when loading and compiling the graph.

    disable_configure_by_model_info: disable the configuration dumped with the
    model; if set to True, none of the configuration packed in the model will
    be applied, and users must configure the network themselves.
    """

    _fields_ = [
        ("disable_configure_by_model_info", c_int),
    ]

    def __init__(self, disable_model_config=False):
        self.disable_configure_by_model_info = disable_model_config

    def __repr__(self):
        data = {
            "disable_configure_by_model_info": bool(
                self.disable_configure_by_model_info
            ),
        }
        return data.__repr__()
class LiteIO(Structure):
    """
    config the network input and output item
...@@ -365,6 +390,7 @@ class _NetworkAPI(_LiteCObjBase):
            "LITE_get_model_io_info_by_memory",
            [c_char_p, c_size_t, LiteConfig, POINTER(_LiteNetworkIO)],
        ),
        ("LITE_extra_configure", [_Cnetwork, LiteExtraConfig]),
    ]
...@@ -541,6 +567,12 @@ class LiteNetwork(object):
        ret_name = [names[i].decode("utf-8") for i in range(nr_output.value)]
        return ret_name
    def extra_configure(self, extra_config):
        """
        Set the extra configuration into the network.
        """
        self._api.LITE_extra_configure(self._network, extra_config)

    def share_weights_with(self, src_network):
        """
        share weights with the loaded network
...
...@@ -112,6 +112,13 @@ class TestNetwork(TestShuffleNet):
        network.load(model_path)
        self.do_forward(network)

    def test_disable_model_config(self):
        model_path = os.path.join(self.source_dir, "test_packed_model_rc4.lite")
        network = LiteNetwork()
        network.extra_configure(LiteExtraConfig(True))
        network.load(model_path)
        self.do_forward(network)

    def test_pack_cache_to_model(self):
        model_path = os.path.join(self.source_dir, "test_pack_cache_to_model.lite")
        network = LiteNetwork()
...
...@@ -31,7 +31,6 @@ using namespace mgb;
LITE_DYN_TYPE_OBJ_FINAL_IMPL(NetworkImplDft);
void NetworkImplDft::set_config(const Config& config) {
m_user_config = std::make_unique<Config>();
*m_user_config = config;
m_compnode_locator = to_compnode_locator(m_user_config->device_type);
m_compnode_locator.device = config.device_id;
...@@ -428,8 +427,11 @@ void NetworkImplDft::load_model(
global_layout_transform();
//! some optimization options may be invalid in some cases, so here just
//! automatically determine whether each option will be applied
adapt_option_valid();
//! find how many compnodes the model has; this should be called before update_io
cross_compnode_model_detect();
//! update the IO of the network
...@@ -496,7 +498,6 @@ void NetworkImplDft::finish() const {
}
void NetworkImplDft::set_io(const NetworkIO& network_io) {
m_network_io = std::make_unique<NetworkIOInner>();
for (auto&& in : network_io.inputs) {
m_network_io->inputs.emplace_back(in);
}
...
...@@ -29,7 +29,11 @@ class NetworkImplDft final : public Network::NetworkImplBase {
LITE_DYN_TYPE_OBJ_FINAL_DECL;
public:
NetworkImplDft() {
m_load_config.comp_graph = mgb::ComputingGraph::make();
m_user_config = std::make_unique<Config>();
m_network_io = std::make_unique<NetworkIOInner>();
}
using S = megdnn::param::ExecutionPolicy::Strategy;
using Var = mgb::cg::SymbolVar;
//! set the config of the network, include:
...
...@@ -80,14 +80,17 @@ void Network::prase_model(std::shared_ptr<void> model_data, size_t size) {
ModelParser model_parser(model_data, size);
//! parse the model info
if (model_parser.parse_model_info(
m_config, m_network_io, separate_config_map, m_extra_info,
!m_extra_config.disable_configure_by_model_info)) {
if (m_config.backend == LiteBackend::LITE_DEFAULT &&
m_impl->get_backend_type() != LiteBackend::LITE_DEFAULT) {
m_impl.reset(try_call_func<NetworkImplDft, lite::Network::NetworkImplBase*>(
"parse_model"));
}
if (!m_extra_config.disable_configure_by_model_info) {
m_impl->set_config(m_config);
m_impl->set_io(m_network_io);
}
}
//! decrypt the model
size_t model_length;
...@@ -290,6 +293,18 @@ void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
LITE_ERROR_HANDLER_END
}
void Network::extra_configure(const ExtraConfig& extra_config) {
LITE_ERROR_HANDLER_BEGIN
if (!extra_config.disable_configure_by_model_info) {
LITE_ASSERT(
!m_loaded,
"disable_configure_by_model_info should be configured before model "
"loaded.");
}
m_extra_config = extra_config;
LITE_ERROR_HANDLER_END
}
/*********************** MGE special network function ***************/
void Runtime::set_cpu_threads_number(
...
...@@ -43,7 +43,7 @@ void ModelParser::parse_header() {
bool ModelParser::parse_model_info(
Config& network_config, NetworkIO& network_io,
std::unordered_map<std::string, LiteAny>& isolated_config_map,
std::string& extra_info, bool configure_valid) const {
//! no model info, no parse, direct return
if (m_is_bare_model || !m_info) {
return false;
...@@ -78,7 +78,7 @@ bool ModelParser::parse_model_info(
}
}
//! parse ModelInfo::algo_policy
if (m_info->algo_policy() && configure_valid) {
size_t cache_length = m_info->algo_policy()->size();
const uint8_t* cache = m_info->algo_policy()->Data();
if (m_info_cache_parse_func_name == "LITE_parse_cache") {
...@@ -93,6 +93,10 @@ bool ModelParser::parse_model_info(
} else {
LITE_THROW("opencl binary cache is not given");
}
} else {
LITE_THROW(ssprintf(
"model cache parse function of %s is not defined.",
m_info_cache_parse_func_name.c_str()));
}
}
return true;
...
...@@ -25,7 +25,7 @@ public:
bool parse_model_info(
Config& network_config, NetworkIO& network_io,
std::unordered_map<std::string, LiteAny>& isolated_config_map,
std::string& extra_info, bool configure_valid) const;
//! parse the model and decrypt the model
std::shared_ptr<void> parse_model(size_t& model_length, const Config& config) const;
...
...@@ -7,6 +7,8 @@
#include "lite/global.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/infile_persistent_cache.h"
#include "megbrain/utils/persistent_cache.h"
#include "test_common.h" #include "test_common.h"
#include <string.h> #include <string.h>
...@@ -173,6 +175,29 @@ TEST(TestNetWorkOptions, test_cache) { ...@@ -173,6 +175,29 @@ TEST(TestNetWorkOptions, test_cache) {
compare_lite_tensor<float>(output_tensor, result_mgb); compare_lite_tensor<float>(output_tensor, result_mgb);
} }
TEST(TestNetWorkOptions, DisableModelInfo) {
//! clear the cache set by other test
mgb::PersistentCache::inst().set_impl(
std::make_shared<mgb::InMemoryPersistentCache>());
Config config;
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./test_pack_cache_to_model.lite";
std::string model_path2 = "./test_pack_cache_to_model.lite";
std::string input_name = "data";
std::shared_ptr<Network> network = std::make_shared<Network>(config);
network->extra_configure({true});
Runtime::set_cpu_inplace_mode(network);
network->load_model(model_path);
//! the fast-run cache is not configured, so dumping the cache is not supported
ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), false);
ASSERT_EQ(Runtime::is_cpu_inplace_mode(network), true);
std::shared_ptr<Network> network2 = std::make_shared<Network>(config);
network2->load_model(model_path2);
//! the fast-run cache is configured by the model information
ASSERT_EQ(mgb::PersistentCache::inst().support_dump_cache(), true);
}
TEST(TestNetWorkOptions, FastRunIgnorBatch) {
Config config;
auto tensor = get_input_data("./input_data.npy");
...