graph_rt.h 2.3 KB
Newer Older
M
Megvii Engine Team 已提交
1 2 3 4 5 6 7 8 9 10 11
/**
 * \file imperative/python/src/graph_rt.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
#pragma once

#include "./helper.h"

#include <memory>
#include <mutex>
#include <future>

#include "megbrain/graph.h"

/**
 * Pointer-like handle to a graph node of type T that also holds a shared
 * reference to the ComputingGraph returned by node->owner_graph().
 *
 * Holding that std::shared_ptr keeps the owning graph alive for as long as
 * any handle to one of its nodes exists (presumably so the node itself stays
 * valid as well -- NOTE(review): confirm the graph owns T's storage).
 * Constructing from a null node yields an empty handle that references no
 * graph.
 *
 * The observers are const and hand out the raw (non-const) pointer, like
 * std::shared_ptr does. pybind11's default holder_helper calls get() through
 * a const reference to the holder, which requires get() to be const.
 */
template<typename T>
class GraphNodePtr {
    std::shared_ptr<mgb::cg::ComputingGraph> m_graph;
    T* m_node;
public:
    // Left non-explicit: callers may rely on converting a raw T* into a
    // GraphNodePtr<T> implicitly.
    GraphNodePtr(T* node) :
        // Only dereference `node` when it is non-null; a null node gets an
        // empty graph reference instead of a null dereference.
        m_graph(node ? node->owner_graph()->shared_from_this() : nullptr),
        m_node(node) {}
    T* operator->() const {return m_node;}
    T& operator*() const {return *m_node;}
    operator bool() const {return m_node;}
    T* get() const {return m_node;}
};

// Register GraphNodePtr<T> as a pybind11 holder type. The trailing `true` is
// pybind11's always_construct_holder flag.
PYBIND11_DECLARE_HOLDER_TYPE(T, GraphNodePtr<T>, true);

/**
 * One-slot hand-off point: each set() delivers a value of type R that exactly
 * one matching get() receives, through a std::promise / std::future pair.
 *
 * Calls must alternate between get() and set(); either side may run at most
 * one call ahead of the other (enforced by mgb_assert). m_read_ahead records
 * that lead:
 *   +1  a get() has already taken the future for the next set();
 *    0  balanced;
 *   -1  a set() has already stored the value for the next get().
 * Each std::promise carries exactly one value. It is replaced by a fresh one
 * as soon as both its future has been taken and its value has been set.
 *
 * All members are accessed under m_lock; get() waits on the future after
 * releasing it.
 */
template<typename R>
class Rendezvous {
    std::mutex m_lock;
    int m_read_ahead = 0;
    std::promise<R> m_promise;
public:
    Rendezvous() = default;
    Rendezvous(const Rendezvous& rhs) = delete;

    // A defaulted move constructor would be implicitly deleted (std::mutex is
    // not movable), so the state is transferred by hand under rhs's lock.
    // Swapping hands our freshly constructed promise to rhs, which together
    // with m_read_ahead == 0 leaves rhs in the same state reset() produces.
    Rendezvous(Rendezvous&& rhs) {
        std::lock_guard<std::mutex> rhs_guard(rhs.m_lock);
        m_promise.swap(rhs.m_promise);
        m_read_ahead = rhs.m_read_ahead;
        rhs.m_read_ahead = 0;
    }

    Rendezvous& operator=(const Rendezvous& rhs) = delete;

    // Takes over rhs's pending state. Both mutexes are held (std::lock avoids
    // lock-order deadlock between two concurrent cross-assignments), so
    // neither object is seen half-updated by a concurrent get()/set(). The
    // promise previously held by *this is discarded, as in reset().
    // Self-move is a no-op instead of discarding our own promise.
    Rendezvous& operator=(Rendezvous&& rhs) {
        if (this == &rhs) {
            return *this;
        }
        std::lock(m_lock, rhs.m_lock);
        std::lock_guard<std::mutex> lhs_guard(m_lock, std::adopt_lock);
        std::lock_guard<std::mutex> rhs_guard(rhs.m_lock, std::adopt_lock);
        m_read_ahead = rhs.m_read_ahead;
        m_promise = std::move(rhs.m_promise);
        // Leave rhs usable, in the state reset() would produce.
        rhs.m_read_ahead = 0;
        rhs.m_promise = {};
        return *this;
    }

    /**
     * Receive the next value, blocking until the matching set() has run.
     * At most one get() may be outstanding ahead of set() (asserted).
     * The wait happens without holding m_lock. Whatever
     * std::future::get() throws is propagated, e.g. a std::future_error if
     * reset() replaces the promise while this call is still waiting.
     */
    R get() {
        std::future<R> f;
        {
            MGB_LOCK_GUARD(m_lock);
            mgb_assert(m_read_ahead <= 0);
            mgb_assert(m_read_ahead >= -1);
            f = m_promise.get_future();
            if (m_read_ahead == -1) {
                // Value was already set and its future is now taken: this
                // promise is spent, start a fresh one for the next round.
                m_promise = {};
            }
            ++m_read_ahead;
        }
        return f.get();
    }

    /**
     * Deliver `value` (forwarded to std::promise<R>::set_value) to the
     * pending or next get(). At most one set() may be outstanding ahead of
     * get() (asserted).
     */
    template<typename T>
    void set(T&& value) {
        MGB_LOCK_GUARD(m_lock);
        mgb_assert(m_read_ahead >= 0);
        mgb_assert(m_read_ahead <= 1);
        m_promise.set_value(std::forward<T>(value));
        if (m_read_ahead == 1) {
            // A get() already holds this promise's future: it is spent,
            // start a fresh one for the next round.
            m_promise = {};
        }
        --m_read_ahead;
    }

    /**
     * Drop any pending value or waiter by replacing the promise, and return
     * to the balanced initial state.
     */
    void reset() {
        MGB_LOCK_GUARD(m_lock);
        m_promise = {};
        m_read_ahead = 0;
    }
};

/// Python-binding entry point for this component, defined out of line.
/// NOTE(review): presumably registers the graph-runtime types/functions on
/// module `m` -- confirm against the definition (graph_rt.cpp).
extern void init_graph_rt(pybind11::module m);