dispatcher.cpp 5.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
#include "./dispatcher.h"
#include "./pyext17.h"
#include "megbrain/utils/hash.h"
#include "megbrain/utils/small_vector.h"

#include <unordered_map>
#include <structmember.h>

namespace py = pybind11;
namespace pyx = pyext17;

namespace {

// One callable registered with a Dispatcher, together with its on/off switch.
struct Handler {
    // Borrowed pointer: Dispatcher::registry stores the same callable as a
    // py::object key in the entry that owns this Handler.
    PyObject* func = nullptr; // borrowed
    // Disabled handlers are skipped when the dispatcher walks an MRO.
    bool enabled = true;

    // Default member initializers above give a default-constructed Handler a
    // well-defined state (no callable, enabled) instead of indeterminate values.
    Handler() = default;
    Handler(PyObject* func_, bool enable = true) : func(func_), enabled(enable) {}
};

// Cache key for a call: one entry per positional argument, filled with the
// argument's Py_TYPE pointer in Dispatcher::prepare_call.
// NOTE(review): the 8 is presumably SmallVector's inline capacity — confirm.
using FastSig = mgb::SmallVector<void*, 8>;
// Ordered list of candidate handlers for one signature; the pointers are
// taken from the unique_ptrs stored in Dispatcher::registry.
using MRO = std::vector<Handler*>;

// State of one in-flight dispatch: which handler chain is being walked and
// how far along it the walk currently is.
struct Frame {
    // Handler chain for this call; not owned by the frame.
    MRO* mro = nullptr;
    // Index into *mro of the handler currently being tried.
    size_t mro_offset = 0;

    // Default member initializers above give a default-constructed Frame a
    // well-defined state (no chain, offset 0) instead of indeterminate values.
    Frame() = default;
    Frame(MRO* mro_, size_t mro_offset_ = 0) : mro(mro_), mro_offset(mro_offset_) {}
};

struct FastSigHash {
    size_t operator()(const FastSig& sig) const {
        auto* ptr = &sig.front();
        return mgb::XXHash()
            .update(ptr, sig.size() * sizeof(FastSig::value_type))
            .digest();
    }
};

// Hash functor that hashes a Python handle by the address of the object it
// refers to, delegating to the inherited std::hash<void*>.
struct ObjectIdHash : std::hash<void*> {
    size_t operator()(const py::handle& h) const {
        const std::hash<void*>& address_hash = *this;
        return address_hash(h.ptr());
    }
};

// Type-based multiple dispatcher exposed to Python.
//
// Handlers are registered with enable(). On a call, the tuple of argument
// types is resolved (through the Python-side `dispatch_iter` attribute) to an
// ordered handler chain, which is cached per signature. Handlers are tried in
// order; a handler returning NotImplemented passes control to the next one.
// super() lets a running handler continue with the handlers after it.
//
// NOTE(review): cache keys are raw type pointers with no reference held by
// the cache, and frames on `stack` point into MROs owned by `cache`. Calling
// clear_cache() from inside a handler would therefore leave in-flight frames
// pointing at destroyed MROs — confirm that callers never do that.
struct Dispatcher {
    // Resolved handler chains keyed by argument-type signature.
    std::unordered_map<FastSig, std::unique_ptr<MRO>, FastSigHash> cache;
    // One frame per dispatch currently in flight (nested calls and super()).
    std::vector<Frame> stack;
    // Registered callables; the py::object key keeps the callable alive and
    // the Handler stores a borrowed pointer to it.
    std::unordered_map<py::object, std::unique_ptr<Handler>, ObjectIdHash> registry;

    // The Python object wrapping this Dispatcher.
    inline py::handle self() {
        return pyx::wrap<Dispatcher>::pycast(this);
    }

    // Builds the type signature of `args`, looks up (or resolves and caches)
    // its handler chain and pushes a new frame for it.
    // Returns false, with a Python error set by resolve(), when resolution fails.
    bool prepare_call(PyObject*const* args, Py_ssize_t nargs) {
        FastSig sig(nargs);
        for (Py_ssize_t i = 0; i < nargs; ++i) {
            sig[i] = Py_TYPE(args[i]);
        }
        auto it = cache.find(sig);
        if (it == cache.end()) {
            if (auto mro = resolve(sig)) {
                it = cache.emplace(std::move(sig), std::move(mro)).first;
            } else {
                return false;
            }
        }
        stack.emplace_back(it->second.get());
        return true;
    }

    // Walks the handler chain of the top frame, starting at its current
    // offset, invoking `caller(func)` for each enabled handler until one
    // returns something other than NotImplemented (including nullptr, i.e. a
    // raised Python error). Always pops the frame before returning.
    // Requires a non-empty `stack`. Returns a new reference, or nullptr with
    // NotImplementedError set when every handler declined.
    template<typename T>
    PyObject* do_call(T&& caller) {
        // The handler may re-enter this dispatcher (a nested call or super()),
        // which pushes onto `stack` and may reallocate its storage. Address
        // the frame by index and re-read it after every call instead of
        // keeping references into the vector alive across the call. Nested
        // dispatches pop what they push, so `depth` is the top again once
        // `caller` returns.
        const size_t depth = stack.size() - 1;
        MRO* const mro = stack[depth].mro;
        while (stack[depth].mro_offset < mro->size()) {
            Handler* handler = (*mro)[stack[depth].mro_offset];
            if (handler->enabled) {
                PyObject* ret = caller(handler->func);
                if (ret != Py_NotImplemented) {
                    stack.pop_back();
                    return ret;
                }
                // Drop the NotImplemented reference and try the next handler.
                Py_DECREF(ret);
            }
            // The offset is stored in the frame so that super() can see which
            // handler is currently running.
            ++stack[depth].mro_offset;
        }
        PyErr_SetString(PyExc_NotImplementedError, "mro exhausted");
        stack.pop_back();
        return nullptr;
    }

    // Asks the Python side (`self.dispatch_iter(*types)`) for the ordered
    // handlers matching `sig` and maps each one to its registered Handler.
    // Returns nullptr with a Python error set on any failure.
    std::unique_ptr<MRO> resolve(const FastSig& sig) {
        try {
            py::tuple args(sig.size());
            for (size_t i = 0; i < sig.size(); ++i) {
                args[i] = (PyObject*)sig[i];
            }
            auto mro_iter = self().attr("dispatch_iter")(*args);
            auto ret = std::make_unique<MRO>();
            for (auto item : mro_iter) {
                auto it = registry.find(py::reinterpret_borrow<py::object>(item));
                if (it == registry.end()) {
                    PyErr_SetString(PyExc_RuntimeError, "resolved to unregistered function");
                    return nullptr;
                }
                ret->push_back(it->second.get());
            }
            return ret;
        } catch (py::error_already_set& e) {
            // Hand the captured Python exception back to the interpreter.
            e.restore();
        } catch (const std::runtime_error& e) {
            PyErr_SetString(PyExc_RuntimeError, e.what());
        }
        return nullptr;
    }

public:
    static constexpr auto tp_name = "Dispatcher";

    // Dispatches on positional arguments passed as a C array.
    PyObject* tp_vectorcall(PyObject*const* args, Py_ssize_t nargs) {
        if (!prepare_call(args, nargs)) return nullptr;
        return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);});
    }

    // Dispatches on the types of the positional tuple `args`; `kwargs` is
    // forwarded to the handler but takes no part in the signature.
    PyObject* tp_call(PyObject* args, PyObject* kwargs) {
        if (!prepare_call(&PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args))) return nullptr;
        return do_call([=](PyObject* func){return PyObject_Call(func, args, kwargs);});
    }

    // Continues the innermost in-flight dispatch with the handlers that come
    // after the one currently running, passing them `args`.
    // Sets RuntimeError when no dispatch is in flight.
    PyObject* super(PyObject*const* args, Py_ssize_t nargs) {
        if (stack.empty()) {
            PyErr_SetString(PyExc_RuntimeError, "super called at top level");
            return nullptr;
        }
        // Copy the top frame first, so no reference into `stack` is alive
        // while the vector grows.
        Frame next = stack.back();
        ++next.mro_offset;
        stack.push_back(next);
        return do_call([=](PyObject* func){return _PyObject_FastCall(func, args, nargs);});
    }

    // Registers `func` as an enabled handler, or re-enables it if it is
    // already registered.
    void enable(PyObject* func) {
        auto obj = py::reinterpret_borrow<py::object>(func);
        auto it = registry.find(obj);
        if (it != registry.end()) {
            it->second->enabled = true;
        } else {
            registry.emplace(std::move(obj), std::make_unique<Handler>(func));
        }
    }

    // Marks a registered handler as disabled so dispatch skips it.
    // Returns None, or nullptr with ValueError set if `func` is not registered.
    PyObject* disable(PyObject* func) {
        auto obj = py::reinterpret_borrow<py::object>(func);
        auto it = registry.find(obj);
        if (it == registry.end()) {
            PyErr_SetString(PyExc_ValueError, "function not registered");
            return nullptr;
        } else {
            it->second->enabled = false;
        }
        Py_RETURN_NONE;
    }

    // Drops every cached signature -> handler-chain resolution.
    void clear_cache() {
        cache.clear();
    }
};

} // namespace

// Builds the Python type object for Dispatcher, registers its methods and
// publishes it on module `m` under the name "Dispatcher".
// Throws py::error_already_set when finalize() yields a null type.
void init_dispatcher(py::module m) {
    auto* py_type = pyx::wrap<Dispatcher>::type()
        .def<&Dispatcher::enable>("enable")
        .def<&Dispatcher::disable>("disable")
        .def<&Dispatcher::clear_cache>("clear_cache")
        .def<&Dispatcher::tp_vectorcall>("call")
        .def<&Dispatcher::super>("super")
        .finalize();

    if (py_type == nullptr) {
        throw py::error_already_set();
    }

    m.attr("Dispatcher") = py_type;
}