graph_rt.h 3.5 KB
Newer Older
1 2 3 4
#pragma once

#include "./helper.h"

M
Megvii Engine Team 已提交
5
#include <future>
6 7 8
#include <memory>
#include <mutex>
#include "megbrain/graph.h"
M
Megvii Engine Team 已提交
9
#include "megbrain/plugin/opr_footprint.h"
10

11 12
namespace py = pybind11;
extern py::object Py_Varnode;
13
extern const std::unique_ptr<mgb::OprFootprint> _imperative_sm_opr_footprint_ptr;
M
Megvii Engine Team 已提交
14
template <typename T>
15 16 17
class GraphNodePtr {
    std::shared_ptr<mgb::cg::ComputingGraph> m_graph;
    T* m_node;
M
Megvii Engine Team 已提交
18

19
public:
M
Megvii Engine Team 已提交
20 21 22 23 24 25 26
    GraphNodePtr(T* node)
            : m_graph(node ? node->owner_graph()->shared_from_this() : nullptr),
              m_node(node) {}
    T* operator->() { return m_node; }
    T& operator*() { return *m_node; }
    operator bool() { return m_node; }
    T* get() { return m_node; }
27 28 29 30
};

PYBIND11_DECLARE_HOLDER_TYPE(T, GraphNodePtr<T>, true);

31 32 33 34 35 36
/**
 * Type-erased interface of Rendezvous<R>: lets a caller deliver an exception
 * without knowing the value type R.
 */
class RendezvousBase {
public:
    //! virtual so derived rendezvous objects can be destroyed via this base
    virtual ~RendezvousBase() noexcept = default;

    //! implementations deliver \p error to the reader in place of a value
    virtual void set_exception(std::exception_ptr error) = 0;
};

M
Megvii Engine Team 已提交
37 38
template <typename R>
class Rendezvous : public RendezvousBase {
39 40
    std::mutex m_lock;
    int m_read_ahead = 0;
M
Megvii Engine Team 已提交
41
    bool m_drop_next = false;
42 43
    std::promise<R> m_promise;
    Rendezvous() = default;
44
    struct Factory {
M
Megvii Engine Team 已提交
45 46
        template <typename... Args>
        static auto make_rendezvous(Args&&... args) {
47 48 49 50
            auto ptr = new Rendezvous<R>{std::forward(args)...};
            return std::shared_ptr<Rendezvous<R>>(ptr);
        }
    };
M
Megvii Engine Team 已提交
51

52
public:
53
    Rendezvous(const Rendezvous& rhs) = delete;
54
    Rendezvous(Rendezvous&& rhs) = delete;
55 56
    Rendezvous& operator=(const Rendezvous& rhs) = delete;

M
Megvii Engine Team 已提交
57 58
    template <typename... Args>
    static auto make(Args&&... args) {
59 60 61
        return Factory::make_rendezvous(std::forward<Args>(args)...);
    }

62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
    R get() {
        std::future<R> f;
        {
            MGB_LOCK_GUARD(m_lock);
            mgb_assert(m_read_ahead <= 0);
            mgb_assert(m_read_ahead >= -1);
            f = m_promise.get_future();
            if (m_read_ahead == -1) {
                m_promise = {};
            }
            ++m_read_ahead;
        }
        return f.get();
    }

M
Megvii Engine Team 已提交
77 78 79 80 81 82 83 84 85 86 87 88
    void drop() {
        MGB_LOCK_GUARD(m_lock);
        mgb_assert(m_read_ahead <= 0);
        mgb_assert(m_read_ahead >= -1);
        if (m_read_ahead == -1) {
            m_promise = {};
        } else {
            m_drop_next = true;
        }
        ++m_read_ahead;
    }

M
Megvii Engine Team 已提交
89
    template <typename T>
90 91 92 93
    void set(T&& value) {
        MGB_LOCK_GUARD(m_lock);
        mgb_assert(m_read_ahead >= 0);
        mgb_assert(m_read_ahead <= 1);
M
Megvii Engine Team 已提交
94 95 96 97 98
        if (m_drop_next) {
            m_drop_next = false;
        } else {
            m_promise.set_value(std::forward<T>(value));
        }
99 100 101 102 103 104 105 106 107 108
        if (m_read_ahead == 1) {
            m_promise = {};
        }
        --m_read_ahead;
    }

    void reset() {
        MGB_LOCK_GUARD(m_lock);
        m_promise = {};
        m_read_ahead = 0;
M
Megvii Engine Team 已提交
109
        m_drop_next = false;
110
    }
111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133

    void set_exception(std::exception_ptr e) {
        if (e) {
            MGB_LOCK_GUARD(m_lock);
            if (m_read_ahead >= 0) {
                mgb_assert(m_read_ahead <= 1);
                if (m_drop_next) {
                    m_drop_next = false;
                } else {
                    m_promise.set_exception(e);
                }
                if (m_read_ahead == 1) {
                    m_promise = {};
                }
                --m_read_ahead;
            } else {
                mgb_assert(m_read_ahead == -1);
                // TODO: maybe exception should be ignored
                // if value was already set ?
                m_promise.set_exception(e);
            }
        }
    }
134 135 136
};

//! Initialization entry point that receives pybind11 module \p m; defined in
//! another translation unit. Presumably registers the graph-runtime bindings
//! declared in this header — confirm against the definition.
void init_graph_rt(pybind11::module m);