提交 643ab1c1 编写于 作者: M Megvii Engine Team

fix(lite): fix lite compnode mapper

GitOrigin-RevId: 994308b511e097496dad97b6421b61b668e3b1bb
上级 0ad5eeae
...@@ -33,7 +33,7 @@ typedef enum { ...@@ -33,7 +33,7 @@ typedef enum {
LITE_NPU = 4, LITE_NPU = 4,
LITE_CAMBRICON = 5, LITE_CAMBRICON = 5,
//! when the device information is set in model, so set LITE_DEVICE_DEFAULT //! when the device information is set in model, so set LITE_DEVICE_DEFAULT
//! in lite //! in lite, which equal to xpu in megengine
LITE_DEVICE_DEFAULT = 6, LITE_DEVICE_DEFAULT = 6,
} LiteDeviceType; } LiteDeviceType;
......
...@@ -116,35 +116,22 @@ void NetworkImplDft::application_config() { ...@@ -116,35 +116,22 @@ void NetworkImplDft::application_config() {
m_load_config.tensor_value_loader = decompressed_tensor_value_loader; m_load_config.tensor_value_loader = decompressed_tensor_value_loader;
} }
//! if device is LITE_NONE, the compnode information is stored in model //! if device is LITE_NONE, the compnode information is stored in model or
//! xpu in MegEngine
if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) { if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) {
//! currently not set Locator type because an atlas mgb model is a
//! cross-compnode graph
if (device_type == LiteDeviceType::LITE_ATLAS) {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::ATLAS) { if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
loc.device = m_compnode_locator.device; loc.type = m_compnode_locator.type;
loc.stream = m_compnode_locator.stream;
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
loc.stream = m_nr_threads;
} }
};
//! currently not set Locator type because a cambricon mgb model is a
//! cross-compnode graph
} else if (device_type == LiteDeviceType::LITE_CAMBRICON) {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::CAMBRICON) {
loc.device = m_compnode_locator.device; loc.device = m_compnode_locator.device;
loc.stream = m_compnode_locator.stream; //! if user set the thread number and the compnode is multithread
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) { if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
m_nr_threads != 1) {
loc.stream = m_nr_threads; loc.stream = m_nr_threads;
}
};
} else { } else {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { loc.stream = m_compnode_locator.stream;
loc = m_compnode_locator;
};
} }
};
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册