diff --git a/src/core/impl/comp_node/atlas/comp_node.cpp b/src/core/impl/comp_node/atlas/comp_node.cpp index ab71e0ad31c353ff346f026f061e212b22f4f5c1..7c03c8f004e2ba9ed8aa2ee3fb0009636f0485cb 100644 --- a/src/core/impl/comp_node/atlas/comp_node.cpp +++ b/src/core/impl/comp_node/atlas/comp_node.cpp @@ -420,7 +420,7 @@ CompNode::Impl* AtlasCompNode::load_atlas(const Locator& locator, for (int i = 0; i < sd.nr_node; ++i) { auto&& cur = sd.node[i]; if (cur.m_initialized) { - if (cur.m_locator_logical == locator_logical) { + if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { return &cur; } } else { diff --git a/src/core/impl/comp_node/cambricon/comp_node.cpp b/src/core/impl/comp_node/cambricon/comp_node.cpp index c7f9da35188d46d8fcea2c53cae04bee41f32dc9..9da1494a2cdd0a489298c0b1532bcf8682342078 100644 --- a/src/core/impl/comp_node/cambricon/comp_node.cpp +++ b/src/core/impl/comp_node/cambricon/comp_node.cpp @@ -604,7 +604,7 @@ CompNode::Impl* CambriconCompNode::load_cambricon( for (int i = 0; i < sd.nr_node; ++i) { auto&& cur = sd.node[i]; if (cur.m_initialized) { - if (cur.m_locator_logical == locator_logical) { + if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { return &cur; } } else { diff --git a/src/core/impl/comp_node/comp_node.cpp b/src/core/impl/comp_node/comp_node.cpp index c276cac980b03dcdb0e2c8148bb9c58b69eba57e..725eb3832762c4811f4725fba5a2ec795be55d7f 100644 --- a/src/core/impl/comp_node/comp_node.cpp +++ b/src/core/impl/comp_node/comp_node.cpp @@ -250,6 +250,10 @@ void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { void CompNode::Locator::set_unspec_device_type(DeviceType type) { mgb_assert(type != DeviceType::UNSPEC); + if (type != DeviceType::CPU && type != DeviceType::CUDA) { + mgb_log_warn("to resolve unspec device type as one except " + "CUDA and CPU may lead to unknown problems."); + } g_unspec_locator_type = type; } diff --git a/src/core/impl/comp_node/cpu/comp_node.cpp b/src/core/impl/comp_node/cpu/comp_node.cpp index c639e992ad3e0403f37af071412af11c3da1ef22..140a8cec4783a59ec447139db4fe7a469bef19b1 100644 --- a/src/core/impl/comp_node/cpu/comp_node.cpp +++ b/src/core/impl/comp_node/cpu/comp_node.cpp @@ -723,12 +723,13 @@ struct CpuCompNode::Pool { impl_storage[MAX_NR_COMP_NODE]; size_t nr_used_impl_storage = 0; - ThinHashMap, - std::unique_ptr> logical2impl; + std::unordered_map, + CompNode::LocatorPairHashKey::Hash> locator2impl; ThinHashMap, std::weak_ptr> physical2queue; - ThinHashMap, - std::unique_ptr> - logical2impl_multi_thread; + std::unordered_map, + CompNode::LocatorPairHashKey::Hash> locator2impl_multi_thread; ThinHashMap, std::weak_ptr> physical2queue_multithead; }; @@ -792,14 +793,9 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, MGB_LOCK_GUARD(sm_pool->mtx); // encode both device ID and type into a int - int compact_logical_device = locator_logical.device; - mgb_assert(compact_logical_device >= -1 || - compact_logical_device <= Locator::DEVICE_CPU_DEFAULT); - if (locator_logical.type == CompNode::DeviceType::UNSPEC) { - compact_logical_device += std::numeric_limits::min() + 1; - mgb_assert(compact_logical_device < - Locator::DEVICE_MULTITHREAD_DEFAULT); - } else { + mgb_assert(locator_logical.device >= -1 || + locator_logical.device <= Locator::DEVICE_CPU_DEFAULT); + if (locator_logical.type != CompNode::DeviceType::UNSPEC) { mgb_assert(locator_logical.type == CompNode::DeviceType::CPU || locator_logical.type == CompNode::DeviceType::MULTITHREAD); } @@ -811,8 +807,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, pqueue = std::make_shared(locator); pqueue_weak = pqueue; } - auto&& pimpl = sm_pool->logical2impl[{compact_logical_device, - locator_logical.stream}]; + auto&& pimpl = sm_pool->locator2impl[{locator, + locator_logical}]; if (!pimpl) { mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, "too many cpu comp nodes; max %d allowed", @@ -833,8 +829,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, pqueue = std::make_shared(locator); pqueue_weak = pqueue; } - auto&& pimpl = sm_pool->logical2impl_multi_thread[{ - compact_logical_device, locator_logical.nr_threads}]; + auto&& pimpl = sm_pool->locator2impl_multi_thread[{ + locator, locator_logical}]; if (!pimpl) { mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, "too many cpu multithread comp nodes; max %d allowed", @@ -854,9 +850,9 @@ void CpuCompNode::sync_all() { return; MGB_LOCK_GUARD(sm_pool->mtx); - for (auto &&i: sm_pool->logical2impl) + for (auto &&i: sm_pool->locator2impl) i.second->sync(); - for (auto&& i : sm_pool->logical2impl_multi_thread) + for (auto&& i : sm_pool->locator2impl_multi_thread) i.second->sync(); } diff --git a/src/core/impl/comp_node/cuda/comp_node.cpp b/src/core/impl/comp_node/cuda/comp_node.cpp index dcfe549d47b8ea2973e1f084e92ad8c366ef3537..83804687970637799dda9e57ab06bba7902b5e4d 100644 --- a/src/core/impl/comp_node/cuda/comp_node.cpp +++ b/src/core/impl/comp_node/cuda/comp_node.cpp @@ -718,7 +718,7 @@ CompNode::Impl* CudaCompNode::load_cuda( for (int i = 0; i < sd.nr_node; ++ i) { auto &&cur = sd.node[i]; if (cur.m_initialized) { - if (cur.m_locator_logical == locator_logical) { + if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { return &cur; } } else { diff --git a/src/core/impl/comp_node/rocm/comp_node.cpp b/src/core/impl/comp_node/rocm/comp_node.cpp index c7300b39167d33f461d81d1a8b370649ae557156..07a53d7a665f90f82fd1faa27df27b22cbebbfe6 100644 --- a/src/core/impl/comp_node/rocm/comp_node.cpp +++ b/src/core/impl/comp_node/rocm/comp_node.cpp @@ -606,7 +606,7 @@ CompNode::Impl* ROCmCompNode::load_rocm(const Locator& locator, for (int i = 0; i < sd.nr_node; ++i) { auto&& cur = sd.node[i]; if (cur.m_initialized) { - if (cur.m_locator_logical == locator_logical) { + if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { return &cur; } } else { diff --git a/src/core/include/megbrain/comp_node.h b/src/core/include/megbrain/comp_node.h index 464235f8316fd6d5bad498e3e1f649a57425be05..2ff854495db00056a2ce01160f4136f575660595 100644 --- a/src/core/include/megbrain/comp_node.h +++ b/src/core/include/megbrain/comp_node.h @@ -168,6 +168,22 @@ class CompNode { return type == rhs.type && device == rhs.device && stream == rhs.stream; } + + }; + + struct LocatorPairHashKey { + Locator locator, locator_logical; + + bool operator==(const LocatorPairHashKey& rhs) const { + return locator == rhs.locator && locator_logical == rhs.locator_logical; + } + + struct Hash { + size_t operator()(const LocatorPairHashKey& k) const { + return hash_pair_combine(mgb::hash(k.locator), + mgb::hash(k.locator_logical)); + } + }; }; //! predefined special streams @@ -537,6 +553,7 @@ class CompNode { friend class CompNodeEnv; friend struct HashTrait; + friend struct HashTrait; friend class CompNodeImplHelper; public: CompNode(ImplBase* impl) : m_impl{impl} {} @@ -686,6 +703,15 @@ struct HashTrait { } }; +template<> +struct HashTrait { + static size_t eval(const CompNode::Locator &val) { + return static_cast(val.device) + + (static_cast(val.type) << 4) + + (static_cast(val.stream) << 8); + } +}; + namespace comp_node_detail { /*! diff --git a/src/core/test/comp_node.cpp b/src/core/test/comp_node.cpp index 1ad42db74f0dbb1a6ee321e19de545d199d267fc..dc042b2e1cfced4450bf6ee908f3cc0cac645c54 100644 --- a/src/core/test/comp_node.cpp +++ b/src/core/test/comp_node.cpp @@ -86,19 +86,34 @@ TEST(TestCompNode, SetDefaultDev) { CompNode::finalize(); using L = CompNode::Locator; auto orig_dt = L::parse("xpu").to_physical(), - orig_gpu = L::parse("gpux").to_physical(); + orig_gpu = L::parse("gpux").to_physical(), + orig_cpu = L::parse("cpux").to_physical(); constexpr auto CUDA = CompNode::DeviceType::CUDA; + constexpr auto CPU = CompNode::DeviceType::CPU; L::set_unspec_device_type(CUDA); - L::set_device_map(CUDA, -1, 2); - auto run = []() { - ASSERT_EQ(CompNode::load("xpu").locator(), L::parse("gpu2")); + + auto run = [](int device) { + ASSERT_EQ(CompNode::load("xpu").locator(), + L::parse("gpu" + std::to_string(device))); + }; + auto run_cpu = [](int device) { + ASSERT_EQ(CompNode::load("cpux").locator(), + L::parse("cpu" + std::to_string(device))); }; MGB_TRY { - run(); + L::set_device_map(CUDA, -1, 2); + run(2); + L::set_device_map(CUDA, -1, 1); + run(1); + L::set_device_map(CPU, -1, 2); + run_cpu(2); + L::set_device_map(CPU, -1, 1); + run_cpu(1); } MGB_FINALLY({ L::set_unspec_device_type(orig_dt.type); L::set_device_map(CUDA, -1, orig_gpu.device); + L::set_device_map(CPU, -1, orig_cpu.device); }); CompNode::finalize(); }