diff --git a/imperative/src/impl/physical_tensor.cpp b/imperative/src/impl/physical_tensor.cpp index d17965fb26bc563ca0f336dac279b1661d4630d4..38be5c0315fc37b2b344cd6a80427aabf9f6e66f 100644 --- a/imperative/src/impl/physical_tensor.cpp +++ b/imperative/src/impl/physical_tensor.cpp @@ -59,10 +59,7 @@ class CompNodeSyncManager { void emplace(uint64_t t, A&& a) { map.emplace_hint(map.end(), t, std::forward(a)); } - void release(uint64_t t) { - auto it = map.upper_bound(t); - map.erase(map.begin(), it); - } + void release(uint64_t t) { map.erase(map.begin(), map.upper_bound(t)); } }; //! next virtual event @@ -99,6 +96,7 @@ class CompNodeSyncManager { return cndata.events.emplace_hint(cndata.events.end(), cndata.next++, e); } + // get a real event t' such that t <= t' std::pair get_event( CompNode cn, size_t cnid, uint64_t t, std::unique_lock& lock) { auto& cndata = m_cndata[cnid]; @@ -145,8 +143,11 @@ class CompNodeSyncManager { std::vector stats; std::vector todos; + std::vector updated; std::unique_lock lock(m_mtx); for (;;) { + updated.clear(); + updated.resize(m_cndata.size(), false); // copy events to a temporary storage so that we may unlock while polling stats.resize(m_cndata.size()); for (size_t cnid = 0; cnid < m_cndata.size(); ++cnid) { @@ -192,35 +193,63 @@ class CompNodeSyncManager { lock.lock(); + // update completed + for (auto [cnid, stat] : views::enumerate(stats)) { + if (stat.num_success == 0) { + continue; + } + auto t = stat.it->first; + auto& cndata = m_cndata[cnid]; + if (cndata.completed < t) { + cndata.completed = t; + updated[cnid] = true; + // also propagate by the transitive <= relation to ensure that + // we can safely delete ordering information without performance + // degradation even if some completion events are missed by our query + auto it = cndata.ordering.upper_bound(t); + if (it != cndata.ordering.begin()) { + it = std::prev(it); + for (auto [cnid, t] : views::enumerate(it->second)) { + auto& cndata = m_cndata[cnid]; + if (cndata.completed < t) { + cndata.completed = t; + updated[cnid] = true; + } + } + } + } + } + // release dev storage for (size_t receiver_cnid = 0; receiver_cnid < m_cndata.size(); ++receiver_cnid) { for (size_t releaser_cnid = 0; releaser_cnid < m_cndata[receiver_cnid].release_queues.size(); ++releaser_cnid) { - if (releaser_cnid >= stats.size() || - stats[releaser_cnid].num_success == 0) { + if (!(releaser_cnid < updated.size() && updated[releaser_cnid])) { continue; } auto& q = m_cndata[receiver_cnid].release_queues[releaser_cnid]; - q.release(stats[releaser_cnid].it->first); + q.release(m_cndata[releaser_cnid].completed); } } - for (size_t cnid = 0; cnid < stats.size(); ++cnid) { - if (stats[cnid].num_success == 0) { + for (size_t cnid = 0; cnid < updated.size(); ++cnid) { + if (!updated[cnid]) { continue; } auto& cndata = m_cndata[cnid]; - auto it = stats[cnid].it; - auto t = it->first; - // update completed - cndata.completed = t; + auto t = cndata.completed; // release host storage cndata.host_release_queue.release(t); // remove completed events - auto& events = cndata.events; - events.erase(events.begin(), std::next(it)); + [&](auto& map) { + map.erase(map.begin(), map.upper_bound(t)); + }(cndata.events); + // delete ordering information + [&](auto& map) { + map.erase(map.begin(), map.upper_bound(t)); + }(cndata.ordering); } using namespace std::literals; @@ -287,6 +316,17 @@ public: auto waitee_id = get_cnid_unsafe(waitee); auto& waiter_data = m_cndata.at(waiter_id); auto& waitee_data = m_cndata.at(waitee_id); + + if (t <= waitee_data.completed) { + return; + } + + if (waiter_data.ordering.size() && + waitee_id < waiter_data.ordering.rbegin()->second.size() && + t <= waiter_data.ordering.rbegin()->second[waitee_id]) { + return; + } + auto [t_waitee, e] = get_event(waitee, waitee_id, t, lock); // DO NOT unlock around this line! Event* could be invalidated! @@ -301,7 +341,7 @@ public: ordering[waitee_id] = t_waitee; ordering[waiter_id] = t_waiter; { - auto it = waitee_data.ordering.lower_bound(t_waitee); + auto it = waitee_data.ordering.upper_bound(t_waitee); if (it != waitee_data.ordering.begin()) { for (auto [a, b] : views::zip(ordering, std::prev(it)->second)) { static_assert(std::is_lvalue_reference_v);