Commit 2001c494 authored by Megvii Engine Team

feat(lite): add input shape parse for load and run

GitOrigin-RevId: ec44429f55050af49a24ab1e2aa671a08aff460e
Parent 64a8aaaf
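In brief: an input value of the form {d0,d1,..,dn} is now parsed as a shape only, so DataParser allocates an empty tensor of that shape and the data is filled with dummy values when the model input is configured. A hypothetical invocation, assuming load_and_run's usual name:value syntax for --input, could look like:

    load_and_run model.mge --input "data:{1,3,224,224}"

The commit also introduces a layout_transform_batch_size flag that forces the batch dimension of every model input when the global layout transform optimization is enabled.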
@@ -180,7 +180,13 @@ void DataParser::parse_npy(const std::string& name, const std::string& path) {
inputs.insert(std::make_pair(name, std::move(hv)));
}
void DataParser::parse_string(const std::string name, const std::string& str) {
void DataParser::parse_string(const std::string& name, const std::string& str) {
//! parse shape
if ('{' == str[0]) {
parse_shape(name, str);
return;
}
// data type
megdnn::DType data_type = mgb::dtype::Int32();
if (str.find(".") != std::string::npos) {
@@ -257,3 +263,31 @@ void DataParser::parse_string(const std::string name, const std::string& str) {
}
inputs.insert(std::make_pair(name, std::move(hv)));
}
void DataParser::parse_shape(const std::string& name, const std::string& str) {
//! {d0,d1,..,dn}
mgb_assert(
"{" == str.substr(0, 1),
"invalid value: %s for parse_shape, valid format: {d0,d1,..,dn}\n",
str.c_str());
megdnn::SmallVector<size_t> shape;
std::string shape_size = "";
for (size_t i = 0; i < str.size(); ++i) {
char c = str[i];
if ('{' == c || ' ' == c) {
continue;
} else if (',' == c || '}' == c) {
shape.push_back(std::stoul(shape_size));
shape_size = "";
if ('}' == c) {
break;
}
} else {
shape_size += c;
}
}
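//! build a tensor that carries only the parsed shape; its storage stays
//! empty so the input can be filled with dummy data later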
mgb::HostTensorND hv(mgb::CompNode::default_cpu(), shape);
mgb::HostTensorStorage storage(mgb::CompNode::default_cpu());
hv.only_reset_raw_storage(storage);
inputs.insert(std::make_pair(name, std::move(hv)));
}
@@ -30,7 +30,10 @@ private:
//! parser for .npy data
void parse_npy(const std::string& name, const std::string& path);
//! parser for user define string
void parse_string(const std::string name, const std::string& str);
//! parser for user defined string
void parse_string(const std::string& name, const std::string& str);
//! parser for user defined shape
void parse_shape(const std::string& name, const std::string& str);
};
} // namespace lar
@@ -73,7 +73,17 @@ void InputOption::config_model_internel<ModelMdl>(
tensormap.find(i.first) != tensormap.end(),
"can't find tesnor named %s", i.first.c_str());
auto& in = tensormap.find(i.first)->second;
in->copy_from(i.second);
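//! a tensor parsed from a shape-only input has empty storage; fill a
//! host tensor of the same shape with dummy byte values before copying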
if (i.second.storage().empty()) {
mgb::HostTensorND hv;
hv.comp_node(mgb::CompNode::default_cpu(), true)
.dtype(in->dtype())
.resize(i.second.shape());
mgb::dt_byte* raw_ptr = hv.raw_ptr();
memset((char*)raw_ptr, 1, hv.layout().total_nr_elems());
in->copy_from(hv);
} else {
in->copy_from(i.second);
}
}
}
}
@@ -39,10 +39,24 @@ void GoptLayoutOption::config_model_internel<ModelLite>(
template <>
void GoptLayoutOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (m_layout_transform) {
mgb_log_warn("using global layout transform optimization\n");
mgb_log_debug("update input shape for global layout transform\n");
auto&& load_result = model->get_mdl_load_result();
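//! when a batch size is forced via layout_transform_batch_size, overwrite
//! the batch dimension of every input tensor and refill it with dummy bytes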
if (m_force_batch_size > 0) {
for (auto&& i : load_result.tensor_map) {
auto& in = i.second;
mgb::TensorShape new_shape = in->shape();
new_shape[0] = m_force_batch_size;
mgb::HostTensorND new_tensor;
new_tensor.comp_node(mgb::CompNode::default_cpu(), true)
.dtype(in->dtype())
.resize(new_shape);
mgb::dt_byte* raw_ptr = new_tensor.raw_ptr();
memset((char*)raw_ptr, 1, new_tensor.layout().total_nr_elems());
in->copy_from(new_tensor);
}
}
for (auto&& item : load_result.output_var_list) {
if (item.shape()[0] > 1) {
mgb_log_warn(
@@ -81,7 +95,11 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
}
load_result.output_var_list = output_vars;
}
}
} else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (m_layout_transform) {
mgb_log_warn("using global layout transform optimization\n");
auto&& load_result = model->get_mdl_load_result();
load_result.output_var_list = mgb::gopt::layout_transform(
load_result.output_var_list, m_layout_transform_target);
@@ -156,6 +174,8 @@ GoptLayoutOption::GoptLayoutOption() {
}
m_layout_transform_dump_file = FLAGS_layout_transform_dump;
m_force_batch_size = FLAGS_layout_transform_batch_size;
m_option = {
{"layout_transform", lar::String::make("")},
};
@@ -182,6 +202,14 @@ bool GoptLayoutOption::is_valid() {
}
}
ret = ret || !FLAGS_layout_transform_dump.empty();
if (FLAGS_layout_transform_batch_size > 0) {
mgb_assert(
FLAGS_layout_transform_batch_size > 0 &&
!FLAGS_layout_transform.empty(),
"\"layout-transform-batch-size\" should be set with "
"\"layout-transform\"");
ret = ret || FLAGS_layout_transform_batch_size > 0;
}
return ret || m_valid;
}
@@ -233,5 +261,8 @@ DEFINE_string(
"The computing graph after global layout transform will be dumped to the given "
"file path.");
DEFINE_int32(
layout_transform_batch_size, -1,
"the batch size of input for global layout transform optimization working on");
REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option);
REGIST_OPTION_VALIDATER(gopt_layout, lar::GoptLayoutOption::set_valid);
\ No newline at end of file
@@ -5,6 +5,7 @@
#include "models/model.h"
#include "option_base.h"
DECLARE_string(layout_transform);
DECLARE_int32(layout_transform_batch_size);
DECLARE_string(layout_transform_dump);
namespace lar {
@@ -38,5 +39,6 @@ private:
mgb::gopt::GraphTuningOptions::Target m_layout_transform_target;
static bool m_valid;
OptionValMap m_option;
int32_t m_force_batch_size;
};
} // namespace lar