Commit 2d6476a4 authored by Megvii Engine Team

feat(lite): add auto decide model inference format option

GitOrigin-RevId: fcbf945de59a8d9a861e3605a40e69c942c05f4e
Parent 10a0349e
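For context, the new flag is consumed through the existing `lite::Config` before a `Network` is created, mirroring the test added at the end of this commit. A minimal C++ sketch, assuming the public header `lite/network.h` and a placeholder model path:

```cpp
#include <memory>

#include "lite/network.h"  // lite::Config, lite::Network

int main() {
    lite::Config config;
    // Let lite pick nchw88 / nchw44 / nchw44-dot heuristically from the device info.
    config.auto_optimize_inference = true;

    auto network = std::make_shared<lite::Network>(config);
    network->load_model("./shufflenet.mge");  // placeholder model path

    network->forward();
    network->wait();
    return 0;
}
```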
......@@ -114,6 +114,9 @@ struct LITE_API Options {
* model is not pack json information data inside
*
* @param options configuration of Options
*
* @param auto_optimize_inference lite will detect the device information and
* set the options heuristically
*/
struct LITE_API Config {
bool has_compression = false;
......@@ -122,6 +125,7 @@ struct LITE_API Config {
LiteBackend backend = LiteBackend::LITE_DEFAULT;
std::string bare_model_cryption_name = {};
Options options = {};
bool auto_optimize_inference = false;
};
/*!
......
......@@ -100,6 +100,9 @@ extern LITE_API const LiteOptions default_option;
*
*\param has_compression flag whether the model is compressed; the compression
*method will be read from the model
*\param auto_optimize_inference lite will detect the device information and
* set the options heuristically
*/
typedef struct LiteConfig {
int has_compression;
......@@ -108,6 +111,7 @@ typedef struct LiteConfig {
LiteBackend backend;
const char* bare_model_cryption_name;
LiteOptions options;
int auto_optimize_inference;
} LiteConfig;
//! get default config
......
......@@ -42,7 +42,8 @@ LiteConfig default_config_t = {
.device_type = LiteDeviceType::LITE_CPU,
.backend = LiteBackend::LITE_DEFAULT,
.bare_model_cryption_name = nullptr,
.options = default_option};
.options = default_option,
.auto_optimize_inference = false};
LiteConfig* default_config() {
return &default_config_t;
}
......@@ -133,6 +134,8 @@ lite::Config convert_to_lite_config(const LiteConfig c_config) {
lite_config.options.enable_nchw32 = c_config.options.enable_nchw32;
lite_config.options.enable_nchw64 = c_config.options.enable_nchw64;
lite_config.auto_optimize_inference = c_config.auto_optimize_inference;
return lite_config;
}
......
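On the C side the flag is a plain `int`. A minimal sketch of enabling it on top of the library defaults; the header path `lite-c/network_c.h` and the later network-creation call are assumptions, not part of this diff:

```cpp
#include "lite-c/network_c.h"  // assumed public header declaring LiteConfig and default_config()

// Start from the library defaults and switch on the new heuristic option;
// the resulting config is then passed to the usual C network-creation call.
LiteConfig make_auto_optimized_config() {
    LiteConfig config = *default_config();
    config.auto_optimize_inference = 1;  // non-zero enables heuristic format selection
    return config;
}
```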
......@@ -171,15 +171,18 @@ class LiteConfig(Structure):
options: configuration of Options
auto_optimize_inference: lite will detect the device information and set the options heuristically
Examples:
.. code-block::
from megenginelite import *
config = LiteConfig()
config.has_compression = false
config.has_compression = False
config.device_type = LiteDeviceType.LITE_CPU
config.backend = LiteBackend.LITE_DEFAULT
config.bare_model_cryption_name = "AES_default".encode("utf-8")
config.auto_optimize_inference = False
"""
_fields_ = [
......@@ -189,6 +192,7 @@ class LiteConfig(Structure):
("backend", c_int),
("_bare_model_cryption_name", c_char_p),
("options", LiteOptions),
("auto_optimize_inference", c_int),
]
def __init__(self, device_type=LiteDeviceType.LITE_CPU, option=None):
......@@ -202,6 +206,7 @@ class LiteConfig(Structure):
self.use_loader_dynamic_param = 0
self.has_compression = 0
self.backend = LiteBackend.LITE_DEFAULT
self.auto_optimize_inference = 0
@property
def bare_model_cryption_name(self):
......@@ -223,6 +228,7 @@ class LiteConfig(Structure):
"backend": LiteBackend(self.backend),
"bare_model_cryption_name": self.bare_model_cryption_name,
"options": self.options,
"auto_optimize_inference": self.auto_optimize_inference,
}
return data.__repr__()
......
......@@ -21,6 +21,10 @@
#include "megcore_opencl.h"
#endif
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif
#include <fstream>
#include <memory>
#include <set>
......@@ -42,14 +46,7 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
LITE_ASSERT(src_impl.m_loader, "Clone network must after the network is loaded.");
m_load_result = src_impl.m_loader->load(m_load_config, true);
//! flag whether the model is a cross-compnode model
cross_compnode_model_detect();
//! update the IO of the network
update_io();
//! replace the IO when there is device input or output
compile_graph();
configure_after_loaded();
}
void NetworkImplDft::application_config() {
......@@ -364,7 +361,7 @@ void NetworkImplDft::adapt_option_valid() {
}
}
void NetworkImplDft::global_layout_transform() {
void NetworkImplDft::layout_transform_optimization() {
if (m_set_layout_transform) {
mgb::ThinHashMap<mgb::SymbolVar, mgb::SymbolVar> out_var_map;
auto output_var_array = mgb::gopt::layout_transform(
......@@ -382,6 +379,103 @@ void NetworkImplDft::global_layout_transform() {
for (auto&& item : m_load_result.output_var_map) {
item.second = out_var_map[item.second];
}
} else if (m_user_config->auto_optimize_inference) {
//! set model weight preprocess
m_load_config.comp_graph->options().graph_opt.weight_preprocess = true;
LITE_LOG(
"weight_preprocess is enabled, this maybe use more memory when "
"infernece.");
//! get the current format and data type of the model
bool is_model_nchw = true;
//! whether any convolution is int8
bool is_model_int8 = false;
//! whether all convolutions are float32
bool is_model_float32 = true;
float conv_cnt = 0;
float dimshuffle_cnt = 0;
auto detect_int8_model = [&](const VarNode* input) {
if (input->dtype().enumv() == megdnn::DTypeEnum::QuantizedS8 ||
input->dtype().enumv() == megdnn::DTypeEnum::Quantized8Asymm) {
is_model_int8 = true;
is_model_float32 = false;
} else if (input->dtype().enumv() == megdnn::DTypeEnum::Float32) {
is_model_float32 = (is_model_float32 && true);
} else {
is_model_float32 = false;
}
};
cg::DepOprIter dep([&](cg::OperatorNodeBase* opr) {
if (auto conv = opr->try_cast_final<opr::ConvolutionForward>()) {
if (conv->param().format != megdnn::param::ConvBias::Format::NCHW) {
is_model_nchw = false;
}
conv_cnt++;
detect_int8_model(conv->input(0));
} else if (auto conv_bias = opr->try_cast_final<opr::ConvBias>()) {
if (conv_bias->param().format !=
megdnn::param::ConvBias::Format::NCHW) {
is_model_nchw = false;
}
conv_cnt++;
detect_int8_model(conv_bias->input(0));
} else if (auto dimshuffle = opr->try_cast_final<opr::Dimshuffle>()) {
LITE_MARK_USED_VAR(dimshuffle);
dimshuffle_cnt++;
}
});
for (auto&& i : m_load_result.output_var_list)
dep.add(i);
float ratio_dimshuffle_conv = 0;
if (conv_cnt > 0) {
ratio_dimshuffle_conv = dimshuffle_cnt / conv_cnt;
}
//! format optimization can only be applied to NCHW models; shufflenet-like
//! models hurt performance when using the nchw88 or nchw44 format, so here we
//! heuristically gate on the ratio of dimshuffle to convolution operators
if (!is_model_nchw || ratio_dimshuffle_conv > 0.15f) {
return;
}
//! determine the layout from the device information
//! TODO: shufflenet-like models using nchw88 or nchw44 will hurt the
//! performance
if (m_user_config->device_type == LITE_CPU) {
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
cpuinfo_initialize();
//! if all convolution and matmul data types are float32
if (is_model_float32) {
//! if the device is x86 and supports AVX, use the nchw88 format
if (cpuinfo_has_x86_avx()) {
m_load_config.comp_graph->options().graph_opt.enable_nchw88();
LITE_LOG("Configure model inference with nchw88 format.");
} else if (cpuinfo_has_x86_sse2() && !cpuinfo_has_x86_sse3()) {
//! if x86 only supports SSE2, use the nchw44 format
m_load_config.comp_graph->options().graph_opt.enable_nchw44();
LITE_LOG("Configure model inference with nchw44 format.");
} else if (cpuinfo_has_arm_neon()) {
//! if the device is Arm, use the nchw44 format
m_load_config.comp_graph->options().graph_opt.enable_nchw44();
LITE_LOG("Configure model inference with nchw44 format.");
}
} else if (is_model_int8) {
//! if the data type of convolution is int8
//! if the device is Arm and supports dot product, use the nchw44-dot format
if (cpuinfo_has_arm_neon() && cpuinfo_has_arm_neon_dot()) {
m_load_config.comp_graph->options().graph_opt.enable_nchw44_dot();
LITE_LOG("Configure model inference with nchw44-dot format.");
} else if (cpuinfo_has_arm_neon()) {
//! if the device is Arm and does not support dot product, use the nchw44 format
m_load_config.comp_graph->options().graph_opt.enable_nchw44();
LITE_LOG("Configure model inference with nchw44 format.");
}
}
#endif
}
}
}
......@@ -422,10 +516,13 @@ void NetworkImplDft::load_model(
}
m_load_result = m_loader->load(m_load_config, true);
configure_after_loaded();
}
void NetworkImplDft::configure_after_loaded() {
modify_exection_policy();
global_layout_transform();
layout_transform_optimization();
//! some optimization options may be invalid in some cases, so here we
//! automatically determine whether they should apply.
......
......@@ -178,8 +178,10 @@ private:
//! call_back to the outputspec
void make_output_spec();
//! do the global layout transform for the given platform target
void global_layout_transform();
//! do layout transform for the given platform target: either the global
//! layout optimization or a heuristic choice of the best layout according to
//! the device information
void layout_transform_optimization();
//! modify the execution policy
void modify_exection_policy();
......@@ -223,6 +225,9 @@ private:
//! adapt option validity; it should be called after update_io
void adapt_option_valid();
//! configure and optimize the network after it is loaded
void configure_after_loaded();
private:
bool m_async = false;
bool m_is_cpu_inplace_mode = false;
......
......@@ -48,6 +48,35 @@ TEST(TestNetWorkOptions, no_var_sanity_check_and_record) {
compare_lite_tensor<float>(output_tensor, result_mgb);
}
TEST(TestNetWorkOptions, auto_optimize_inference_layout) {
Config config;
auto tensor = get_input_data("./input_data.npy");
std::string model_path = "./shufflenet.mge";
std::string input_name = "data";
auto result_mgb = mgb_lar(model_path, config, input_name, tensor);
config.auto_optimize_inference = true;
std::shared_ptr<Network> network = std::make_shared<Network>(config);
network->load_model(model_path);
std::shared_ptr<Tensor> input_tensor = network->get_io_tensor(input_name);
auto src_ptr = tensor->get_memory_ptr();
auto src_layout = tensor->get_layout();
input_tensor->reset(src_ptr, src_layout);
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
auto result_tensor = std::make_shared<Tensor>(
LiteDeviceType::LITE_CPU, Layout{{1, 1000}, 2, LiteDataType::LITE_FLOAT});
void* out_data = result_tensor->get_memory_ptr();
output_tensor->reset(out_data, result_tensor->get_layout());
network->forward();
network->wait();
compare_lite_tensor<float>(output_tensor, result_mgb);
}
TEST(TestNetWorkOptions, const_shape) {
Config config;
auto tensor = get_input_data("./input_data.npy");
......