提交 4fcf8b49 编写于 作者: M Megvii Engine Team

chore(mge): improve mem borrow

GitOrigin-RevId: 599562260cc1a668788a173b09bdd42a6a6615ca
上级 a7ca0588
......@@ -59,10 +59,7 @@ class CompNodeSyncManager {
void emplace(uint64_t t, A&& a) {
// Insert an entry keyed by `t` into `map`, passing `a` through
// std::forward<A> to construct the mapped value.
// `map.end()` is supplied as the insertion hint; for the standard ordered
// associative containers this hint is the cheap case when `t` sorts after
// every key already stored.
// NOTE(review): presumably callers emplace with increasing `t` -- confirm
// against the call sites, which are not visible here.
map.emplace_hint(map.end(), t, std::forward<A>(a));
}
// Erase every entry of `map` whose key is <= t.
// upper_bound(t) yields the first entry with key > t, so the half-open range
// [begin, it) covers exactly the keys not greater than `t`.
// NOTE(review): presumably destroying the erased values is what gives back
// the queued storage -- confirm against the mapped type.
void release(uint64_t t) {
auto it = map.upper_bound(t);
map.erase(map.begin(), it);
}
// Erase every entry of `map` whose key is <= t.
// upper_bound(t) yields the first entry with key > t, so the erased range
// [begin, upper_bound(t)) covers exactly the keys not greater than `t`.
// NOTE(review): presumably destroying the erased values is what gives back
// the queued storage -- confirm against the mapped type.
void release(uint64_t t) { map.erase(map.begin(), map.upper_bound(t)); }
};
//! next virtual event
......@@ -99,6 +96,7 @@ class CompNodeSyncManager {
return cndata.events.emplace_hint(cndata.events.end(), cndata.next++, e);
}
// get a real event t' such that t <= t'
std::pair<uint64_t, CompNode::Event*> get_event(
CompNode cn, size_t cnid, uint64_t t, std::unique_lock<std::mutex>& lock) {
auto& cndata = m_cndata[cnid];
......@@ -145,8 +143,11 @@ class CompNodeSyncManager {
std::vector<Stat> stats;
std::vector<Item> todos;
std::vector<bool> updated;
std::unique_lock lock(m_mtx);
for (;;) {
updated.clear();
updated.resize(m_cndata.size(), false);
// copy events to a temporary storage so that we may unlock while polling
stats.resize(m_cndata.size());
for (size_t cnid = 0; cnid < m_cndata.size(); ++cnid) {
......@@ -192,35 +193,63 @@ class CompNodeSyncManager {
lock.lock();
// update completed
for (auto [cnid, stat] : views::enumerate(stats)) {
if (stat.num_success == 0) {
continue;
}
auto t = stat.it->first;
auto& cndata = m_cndata[cnid];
if (cndata.completed < t) {
cndata.completed = t;
updated[cnid] = true;
// also propagate by the transitive <= relation to ensure that
// we can safely delete ordering information without performance
// degradation even if some completion events are missed by our query
auto it = cndata.ordering.upper_bound(t);
if (it != cndata.ordering.begin()) {
it = std::prev(it);
for (auto [cnid, t] : views::enumerate(it->second)) {
auto& cndata = m_cndata[cnid];
if (cndata.completed < t) {
cndata.completed = t;
updated[cnid] = true;
}
}
}
}
}
// release dev storage
for (size_t receiver_cnid = 0; receiver_cnid < m_cndata.size();
++receiver_cnid) {
for (size_t releaser_cnid = 0;
releaser_cnid < m_cndata[receiver_cnid].release_queues.size();
++releaser_cnid) {
if (releaser_cnid >= stats.size() ||
stats[releaser_cnid].num_success == 0) {
if (!(releaser_cnid < updated.size() && updated[releaser_cnid])) {
continue;
}
auto& q = m_cndata[receiver_cnid].release_queues[releaser_cnid];
q.release(stats[releaser_cnid].it->first);
q.release(m_cndata[releaser_cnid].completed);
}
}
for (size_t cnid = 0; cnid < stats.size(); ++cnid) {
if (stats[cnid].num_success == 0) {
for (size_t cnid = 0; cnid < updated.size(); ++cnid) {
if (!updated[cnid]) {
continue;
}
auto& cndata = m_cndata[cnid];
auto it = stats[cnid].it;
auto t = it->first;
// update completed
cndata.completed = t;
auto t = cndata.completed;
// release host storage
cndata.host_release_queue.release(t);
// remove completed events
auto& events = cndata.events;
events.erase(events.begin(), std::next(it));
[&](auto& map) {
map.erase(map.begin(), map.upper_bound(t));
}(cndata.events);
// delete ordering information
[&](auto& map) {
map.erase(map.begin(), map.upper_bound(t));
}(cndata.ordering);
}
using namespace std::literals;
......@@ -287,6 +316,17 @@ public:
auto waitee_id = get_cnid_unsafe(waitee);
auto& waiter_data = m_cndata.at(waiter_id);
auto& waitee_data = m_cndata.at(waitee_id);
if (t <= waitee_data.completed) {
return;
}
if (waiter_data.ordering.size() &&
waitee_id < waiter_data.ordering.rbegin()->second.size() &&
t <= waiter_data.ordering.rbegin()->second[waitee_id]) {
return;
}
auto [t_waitee, e] = get_event(waitee, waitee_id, t, lock);
// DO NOT unlock around this line! Event* could be invalidated!
......@@ -301,7 +341,7 @@ public:
ordering[waitee_id] = t_waitee;
ordering[waiter_id] = t_waiter;
{
auto it = waitee_data.ordering.lower_bound(t_waitee);
auto it = waitee_data.ordering.upper_bound(t_waitee);
if (it != waitee_data.ordering.begin()) {
for (auto [a, b] : views::zip(ordering, std::prev(it)->second)) {
static_assert(std::is_lvalue_reference_v<decltype(a)>);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册