提交 4fcf8b49 编写于 作者: M Megvii Engine Team

chore(mge): improve mem borrow

GitOrigin-RevId: 599562260cc1a668788a173b09bdd42a6a6615ca
上级 a7ca0588
...@@ -59,10 +59,7 @@ class CompNodeSyncManager { ...@@ -59,10 +59,7 @@ class CompNodeSyncManager {
void emplace(uint64_t t, A&& a) { void emplace(uint64_t t, A&& a) {
map.emplace_hint(map.end(), t, std::forward<A>(a)); map.emplace_hint(map.end(), t, std::forward<A>(a));
} }
void release(uint64_t t) { void release(uint64_t t) { map.erase(map.begin(), map.upper_bound(t)); }
auto it = map.upper_bound(t);
map.erase(map.begin(), it);
}
}; };
//! next virtual event //! next virtual event
...@@ -99,6 +96,7 @@ class CompNodeSyncManager { ...@@ -99,6 +96,7 @@ class CompNodeSyncManager {
return cndata.events.emplace_hint(cndata.events.end(), cndata.next++, e); return cndata.events.emplace_hint(cndata.events.end(), cndata.next++, e);
} }
// get a real event t' such that t <= t'
std::pair<uint64_t, CompNode::Event*> get_event( std::pair<uint64_t, CompNode::Event*> get_event(
CompNode cn, size_t cnid, uint64_t t, std::unique_lock<std::mutex>& lock) { CompNode cn, size_t cnid, uint64_t t, std::unique_lock<std::mutex>& lock) {
auto& cndata = m_cndata[cnid]; auto& cndata = m_cndata[cnid];
...@@ -145,8 +143,11 @@ class CompNodeSyncManager { ...@@ -145,8 +143,11 @@ class CompNodeSyncManager {
std::vector<Stat> stats; std::vector<Stat> stats;
std::vector<Item> todos; std::vector<Item> todos;
std::vector<bool> updated;
std::unique_lock lock(m_mtx); std::unique_lock lock(m_mtx);
for (;;) { for (;;) {
updated.clear();
updated.resize(m_cndata.size(), false);
// copy events to a temporary storage so that we may unlock while polling // copy events to a temporary storage so that we may unlock while polling
stats.resize(m_cndata.size()); stats.resize(m_cndata.size());
for (size_t cnid = 0; cnid < m_cndata.size(); ++cnid) { for (size_t cnid = 0; cnid < m_cndata.size(); ++cnid) {
...@@ -192,35 +193,63 @@ class CompNodeSyncManager { ...@@ -192,35 +193,63 @@ class CompNodeSyncManager {
lock.lock(); lock.lock();
// update completed
for (auto [cnid, stat] : views::enumerate(stats)) {
if (stat.num_success == 0) {
continue;
}
auto t = stat.it->first;
auto& cndata = m_cndata[cnid];
if (cndata.completed < t) {
cndata.completed = t;
updated[cnid] = true;
// also propagate by the transitive <= relation to ensure that
// we can safely delete ordering information without performance
// degradation even if some completion events are missed by our query
auto it = cndata.ordering.upper_bound(t);
if (it != cndata.ordering.begin()) {
it = std::prev(it);
for (auto [cnid, t] : views::enumerate(it->second)) {
auto& cndata = m_cndata[cnid];
if (cndata.completed < t) {
cndata.completed = t;
updated[cnid] = true;
}
}
}
}
}
// release dev storage // release dev storage
for (size_t receiver_cnid = 0; receiver_cnid < m_cndata.size(); for (size_t receiver_cnid = 0; receiver_cnid < m_cndata.size();
++receiver_cnid) { ++receiver_cnid) {
for (size_t releaser_cnid = 0; for (size_t releaser_cnid = 0;
releaser_cnid < m_cndata[receiver_cnid].release_queues.size(); releaser_cnid < m_cndata[receiver_cnid].release_queues.size();
++releaser_cnid) { ++releaser_cnid) {
if (releaser_cnid >= stats.size() || if (!(releaser_cnid < updated.size() && updated[releaser_cnid])) {
stats[releaser_cnid].num_success == 0) {
continue; continue;
} }
auto& q = m_cndata[receiver_cnid].release_queues[releaser_cnid]; auto& q = m_cndata[receiver_cnid].release_queues[releaser_cnid];
q.release(stats[releaser_cnid].it->first); q.release(m_cndata[releaser_cnid].completed);
} }
} }
for (size_t cnid = 0; cnid < stats.size(); ++cnid) { for (size_t cnid = 0; cnid < updated.size(); ++cnid) {
if (stats[cnid].num_success == 0) { if (!updated[cnid]) {
continue; continue;
} }
auto& cndata = m_cndata[cnid]; auto& cndata = m_cndata[cnid];
auto it = stats[cnid].it; auto t = cndata.completed;
auto t = it->first;
// update completed
cndata.completed = t;
// release host storage // release host storage
cndata.host_release_queue.release(t); cndata.host_release_queue.release(t);
// remove completed events // remove completed events
auto& events = cndata.events; [&](auto& map) {
events.erase(events.begin(), std::next(it)); map.erase(map.begin(), map.upper_bound(t));
}(cndata.events);
// delete ordering information
[&](auto& map) {
map.erase(map.begin(), map.upper_bound(t));
}(cndata.ordering);
} }
using namespace std::literals; using namespace std::literals;
...@@ -287,6 +316,17 @@ public: ...@@ -287,6 +316,17 @@ public:
auto waitee_id = get_cnid_unsafe(waitee); auto waitee_id = get_cnid_unsafe(waitee);
auto& waiter_data = m_cndata.at(waiter_id); auto& waiter_data = m_cndata.at(waiter_id);
auto& waitee_data = m_cndata.at(waitee_id); auto& waitee_data = m_cndata.at(waitee_id);
if (t <= waitee_data.completed) {
return;
}
if (waiter_data.ordering.size() &&
waitee_id < waiter_data.ordering.rbegin()->second.size() &&
t <= waiter_data.ordering.rbegin()->second[waitee_id]) {
return;
}
auto [t_waitee, e] = get_event(waitee, waitee_id, t, lock); auto [t_waitee, e] = get_event(waitee, waitee_id, t, lock);
// DO NOT unlock around this line! Event* could be invalidated! // DO NOT unlock around this line! Event* could be invalidated!
...@@ -301,7 +341,7 @@ public: ...@@ -301,7 +341,7 @@ public:
ordering[waitee_id] = t_waitee; ordering[waitee_id] = t_waitee;
ordering[waiter_id] = t_waiter; ordering[waiter_id] = t_waiter;
{ {
auto it = waitee_data.ordering.lower_bound(t_waitee); auto it = waitee_data.ordering.upper_bound(t_waitee);
if (it != waitee_data.ordering.begin()) { if (it != waitee_data.ordering.begin()) {
for (auto [a, b] : views::zip(ordering, std::prev(it)->second)) { for (auto [a, b] : views::zip(ordering, std::prev(it)->second)) {
static_assert(std::is_lvalue_reference_v<decltype(a)>); static_assert(std::is_lvalue_reference_v<decltype(a)>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册