#pragma once #include "./helper.h" #include #include #include #include "megbrain/graph.h" #include "megbrain/plugin/opr_footprint.h" template class GraphNodePtr { std::shared_ptr m_graph; T* m_node; public: GraphNodePtr(T* node) : m_graph(node ? node->owner_graph()->shared_from_this() : nullptr), m_node(node) {} T* operator->() { return m_node; } T& operator*() { return *m_node; } operator bool() { return m_node; } T* get() { return m_node; } }; PYBIND11_DECLARE_HOLDER_TYPE(T, GraphNodePtr, true); class RendezvousBase { public: virtual ~RendezvousBase() = default; virtual void set_exception(std::exception_ptr p) = 0; }; template class Rendezvous : public RendezvousBase { std::mutex m_lock; int m_read_ahead = 0; bool m_drop_next = false; std::promise m_promise; Rendezvous() = default; struct Factory { template static auto make_rendezvous(Args&&... args) { auto ptr = new Rendezvous{std::forward(args)...}; return std::shared_ptr>(ptr); } }; public: Rendezvous(const Rendezvous& rhs) = delete; Rendezvous(Rendezvous&& rhs) = delete; Rendezvous& operator=(const Rendezvous& rhs) = delete; template static auto make(Args&&... args) { return Factory::make_rendezvous(std::forward(args)...); } R get() { std::future f; { MGB_LOCK_GUARD(m_lock); mgb_assert(m_read_ahead <= 0); mgb_assert(m_read_ahead >= -1); f = m_promise.get_future(); if (m_read_ahead == -1) { m_promise = {}; } ++m_read_ahead; } return f.get(); } void drop() { MGB_LOCK_GUARD(m_lock); mgb_assert(m_read_ahead <= 0); mgb_assert(m_read_ahead >= -1); if (m_read_ahead == -1) { m_promise = {}; } else { m_drop_next = true; } ++m_read_ahead; } template void set(T&& value) { MGB_LOCK_GUARD(m_lock); mgb_assert(m_read_ahead >= 0); mgb_assert(m_read_ahead <= 1); if (m_drop_next) { m_drop_next = false; } else { m_promise.set_value(std::forward(value)); } if (m_read_ahead == 1) { m_promise = {}; } --m_read_ahead; } void reset() { MGB_LOCK_GUARD(m_lock); m_promise = {}; m_read_ahead = 0; m_drop_next = false; } void set_exception(std::exception_ptr e) { if (e) { MGB_LOCK_GUARD(m_lock); if (m_read_ahead >= 0) { mgb_assert(m_read_ahead <= 1); if (m_drop_next) { m_drop_next = false; } else { m_promise.set_exception(e); } if (m_read_ahead == 1) { m_promise = {}; } --m_read_ahead; } else { mgb_assert(m_read_ahead == -1); // TODO: maybe exception should be ignored // if value was already set ? m_promise.set_exception(e); } } } }; void init_graph_rt(pybind11::module m);