提交 a1ca390d 编写于 作者: M Megvii Engine Team

fix(lite): fix const shape error for lar fitting mode

GitOrigin-RevId: 1cea25fe4cec845d509295286e03bf112ba99eea
上级 2001c494
...@@ -29,7 +29,8 @@ void OptionsFastManager::init(std::shared_ptr<OptionMap>& options) { ...@@ -29,7 +29,8 @@ void OptionsFastManager::init(std::shared_ptr<OptionMap>& options) {
m_internal_options_name = { m_internal_options_name = {
{"enable_fuse_conv_bias_with_z"}, {"enable_fuse_conv_bias_with_z"},
{"enable_fuse_preprocess"}, {"enable_fuse_preprocess"},
{"record_comp_seq"}}; {"record_comp_seq"},
{"const_shape"}};
//! record the independent option value //! record the independent option value
for (auto& option : *options) { for (auto& option : *options) {
auto option_vals = option.second->get_option(); auto option_vals = option.second->get_option();
...@@ -226,17 +227,21 @@ void OptionsTimeProfiler::profile_with_given_options( ...@@ -226,17 +227,21 @@ void OptionsTimeProfiler::profile_with_given_options(
auto average = inference_time / runtime_param.run_iter; auto average = inference_time / runtime_param.run_iter;
if (exception_state) { if (exception_state) {
average = TIME_OUT; average = TIME_OUT;
} printf("out of time (this may be caused by some exception, please checkout the "
"log) when profile option:\n%s\n",
//! record profile result option_code.c_str());
printf("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(), average); } else {
m_options_profile_result.insert({option_code, average}); printf("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(),
average);
//! record the best result //! record profile result
m_options_profile_result.insert({option_code, average});
if (average < m_best_setting.second) {
m_best_setting.first = option_code; //! record the best result
m_best_setting.second = average;
if (average < m_best_setting.second) {
m_best_setting.first = option_code;
m_best_setting.second = average;
}
} }
} }
/////////////////////////// UserInfoParser ///////////////////////////// /////////////////////////// UserInfoParser /////////////////////////////
...@@ -244,23 +249,9 @@ void UserInfoParser::get_user_info() { ...@@ -244,23 +249,9 @@ void UserInfoParser::get_user_info() {
//! register user information tips //! register user information tips
std::vector<std::pair<std::string, std::string>> info_tips; std::vector<std::pair<std::string, std::string>> info_tips;
m_user_info["fitting_preference"] = "Inferspeed"; m_user_info["fitting_preference"] = "Inferspeed";
info_tips.push_back(
{"use_const_shape", "whether the input shape is constant?(yes/no)?"});
for (auto& tip : info_tips) {
std::cout << tip.second;
std::string answer = "";
std::cin >> answer;
m_user_info[tip.first] = answer;
}
} }
void UserInfoParser::parse_info(std::shared_ptr<OptionsFastManager>& manager) { void UserInfoParser::parse_info(std::shared_ptr<OptionsFastManager>& manager) {
std::vector<std::string> fixed_options; std::vector<std::string> fixed_options;
if (m_user_info["use_const_shape"] == "yes") {
fixed_options.push_back("const_shape");
} else if (m_user_info["use_const_shape"] != "no") {
mgb_log_error("invalid user information for \"use_const_shape\"");
}
fixed_options.push_back("enable_fuse_conv_bias_nonlinearity"); fixed_options.push_back("enable_fuse_conv_bias_nonlinearity");
std::vector<std::string> tmp_options; std::vector<std::string> tmp_options;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册