From b6376c83a0025ace8ec2215a04af95321a4f77a9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 1 Feb 2021 15:17:18 +0800 Subject: [PATCH] fix(mgb/comp_node): use extra physical device to decide whether reuse existed CompNodeImpl GitOrigin-RevId: 5dddc68a84ae2d1e31f948eda725e0d42f1ec1bd --- src/core/impl/comp_node/atlas/comp_node.cpp | 2 +- .../impl/comp_node/cambricon/comp_node.cpp | 2 +- src/core/impl/comp_node/comp_node.cpp | 4 +++ src/core/impl/comp_node/cpu/comp_node.cpp | 34 ++++++++----------- src/core/impl/comp_node/cuda/comp_node.cpp | 2 +- src/core/impl/comp_node/rocm/comp_node.cpp | 2 +- src/core/include/megbrain/comp_node.h | 26 ++++++++++++++ src/core/test/comp_node.cpp | 25 +++++++++++--- 8 files changed, 69 insertions(+), 28 deletions(-) diff --git a/src/core/impl/comp_node/atlas/comp_node.cpp b/src/core/impl/comp_node/atlas/comp_node.cpp index ab71e0ad3..7c03c8f00 100644 --- a/src/core/impl/comp_node/atlas/comp_node.cpp +++ b/src/core/impl/comp_node/atlas/comp_node.cpp @@ -420,7 +420,7 @@ CompNode::Impl* AtlasCompNode::load_atlas(const Locator& locator, for (int i = 0; i < sd.nr_node; ++i) { auto&& cur = sd.node[i]; if (cur.m_initialized) { - if (cur.m_locator_logical == locator_logical) { + if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { return &cur; } } else { diff --git a/src/core/impl/comp_node/cambricon/comp_node.cpp b/src/core/impl/comp_node/cambricon/comp_node.cpp index c7f9da351..9da1494a2 100644 --- a/src/core/impl/comp_node/cambricon/comp_node.cpp +++ b/src/core/impl/comp_node/cambricon/comp_node.cpp @@ -604,7 +604,7 @@ CompNode::Impl* CambriconCompNode::load_cambricon( for (int i = 0; i < sd.nr_node; ++i) { auto&& cur = sd.node[i]; if (cur.m_initialized) { - if (cur.m_locator_logical == locator_logical) { + if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { return &cur; } } else { diff --git a/src/core/impl/comp_node/comp_node.cpp b/src/core/impl/comp_node/comp_node.cpp index c276cac98..725eb3832 100644 --- a/src/core/impl/comp_node/comp_node.cpp +++ b/src/core/impl/comp_node/comp_node.cpp @@ -250,6 +250,10 @@ void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { void CompNode::Locator::set_unspec_device_type(DeviceType type) { mgb_assert(type != DeviceType::UNSPEC); + if (type != DeviceType::CPU && type != DeviceType::CUDA) { + mgb_log_warn("to resolve unspec device type as one except " + "CUDA and CPU may lead to unknown problems."); + } g_unspec_locator_type = type; } diff --git a/src/core/impl/comp_node/cpu/comp_node.cpp b/src/core/impl/comp_node/cpu/comp_node.cpp index c639e992a..140a8cec4 100644 --- a/src/core/impl/comp_node/cpu/comp_node.cpp +++ b/src/core/impl/comp_node/cpu/comp_node.cpp @@ -723,12 +723,13 @@ struct CpuCompNode::Pool { impl_storage[MAX_NR_COMP_NODE]; size_t nr_used_impl_storage = 0; - ThinHashMap, - std::unique_ptr> logical2impl; + std::unordered_map, + CompNode::LocatorPairHashKey::Hash> locator2impl; ThinHashMap, std::weak_ptr> physical2queue; - ThinHashMap, - std::unique_ptr> - logical2impl_multi_thread; + std::unordered_map, + CompNode::LocatorPairHashKey::Hash> locator2impl_multi_thread; ThinHashMap, std::weak_ptr> physical2queue_multithead; }; @@ -792,14 +793,9 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, MGB_LOCK_GUARD(sm_pool->mtx); // encode both device ID and type into a int - int compact_logical_device = locator_logical.device; - mgb_assert(compact_logical_device >= -1 || - compact_logical_device <= Locator::DEVICE_CPU_DEFAULT); - if (locator_logical.type == CompNode::DeviceType::UNSPEC) { - compact_logical_device += std::numeric_limits::min() + 1; - mgb_assert(compact_logical_device < - Locator::DEVICE_MULTITHREAD_DEFAULT); - } else { + mgb_assert(locator_logical.device >= -1 || + locator_logical.device <= Locator::DEVICE_CPU_DEFAULT); + if (locator_logical.type != CompNode::DeviceType::UNSPEC) { mgb_assert(locator_logical.type == CompNode::DeviceType::CPU || locator_logical.type == CompNode::DeviceType::MULTITHREAD); } @@ -811,8 +807,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, pqueue = std::make_shared(locator); pqueue_weak = pqueue; } - auto&& pimpl = sm_pool->logical2impl[{compact_logical_device, - locator_logical.stream}]; + auto&& pimpl = sm_pool->locator2impl[{locator, + locator_logical}]; if (!pimpl) { mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, "too many cpu comp nodes; max %d allowed", @@ -833,8 +829,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, pqueue = std::make_shared(locator); pqueue_weak = pqueue; } - auto&& pimpl = sm_pool->logical2impl_multi_thread[{ - compact_logical_device, locator_logical.nr_threads}]; + auto&& pimpl = sm_pool->locator2impl_multi_thread[{ + locator, locator_logical}]; if (!pimpl) { mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, "too many cpu multithread comp nodes; max %d allowed", @@ -854,9 +850,9 @@ void CpuCompNode::sync_all() { return; MGB_LOCK_GUARD(sm_pool->mtx); - for (auto &&i: sm_pool->logical2impl) + for (auto &&i: sm_pool->locator2impl) i.second->sync(); - for (auto&& i : sm_pool->logical2impl_multi_thread) + for (auto&& i : sm_pool->locator2impl_multi_thread) i.second->sync(); } diff --git a/src/core/impl/comp_node/cuda/comp_node.cpp b/src/core/impl/comp_node/cuda/comp_node.cpp index dcfe549d4..838046879 100644 --- a/src/core/impl/comp_node/cuda/comp_node.cpp +++ b/src/core/impl/comp_node/cuda/comp_node.cpp @@ -718,7 +718,7 @@ CompNode::Impl* CudaCompNode::load_cuda( for (int i = 0; i < sd.nr_node; ++ i) { auto &&cur = sd.node[i]; if (cur.m_initialized) { - if (cur.m_locator_logical == locator_logical) { + if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { return &cur; } } else { diff --git a/src/core/impl/comp_node/rocm/comp_node.cpp b/src/core/impl/comp_node/rocm/comp_node.cpp index c7300b391..07a53d7a6 100644 --- a/src/core/impl/comp_node/rocm/comp_node.cpp +++ b/src/core/impl/comp_node/rocm/comp_node.cpp @@ -606,7 +606,7 @@ CompNode::Impl* ROCmCompNode::load_rocm(const Locator& locator, for (int i = 0; i < sd.nr_node; ++i) { auto&& cur = sd.node[i]; if (cur.m_initialized) { - if (cur.m_locator_logical == locator_logical) { + if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) { return &cur; } } else { diff --git a/src/core/include/megbrain/comp_node.h b/src/core/include/megbrain/comp_node.h index 464235f83..2ff854495 100644 --- a/src/core/include/megbrain/comp_node.h +++ b/src/core/include/megbrain/comp_node.h @@ -168,6 +168,22 @@ class CompNode { return type == rhs.type && device == rhs.device && stream == rhs.stream; } + + }; + + struct LocatorPairHashKey { + Locator locator, locator_logical; + + bool operator==(const LocatorPairHashKey& rhs) const { + return locator == rhs.locator && locator_logical == rhs.locator_logical; + } + + struct Hash { + size_t operator()(const LocatorPairHashKey& k) const { + return hash_pair_combine(mgb::hash(k.locator), + mgb::hash(k.locator_logical)); + } + }; }; //! predefined special streams @@ -537,6 +553,7 @@ class CompNode { friend class CompNodeEnv; friend struct HashTrait; + friend struct HashTrait; friend class CompNodeImplHelper; public: CompNode(ImplBase* impl) : m_impl{impl} {} @@ -686,6 +703,15 @@ struct HashTrait { } }; +template<> +struct HashTrait { + static size_t eval(const CompNode::Locator &val) { + return static_cast(val.device) + + (static_cast(val.type) << 4) + + (static_cast(val.stream) << 8); + } +}; + namespace comp_node_detail { /*! diff --git a/src/core/test/comp_node.cpp b/src/core/test/comp_node.cpp index 1ad42db74..dc042b2e1 100644 --- a/src/core/test/comp_node.cpp +++ b/src/core/test/comp_node.cpp @@ -86,19 +86,34 @@ TEST(TestCompNode, SetDefaultDev) { CompNode::finalize(); using L = CompNode::Locator; auto orig_dt = L::parse("xpu").to_physical(), - orig_gpu = L::parse("gpux").to_physical(); + orig_gpu = L::parse("gpux").to_physical(), + orig_cpu = L::parse("cpux").to_physical(); constexpr auto CUDA = CompNode::DeviceType::CUDA; + constexpr auto CPU = CompNode::DeviceType::CPU; L::set_unspec_device_type(CUDA); - L::set_device_map(CUDA, -1, 2); - auto run = []() { - ASSERT_EQ(CompNode::load("xpu").locator(), L::parse("gpu2")); + + auto run = [](int device) { + ASSERT_EQ(CompNode::load("xpu").locator(), + L::parse("gpu" + std::to_string(device))); + }; + auto run_cpu = [](int device) { + ASSERT_EQ(CompNode::load("cpux").locator(), + L::parse("cpu" + std::to_string(device))); }; MGB_TRY { - run(); + L::set_device_map(CUDA, -1, 2); + run(2); + L::set_device_map(CUDA, -1, 1); + run(1); + L::set_device_map(CPU, -1, 2); + run_cpu(2); + L::set_device_map(CPU, -1, 1); + run_cpu(1); } MGB_FINALLY({ L::set_unspec_device_type(orig_dt.type); L::set_device_map(CUDA, -1, orig_gpu.device); + L::set_device_map(CPU, -1, orig_cpu.device); }); CompNode::finalize(); } -- GitLab