diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index 0375ad5a0abdd79365f01097bfdeb0294162cf87..58c1460b57209de17cab4285d1305c133a1f22a2 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -112,10 +112,22 @@ void NetworkImplDft::application_config() {
             loc.type = m_compnode_locator.type;
         }
         loc.device = m_compnode_locator.device;
+        //! the user configured stream
+        auto stream = m_compnode_locator.stream;
         //! if user set the thread number and the compnode is multithread
-        if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD &&
-            m_nr_threads != 1) {
-            loc.stream = m_nr_threads;
+        if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
+            if (m_nr_threads != 1) {
+                loc.nr_threads = m_nr_threads;
+            }
+            //! user set the stream to separate the different multithread
+            if (stream != 0) {
+                auto device = m_compnode_locator.device;
+                //! the device is also set by user, so combine them to one
+                //! int
+                if (device == -1) {
+                    loc.device = stream;
+                }
+            }
         } else {
             loc.stream = m_compnode_locator.stream;
         }
diff --git a/lite/src/network_impl_base.h b/lite/src/network_impl_base.h
index 49117df4dbdda2f0b2e039e51d90087b5060bf25..b7b319f9f9ac05a5af1cd631699ccb6b058c6a74 100644
--- a/lite/src/network_impl_base.h
+++ b/lite/src/network_impl_base.h
@@ -148,7 +148,8 @@ public:
     virtual void set_device_id(int device_id) = 0;
     virtual int get_device_id() const = 0;
     virtual LiteBackend get_backend_type() const = 0;
-    //! set stream id, default stream id = 0
+    //! set stream id, default stream id = 0, if there are multi compnode in a
+    //! model, set all the compnode stream to the stream_id
     virtual void set_stream_id(int stream_id) = 0;
     virtual int get_stream_id() const = 0;
 
diff --git a/lite/test/test_network.cpp b/lite/test/test_network.cpp
index 1ca8141877052b996a80b0337281b0853ceff653..e23b0af089309e1cc1b1b108fb187c77a8f33afd 100644
--- a/lite/test/test_network.cpp
+++ b/lite/test/test_network.cpp
@@ -1665,6 +1665,30 @@ TEST(TestNetWork, AtlasLoadAtlasDeviceInput) {
 TEST(TestNetWork, AtlasDeviceID) {
     load_device_id(LiteDeviceType::LITE_ATLAS, 1, "./model_atlas.mgb");
 }
+
+TEST(TestNetWork, AtlasCrossCompnodeStreamID) {
+    auto thread = [](int stream_id) {
+        lite::Config config;
+        config.device_type = LiteDeviceType::LITE_ATLAS;
+        auto network = std::make_shared<Network>(config);
+        network->set_stream_id(stream_id);
+        network->load_model("./model_atlas.mgb");
+        std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
+        std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
+        for (int i = 0; i < 10; i++) {
+            network->forward();
+            network->wait();
+        }
+    };
+    std::thread t0(thread, 1);
+    std::thread t1(thread, 2);
+    std::thread t2(thread, 3);
+    std::thread t3(thread, 4);
+    t0.join();
+    t1.join();
+    t2.join();
+    t3.join();
+}
 #endif
 
 #if MGB_CAMBRICON