Commit 2d6476a4 authored by Megvii Engine Team

feat(lite): add auto decide model inference format option

GitOrigin-RevId: fcbf945de59a8d9a861e3605a40e69c942c05f4e
Parent 10a0349e
...@@ -114,6 +114,9 @@ struct LITE_API Options {
* model is not pack json information data inside
*
* @param options configuration of Options
*
* @param auto_optimize_inference lite will detect the device information and
* set the options heuristically
*/
struct LITE_API Config {
bool has_compression = false;
...@@ -122,6 +125,7 @@ struct LITE_API Config {
LiteBackend backend = LiteBackend::LITE_DEFAULT;
std::string bare_model_cryption_name = {};
Options options = {};
bool auto_optimize_inference = false;
};
/*!
......
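For reference, a minimal C++ sketch of how a caller could enable the new flag; it mirrors the new test case at the bottom of this diff, the model path is a placeholder, and the "lite/network.h" include location is assumed:

    #include <memory>
    #include "lite/network.h"  // assumed public header for lite::Config / lite::Network

    int main() {
        lite::Config config;
        config.auto_optimize_inference = true;  // let lite choose the inference layout heuristically
        auto network = std::make_shared<lite::Network>(config);
        network->load_model("./shufflenet.mge");  // placeholder path, as in the new test
        // inputs/outputs are then set and forward()/wait() called exactly as in
        // TEST(TestNetWorkOptions, auto_optimize_inference_layout) below
        return 0;
    }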
...@@ -100,6 +100,9 @@ extern LITE_API const LiteOptions default_option;
*
*\param has_compression flag whether the model is compressed, the compress
*method will read form the model
*\param auto_optimize_inference lite will detect the device information and
* set the options heuristically
*/
typedef struct LiteConfig {
int has_compression;
...@@ -108,6 +111,7 @@ typedef struct LiteConfig {
LiteBackend backend;
const char* bare_model_cryption_name;
LiteOptions options;
int auto_optimize_inference;
} LiteConfig;
//! get default config
......
...@@ -42,7 +42,8 @@ LiteConfig default_config_t = {
.device_type = LiteDeviceType::LITE_CPU,
.backend = LiteBackend::LITE_DEFAULT,
.bare_model_cryption_name = nullptr,
.options = default_option,
.auto_optimize_inference = false};
LiteConfig* default_config() {
return &default_config_t;
}
...@@ -133,6 +134,8 @@ lite::Config convert_to_lite_config(const LiteConfig c_config) {
lite_config.options.enable_nchw32 = c_config.options.enable_nchw32;
lite_config.options.enable_nchw64 = c_config.options.enable_nchw64;
lite_config.auto_optimize_inference = c_config.auto_optimize_inference;
return lite_config;
}
......
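In the C API the flag is a plain int on LiteConfig; a hedged sketch of setting it via the default config shown above (the lite-c header path and the LITE_make_network entry point are assumptions, not taken from this diff):

    #include "lite-c/network_c.h"  // assumed C API header

    void build_with_auto_optimize(void) {
        LiteConfig config = *default_config();  // copy the defaults listed above
        config.auto_optimize_inference = 1;     // int flag in the C struct
        // the struct is then passed to the network constructor (e.g. LITE_make_network),
        // which translates it through convert_to_lite_config() as shown in this hunk
    }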
...@@ -171,15 +171,18 @@ class LiteConfig(Structure):
options: configuration of Options
auto_optimize_inference: lite will detect the device information and set the options heuristically
Examples:
.. code-block::
from megenginelite import *
config = LiteConfig()
config.has_compression = False
config.device_type = LiteDeviceType.LITE_CPU
config.backend = LiteBackend.LITE_DEFAULT
config.bare_model_cryption_name = "AES_default".encode("utf-8")
config.auto_optimize_inference = False
""" """
_fields_ = [ _fields_ = [
...@@ -189,6 +192,7 @@ class LiteConfig(Structure): ...@@ -189,6 +192,7 @@ class LiteConfig(Structure):
("backend", c_int), ("backend", c_int),
("_bare_model_cryption_name", c_char_p), ("_bare_model_cryption_name", c_char_p),
("options", LiteOptions), ("options", LiteOptions),
("auto_optimize_inference", c_int),
]
def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
...@@ -202,6 +206,7 @@ class LiteConfig(Structure):
self.use_loader_dynamic_param = 0
self.has_compression = 0
self.backend = LiteBackend.LITE_DEFAULT
self.auto_optimize_inference = 0
@property
def bare_model_cryption_name(self):
...@@ -223,6 +228,7 @@ class LiteConfig(Structure):
"backend": LiteBackend(self.backend),
"bare_model_cryption_name": self.bare_model_cryption_name,
"options": self.options,
"auto_optimize_inference": self.auto_optimize_inference,
}
return data.__repr__()
......
...@@ -21,6 +21,10 @@
#include "megcore_opencl.h"
#endif
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif
#include <fstream>
#include <memory>
#include <set>
...@@ -42,14 +46,7 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
LITE_ASSERT(src_impl.m_loader, "Clone network must after the network is loaded.");
m_load_result = src_impl.m_loader->load(m_load_config, true);
//! flag whether the mode is cross compnode model
cross_compnode_model_detect();
//! update the IO of the network
update_io();
//! replace the IO when there is device input or output
compile_graph();
configure_after_loaded();
}
void NetworkImplDft::application_config() {
...@@ -364,7 +361,7 @@ void NetworkImplDft::adapt_option_valid() {
}
}
void NetworkImplDft::layout_transform_optimization() {
if (m_set_layout_transform) {
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
auto output_var_array = mgb::gopt::layout_transform(
...@@ -382,6 +379,103 @@ void NetworkImplDft::global_layout_transform() {
for (auto&& item : m_load_result.output_var_map) {
item.second = out_var_map[item.second];
}
} else if (m_user_config->auto_optimize_inference) {
//! set model weight preprocess
m_load_config.comp_graph->options().graph_opt.weight_preprocess = true;
LITE_LOG(
"weight_preprocess is enabled, this may use more memory during "
"inference.");
//! get the current format and data type of the model
bool is_model_nchw = true;
//! whether any convolution is int8
bool is_model_int8 = false;
//! whether all convolutions are float32
bool is_model_float32 = true;
float conv_cnt = 0;
float dimshuffle_cnt = 0;
auto detect_int8_model = [&](const VarNode* input) {
if (input->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
input->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm) {
is_model_int8 = true;
is_model_float32 = false;
} else if (input->dtype().enumv() == megdnn::DTypeEnum::Float32) {
is_model_float32 = (is_model_float32 && true);
} else {
is_model_float32 = false;
}
};
cg::DepOprIter dep([&](cg::OperatorNodeBase* opr) {
if (auto conv = opr->try_cast_final<opr::ConvolutionForward>()) {
if (conv->param().format != megdnn::param::ConvBias::Format::NCHW) {
is_model_nchw = false;
}
conv_cnt++;
detect_int8_model(conv->input(0));
} else if (auto conv_bias = opr->try_cast_final<opr::ConvBias>()) {
if (conv_bias->param().format !=
megdnn::param::ConvBias::Format::NCHW) {
is_model_nchw = false;
}
conv_cnt++;
detect_int8_model(conv_bias->input(0));
} else if (auto dimshuffle = opr->try_cast_final<opr::Dimshuffle>()) {
LITE_MARK_USED_VAR(dimshuffle);
dimshuffle_cnt++;
}
});
for (auto&& i : m_load_result.output_var_list)
dep.add(i);
float radio_dimshuffle_conv = 0;
if (conv_cnt > 0) {
radio_dimshuffle_conv = dimshuffle_cnt / conv_cnt;
}
//! format optimization can only be applied to NCHW models; shufflenet-like
//! models will hurt performance when using nchw88 or nchw44 format, so here
//! we heuristically gate on the ratio of dimshuffle to convolution operators
if (!is_model_nchw || radio_dimshuffle_conv > 0.15f) {
return;
}
//! determine the layout from the device information
//! TODO: shufflenet-like models using nchw88 or nchw44 will hurt the
//! performance
if (m_user_config->device_type == LITE_CPU) {
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
cpuinfo_initialize();
//! if all convolution and matmul data types are float32
if (is_model_float32) {
//! if the device is x86 and supports AVX, use nchw88 format
if (cpuinfo_has_x86_avx()) {
m_load_config.comp_graph->options().graph_opt.enable_nchw88();
LITE_LOG("Configure model inference with nchw88 format.");
} else if (cpuinfo_has_x86_sse2() && !cpuinfo_has_x86_sse3()) {
//! if x86 only supports SSE2, use nchw44 format
m_load_config.comp_graph->options().graph_opt.enable_nchw44();
LITE_LOG("Configure model inference with nchw44 format.");
} else if (cpuinfo_has_arm_neon()) {
//! if the device is ARM, use nchw44 format
m_load_config.comp_graph->options().graph_opt.enable_nchw44();
LITE_LOG("Configure model inference with nchw44 format.");
}
} else if (is_model_int8) {
//! if the data type of convolution is int8
//! if the device is ARM and supports dot product, use nchw44-dot format
if (cpuinfo_has_arm_neon() && cpuinfo_has_arm_neon_dot()) {
m_load_config.comp_graph->options().graph_opt.enable_nchw44_dot();
LITE_LOG("Configure model inference with nchw44-dot format.");
} else if (cpuinfo_has_arm_neon()) {
//! if the device is ARM and does not support dot product, use nchw44 format
m_load_config.comp_graph->options().graph_opt.enable_nchw44();
LITE_LOG("Configure model inference with nchw44 format.");
}
}
#endif
}
}
}
...@@ -422,10 +516,13 @@ void NetworkImplDft::load_model(
}
m_load_result = m_loader->load(m_load_config, true);
configure_after_loaded();
}
void NetworkImplDft::configure_after_loaded() {
modify_exection_policy();
layout_transform_optimization();
//! some optimization option maybe invalid in some case, so here just
//! auto determine whether some options will apply.
......
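As a worked example of the heuristic added above, a minimal standalone sketch of the gating logic, with the cpuinfo queries replaced by hard-coded booleans (the 0.15 threshold and the layout choices come from the code; the x86 SSE2-only branch is omitted for brevity):

    #include <cstdio>

    int main() {
        // hypothetical operator counts collected by the DepOprIter pass
        float conv_cnt = 20.f, dimshuffle_cnt = 1.f;
        bool is_model_nchw = true, is_model_float32 = true, is_model_int8 = false;
        // stand-ins for cpuinfo_has_x86_avx() / cpuinfo_has_arm_neon() / *_neon_dot()
        bool has_x86_avx = true, has_arm_neon = false, has_arm_dot = false;

        float ratio = conv_cnt > 0 ? dimshuffle_cnt / conv_cnt : 0.f;
        if (!is_model_nchw || ratio > 0.15f) {
            // shufflenet-like models (many dimshuffles per conv) keep plain NCHW
            std::printf("ratio=%.2f -> keep NCHW\n", ratio);
            return 0;
        }
        if (is_model_float32) {
            if (has_x86_avx)
                std::printf("enable nchw88\n");      // x86 with AVX
            else if (has_arm_neon)
                std::printf("enable nchw44\n");      // ARM NEON, float32
        } else if (is_model_int8) {
            if (has_arm_neon && has_arm_dot)
                std::printf("enable nchw44-dot\n");  // ARM with dot product, int8
            else if (has_arm_neon)
                std::printf("enable nchw44\n");      // ARM without dot product, int8
        }
        return 0;
    }

Here 1 / 20 = 0.05 is below the 0.15 gate, the model is float32, and AVX is available, so the sketch prints "enable nchw88", matching the nchw88 path in the real code.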
...@@ -178,8 +178,10 @@ private:
//! call_back to the outputspec
void make_output_spec();
//! do layout transform for the given platform target: either the global
//! layout optimization or a heuristic choice of the best layout according to
//! the device information
void layout_transform_optimization();
//! modify the execution policy
void modify_exection_policy();
...@@ -223,6 +225,9 @@ private:
//! adapt option valid, it should call after update_io
void adapt_option_valid();
//! configure and optimize network after loaded
void configure_after_loaded();
private:
bool m_async = false;
bool m_is_cpu_inplace_mode = false;
......
...@@ -48,6 +48,35 @@ TEST(TestNetWorkOptions, no_var_sanity_check_and_record) {
compare_lite_tensor<float>(output_tensor, result_mgb);
}
TEST(TestNetWorkOptions, auto_optimize_inference_layout) {
Config config;
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
std::string input_name = "data";
auto result_mgb = mgb_lar(model_path, config, input_name, tensor);
config.auto_optimize_inference = true;
std::shared_ptr<Network> network = std::make_shared<Network>(config);
network->load_model(model_path);
std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
auto src_ptr = tensor->get_memory_ptr();
auto src_layout = tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
auto result_tensor = std::make_shared<Tensor>(
LiteDeviceType::LITE_CPU, Layout{{1, 1000}, 2, LiteDataType::LITE_FLOAT});
void* out_data = result_tensor->get_memory_ptr();
output_tensor->reset(out_data, result_tensor->get_layout());
network->forward();
network->wait();
compare_lite_tensor<float>(output_tensor, result_mgb);
}
TEST(TestNetWorkOptions, const_shape) {
Config config;
auto tensor = get_input_data("./input_data.npy");
......