/** * \file imperative/python/src/pyext17.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include #include #include #include namespace pyext17 { #ifdef METH_FASTCALL constexpr bool has_fastcall = true; #else constexpr bool has_fastcall = false; #endif #ifdef _Py_TPFLAGS_HAVE_VECTORCALL constexpr bool has_vectorcall = true; #else constexpr bool has_vectorcall = false; #endif template struct invocable_with { template constexpr bool operator()(T&& lmb) { return std::is_invocable_v; } }; #define HAS_MEMBER_TYPE(T, U) invocable_with{}([](auto&& x) -> typename std::decay_t::U {}) #define HAS_MEMBER(T, m) invocable_with{}([](auto&& x) -> decltype(&std::decay_t::m) {}) inline PyObject* cvt_retval(PyObject* rv) { return rv; } #define CVT_RET_PYOBJ(...) \ if constexpr (std::is_same_v) { \ __VA_ARGS__; \ Py_RETURN_NONE; \ } else { \ return cvt_retval(__VA_ARGS__); \ } template struct wrap { private: typedef wrap wrap_t; public: PyObject_HEAD std::aligned_storage_t storage; #ifdef _Py_TPFLAGS_HAVE_VECTORCALL PyObject* (*vectorcall_slot)(PyObject*, PyObject*const*, size_t, PyObject*); #endif inline T* inst() { return reinterpret_cast(&storage); } inline static PyObject* pycast(T* ptr) { return (PyObject*)((char*)ptr - offsetof(wrap_t, storage)); } private: // method wrapper enum struct meth_type { noarg, varkw, fastcall, singarg }; template struct detect_meth_type { static constexpr meth_type value = []() { using F = decltype(f); static_assert(std::is_member_function_pointer_v); if constexpr (std::is_invocable_v) { return meth_type::noarg; } else if constexpr (std::is_invocable_v) { return meth_type::varkw; } else if constexpr (std::is_invocable_v) { return meth_type::fastcall; } else if constexpr (std::is_invocable_v) { return meth_type::singarg; } else { static_assert(!std::is_same_v); } }(); }; template struct meth {}; template struct meth { static constexpr int flags = METH_NOARGS; static PyObject* impl(PyObject* self, PyObject*) { auto* inst = reinterpret_cast(self)->inst(); CVT_RET_PYOBJ((inst->*f)()); } }; template struct meth { static constexpr int flags = METH_VARARGS | METH_KEYWORDS; static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { auto* inst = reinterpret_cast(self)->inst(); CVT_RET_PYOBJ((inst->*f)(args, kwargs)); } }; template struct meth { #ifdef METH_FASTCALL static constexpr int flags = METH_FASTCALL; static PyObject* impl(PyObject* self, PyObject*const* args, Py_ssize_t nargs) { auto* inst = reinterpret_cast(self)->inst(); CVT_RET_PYOBJ((inst->*f)(args, nargs)); } #else static constexpr int flags = METH_VARARGS; static PyObject* impl(PyObject* self, PyObject* args) { auto* inst = reinterpret_cast(self)->inst(); auto* arr = &PyTuple_GET_ITEM(args, 0); auto size = PyTuple_GET_SIZE(args); CVT_RET_PYOBJ((inst->*f)(arr, size)); } #endif }; template struct meth { static constexpr int flags = METH_O; static PyObject* impl(PyObject* self, PyObject* obj) { auto* inst = reinterpret_cast(self)->inst(); CVT_RET_PYOBJ((inst->*f)(obj)); } }; template static constexpr PyMethodDef make_meth_def(const char* name, const char* doc = nullptr) { using M = meth::value, f>; return {name, (PyCFunction)M::impl, M::flags, doc}; } // polyfills struct tp_vectorcall { static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall); static constexpr bool haskw = [](){if constexpr (valid) if constexpr (std::is_invocable_v) return true; return false;}(); template static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject *kwnames) { auto* inst = reinterpret_cast(self)->inst(); if constexpr (haskw) { CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames)); } else { if (kwnames && PyTuple_GET_SIZE(kwnames)) { PyErr_SetString(PyExc_TypeError, "expect no keyword argument"); return nullptr; } CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf)); } } static constexpr Py_ssize_t offset = []() {if constexpr (valid) return offsetof(wrap_t, vectorcall_slot); else return 0;}(); }; struct tp_call { static constexpr bool provided = HAS_MEMBER(T, tp_call); static constexpr bool static_form = invocable_with{}( [](auto&& t, auto... args) -> decltype(std::decay_t::tp_call(args...)) {}); static constexpr bool valid = provided || tp_vectorcall::valid; template static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { auto* inst = reinterpret_cast(self)->inst(); CVT_RET_PYOBJ(inst->tp_call(args, kwargs)); } static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call; else if constexpr (provided) return impl<>; #ifdef _Py_TPFLAGS_HAVE_VECTORCALL else if constexpr (valid) return PyVectorcall_Call; #endif else return nullptr;}(); }; struct tp_new { static constexpr bool provided = HAS_MEMBER(T, tp_new); static constexpr bool varkw = std::is_constructible_v; static constexpr bool noarg = std::is_default_constructible_v; template static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { auto* self = type->tp_alloc(type, 0); auto* inst = reinterpret_cast(self)->inst(); if constexpr (has_vectorcall && tp_vectorcall::valid) { reinterpret_cast(self)->vectorcall_slot = &tp_vectorcall::template impl<>; } if constexpr (varkw) { new(inst) T(args, kwargs); } else { new(inst) T(); } return self; } static constexpr newfunc value = []() {if constexpr (provided) return T::tp_new; else if constexpr (varkw || noarg) return impl<>; else return nullptr;}(); }; struct tp_dealloc { static constexpr bool provided = HAS_MEMBER(T, tp_dealloc); template static void impl(PyObject* self) { reinterpret_cast(self)->inst()->~T(); Py_TYPE(self)->tp_free(self); } static constexpr destructor value = []() {if constexpr (provided) return T::tp_dealloc; else return impl<>;}(); }; public: class TypeBuilder { std::vector m_methods; PyTypeObject m_type; bool m_finalized = false; bool m_ready = false; void check_finalized() { if (m_finalized) { throw std::runtime_error("type is already finalized"); } } public: TypeBuilder(const TypeBuilder&) = delete; TypeBuilder& operator=(const TypeBuilder&) = delete; TypeBuilder() : m_type{PyVarObject_HEAD_INIT(nullptr, 0)} { constexpr auto has_tp_name = HAS_MEMBER(T, tp_name); if constexpr (has_tp_name) { m_type.tp_name = T::tp_name; } m_type.tp_dealloc = tp_dealloc::value; #ifdef _Py_TPFLAGS_HAVE_VECTORCALL m_type.tp_vectorcall_offset = tp_vectorcall::offset; #endif m_type.tp_call = tp_call::value; m_type.tp_basicsize = sizeof(wrap_t); m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; #ifdef _Py_TPFLAGS_HAVE_VECTORCALL if constexpr (tp_vectorcall::valid) { m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL; } #endif m_type.tp_new = tp_new::value; } PyTypeObject* operator->() { return &m_type; } bool ready() const { return m_ready; } PyObject* finalize() { if (!m_finalized) { if (m_methods.size()) { m_methods.push_back({0}); if (m_type.tp_methods) { PyErr_SetString(PyExc_SystemError, "tp_method is already set"); return nullptr; } m_type.tp_methods = &m_methods[0]; } if (PyType_Ready(&m_type)) { return nullptr; } m_ready = true; } return (PyObject*)&m_type; } template TypeBuilder& def(const char* name, const char* doc = nullptr) { check_finalized(); m_methods.push_back(make_meth_def(name, doc)); return *this; } }; static TypeBuilder& type() { static TypeBuilder type_helper; return type_helper; } }; } // namespace pyext17 #undef HAS_MEMBER_TYPE #undef HAS_MEMBER #undef CVT_RET_PYOBJ