提交 643ab1c1 编写于 作者: M Megvii Engine Team

fix(lite): fix lite compnode mapper

GitOrigin-RevId: 994308b511e097496dad97b6421b61b668e3b1bb
上级 0ad5eeae
......@@ -33,7 +33,7 @@ typedef enum {
LITE_NPU = 4,
LITE_CAMBRICON = 5,
//! when the device information is set in model, so set LITE_DEVICE_DEFAULT
//! in lite
//! in lite, which equal to xpu in megengine
LITE_DEVICE_DEFAULT = 6,
} LiteDeviceType;
......
......@@ -116,35 +116,22 @@ void NetworkImplDft::application_config() {
m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
}
//! if device is LITE_NONE, the compnode information is stored in model
//! if device is LITE_NONE, the compnode information is stored in model or
//! xpu in MegEngine
if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) {
//! currently not set Locator type because an atlas mgb model is a
//! cross-compnode graph
if (device_type == LiteDeviceType::LITE_ATLAS) {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::ATLAS) {
loc.device = m_compnode_locator.device;
loc.stream = m_compnode_locator.stream;
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
loc.stream = m_nr_threads;
}
};
//! currently not set Locator type because a cambricon mgb model is a
//! cross-compnode graph
} else if (device_type == LiteDeviceType::LITE_CAMBRICON) {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::CAMBRICON) {
loc.device = m_compnode_locator.device;
loc.stream = m_compnode_locator.stream;
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
loc.stream = m_nr_threads;
}
};
} else {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
loc = m_compnode_locator;
};
}
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
loc.type = m_compnode_locator.type;
}
loc.device = m_compnode_locator.device;
//! if user set the thread number and the compnode is multithread
if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
m_nr_threads != 1) {
loc.stream = m_nr_threads;
} else {
loc.stream = m_compnode_locator.stream;
}
};
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册