ops.cpp 30.0 KB
Newer Older
1
#include "./ops.h"
2 3
#include "./helper.h"
#include "./tensor.h"
4

5
#include "megbrain/common.h"
6
#include "megbrain/custom/adaptor.h"
7
#include "megbrain/imperative.h"
8
#include "megbrain/imperative/graph_builder.h"
M
Megvii Engine Team 已提交
9
#include "megbrain/imperative/ops/autogen.h"
10 11
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
12
#include "megbrain/imperative/ops/rng.h"
M
Megvii Engine Team 已提交
13
#include "megbrain/imperative/ops/utility.h"
14

15 16 17
#include <Python.h>
#include <unordered_map>

18
namespace py = pybind11;
19
using namespace mgb::imperative;
20

21 22 23 24 25 26 27 28
namespace {
/*!
 * \brief Upper-case every character of \p in (C-locale `toupper` semantics).
 *
 * Enum member lookups normalize user-supplied names with this function before
 * searching the wrappers' `mem2value` tables, which makes them
 * case-insensitive.
 *
 * Each char is widened through `unsigned char` before calling `toupper`:
 * passing a negative `char` (e.g. a byte of a UTF-8 sequence where `char` is
 * signed) to `toupper` is undefined behavior.
 */
std::string normalize_enum(const std::string& in) {
    std::string ret;
    ret.reserve(in.size());
    for (char c : in) {
        ret += static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    return ret;
}
}  // anonymous namespace

// Handler tail for a `try { ... }` inside a raw CPython slot/entry point:
// converts any escaping C++ exception into a pending Python exception and
// makes the enclosing function return RETVAL (its error indicator).
// Handlers go from most to least specific: the pybind11 exception types are
// more specific than std::exception.
#define CATCH_ALL(RETVAL)                                  \
    catch (py::error_already_set & err_) {                 \
        /* Python error object already exists: re-raise */ \
        err_.restore();                                    \
        return RETVAL;                                     \
    }                                                      \
    catch (py::builtin_exception & err_) {                 \
        /* pybind11 exception with a fixed Python type */  \
        err_.set_error();                                  \
        return RETVAL;                                     \
    }                                                      \
    catch (std::exception & err_) {                        \
        PyErr_SetString(PyExc_RuntimeError, err_.what());  \
        return RETVAL;                                     \
    }
44

45
namespace {
M
Megvii Engine Team 已提交
46
// PyOp(X) names the Python wrapper struct for C++ op type X
// (PyOp(OpDef) is PyOpDef). PyOpType(X) is that wrapper's static PyTypeObject.
#define PyOp(name) Py##name
#define PyOpType(name) PyOp(name)::py_type

// PyOpDefBegin(X) ... PyOpDefEnd(X) bracket the definition of the wrapper
// struct for op type X. The struct derives from PyOpDef, aliases X as `Ty`,
// exposes the held op downcast to `Ty` (via cast_final_safe) as `inst()`, and
// declares a static `py_type`, which PyOpDefEnd then defines out of line.
#define PyOpDefBegin(name)                    \
    struct PyOp(name) : PyOpDef {             \
        using Ty = name;                      \
        Ty& inst() {                          \
            return op->cast_final_safe<Ty>(); \
        }                                     \
        static PyTypeObject py_type;

#define PyOpDefEnd(name) \
    };                   \
    PyTypeObject PyOpType(name);

// Body of a tp_richcompare slot for two comparable values. Evaluates
// `val1 <op> val2` for the rich-comparison code `op` (Py_EQ, Py_NE, Py_LT,
// Py_GT, Py_LE, Py_GE) and returns Py_True / Py_False from the *enclosing*
// function. Any other `op` is treated as impossible and aborts the interpreter.
#define RETURN_RICHCOMPARE(val1, val2, op)                            \
    do {                                                              \
        bool rc_result_ = false;                                      \
        switch (op) {                                                 \
            case Py_EQ:                                               \
                rc_result_ = (val1) == (val2);                        \
                break;                                                \
            case Py_NE:                                               \
                rc_result_ = (val1) != (val2);                        \
                break;                                                \
            case Py_LT:                                               \
                rc_result_ = (val1) < (val2);                         \
                break;                                                \
            case Py_GT:                                               \
                rc_result_ = (val1) > (val2);                         \
                break;                                                \
            case Py_LE:                                               \
                rc_result_ = (val1) <= (val2);                        \
                break;                                                \
            case Py_GE:                                               \
                rc_result_ = (val1) >= (val2);                        \
                break;                                                \
            default:                                                  \
                Py_FatalError("Unreachable C code path reached");     \
        }                                                             \
        if (rc_result_)                                               \
            Py_RETURN_TRUE;                                           \
        Py_RETURN_FALSE;                                              \
    } while (0)

92
template <typename T>
93 94 95 96 97 98 99 100 101
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
    PyObject* obj = type->tp_alloc(type, 0);
    T* self = reinterpret_cast<T*>(obj);
    if (self != NULL) {
        self->op = T::Ty::make();
    }
    return obj;
}

M
Megvii Engine Team 已提交
102
template <typename T, typename SNIFAE = void>
103
struct serialization {
M
Megvii Engine Team 已提交
104 105 106
    static T load(py::object obj) { return py::cast<T>(obj); }
    template <
            typename U, typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
107 108 109 110 111
    static py::object dump(U&& t) {
        return py::cast(std::forward<U>(t));
    }
};

M
Megvii Engine Team 已提交
112
template <typename T>
113 114 115 116 117
void py_dealloc_generic(PyObject* obj) {
    reinterpret_cast<T*>(obj)->op.reset();
    Py_TYPE(obj)->tp_free(obj);
}

M
Megvii Engine Team 已提交
118
template <typename T, typename U, U T::Ty::*attr>
119 120
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
    auto& op = reinterpret_cast<T*>(obj)->inst();
121
    return py::cast(op.*attr).release().ptr();
122 123 124 125
}
#define py_get_generic(name, attr) \
    py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

M
Megvii Engine Team 已提交
126
template <typename T, typename U, U T::Ty::*attr>
127 128 129 130 131 132 133
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& op = reinterpret_cast<T*>(obj)->inst();
    try {
134 135 136
        // TODO: remove this guard which is used for pybind11 implicit conversion
        py::detail::loader_life_support guard{};
        op.*attr = py::cast<U>(py::handle(value));
M
Megvii Engine Team 已提交
137 138
    }
    CATCH_ALL(-1)
139
    return 0;
140 141 142 143 144
}
#define py_set_generic(name, attr) \
    py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

struct PyOpDef {
M
Megvii Engine Team 已提交
145
    PyObject_HEAD std::shared_ptr<OpDef> op;
146 147
    static PyTypeObject py_type;
    static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
148
    static PyGetSetDef py_getsetters[];
M
Megvii Engine Team 已提交
149 150
    static Py_hash_t tp_hash(PyObject* obj);
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op);
151
    static PyObject* py_repr(PyObject* self) {
M
Megvii Engine Team 已提交
152
        return py::cast(reinterpret_cast<PyOpDef*>(self)->op->make_name())
153 154 155
                .release()
                .ptr();
    }
156 157 158 159
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;

160
/*!
 * \brief Getter for `OpDef.scope`.
 *
 * Raises AttributeError when no C++ op is attached (Python-only ops), and
 * converts any C++ exception from the conversion into a Python error instead
 * of letting it escape the slot.
 */
PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
    auto& op = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!op) {
        PyErr_SetString(PyExc_AttributeError, "OpDef has no underlying C++ op");
        return nullptr;
    }
    try {
        return py::cast(op->scope()).release().ptr();
    }
    CATCH_ALL(nullptr)
}

/*!
 * \brief Setter for `OpDef.scope`: accepts a str.
 *
 * Deletion is rejected with TypeError. A wrapper without a C++ op (Python-only
 * ops) raises AttributeError instead of dereferencing a null pointer. Returns 0
 * on success, -1 with a Python exception set otherwise.
 */
int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& op = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!op) {
        PyErr_SetString(PyExc_AttributeError, "OpDef has no underlying C++ op");
        return -1;
    }
    try {
        op->set_scope(py::cast<std::string>(py::handle(value)));
    }
    CATCH_ALL(-1)
    return 0;
}

//! Attributes available on every OpDef wrapper; ends with a null sentinel entry.
PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
        {
                const_cast<char*>("scope"),  // name
                py_get_scope,                // getter
                py_set_scope,                // setter
                const_cast<char*>("scope"),  // doc
                NULL,                        // closure
        },
        {nullptr, nullptr, nullptr, nullptr, nullptr},  // sentinel
};
181

M
Megvii Engine Team 已提交
182 183
/*!
 * \brief tp_hash slot: the C++ op's hash().
 *
 * Wrappers without a C++ op (Python-only ops) fall back to object's identity
 * hash, matching the identity equality used by tp_richcompare for them.
 */
Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
    auto& op = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!op) {
        return PyBaseObject_Type.tp_hash(obj);
    }
    try {
        auto h = static_cast<Py_hash_t>(op->hash());
        // CPython reserves -1 as the error return of tp_hash.
        return h == -1 ? -2 : h;
    }
    CATCH_ALL(-1)
}

M
Megvii Engine Team 已提交
186
/*!
 * \brief tp_richcompare slot for OpDef wrappers.
 *
 * Two wrappers are equal iff their C++ ops are is_same(). If either side has
 * no C++ op (Python-only ops), identity is used instead. Ordering comparisons,
 * and comparisons against objects that are not OpDef wrappers, return
 * NotImplemented. Previously `other` was reinterpret_cast unchecked.
 */
PyObject* PyOp(OpDef)::tp_richcompare(PyObject* self, PyObject* other, int op) {
    if ((op == Py_EQ || op == Py_NE) && PyObject_TypeCheck(other, &PyOpType(OpDef))) {
        auto& lhs = reinterpret_cast<PyOp(OpDef)*>(self)->op;
        auto& rhs = reinterpret_cast<PyOp(OpDef)*>(other)->op;
        bool same = false;
        if (lhs && rhs) {
            try {
                same = lhs->is_same(*rhs);
            }
            CATCH_ALL(nullptr)
        } else {
            same = (self == other);
        }
        RETURN_RICHCOMPARE(same, true, op);
    }
    Py_RETURN_NOTIMPLEMENTED;
}

M
Megvii Engine Team 已提交
195
//! Per-enum compile-time traits; the wrappers below read `EnumTrait<T>::name`
//! and `EnumTrait<T>::max`. Specializations presumably come from generated code.
template <class T>
struct EnumTrait;

M
Megvii Engine Team 已提交
198 199 200 201 202 203 204
// Shared head of the enum wrapper structs below (T must be an enum type):
//   value       - the wrapped enum value, laid out right after the PyObject header
//   name        - EnumTrait<T>::name, used as prefix in repr
//   type        - the Python type object of the wrapper
//   members     - member names (indexed by value, or by bit for bit-combined enums)
//   mem2value   - member lookup table, queried with normalize_enum()-ed names
//   pyobj_insts - cached Python instances returned by cast()
#define PyEnumHead                                          \
    static_assert(std::is_enum_v<T>);                       \
    PyObject_HEAD                                           \
    T value;                                                \
    static constexpr const char* name = EnumTrait<T>::name; \
    static PyTypeObject* type;                              \
    static const char* members[];                           \
    static std::unordered_map<std::string, T> mem2value;    \
    static PyObject* pyobj_insts[];

M
Megvii Engine Team 已提交
207
template <typename T>
208
struct EnumWrapper {
M
Megvii Engine Team 已提交
209
    PyEnumHead std::string to_string() const {
210
        return members[static_cast<size_t>(value)];
211 212
    }
    static PyObject* py_repr(PyObject* self) {
213
        return py::cast(
M
Megvii Engine Team 已提交
214 215 216 217
                       std::string(name) + "." +
                       reinterpret_cast<EnumWrapper*>(self)->to_string())
                .release()
                .ptr();
218
    }
219 220 221 222 223 224 225

    static PyObject* py_dump(PyObject* self) {
        return py::cast(reinterpret_cast<EnumWrapper*>(self)->to_string())
                .release()
                .ptr();
    }

M
Megvii Engine Team 已提交
226
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
227
        if (op == Py_EQ || op == Py_NE) {
228 229 230 231 232 233
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
234 235 236
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
237 238
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
239
        if (PyObject_TypeCheck(obj, type)) {
240 241 242 243
            value = reinterpret_cast<EnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
M
Megvii Engine Team 已提交
244
            auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
245 246 247 248 249 250
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
251
        }
252
        return false;
253
    }
254 255 256 257 258
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        PyObject* obj = pyobj_insts[v];
        Py_INCREF(obj);
259 260 261 262
        return obj;
    }
};

M
Megvii Engine Team 已提交
263
template <typename T>
264
struct BitCombinedEnumWrapper {
M
Megvii Engine Team 已提交
265
    PyEnumHead std::string to_string() const {
266 267
        uint32_t value_int = static_cast<uint32_t>(value);
        if (value_int == 0) {
268 269
            return "None";
        } else {
270
            std::string ret;
271 272
            bool first = true;
            for (uint32_t i = 0; i < 32; i++) {
273
                if (value_int >> i & 1) {
274 275 276 277 278
                    if (!first) {
                        ret += " + ";
                    } else {
                        first = false;
                    }
279
                    ret += (std::string(name) + "." + members[i]);
280 281 282 283 284
                }
            }
            return ret;
        }
    }
M
Megvii Engine Team 已提交
285 286
    static PyObject* py_new_combined_enum(
            PyTypeObject* type, PyObject* args, PyObject*) {
287 288 289 290
        if (!PyTuple_Size(args)) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
            return obj;
M
Megvii Engine Team 已提交
291
        } else {
292 293 294 295 296
            PyObject* input;
            if (!PyArg_ParseTuple(args, "|O", &input)) {
                return nullptr;
            }
            T value;
297 298 299
            if (load(input, value)) {
                return cast(value);
            } else {
M
Megvii Engine Team 已提交
300 301 302 303 304 305
                PyErr_SetString(
                        PyExc_RuntimeError,
                        mgb::ssprintf(
                                "Cannot convert type %s to type %s\n",
                                input->ob_type->tp_name, name)
                                .c_str());
306 307
                return nullptr;
            }
308 309 310
        }
    }
    static PyObject* py_repr(PyObject* self) {
M
Megvii Engine Team 已提交
311 312 313
        return py::cast(reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
                .release()
                .ptr();
314
    }
315 316 317 318 319 320 321 322 323 324 325 326 327

    static PyObject* py_dump(PyObject* self) {
        std::vector<std::string> result;
        auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value;
        uint32_t value_int = static_cast<uint32_t>(value);
        for (uint32_t i = 0; i < 32; i++) {
            if (value_int >> i & 1) {
                result.push_back(members[i]);
            }
        }
        return py::tuple(py::cast(result)).release().ptr();
    }

328
    static PyObject* py_or(PyObject* self, PyObject* other) {
M
Megvii Engine Team 已提交
329
        if (!(self->ob_type == other->ob_type)) {
330 331 332 333 334 335
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in or operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
336
        return cast(lhs | rhs);
337 338 339 340 341 342 343 344 345
    }
    static PyObject* py_and(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in and operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
346
        return cast(lhs & rhs);
347 348 349
    }
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
350 351 352 353 354 355
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
356 357 358
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
359 360
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
361
        if (PyObject_TypeCheck(obj, type)) {
362 363 364 365
            value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
M
Megvii Engine Team 已提交
366
            auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
367 368 369 370 371 372 373
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
374
        if (py::isinstance<py::tuple>(src)) {
M
Megvii Engine Team 已提交
375
            auto params = py::cast<std::vector<std::string>>(src);
376
            bool first = true;
M
Megvii Engine Team 已提交
377
            for (auto s : params) {
378 379 380 381 382 383 384 385 386 387 388 389 390 391
                auto&& iter = mem2value.find(normalize_enum(s));
                if (iter != mem2value.end()) {
                    if (first) {
                        value = iter->second;
                        first = false;
                    } else {
                        value |= iter->second;
                    }
                } else {
                    return false;
                }
            }
            return true;
        }
392 393
        if (py::isinstance<py::int_>(obj)) {
            auto v = py::cast<std::underlying_type_t<T>>(src);
M
Megvii Engine Team 已提交
394
            if (v > EnumTrait<T>::max) {
395 396 397 398
                return false;
            }
            value = static_cast<T>(v);
            return true;
399
        }
400
        return false;
401
    }
402 403 404 405
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        if ((!v) || (v & (v - 1))) {
406
            PyObject* obj = type->tp_alloc(type, 0);
407 408 409 410 411 412 413
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
            return obj;
        } else {
            PyObject* obj = pyobj_insts[__builtin_ctz(v)];
            Py_INCREF(obj);
            return obj;
        }
414 415 416
    }
};

M
Megvii Engine Team 已提交
417 418
template <typename T>
struct serialization<T, std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
419 420 421 422 423
    static T load(py::object obj) {
        auto caster = pybind11::detail::type_caster<T>();
        if (caster.load(obj, true)) {
            return caster;
        } else {
M
Megvii Engine Team 已提交
424 425
            PyErr_SetString(PyExc_RuntimeError, "load faild \n");
            return caster;
426 427
        }
    }
M
Megvii Engine Team 已提交
428
    static py::object dump(T t) { return py::cast(t).attr("dump")(); }
429 430
};

431
//! Builds the `OpDef` base type (megengine.core._imperative_rt.OpDef) and adds it to \p m.
void _init_py_op_def(py::module m) {
    using py_op = PyOp(OpDef);
    PyTypeObject& tp = PyOpType(OpDef);
    tp = {PyVarObject_HEAD_INIT(NULL, 0)};
    // identity and layout
    tp.tp_name = "megengine.core._imperative_rt.OpDef";
    tp.tp_doc = "OpDef";
    tp.tp_basicsize = sizeof(py_op);
    tp.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    tp.tp_base = &PyBaseObject_Type;
    // slots
    tp.tp_dealloc = py_dealloc_generic<py_op>;
    tp.tp_repr = py_op::py_repr;
    tp.tp_hash = py_op::tp_hash;
    tp.tp_richcompare = py_op::tp_richcompare;
    tp.tp_getset = py_op::py_getsetters;
    mgb_assert(PyType_Ready(&tp) >= 0);
    m.add_object("OpDef", reinterpret_cast<PyObject*>(&tp));
}

/*********** begin of hand-write opdefs **************/
450 451 452 453 454 455 456
/*!
 * \brief Base type for ops defined in Python.
 *
 * Its tp_new leaves `op` empty. type_caster<OpDef>::load later wraps such
 * objects in a GenericPyOp.
 */
struct PyOpBase : PyOpDef {
    static PyTypeObject py_type;

    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
        PyObject* obj = type->tp_alloc(type, 0);
        if (obj == nullptr) {
            return nullptr;
        }
        auto* self = reinterpret_cast<PyOpBase*>(obj);
        // tp_alloc only provides raw memory: run the shared_ptr constructor.
        new (&self->op) std::shared_ptr<OpDef>();
        return obj;
    }
};
PyTypeObject PyOpBase::py_type;

//! Builds `PyOpBase` (an OpDef subclass whose instances start without a C++ op) and adds it to \p m.
void _init_py_op_base(py::module m) {
    using py_op = PyOpBase;
    PyTypeObject& tp = py_op::py_type;
    tp = {PyVarObject_HEAD_INIT(NULL, 0)};
    tp.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
    tp.tp_doc = "PyOpBase";
    tp.tp_basicsize = sizeof(py_op);
    tp.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    tp.tp_base = &PyOpType(OpDef);
    tp.tp_new = py_op::tp_new;
    tp.tp_dealloc = py_dealloc_generic<py_op>;
    mgb_assert(PyType_Ready(&tp) >= 0);
    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&tp));
}

479 480 481 482 483
/*********** end of hand-write opdefs **************/

// auto generated opdefs
#include "opdef.cpy.inl"

484
#undef CATCH_ALL
M
Megvii Engine Team 已提交
485
}  // anonymous namespace
486 487 488 489 490 491 492 493 494

namespace PYBIND11_NAMESPACE {
namespace detail {
//! Python -> OpDef. Accepts any OpDef wrapper. Wrappers without a C++ op
//! (ops defined only in Python) are wrapped in a GenericPyOp that keeps the
//! Python object alive.
bool type_caster<OpDef>::load(handle src, bool convert) {
    PyObject* obj = src.ptr();
    const bool is_opdef = PyObject_TypeCheck(obj, &PyOpType(OpDef));
    if (!is_opdef) {
        return false;
    }
    auto& held = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (held) {
        value = held;
    } else {
        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
    }
    return true;
}
/*!
 * \brief OpDef -> Python wrapper (new reference).
 *
 * A GenericPyOp returns the Python object it wraps. Any other op gets a fresh
 * instance of the wrapper type registered for its dyn_typeinfo() in
 * ctype2pytype, or of the plain OpDef type if none is registered. Returns a
 * null handle, with the error left by tp_alloc, if allocation fails. Previously
 * a NULL result was handed to PyObject_TypeCheck.
 */
handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
        return object(pyop->obj).release();
    }
    PyTypeObject* pytype;
    auto& c2p = PyOp(OpDef)::ctype2pytype;
    auto&& iter = c2p.find(op.dyn_typeinfo());
    if (iter != c2p.end()) {  // FIXME: should always meet this condition
        pytype = iter->second;
    } else {  // which means unregistered op type, just make it as an opaque op type
        // currently, only OprAttr goes into this branch
        pytype = &PyOpType(OpDef);
    }
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    if (!obj) {
        return handle();
    }
    mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
    return py::handle(obj);
}
519

M
Megvii Engine Team 已提交
520 521 522 523 524 525 526
#define ENUM_CASTER_IMPL(T)                                                    \
    bool type_caster<T>::load(handle src, bool) {                              \
        return EnumWrapper<T>::load(src, value);                               \
    }                                                                          \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
        return EnumWrapper<T>::cast(value);                                    \
    }
527 528
FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)

M
Megvii Engine Team 已提交
529 530 531 532 533 534 535
#define BIT_COMBINED_ENUM_CASTER_IMPL(T)                                       \
    bool type_caster<T>::load(handle src, bool) {                              \
        return BitCombinedEnumWrapper<T>::load(src, value);                    \
    }                                                                          \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
        return BitCombinedEnumWrapper<T>::cast(value);                         \
    }
536 537
FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)

M
Megvii Engine Team 已提交
538 539
}  // namespace detail
}  // namespace PYBIND11_NAMESPACE
540

541
/*!
 * \brief Registers the op-related bindings on \p m.
 *
 * Registers:
 * - the OpDef / PyOpBase base types and all generated op types;
 * - the RNG-handle helpers;
 * - the `SubgraphBuilder` class;
 * - the JIT switches;
 * - the `_custom` submodule.
 */
void init_ops(py::module m) {
    _init_py_op_def(m);
    _init_py_op_base(m);
    INIT_ALL_OP(m)

    m.def("new_rng_handle", &rng::new_handle);
    m.def(
            "delete_rng_handle",
            [](size_t handle) {
                // Drain outstanding work (interpreter, devices, python task
                // queue) before the handle is deleted.
                if (mgb::imperative::python::interpreter_for_py->check_available()) {
                    mgb::imperative::python::interpreter_for_py->sync();
                }
                mgb::CompNode::sync_all();
                mgb::CompNode::foreach ([](mgb::CompNode cn) {
                    auto err = cn.check_async_error();
                    mgb_assert(!err, "%s", err->what());
                });
                py_task_q.wait_all_task_finish();
                rng::delete_handle(handle);
            },
            py::call_guard<py::gil_scoped_release>());
    m.def("set_global_rng_seed", [](uint64_t seed) -> void {
        mgb_assert(
                python::interpreter_for_py->check_available(),
                "set global random seed failed since imperative interpreter has been "
                "destroyed");
        python::interpreter_for_py->sync();
        mgb::CompNode::sync_all();
        rng::set_global_rng_seed(seed);
    });
    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
    m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);

    // Records a Subgraph incrementally from Python. build() (reached through
    // get/compile/jit_fuse) sets `key`, after which every mutator asserts.
    struct PySubgraphBuilder {
        explicit PySubgraphBuilder(std::string name) : name{std::move(name)} {}
        std::string name;
        Subgraph graph;
        //! one flag per graph output (sized by `outputs`, set by `outputs_has_grad`)
        mgb::SmallVector<bool> output_grad_mask;
        Subgraph::var_t next_var = 1;
        std::shared_ptr<mgb::Hashable> key = nullptr;

        std::shared_ptr<OpDef> build() {
            if (key == nullptr) {
                key = std::make_shared<UniqueKey>();
            }
            return SubgraphOp::make(
                    name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
        }
    };

    py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
            .def(py::init<std::string>())
            .def(py::init<PySubgraphBuilder>())
            .def("input",
                 [](PySubgraphBuilder& self) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     self.graph.inputs.push_back(var);
                     return var;
                 })
            .def("apply",
                 [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                    Subgraph::vars_t inputs, size_t nr_outputs) {
                     mgb_assert(self.key == nullptr);
                     Subgraph::vars_t outputs;
                     for (size_t i = 0; i < nr_outputs; ++i) {
                         outputs.push_back(self.next_var++);
                     }
                     self.graph.exprs.push_back({op, inputs, outputs});
                     return outputs;
                 })
            .def("apply_const",
                 [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                    mgb::CompNode cn) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     mgb::HostTensorND hvalue(cn);
                     npy::np2tensor(
                             value.cast<py::array>().ptr(),
                             npy::Meth::copy_into(&hvalue), dtype);
                     self.graph.constants.push_back({var, Tensor::make(hvalue)});
                     return var;
                 })
            .def("outputs",
                 [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
                     mgb_assert(self.key == nullptr);
                     self.graph.outputs = outputs;
                     self.output_grad_mask.resize(outputs.size(), true);
                 })
            .def("outputs_has_grad",
                 [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
                     mgb_assert(self.key == nullptr);
                     mgb_assert(
                             self.graph.outputs.size() == self.output_grad_mask.size());
                     // Once outputs are declared, the replacement mask must carry
                     // exactly one flag per output (the check above only compares
                     // the old mask).
                     mgb_assert(
                             self.graph.outputs.empty() ||
                                     outputs_has_grad.size() == self.graph.outputs.size(),
                             "outputs_has_grad expects %zu flags, got %zu",
                             static_cast<size_t>(self.graph.outputs.size()),
                             static_cast<size_t>(outputs_has_grad.size()));
                     self.output_grad_mask = std::move(outputs_has_grad);
                 })
            .def("get",
                 [](PySubgraphBuilder& self) {
                     return (std::shared_ptr<OpDef>)self.build();
                 })
            .def("compile",
                 [](PySubgraphBuilder& self, int gopt_level) {
                     return (std::shared_ptr<OpDef>)CompiledOp::make(
                             self.build(), gopt_level);
                 })
            .def("jit_fuse", [](PySubgraphBuilder& self) {
                return (std::shared_ptr<OpDef>)CompiledOp::make(
                        JITFusionOp::make(self.build()));
            });

    m.def("set_jit_enabled", &JITFusionOp::set_enabled);
    bool jit_supported = false;
#if MGB_JIT
    jit_supported = true;
#endif
    m.attr("jit_supported") = jit_supported;

    auto custom = submodule(m, "_custom");
    init_custom(custom);
}

M
Megvii Engine Team 已提交
662 663 664 665
// `case` label for a scalar or string custom-op parameter: converts the Python
// value `kv.second` to `static_type` and stores it in `param_val`.
// Requires `kv` and `param_val` to be in scope at the expansion site.
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type)   \
    case custom::ParamDynType::dyn_type: {                     \
        param_val = py::handle(kv.second).cast<static_type>(); \
        break;                                                 \
    }

M
Megvii Engine Team 已提交
668 669 670 671 672 673 674 675 676 677
// `case` label for a list-valued custom-op parameter: converts every element
// of the Python list `kv.second` to the element type of `static_type` and
// stores the collected container in `param_val`.
// Requires `kv` and `param_val` to be in scope at the expansion site.
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type)                            \
    case custom::ParamDynType::dyn_type: {                                          \
        auto pyvals = py::handle(kv.second).cast<py::list>();                       \
        static_type vals;                                                           \
        using basic_type = custom::get_vector_template_arg_type<static_type>::type; \
        for (auto&& pyval : pyvals) {                                               \
            vals.push_back(py::handle(pyval).cast<basic_type>());                   \
        }                                                                           \
        /* `vals` is dead after this point: move instead of copying */             \
        param_val = std::move(vals);                                                \
        break;                                                                      \
    }

M
Megvii Engine Team 已提交
680
/*!
 * \brief `_custom._make_custom_op(op_name, params)`: builds a custom OpDef.
 *
 * Raw CPython entry point: METH_FASTCALL, or reached through
 * py35_make_custom_op when METH_FASTCALL is not available.
 * - args[0] is the op name.
 * - args[1] is a dict of parameter values. Unknown parameter names are
 *   skipped with a warning.
 *
 * Returns a new OpDef wrapper, or NULL with a Python exception set. C++
 * exceptions (bad argument types, failed assertions, ...) are translated
 * into Python errors here, because none may escape into the interpreter.
 */
PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
    try {
#if MGB_CUSTOM_OP
        if (nargs < 2) {
            PyErr_Format(
                    PyExc_TypeError,
                    "_make_custom_op expects 2 arguments (op_name, params), got %zd",
                    nargs);
            return nullptr;
        }
        auto op_name = py::handle(args[0]).cast<std::string>();
        auto kwargs = py::handle(args[1]).cast<py::dict>();

        std::shared_ptr<OpDef> opdef = CustomOpDefFactory::inst()->create_opdef(op_name);
        auto& custom_opdef = static_cast<mgb::imperative::CustomOpDef&>(*opdef);
        auto& param = custom_opdef.param();

        for (auto&& kv : kwargs) {
            std::string param_name = py::handle(kv.first).cast<std::string>();

            if (!param.exist(param_name)) {
                mgb_log_warn(
                        "op %s have no param named %s, ignore this param parsed from "
                        "python",
                        op_name.c_str(), param_name.c_str());
                continue;
            }

            auto& param_val = param[param_name];
            switch (param_val.type()) {
                CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
                CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
                CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                case custom::ParamDynType::Device: {
                    param_val =
                            to_custom_device(py::handle(kv.second).cast<mgb::CompNode>());
                    break;
                }
                default: {
                    mgb_assert(
                            false, "param dtype of %s:%s is invalid", op_name.c_str(),
                            param_name.c_str());
                }
            }
        }

        PyTypeObject* pytype = &PyOpType(OpDef);
        PyObject* obj = pytype->tp_alloc(pytype, 0);
        if (!obj) {
            return nullptr;  // tp_alloc has set MemoryError
        }
        reinterpret_cast<PyOp(OpDef)*>(obj)->op = opdef;
        return obj;
#else
        mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
        return nullptr;
#endif
    } catch (py::error_already_set& e) {
        e.restore();
    } catch (py::builtin_exception& e) {
        e.set_error();
    } catch (std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
    }
    return nullptr;
}

#undef CUSTOM_CASE_TO_PARSE_LIST
#undef CUSTOM_CASE_TO_PARSE_NON_LIST

M
Megvii Engine Team 已提交
736
//! `_custom._install(name, path)`: installs the custom-op library at \p path
//! under \p name and returns, as a list, what CustomOpManager::install reports.
py::list install_custom(const std::string& name, const std::string& path) {
#if MGB_CUSTOM_OP
    const auto& installed = custom::CustomOpManager::inst()->install(name, path);
    return py::list(py::cast(installed));
#else
    mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
    return py::list{};
#endif
}

747
//! `_custom._uninstall(name)`: uninstalls the custom-op library registered as
//! \p name and returns, as a list, what CustomOpManager::uninstall reports.
py::list uninstall_custom(const std::string& name) {
#if MGB_CUSTOM_OP
    const auto& ops_in_lib = custom::CustomOpManager::inst()->uninstall(name);
    py::list ret = py::cast(ops_in_lib);
    return ret;
#else
    mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
    // Must match the declared py::list return type (was `return false;`).
    return py::list{};
#endif
}

//! `_custom._get_custom_op_list()`: names of all ops known to CustomOpManager.
py::list get_custom_op_list(void) {
#if MGB_CUSTOM_OP
    std::vector<std::string> names = custom::CustomOpManager::inst()->op_name_list();
    return py::list(py::cast(names));
#else
    mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
    return py::list{};
#endif
}

//! `_custom._get_custom_op_lib_info()`: maps every installed library name to
//! the list of ops it provides (lib_handle->ops_in_lib()).
py::dict get_custom_op_lib_info(void) {
#if MGB_CUSTOM_OP
    auto&& libs = custom::CustomOpManager::inst()->lib_info();
    py::dict ret;
    for (auto&& [lib_name, lib_handle] : libs) {
        py::list ops = py::cast(lib_handle->ops_in_lib());
        ret[py::str(lib_name)] = ops;
    }
    return ret;
#else
    mgb_assert(false, "CustomOp disabled, please build megengine with CustomOp open");
    // Must match the declared py::dict return type (was `return py::list{};`).
    return py::dict{};
#endif
}

784
#ifndef METH_FASTCALL
M
Megvii Engine Team 已提交
785 786 787 788 789
PyObject* py35_make_custom_op(PyObject* self, PyObject* args) {
    auto* arr = &PyTuple_GET_ITEM(args, 0);
    auto size = PyTuple_GET_SIZE(args);
    return make_custom_op(self, arr, size);
};
790 791
#endif

792
/*!
 * \brief Populates the `_custom` submodule \p m.
 *
 * Adds:
 * - the library install/uninstall/query helpers;
 * - the libstdc++ C++11-ABI flag of this build;
 * - the raw `_make_custom_op` entry point (METH_FASTCALL when available,
 *   METH_VARARGS otherwise).
 */
void init_custom(pybind11::module m) {
    m.def("_install", &install_custom);
    m.def("_uninstall", &uninstall_custom);
    m.def("_get_custom_op_list", &get_custom_op_list);
    m.def("_get_custom_op_lib_info", &get_custom_op_lib_info);
    m.def("get_custom_op_abi_tag", [](void) -> int {
        int ret = 0;
#ifdef _GLIBCXX_USE_CXX11_ABI
        ret = _GLIBCXX_USE_CXX11_ABI;
#endif
        return ret;
    });

    // The function object keeps a pointer to this definition, so it must be static.
    static PyMethodDef method_def = {
#ifdef METH_FASTCALL
            "_make_custom_op", (PyCFunction)make_custom_op, METH_FASTCALL, ""
#else
            "_make_custom_op", (PyCFunction)py35_make_custom_op, METH_VARARGS, ""
#endif
    };
    // PyCFunction_NewEx returns a new reference. Own it so it is released after
    // setattr has stored the module's own reference (previously it leaked).
    auto func = pybind11::reinterpret_steal<pybind11::object>(
            PyCFunction_NewEx(&method_def, nullptr, nullptr));
    if (!func) {
        throw pybind11::error_already_set();
    }
    pybind11::setattr(m, method_def.ml_name, func);
}