Commit b6376c83 authored by Megvii Engine Team

fix(mgb/comp_node): use extra physical device to decide whether to reuse an existing CompNodeImpl

GitOrigin-RevId: 5dddc68a84ae2d1e31f948eda725e0d42f1ec1bd
Parent ccbc6761
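The gist of the change: each backend's load_* routine previously reused an initialized CompNodeImpl whenever the logical locator matched, so remapping a logical device (e.g. via Locator::set_device_map) could hand back an impl still bound to the old physical device. The check now also compares the physical locator. A minimal self-contained sketch of that idea, using simplified stand-in types rather than the real MegEngine classes:

```cpp
// Simplified stand-ins for CompNode::Locator and the pool nodes; not the
// real MegEngine types.
#include <cstdio>
#include <vector>

struct Locator {
    int type = 0, device = 0, stream = 0;
    bool operator==(const Locator& rhs) const {
        return type == rhs.type && device == rhs.device && stream == rhs.stream;
    }
};

struct Node {
    bool initialized = false;
    Locator locator;          // physical device backing this impl
    Locator locator_logical;  // logical name the user asked for
};

// The fixed reuse check: both the physical and the logical locator must match.
Node* find_reusable(std::vector<Node>& nodes, const Locator& locator,
                    const Locator& locator_logical) {
    for (auto& cur : nodes) {
        if (cur.initialized && cur.locator == locator &&
            cur.locator_logical == locator_logical) {
            return &cur;
        }
    }
    return nullptr;
}

int main() {
    // One impl created while the logical default mapped to physical device 2.
    std::vector<Node> nodes{{true, {0, 2, 0}, {0, -1, 0}}};
    // After remapping the default to device 1, the logical locator alone
    // would still match; the physical comparison forces a fresh impl.
    std::printf("reusable: %s\n",
                find_reusable(nodes, {0, 1, 0}, {0, -1, 0}) ? "yes" : "no");
}
```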
@@ -420,7 +420,7 @@ CompNode::Impl* AtlasCompNode::load_atlas(const Locator& locator,
     for (int i = 0; i < sd.nr_node; ++i) {
         auto&& cur = sd.node[i];
         if (cur.m_initialized) {
-            if (cur.m_locator_logical == locator_logical) {
+            if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) {
                 return &cur;
             }
         } else {
@@ -604,7 +604,7 @@ CompNode::Impl* CambriconCompNode::load_cambricon(
     for (int i = 0; i < sd.nr_node; ++i) {
         auto&& cur = sd.node[i];
         if (cur.m_initialized) {
-            if (cur.m_locator_logical == locator_logical) {
+            if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) {
                 return &cur;
             }
         } else {
@@ -250,6 +250,10 @@ void CompNode::Locator::set_device_map(DeviceType type, int from, int to) {
 void CompNode::Locator::set_unspec_device_type(DeviceType type) {
     mgb_assert(type != DeviceType::UNSPEC);
+    if (type != DeviceType::CPU && type != DeviceType::CUDA) {
+        mgb_log_warn("resolving the unspec device type to one other than "
+                     "CUDA or CPU may lead to unknown problems.");
+    }
     g_unspec_locator_type = type;
 }
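For context, these two Locator knobs control how unspecified device strings resolve; a usage sketch mirroring the updated test at the bottom of this diff (assuming the megbrain/comp_node.h header and a CUDA-enabled build):

```cpp
// Mirrors the usage exercised by TestCompNode.SetDefaultDev below; assumes
// megbrain/comp_node.h and a CUDA-enabled build.
#include "megbrain/comp_node.h"

void remap_default_device() {
    using L = mgb::CompNode::Locator;
    constexpr auto CUDA = mgb::CompNode::DeviceType::CUDA;

    L::set_unspec_device_type(CUDA);      // "xpu" now resolves to a CUDA device
    L::set_device_map(CUDA, -1, 2);       // map the default device (-1) to gpu2
    auto a = mgb::CompNode::load("xpu");  // backed by gpu2

    L::set_device_map(CUDA, -1, 1);       // remap the default device to gpu1
    // With this fix, load() must not hand back the gpu2 impl.
    auto b = mgb::CompNode::load("xpu");
}
```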
@@ -723,12 +723,13 @@ struct CpuCompNode::Pool {
             impl_storage[MAX_NR_COMP_NODE];
     size_t nr_used_impl_storage = 0;
-    ThinHashMap<std::pair<int, int>,
-                std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>> logical2impl;
+    std::unordered_map<CompNode::LocatorPairHashKey,
+                       std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>,
+                       CompNode::LocatorPairHashKey::Hash> locator2impl;
     ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>> physical2queue;
-    ThinHashMap<std::pair<int, int>,
-                std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>>
-            logical2impl_multi_thread;
+    std::unordered_map<CompNode::LocatorPairHashKey,
+                       std::unique_ptr<CpuCompNodeImpl, CpuCompNodeImplDeleter>,
+                       CompNode::LocatorPairHashKey::Hash> locator2impl_multi_thread;
     ThinHashMap<std::pair<int, int>, std::weak_ptr<WorkerQueue>>
             physical2queue_multithead;
 };
@@ -792,14 +793,9 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
     MGB_LOCK_GUARD(sm_pool->mtx);
-    // encode both device ID and type into a int
-    int compact_logical_device = locator_logical.device;
-    mgb_assert(compact_logical_device >= -1 ||
-               compact_logical_device <= Locator::DEVICE_CPU_DEFAULT);
-    if (locator_logical.type == CompNode::DeviceType::UNSPEC) {
-        compact_logical_device += std::numeric_limits<int>::min() + 1;
-        mgb_assert(compact_logical_device <
-                   Locator::DEVICE_MULTITHREAD_DEFAULT);
-    } else {
+    mgb_assert(locator_logical.device >= -1 ||
+               locator_logical.device <= Locator::DEVICE_CPU_DEFAULT);
+    if (locator_logical.type != CompNode::DeviceType::UNSPEC) {
         mgb_assert(locator_logical.type == CompNode::DeviceType::CPU ||
                    locator_logical.type == CompNode::DeviceType::MULTITHREAD);
     }
@@ -811,8 +807,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
         pqueue = std::make_shared<WorkerQueue>(locator);
         pqueue_weak = pqueue;
     }
-    auto&& pimpl = sm_pool->logical2impl[{compact_logical_device,
-                                          locator_logical.stream}];
+    auto&& pimpl = sm_pool->locator2impl[{locator,
+                                          locator_logical}];
     if (!pimpl) {
         mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE,
                    "too many cpu comp nodes; max %d allowed",
@@ -833,8 +829,8 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
         pqueue = std::make_shared<WorkerQueue>(locator);
         pqueue_weak = pqueue;
     }
-    auto&& pimpl = sm_pool->logical2impl_multi_thread[{
-            compact_logical_device, locator_logical.nr_threads}];
+    auto&& pimpl = sm_pool->locator2impl_multi_thread[{
+            locator, locator_logical}];
     if (!pimpl) {
         mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE,
                    "too many cpu multithread comp nodes; max %d allowed",
@@ -854,9 +850,9 @@ void CpuCompNode::sync_all() {
         return;
     MGB_LOCK_GUARD(sm_pool->mtx);
-    for (auto &&i: sm_pool->logical2impl)
+    for (auto &&i: sm_pool->locator2impl)
         i.second->sync();
-    for (auto&& i : sm_pool->logical2impl_multi_thread)
+    for (auto&& i : sm_pool->locator2impl_multi_thread)
         i.second->sync();
 }
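The load_cpu hunks above also show the pool's two-level caching: worker queues are held through std::weak_ptr keyed by physical device, so they can expire, while the impls are owned by the locator2impl maps. A standalone sketch of that weak_ptr lookup-or-create pattern (the WorkerQueue here is a hypothetical stand-in, not the MegEngine class):

```cpp
#include <cstdio>
#include <map>
#include <memory>

struct WorkerQueue {  // hypothetical stand-in for the real worker queue
    int device;
    explicit WorkerQueue(int dev) : device(dev) {}
};

std::map<int, std::weak_ptr<WorkerQueue>> physical2queue;

// Return the cached queue for a device, creating it if the cache entry is
// missing or has expired; mirrors the pqueue_weak caching in load_cpu.
std::shared_ptr<WorkerQueue> get_queue(int device) {
    auto& weak = physical2queue[device];
    auto queue = weak.lock();
    if (!queue) {
        queue = std::make_shared<WorkerQueue>(device);
        weak = queue;  // cache without extending the queue's lifetime
    }
    return queue;
}

int main() {
    auto q0 = get_queue(0);
    auto q0_again = get_queue(0);
    std::printf("shared: %d\n", q0.get() == q0_again.get());  // 1: reused
}
```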
@@ -718,7 +718,7 @@ CompNode::Impl* CudaCompNode::load_cuda(
     for (int i = 0; i < sd.nr_node; ++ i) {
         auto &&cur = sd.node[i];
         if (cur.m_initialized) {
-            if (cur.m_locator_logical == locator_logical) {
+            if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) {
                 return &cur;
             }
         } else {
@@ -606,7 +606,7 @@ CompNode::Impl* ROCmCompNode::load_rocm(const Locator& locator,
     for (int i = 0; i < sd.nr_node; ++i) {
         auto&& cur = sd.node[i];
         if (cur.m_initialized) {
-            if (cur.m_locator_logical == locator_logical) {
+            if (cur.m_locator == locator && cur.m_locator_logical == locator_logical) {
                 return &cur;
             }
         } else {
@@ -168,6 +168,22 @@ class CompNode {
             return type == rhs.type && device == rhs.device &&
                    stream == rhs.stream;
         }
     };
+
+    struct LocatorPairHashKey {
+        Locator locator, locator_logical;
+
+        bool operator==(const LocatorPairHashKey& rhs) const {
+            return locator == rhs.locator && locator_logical == rhs.locator_logical;
+        }
+
+        struct Hash {
+            size_t operator()(const LocatorPairHashKey& k) const {
+                return hash_pair_combine(mgb::hash(k.locator),
+                                         mgb::hash(k.locator_logical));
+            }
+        };
+    };
 
     //! predefined special streams
@@ -537,6 +553,7 @@ class CompNode {
     friend class CompNodeEnv;
     friend struct HashTrait<CompNode>;
     friend struct HashTrait<CompNode::Locator>;
+    friend class CompNodeImplHelper;
 
 public:
     CompNode(ImplBase* impl) : m_impl{impl} {}
@@ -686,6 +703,15 @@ struct HashTrait<CompNode> {
     }
 };
 
+template<>
+struct HashTrait<CompNode::Locator> {
+    static size_t eval(const CompNode::Locator &val) {
+        return static_cast<size_t>(val.device)
+             + (static_cast<size_t>(val.type) << 4)
+             + (static_cast<size_t>(val.stream) << 8);
+    }
+};
+
 namespace comp_node_detail {
 
 /*!
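LocatorPairHashKey plus the new HashTrait&lt;CompNode::Locator&gt; is what lets the CPU pool key its std::unordered_map on the (physical, logical) locator pair. A self-contained sketch of the same pattern; the combine step is a boost-style stand-in for mgb::hash_pair_combine, whose exact formula is not shown in this diff:

```cpp
#include <cstddef>
#include <cstdio>
#include <unordered_map>

struct Locator {
    int type = 0, device = 0, stream = 0;
    bool operator==(const Locator& rhs) const {
        return type == rhs.type && device == rhs.device && stream == rhs.stream;
    }
};

// Same shift-combine scheme as HashTrait<CompNode::Locator> above.
static size_t hash_locator(const Locator& v) {
    return static_cast<size_t>(v.device) + (static_cast<size_t>(v.type) << 4) +
           (static_cast<size_t>(v.stream) << 8);
}

struct LocatorPairHashKey {
    Locator locator, locator_logical;
    bool operator==(const LocatorPairHashKey& rhs) const {
        return locator == rhs.locator && locator_logical == rhs.locator_logical;
    }
    struct Hash {
        size_t operator()(const LocatorPairHashKey& k) const {
            // stand-in for mgb::hash_pair_combine
            size_t a = hash_locator(k.locator), b = hash_locator(k.locator_logical);
            return a ^ (b + 0x9e3779b9 + (a << 6) + (a >> 2));
        }
    };
};

int main() {
    std::unordered_map<LocatorPairHashKey, int, LocatorPairHashKey::Hash> m;
    m[{{0, 2, 0}, {0, -1, 0}}] = 1;  // logical default mapped to device 2
    m[{{0, 1, 0}, {0, -1, 0}}] = 2;  // after remap: same logical, new physical
    std::printf("%zu distinct impl slots\n", m.size());  // 2
}
```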
@@ -86,19 +86,34 @@ TEST(TestCompNode, SetDefaultDev) {
     CompNode::finalize();
     using L = CompNode::Locator;
     auto orig_dt = L::parse("xpu").to_physical(),
-         orig_gpu = L::parse("gpux").to_physical();
+         orig_gpu = L::parse("gpux").to_physical(),
+         orig_cpu = L::parse("cpux").to_physical();
     constexpr auto CUDA = CompNode::DeviceType::CUDA;
+    constexpr auto CPU = CompNode::DeviceType::CPU;
     L::set_unspec_device_type(CUDA);
-    L::set_device_map(CUDA, -1, 2);
-    auto run = []() {
-        ASSERT_EQ(CompNode::load("xpu").locator(), L::parse("gpu2"));
+    auto run = [](int device) {
+        ASSERT_EQ(CompNode::load("xpu").locator(),
+                  L::parse("gpu" + std::to_string(device)));
     };
+    auto run_cpu = [](int device) {
+        ASSERT_EQ(CompNode::load("cpux").locator(),
+                  L::parse("cpu" + std::to_string(device)));
+    };
     MGB_TRY {
-        run();
+        L::set_device_map(CUDA, -1, 2);
+        run(2);
+        L::set_device_map(CUDA, -1, 1);
+        run(1);
+        L::set_device_map(CPU, -1, 2);
+        run_cpu(2);
+        L::set_device_map(CPU, -1, 1);
+        run_cpu(1);
     } MGB_FINALLY({
         L::set_unspec_device_type(orig_dt.type);
         L::set_device_map(CUDA, -1, orig_gpu.device);
+        L::set_device_map(CPU, -1, orig_cpu.device);
     });
     CompNode::finalize();
 }