diff --git a/lite/include/lite/common_enum_c.h b/lite/include/lite/common_enum_c.h index 23b1894a04451b0aee62d159b8aa7d6a402073c2..656f6ee78c17b88b623757ef26952a29aa53a380 100644 --- a/lite/include/lite/common_enum_c.h +++ b/lite/include/lite/common_enum_c.h @@ -33,7 +33,7 @@ typedef enum { LITE_NPU = 4, LITE_CAMBRICON = 5, //! when the device information is set in model, so set LITE_DEVICE_DEFAULT - //! in lite + //! in lite, which equal to xpu in megengine LITE_DEVICE_DEFAULT = 6, } LiteDeviceType; diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp index b844f5619b48b7fa853d5d671518d82253823db2..6e5ef95a6cb390cee2d6ee9b30482f86dd210b3a 100644 --- a/lite/src/mge/network_impl.cpp +++ b/lite/src/mge/network_impl.cpp @@ -116,35 +116,22 @@ void NetworkImplDft::application_config() { m_load_config.tensor_value_loader = decompressed_tensor_value_loader; } - //! if device is LITE_NONE, the compnode information is stored in model + //! if device is LITE_NONE, the compnode information is stored in model or + //! xpu in MegEngine if (device_type != LiteDeviceType::LITE_DEVICE_DEFAULT) { - //! currently not set Locator type because an atlas mgb model is a - //! cross-compnode graph - if (device_type == LiteDeviceType::LITE_ATLAS) { - m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { - if (loc.type == mgb::CompNode::DeviceType::ATLAS) { - loc.device = m_compnode_locator.device; - loc.stream = m_compnode_locator.stream; - } else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) { - loc.stream = m_nr_threads; - } - }; - //! currently not set Locator type because a cambricon mgb model is a - //! cross-compnode graph - } else if (device_type == LiteDeviceType::LITE_CAMBRICON) { - m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { - if (loc.type == mgb::CompNode::DeviceType::CAMBRICON) { - loc.device = m_compnode_locator.device; - loc.stream = m_compnode_locator.stream; - } else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) { - loc.stream = m_nr_threads; - } - }; - } else { - m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { - loc = m_compnode_locator; - }; - } + m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { + if (loc.type == mgb::CompNode::DeviceType::UNSPEC) { + loc.type = m_compnode_locator.type; + } + loc.device = m_compnode_locator.device; + //! if user set the thread number and the compnode is multithread + if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD && + m_nr_threads != 1) { + loc.stream = m_nr_threads; + } else { + loc.stream = m_compnode_locator.stream; + } + }; } }