提交 6bb9772c 编写于 作者: M Megvii Engine Team

feat(lite): support set the consistent stream id when multiconpnode

GitOrigin-RevId: 2eed8c2fbd6079c9cfc24e40e2746361c99f9796
上级 ba9f67eb
......@@ -112,10 +112,22 @@ void NetworkImplDft::application_config() {
loc.type = m_compnode_locator.type;
}
loc.device = m_compnode_locator.device;
//! the user configured stream
auto stream = m_compnode_locator.stream;
//! if user set the thread number and the compnode is multithread
if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
m_nr_threads != 1) {
loc.stream = m_nr_threads;
if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
if (m_nr_threads != 1) {
loc.nr_threads = m_nr_threads;
}
//! user set the stream to separate the different multithread
if (stream != 0) {
auto device = m_compnode_locator.device;
//! the device is also set by user, so combine them to one
//! int
if (device == -1) {
loc.device = stream;
}
}
} else {
loc.stream = m_compnode_locator.stream;
}
......
......@@ -148,7 +148,8 @@ public:
virtual void set_device_id(int device_id) = 0;
virtual int get_device_id() const = 0;
virtual LiteBackend get_backend_type() const = 0;
//! set stream id, default stream id = 0
//! set stream id, default stream id = 0, if there are multi compnode in a
//! model, set all the compnode stream to the stream_id
virtual void set_stream_id(int stream_id) = 0;
virtual int get_stream_id() const = 0;
......
......@@ -1665,6 +1665,30 @@ TEST(TestNetWork, AtlasLoadAtlasDeviceInput) {
TEST(TestNetWork, AtlasDeviceID) {
load_device_id(LiteDeviceType::LITE_ATLAS, 1, "./model_atlas.mgb");
}
TEST(TestNetWork, AtlasCrossCompnodeStreamID) {
auto thread = [](int stream_id) {
lite::Config config;
config.device_type = LiteDeviceType::LITE_ATLAS;
auto network = std::make_shared<lite::Network>(config);
network->set_stream_id(stream_id);
network->load_model("./model_atlas.mgb");
std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
for (int i = 0; i < 10; i++) {
network->forward();
network->wait();
}
};
std::thread t0(thread, 1);
std::thread t1(thread, 2);
std::thread t2(thread, 3);
std::thread t3(thread, 4);
t0.join();
t1.join();
t2.join();
t3.join();
}
#endif
#if MGB_CAMBRICON
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册