Commit 565466c2 authored by Megvii Engine Team

feat(lite): auto deduce output tensor shape before model forward

GitOrigin-RevId: 78e00dab5da3fcc91bb53d8588c06b5b25295e19
Parent 29f9935d
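With this change, output tensors of a loaded network carry a statically inferred layout before the first forward, so callers can size buffers up front. A minimal usage sketch against the public Lite C++ API (the calls mirror the updated `TestNetWork.GetAllName` test in this commit; the model path and the printed shape are placeholders):

```cpp
#include <cstdio>
#include <memory>

#include "lite/network.h"

int main() {
    // Load a model; "./model.mge" is a placeholder path.
    auto network = std::make_shared<lite::Network>();
    network->load_model("./model.mge");

    // Before this commit the output layout was only known after forward();
    // now it is deduced at load time whenever the graph's shapes are static.
    auto output_tensor = network->get_output_tensor(0);
    auto layout = output_tensor->get_layout();
    for (size_t i = 0; i < layout.ndim; ++i)
        printf("dim[%zu] = %zu\n", i, layout.shapes[i]);
    return 0;
}
```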
@@ -31,6 +31,7 @@ class LiteDataType(IntEnum):
     LITE_INT16 = 3
     LITE_INT8 = 4
     LITE_UINT8 = 5
+    LITE_UINT16 = 6
 
 
 class LiteTensorPhase(IntEnum):
......
@@ -22,6 +22,7 @@ _lite_type_to_nptypes = {
     LiteDataType.LITE_UINT8: np.uint8,
     LiteDataType.LITE_INT8: np.int8,
     LiteDataType.LITE_INT16: np.int16,
+    LiteDataType.LITE_UINT16: np.uint16,
     LiteDataType.LITE_HALF: np.float16,
 }
@@ -33,6 +34,7 @@ _str_nptypes_to_lite_nptypes = {
     np.dtype("uint8"): LiteDataType.LITE_UINT8,
     np.dtype("int8"): LiteDataType.LITE_INT8,
     np.dtype("int16"): LiteDataType.LITE_INT16,
+    np.dtype("uint16"): LiteDataType.LITE_UINT16,
     np.dtype("float16"): LiteDataType.LITE_HALF,
 }
@@ -43,7 +45,7 @@ ctype_to_lite_dtypes = {
     c_ubyte: LiteDataType.LITE_UINT8,
     c_byte: LiteDataType.LITE_INT8,
     c_short: LiteDataType.LITE_INT16,
-    c_ushort: LiteDataType.LITE_INT16,
+    c_ushort: LiteDataType.LITE_UINT16,
 }
@@ -83,7 +85,7 @@ class LiteLayout(Structure):
     def __repr__(self):
         data = {
-            "shapes": list(self.shapes),
+            "shapes": list(self.shapes)[0 : self.ndim],
             "ndim": self.ndim,
             "data_type": _lite_type_to_nptypes[LiteDataType(self.data_type)],
         }
......
@@ -100,6 +100,9 @@ LTensorLayout lite::to_impl_layout(const Layout& layout) {
         case LiteDataType::LITE_INT16:
             mge_layout.dtype = mgb::dtype::Int16();
             break;
+        case LiteDataType::LITE_UINT16:
+            mge_layout.dtype = mgb::dtype::Uint16();
+            break;
         default:
             LITE_THROW(mgb::ssprintf(
                     "unsupport dtype in lite enum id is %d.",
@@ -133,6 +136,9 @@ Layout lite::to_lite_layout(const LTensorLayout& mge_layout) {
         case mgb::DTypeEnum::Int16:
            layout.data_type = LiteDataType::LITE_INT16;
            break;
+        case mgb::DTypeEnum::Uint16:
+            layout.data_type = LiteDataType::LITE_UINT16;
+            break;
         case mgb::DTypeEnum::Int8:
             layout.data_type = LiteDataType::LITE_INT8;
             break;
......
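The two switch additions above keep the new uint16 dtype stable across the Lite/MegBrain boundary. A small round-trip sketch, assuming it is compiled inside lite's source tree where these internal helpers and `lite::Layout` are visible:

```cpp
#include <cassert>

static void check_uint16_round_trip() {
    lite::Layout layout;
    layout.ndim = 2;
    layout.shapes[0] = 4;
    layout.shapes[1] = 8;
    layout.data_type = LiteDataType::LITE_UINT16;

    // to_impl_layout() now maps LITE_UINT16 -> mgb::dtype::Uint16(), and
    // to_lite_layout() maps it back, so the dtype survives both hops.
    auto impl = lite::to_impl_layout(layout);
    auto back = lite::to_lite_layout(impl);
    assert(back.data_type == LiteDataType::LITE_UINT16);
    assert(back.ndim == layout.ndim);
}
```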
@@ -442,6 +442,24 @@ void NetworkImplDft::set_io(const NetworkIO& network_io) {
     }
 }
 
+void NetworkImplDft::try_infer_tensor_layout(
+        std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var) {
+    auto&& static_infer_mgr = m_load_config.comp_graph->static_infer_manager();
+    auto infer_trait = var.node()->get_static_infer_trait();
+    if (std::get<0>(infer_trait)) {
+        auto shape = static_infer_mgr.infer_shape_fallible(var.node());
+        if (!shape) {
+            LITE_WARN(
+                    "Lite infer output shape failed, maybe the model is "
+                    "dynamic shape.\n");
+            return;
+        }
+        Layout layout = to_lite_layout(mgb::TensorLayout{*shape, var.dtype()});
+        tensor->set_layout(layout);
+    }
+}
+
 void NetworkImplDft::update_io() {
     update_input();
     update_output();
@@ -564,6 +582,14 @@ void NetworkImplDft::update_output() {
                 out_it->lite_tensor =
                         std::make_shared<Tensor>(device_id, stream_id, device_type);
             }
+            mgb::SymbolVar var;
+            for (auto&& out_var : m_load_result.output_var_list) {
+                if (out_var.node()->name() == out_it->name) {
+                    var = out_var;
+                    break;
+                }
+            }
+            try_infer_tensor_layout(out_it->lite_tensor, var);
         }
         //! user not set, use default output
     } else {
@@ -579,12 +605,14 @@ void NetworkImplDft::update_output() {
                 it->lite_tensor =
                         std::make_shared<Tensor>(device_id, stream_id, device_type);
             }
+            try_infer_tensor_layout(it->lite_tensor, out);
         } else {
             IOInner output;
             output.name = out.node()->name();
             output.lite_tensor = std::make_shared<Tensor>(
                     device_id, stream_id, device_type, true);
             m_network_io->outputs.push_back({output});
+            try_infer_tensor_layout(output.lite_tensor, out);
         }
     }
 }
......
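For reference, `try_infer_tensor_layout` leans on MegBrain's static shape inference: a var whose shape-infer trait is registered (checked through the new `VarNode::get_static_infer_trait()` accessor in the last hunk) can be asked for its shape without executing the graph, while `infer_shape_fallible` returns `nullptr` when the shape depends on runtime data, in which case Lite keeps the old lazy behavior. A condensed sketch of that pattern, assuming `graph` is the loaded `mgb::ComputingGraph` and `var` one of its output `SymbolVar`s:

```cpp
// Condensed form of the pattern used above (graph/var are assumptions).
auto&& mgr = graph->static_infer_manager();
if (std::get<0>(var.node()->get_static_infer_trait())) {
    // A shape-infer trait is registered; try to resolve the shape statically.
    if (auto shape = mgr.infer_shape_fallible(var.node())) {
        mgb::TensorLayout layout{*shape, var.dtype()};
        // ... size the output tensor from `layout` before forward ...
    } else {
        // Dynamic shape: fall back to allocating after execution.
    }
}
```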
@@ -201,6 +201,10 @@ private:
     //! compile the graph to get the execute function
     void compile_graph();
 
+    //! try to infer output tensor layout
+    void try_infer_tensor_layout(
+            std::shared_ptr<Tensor> tensor, mgb::cg::SymbolVar var);
+
 private:
     bool m_async = false;
     bool m_is_cpu_inplace_mode = false;
......
@@ -102,6 +102,8 @@ LiteLogLevel lite::get_log_level() {
 }
 
 std::string lite::ssprintf(const char* format, ...) {
+    if (!format)
+        return "";
     va_list ap;
     va_start(ap, format);
     auto ret = svsprintf(format, ap);
@@ -110,6 +112,8 @@ std::string lite::ssprintf(const char* format, ...) {
 }
 
 void lite::print_log(LiteLogLevel level, const char* format, ...) {
+    if (!format)
+        return;
     if (static_cast<uint32_t>(level) < static_cast<uint32_t>(get_log_level())) {
         return;
     }
......
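The null-format guards above make the logging helpers safe to call with a null pattern string. A tiny illustration (hypothetical call sites; assumes lite's internal misc.h is visible and that `LiteLogLevel::WARN` is the warning level enumerator):

```cpp
// With the guards in place these are no-ops instead of undefined behavior:
const char* fmt = nullptr;
std::string s = lite::ssprintf(fmt);        // returns ""
lite::print_log(LiteLogLevel::WARN, fmt);   // returns without logging
```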
@@ -90,6 +90,11 @@ TEST(TestNetWork, GetAllName) {
     auto input_names = network->get_all_input_name();
     auto output_names = network->get_all_output_name();
 
+    auto output_tensor = network->get_output_tensor(0);
+    auto out_layout = output_tensor->get_layout();
+    ASSERT_EQ(out_layout.ndim, 2);
+    ASSERT_EQ(out_layout.shapes[0], 1);
+    ASSERT_EQ(out_layout.shapes[1], 1000);
     ASSERT_EQ(input_names.size(), 1);
     ASSERT_EQ(output_names.size(), 1);
     ASSERT_TRUE(input_names[0] == "data");
......
@@ -488,6 +488,13 @@ public:
      */
     MemAllocPlan& init_mem_plan(const DeviceTensorND* fixed_alloc = nullptr);
 
+    /*!
+     * \brief get the shape and value infer trait
+     */
+    const std::tuple<void*, void*>& get_static_infer_trait() {
+        return m_static_infer_trait;
+    }
+
 private:
     //! whether its memory should be allocated by mgb system during graph
     //! execution; initialized in VarNodeMemManager::reset_opr_seq()
......