ops.cpp 29.4 KB
Newer Older
1
#include "./ops.h"
2 3
#include "./helper.h"
#include "./tensor.h"
4

5
#include "megbrain/common.h"
6
#include "megbrain/imperative.h"
7
#include "megbrain/imperative/graph_builder.h"
M
Megvii Engine Team 已提交
8
#include "megbrain/imperative/ops/autogen.h"
9 10
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
11
#include "megbrain/imperative/ops/rng.h"
M
Megvii Engine Team 已提交
12
#include "megbrain/imperative/ops/utility.h"
13

14 15 16
#include <Python.h>
#include <unordered_map>

17
namespace py = pybind11;
18
using namespace mgb::imperative;
19

20 21 22 23 24 25 26 27
namespace {
// Upper-case an enum member name; used to normalize user-supplied strings
// before looking them up in the wrappers' `mem2value` tables.
// Each byte is passed to toupper() as unsigned char: calling it with a
// negative char (any byte >= 0x80, e.g. UTF-8 text) is undefined behaviour.
std::string normalize_enum(const std::string& in) {
    std::string ret;
    ret.reserve(in.size());
    for (char c : in) {
        ret += static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    return ret;
}
}  // anonymous namespace

// Handler tail for a `try { ... }` inside a CPython entry point: converts the
// in-flight C++ exception into a pending Python exception and returns RETVAL,
// the caller's C-API error sentinel (e.g. -1 for attribute setters).
//  - py::error_already_set: re-raise the original Python error;
//  - py::builtin_exception: raise the matching Python builtin exception;
//  - any other std::exception: raise RuntimeError with its what() text.
#define CATCH_ALL(RETVAL)                                \
    catch (py::error_already_set & exc) {                \
        exc.restore();                                   \
        return RETVAL;                                   \
    }                                                    \
    catch (const py::builtin_exception& exc) {           \
        exc.set_error();                                 \
        return RETVAL;                                   \
    }                                                    \
    catch (const std::exception& exc) {                  \
        PyErr_SetString(PyExc_RuntimeError, exc.what()); \
        return RETVAL;                                   \
    }

namespace {
M
Megvii Engine Team 已提交
45
// PyOp(Foo) is the name of the Python wrapper struct for C++ op type Foo
// (i.e. PyFoo); PyOpType(Foo) is that wrapper's static PyTypeObject.
#define PyOp(name) Py##name
#define PyOpType(name) PyOp(name)::py_type

// PyOpDefBegin(Foo) opens the wrapper struct PyFoo (derived from PyOpDef):
// it aliases the C++ type as `Ty`, provides `inst()` which downcasts the held
// op via cast_final_safe, and declares the wrapper's PyTypeObject.
// PyOpDefEnd(Foo) closes that struct and defines the PyTypeObject.
#define PyOpDefBegin(name)                               \
    struct PyOp(name) : PyOpDef {                        \
        using Ty = name;                                 \
        Ty& inst() { return op->cast_final_safe<Ty>(); } \
        static PyTypeObject py_type;

#define PyOpDefEnd(name) \
    };                   \
    PyTypeObject PyOpType(name);

// Return Py_True/Py_False according to `val1 <op> val2`, where `op` is one of
// the rich-comparison codes (Py_EQ, Py_NE, Py_LT, Py_GT, Py_LE, Py_GE).
// Any other code aborts the interpreter via Py_FatalError.
#define RETURN_RICHCOMPARE(val1, val2, op)                        \
    do {                                                          \
        bool richcmp_result_ = false;                             \
        switch (op) {                                             \
            case Py_EQ:                                           \
                richcmp_result_ = ((val1) == (val2));             \
                break;                                            \
            case Py_NE:                                           \
                richcmp_result_ = ((val1) != (val2));             \
                break;                                            \
            case Py_LT:                                           \
                richcmp_result_ = ((val1) < (val2));              \
                break;                                            \
            case Py_GT:                                           \
                richcmp_result_ = ((val1) > (val2));              \
                break;                                            \
            case Py_LE:                                           \
                richcmp_result_ = ((val1) <= (val2));             \
                break;                                            \
            case Py_GE:                                           \
                richcmp_result_ = ((val1) >= (val2));             \
                break;                                            \
            default:                                              \
                Py_FatalError("Unreachable C code path reached"); \
        }                                                         \
        if (richcmp_result_)                                      \
            Py_RETURN_TRUE;                                       \
        Py_RETURN_FALSE;                                          \
    } while (0)

template <typename T>
92 93 94 95 96 97 98 99 100
PyObject* py_new_generic(PyTypeObject* type, PyObject*, PyObject*) {
    PyObject* obj = type->tp_alloc(type, 0);
    T* self = reinterpret_cast<T*>(obj);
    if (self != NULL) {
        self->op = T::Ty::make();
    }
    return obj;
}

M
Megvii Engine Team 已提交
101
template <typename T, typename SNIFAE = void>
102
struct serialization {
M
Megvii Engine Team 已提交
103 104 105
    static T load(py::object obj) { return py::cast<T>(obj); }
    template <
            typename U, typename = std::enable_if_t<std::is_same_v<T, std::decay_t<U>>>>
106 107 108 109 110
    static py::object dump(U&& t) {
        return py::cast(std::forward<U>(t));
    }
};

M
Megvii Engine Team 已提交
111
template <typename T>
112 113 114 115 116
void py_dealloc_generic(PyObject* obj) {
    reinterpret_cast<T*>(obj)->op.reset();
    Py_TYPE(obj)->tp_free(obj);
}

M
Megvii Engine Team 已提交
117
template <typename T, typename U, U T::Ty::*attr>
118 119
PyObject* py_get_generic_impl(PyObject* obj, void* /* closure */) {
    auto& op = reinterpret_cast<T*>(obj)->inst();
120
    return py::cast(op.*attr).release().ptr();
121 122 123 124
}
#define py_get_generic(name, attr) \
    py_get_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

M
Megvii Engine Team 已提交
125
template <typename T, typename U, U T::Ty::*attr>
126 127 128 129 130 131 132
int py_set_generic_impl(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& op = reinterpret_cast<T*>(obj)->inst();
    try {
133 134 135
        // TODO: remove this guard which is used for pybind11 implicit conversion
        py::detail::loader_life_support guard{};
        op.*attr = py::cast<U>(py::handle(value));
M
Megvii Engine Team 已提交
136 137
    }
    CATCH_ALL(-1)
138
    return 0;
139 140 141 142 143
}
#define py_set_generic(name, attr) \
    py_set_generic_impl<PyOp(name), decltype(std::declval<name>().attr), &name::attr>

struct PyOpDef {
M
Megvii Engine Team 已提交
144
    PyObject_HEAD std::shared_ptr<OpDef> op;
145 146
    static PyTypeObject py_type;
    static std::unordered_map<mgb::Typeinfo*, PyTypeObject*> ctype2pytype;
147
    static PyGetSetDef py_getsetters[];
M
Megvii Engine Team 已提交
148 149
    static Py_hash_t tp_hash(PyObject* obj);
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op);
150
    static PyObject* py_repr(PyObject* self) {
M
Megvii Engine Team 已提交
151
        return py::cast(reinterpret_cast<PyOpDef*>(self)->op->make_name())
152 153 154
                .release()
                .ptr();
    }
155 156 157 158
};
PyTypeObject PyOpType(OpDef);
std::unordered_map<mgb::Typeinfo*, PyTypeObject*> PyOp(OpDef)::ctype2pytype;

159
// Getter for OpDef.scope: returns the op's scope() as a Python object.
// Raises AttributeError when no C++ op is attached (e.g. ops built via
// PyOpBase), and converts C++ exceptions into Python exceptions.
PyObject* py_get_scope(PyObject* obj, void* /* closure */) {
    auto& held = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!held) {
        PyErr_SetString(
                PyExc_AttributeError, "scope is unavailable: no C++ OpDef is attached");
        return nullptr;
    }
    try {
        return py::cast(held->scope()).release().ptr();
    }
    CATCH_ALL(nullptr)
}

// Setter for OpDef.scope: `value` must be convertible to std::string.
// Deleting the attribute raises TypeError; objects without a C++ op raise
// AttributeError instead of dereferencing a null pointer. Returns 0 / -1.
int py_set_scope(PyObject* obj, PyObject* value, void* /* closure */) {
    if (value == NULL) {
        PyErr_SetString(PyExc_TypeError, "Cannot delete the attribute");
        return -1;
    }
    auto& held = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!held) {
        PyErr_SetString(
                PyExc_AttributeError, "scope is unavailable: no C++ OpDef is attached");
        return -1;
    }
    try {
        held->set_scope(py::cast<std::string>(py::handle(value)));
    }
    CATCH_ALL(-1)
    return 0;
}

// Attribute table for the OpDef base type: a single read/write `scope`
// attribute, terminated by the all-null sentinel entry.
PyGetSetDef PyOp(OpDef)::py_getsetters[] = {
        {const_cast<char*>("scope"), py_get_scope, py_set_scope,
         const_cast<char*>("scope"), nullptr},
        {nullptr} /* sentinel */};

// tp_hash for OpDef: the wrapped op's hash(). Objects with no C++ op (e.g.
// created through PyOpBase) fall back to the default `object` hash instead of
// dereferencing a null pointer. CPython treats a tp_hash result of -1 as an
// error, so that value is remapped to -2.
Py_hash_t PyOp(OpDef)::tp_hash(PyObject* obj) {
    auto& held = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!held) {
        return PyBaseObject_Type.tp_hash(obj);
    }
    auto hash = static_cast<Py_hash_t>(held->hash());
    return hash == -1 ? -2 : hash;
}

// tp_richcompare for OpDef: only == and != are supported, via OpDef::is_same.
// `other` can be any Python object, so its type is checked before it is
// reinterpreted. When either side holds no C++ op, NotImplemented lets Python
// fall back to its default (identity) comparison.
PyObject* PyOp(OpDef)::tp_richcompare(PyObject* self, PyObject* other, int op) {
    if ((op == Py_EQ || op == Py_NE) && PyObject_TypeCheck(other, &PyOpType(OpDef))) {
        auto& lhs = reinterpret_cast<PyOp(OpDef)*>(self)->op;
        auto& rhs = reinterpret_cast<PyOp(OpDef)*>(other)->op;
        if (lhs && rhs) {
            bool same = lhs->is_same(*rhs);
            RETURN_RICHCOMPARE(same, true, op);
        }
    }
    Py_RETURN_NOTIMPLEMENTED;
}

// Per-enum traits, declared only here; specializations are provided elsewhere.
// The wrappers below read `EnumTrait<T>::name` and `EnumTrait<T>::max`.
template <class T>
struct EnumTrait;

// Common members of the Python enum wrapper structs (expects the enclosing
// template parameter to be named T). Place it at the start of the struct body
// so that PyObject_HEAD leads the object layout, followed by the held value.
// Static tables: the Python type, member names indexed by value / bit,
// normalized-name -> value lookup, and cached per-member Python instances.
#define PyEnumHead                                          \
    static_assert(std::is_enum_v<T>);                       \
    PyObject_HEAD T value;                                  \
    static constexpr const char* name = EnumTrait<T>::name; \
    static PyTypeObject* type;                              \
    static const char* members[];                           \
    static std::unordered_map<std::string, T> mem2value;    \
    static PyObject* pyobj_insts[];

template <typename T>
207
struct EnumWrapper {
M
Megvii Engine Team 已提交
208
    PyEnumHead std::string to_string() const {
209
        return members[static_cast<size_t>(value)];
210 211
    }
    static PyObject* py_repr(PyObject* self) {
212
        return py::cast(
M
Megvii Engine Team 已提交
213 214 215 216
                       std::string(name) + "." +
                       reinterpret_cast<EnumWrapper*>(self)->to_string())
                .release()
                .ptr();
217
    }
218 219 220 221 222 223 224

    static PyObject* py_dump(PyObject* self) {
        return py::cast(reinterpret_cast<EnumWrapper*>(self)->to_string())
                .release()
                .ptr();
    }

M
Megvii Engine Team 已提交
225
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
226
        if (op == Py_EQ || op == Py_NE) {
227 228 229 230 231 232
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
233 234 235
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
236 237
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
238
        if (PyObject_TypeCheck(obj, type)) {
239 240 241 242
            value = reinterpret_cast<EnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
M
Megvii Engine Team 已提交
243
            auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
244 245 246 247 248 249
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
250
        }
251
        return false;
252
    }
253 254 255 256 257
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        PyObject* obj = pyobj_insts[v];
        Py_INCREF(obj);
258 259 260 261
        return obj;
    }
};

M
Megvii Engine Team 已提交
262
template <typename T>
263
struct BitCombinedEnumWrapper {
M
Megvii Engine Team 已提交
264
    PyEnumHead std::string to_string() const {
265 266
        uint32_t value_int = static_cast<uint32_t>(value);
        if (value_int == 0) {
267 268
            return "None";
        } else {
269
            std::string ret;
270 271
            bool first = true;
            for (uint32_t i = 0; i < 32; i++) {
272
                if (value_int >> i & 1) {
273 274 275 276 277
                    if (!first) {
                        ret += " + ";
                    } else {
                        first = false;
                    }
278
                    ret += (std::string(name) + "." + members[i]);
279 280 281 282 283
                }
            }
            return ret;
        }
    }
M
Megvii Engine Team 已提交
284 285
    static PyObject* py_new_combined_enum(
            PyTypeObject* type, PyObject* args, PyObject*) {
286 287 288 289
        if (!PyTuple_Size(args)) {
            PyObject* obj = type->tp_alloc(type, 0);
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = T();
            return obj;
M
Megvii Engine Team 已提交
290
        } else {
291 292 293 294 295
            PyObject* input;
            if (!PyArg_ParseTuple(args, "|O", &input)) {
                return nullptr;
            }
            T value;
296 297 298
            if (load(input, value)) {
                return cast(value);
            } else {
M
Megvii Engine Team 已提交
299 300 301 302 303 304
                PyErr_SetString(
                        PyExc_RuntimeError,
                        mgb::ssprintf(
                                "Cannot convert type %s to type %s\n",
                                input->ob_type->tp_name, name)
                                .c_str());
305 306
                return nullptr;
            }
307 308 309
        }
    }
    static PyObject* py_repr(PyObject* self) {
M
Megvii Engine Team 已提交
310 311 312
        return py::cast(reinterpret_cast<BitCombinedEnumWrapper*>(self)->to_string())
                .release()
                .ptr();
313
    }
314 315 316 317 318 319 320 321 322 323 324 325 326

    static PyObject* py_dump(PyObject* self) {
        std::vector<std::string> result;
        auto value = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value;
        uint32_t value_int = static_cast<uint32_t>(value);
        for (uint32_t i = 0; i < 32; i++) {
            if (value_int >> i & 1) {
                result.push_back(members[i]);
            }
        }
        return py::tuple(py::cast(result)).release().ptr();
    }

327
    static PyObject* py_or(PyObject* self, PyObject* other) {
M
Megvii Engine Team 已提交
328
        if (!(self->ob_type == other->ob_type)) {
329 330 331 332 333 334
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in or operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
335
        return cast(lhs | rhs);
336 337 338 339 340 341 342 343 344
    }
    static PyObject* py_and(PyObject* self, PyObject* other) {
        if (!(self->ob_type == other->ob_type)) {
            return PyErr_Format(
                    PyExc_RuntimeError,
                    "Operand in and operator must be the same type.");
        }
        T lhs = reinterpret_cast<BitCombinedEnumWrapper*>(self)->value,
          rhs = reinterpret_cast<BitCombinedEnumWrapper*>(other)->value;
345
        return cast(lhs & rhs);
346 347 348
    }
    static PyObject* tp_richcompare(PyObject* self, PyObject* other, int op) {
        if (op == Py_EQ || op == Py_NE) {
349 350 351 352 353 354
            T lhs, rhs;
            if (load(other, rhs) && load(self, lhs)) {
                RETURN_RICHCOMPARE(lhs, rhs, op);
            } else {
                RETURN_RICHCOMPARE(0, 1, op);
            }
355 356 357
        }
        Py_RETURN_NOTIMPLEMENTED;
    }
358 359
    static bool load(py::handle src, T& value) {
        PyObject* obj = src.ptr();
360
        if (PyObject_TypeCheck(obj, type)) {
361 362 363 364
            value = reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value;
            return true;
        }
        if (py::isinstance<py::str>(src)) {
M
Megvii Engine Team 已提交
365
            auto&& iter = mem2value.find(normalize_enum(py::cast<std::string>(src)));
366 367 368 369 370 371 372
            if (iter != mem2value.end()) {
                value = iter->second;
                return true;
            } else {
                return false;
            }
        }
373
        if (py::isinstance<py::tuple>(src)) {
M
Megvii Engine Team 已提交
374
            auto params = py::cast<std::vector<std::string>>(src);
375
            bool first = true;
M
Megvii Engine Team 已提交
376
            for (auto s : params) {
377 378 379 380 381 382 383 384 385 386 387 388 389 390
                auto&& iter = mem2value.find(normalize_enum(s));
                if (iter != mem2value.end()) {
                    if (first) {
                        value = iter->second;
                        first = false;
                    } else {
                        value |= iter->second;
                    }
                } else {
                    return false;
                }
            }
            return true;
        }
391 392
        if (py::isinstance<py::int_>(obj)) {
            auto v = py::cast<std::underlying_type_t<T>>(src);
M
Megvii Engine Team 已提交
393
            if (v > EnumTrait<T>::max) {
394 395 396 397
                return false;
            }
            value = static_cast<T>(v);
            return true;
398
        }
399
        return false;
400
    }
401 402 403 404
    static PyObject* cast(const T& value) {
        auto v = static_cast<std::underlying_type_t<T>>(value);
        mgb_assert(v <= EnumTrait<T>::max);
        if ((!v) || (v & (v - 1))) {
405
            PyObject* obj = type->tp_alloc(type, 0);
406 407 408 409 410 411 412
            reinterpret_cast<BitCombinedEnumWrapper*>(obj)->value = value;
            return obj;
        } else {
            PyObject* obj = pyobj_insts[__builtin_ctz(v)];
            Py_INCREF(obj);
            return obj;
        }
413 414 415
    }
};

M
Megvii Engine Team 已提交
416 417
template <typename T>
struct serialization<T, std::enable_if_t<std::is_enum_v<std::decay_t<T>>>> {
418 419 420 421 422
    static T load(py::object obj) {
        auto caster = pybind11::detail::type_caster<T>();
        if (caster.load(obj, true)) {
            return caster;
        } else {
M
Megvii Engine Team 已提交
423 424
            PyErr_SetString(PyExc_RuntimeError, "load faild \n");
            return caster;
425 426
        }
    }
M
Megvii Engine Team 已提交
427
    static py::object dump(T t) { return py::cast(t).attr("dump")(); }
428 429
};

430
// Build the base Python type `OpDef` (hash / equality / repr / `scope`
// attribute backed by PyOpDef) and register it on module `m`.
void _init_py_op_def(py::module m) {
    using py_op = PyOp(OpDef);
    PyTypeObject& type_obj = PyOpType(OpDef);
    type_obj = {PyVarObject_HEAD_INIT(NULL, 0)};
    type_obj.tp_name = "megengine.core._imperative_rt.OpDef";
    type_obj.tp_basicsize = sizeof(py_op);
    type_obj.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    type_obj.tp_doc = "OpDef";
    type_obj.tp_base = &PyBaseObject_Type;
    type_obj.tp_hash = py_op::tp_hash;
    type_obj.tp_richcompare = py_op::tp_richcompare;
    type_obj.tp_getset = py_op::py_getsetters;
    type_obj.tp_repr = py_op::py_repr;
    type_obj.tp_dealloc = py_dealloc_generic<py_op>;
    mgb_assert(PyType_Ready(&type_obj) >= 0);
    m.add_object("OpDef", reinterpret_cast<PyObject*>(&type_obj));
}

/*********** begin of hand-written opdefs **************/
// Base Python type for ops implemented in Python. New instances hold an empty
// `op`; type_caster<OpDef>::load wraps such objects in a GenericPyOp.
struct PyOpBase : PyOpDef {
    static PyTypeObject py_type;

    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
        PyObject* obj = type->tp_alloc(type, 0);
        if (obj == nullptr) {
            return nullptr;
        }
        // tp_alloc returns raw storage: construct the shared_ptr member in place.
        auto* self = reinterpret_cast<PyOpBase*>(obj);
        new (&self->op) decltype(self->op);
        return obj;
    }
};
PyTypeObject PyOpBase::py_type;

// Build the Python type `PyOpBase` (subclass of OpDef, instantiable and
// subclassable from Python) and register it on module `m`.
void _init_py_op_base(py::module m) {
    using py_op = PyOpBase;
    PyTypeObject& type_obj = py_op::py_type;
    type_obj = {PyVarObject_HEAD_INIT(NULL, 0)};
    type_obj.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
    type_obj.tp_basicsize = sizeof(py_op);
    type_obj.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    type_obj.tp_doc = "PyOpBase";
    type_obj.tp_base = &PyOpType(OpDef);
    type_obj.tp_dealloc = py_dealloc_generic<py_op>;
    type_obj.tp_new = py_op::tp_new;
    mgb_assert(PyType_Ready(&type_obj) >= 0);
    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&type_obj));
}

478 479 480 481 482
/*********** end of hand-write opdefs **************/

// auto generated opdefs
#include "opdef.cpy.inl"

483
#undef CATCH_ALL
M
Megvii Engine Team 已提交
484
}  // anonymous namespace
485 486 487 488 489 490 491 492 493

namespace PYBIND11_NAMESPACE {
namespace detail {
bool type_caster<OpDef>::load(handle src, bool convert) {
    PyObject* obj = src.ptr();
    if (!PyObject_TypeCheck(obj, &PyOpType(OpDef))) {
        return false;
    }
    value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
494 495 496 497
    if (!value) {
        // opdef only defined in Python
        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
    }
498 499 500
    return true;
}
handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
501 502 503
    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
        return object(pyop->obj).release();
    }
504 505 506
    PyTypeObject* pytype;
    auto& c2p = PyOp(OpDef)::ctype2pytype;
    auto&& iter = c2p.find(op.dyn_typeinfo());
M
Megvii Engine Team 已提交
507
    if (iter != c2p.end()) {  // FIXME: should always meet this condition
508
        pytype = iter->second;
M
Megvii Engine Team 已提交
509
    } else {  // which means unregistered op type, jsut make it as an opaque op type
510 511 512 513 514 515 516 517
        // currently, only OprAttr goes into this branch
        pytype = &PyOpType(OpDef);
    }
    PyObject* obj = pytype->tp_alloc(pytype, 0);
    mgb_assert(PyObject_TypeCheck(obj, &PyOpType(OpDef)));
    reinterpret_cast<PyOp(OpDef)*>(obj)->op = const_cast<OpDef&>(op).shared_from_this();
    return py::handle(obj);
}
518

M
Megvii Engine Team 已提交
519 520 521 522 523 524 525
#define ENUM_CASTER_IMPL(T)                                                    \
    bool type_caster<T>::load(handle src, bool) {                              \
        return EnumWrapper<T>::load(src, value);                               \
    }                                                                          \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
        return EnumWrapper<T>::cast(value);                                    \
    }
526 527
FOR_EACH_ENUM_PARAM(ENUM_CASTER_IMPL)

M
Megvii Engine Team 已提交
528 529 530 531 532 533 534
#define BIT_COMBINED_ENUM_CASTER_IMPL(T)                                       \
    bool type_caster<T>::load(handle src, bool) {                              \
        return BitCombinedEnumWrapper<T>::load(src, value);                    \
    }                                                                          \
    handle type_caster<T>::cast(const T& value, return_value_policy, handle) { \
        return BitCombinedEnumWrapper<T>::cast(value);                         \
    }
535 536
FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)

M
Megvii Engine Team 已提交
537 538
}  // namespace detail
}  // namespace PYBIND11_NAMESPACE
539

540
// Register op types, RNG helpers, the SubgraphBuilder class, JIT switches and
// the `_custom` submodule on module `m`.
void init_ops(py::module m) {
    _init_py_op_def(m);
    _init_py_op_base(m);
    INIT_ALL_OP(m)

    m.def("new_rng_handle", &rng::new_handle);
    m.def(
            "delete_rng_handle",
            [](size_t handle) {
                if (mgb::imperative::python::interpreter_for_py->check_available()) {
                    mgb::imperative::python::interpreter_for_py->sync();
                }
                mgb::CompNode::sync_all();
                mgb::CompNode::foreach ([](mgb::CompNode cn) {
                    auto err = cn.check_async_error();
                    mgb_assert(!err, "%s", err->what());
                });
                py_task_q.wait_all_task_finish();
                rng::delete_handle(handle);
            },
            py::call_guard<py::gil_scoped_release>());
    m.def("set_global_rng_seed", [](uint64_t seed) -> void {
        mgb_assert(
                python::interpreter_for_py->check_available(),
                "set global random seed failed since imperative interpreter has been "
                "destroyed");
        python::interpreter_for_py->sync();
        mgb::CompNode::sync_all();
        rng::set_global_rng_seed(seed);
    });
    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
    m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode);

    // Incrementally assembles a Subgraph from Python. Once build() has produced
    // an op (which assigns `key`), further mutation is rejected by the asserts.
    struct PySubgraphBuilder {
        explicit PySubgraphBuilder(std::string name) : name{std::move(name)} {}
        std::string name;
        Subgraph graph;
        mgb::SmallVector<bool> output_grad_mask;
        Subgraph::var_t next_var = 1;
        std::shared_ptr<mgb::Hashable> key = nullptr;

        std::shared_ptr<OpDef> build() {
            if (key == nullptr) {
                key = std::make_shared<UniqueKey>();
            }
            return SubgraphOp::make(
                    name, std::make_shared<Subgraph>(graph), output_grad_mask, key);
        }
    };

    py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
            .def(py::init<std::string>())
            .def(py::init<PySubgraphBuilder>())
            .def("input",
                 [](PySubgraphBuilder& self) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     self.graph.inputs.push_back(var);
                     return var;
                 })
            .def("apply",
                 [](PySubgraphBuilder& self, std::shared_ptr<OpDef> op,
                    Subgraph::vars_t inputs, size_t nr_outputs) {
                     mgb_assert(self.key == nullptr);
                     Subgraph::vars_t outputs;
                     for (size_t i = 0; i < nr_outputs; ++i) {
                         outputs.push_back(self.next_var++);
                     }
                     self.graph.exprs.push_back({op, inputs, outputs});
                     return outputs;
                 })
            .def("apply_const",
                 [](PySubgraphBuilder& self, py::object value, mgb::DType dtype,
                    mgb::CompNode cn) {
                     mgb_assert(self.key == nullptr);
                     auto var = self.next_var++;
                     mgb::HostTensorND hvalue(cn);
                     npy::np2tensor(
                             value.cast<py::array>().ptr(),
                             npy::Meth::copy_into(&hvalue), dtype);
                     self.graph.constants.push_back({var, Tensor::make(hvalue)});
                     return var;
                 })
            .def("outputs",
                 [](PySubgraphBuilder& self, Subgraph::vars_t outputs) {
                     mgb_assert(self.key == nullptr);
                     self.graph.outputs = outputs;
                     self.output_grad_mask.resize(outputs.size(), true);
                 })
            .def("outputs_has_grad",
                 [](PySubgraphBuilder& self, mgb::SmallVector<bool> outputs_has_grad) {
                     mgb_assert(self.key == nullptr);
                     mgb_assert(
                             self.graph.outputs.size() == self.output_grad_mask.size());
                     // the new mask must have exactly one flag per graph output
                     mgb_assert(
                             outputs_has_grad.size() == self.graph.outputs.size(),
                             "outputs_has_grad: expected %zu flags, got %zu",
                             static_cast<size_t>(self.graph.outputs.size()),
                             static_cast<size_t>(outputs_has_grad.size()));
                     self.output_grad_mask = outputs_has_grad;
                 })
            .def("get",
                 [](PySubgraphBuilder& self) {
                     return (std::shared_ptr<OpDef>)self.build();
                 })
            .def("compile",
                 [](PySubgraphBuilder& self, int gopt_level) {
                     return (std::shared_ptr<OpDef>)CompiledOp::make(
                             self.build(), gopt_level);
                 })
            .def("jit_fuse", [](PySubgraphBuilder& self) {
                return (std::shared_ptr<OpDef>)CompiledOp::make(
                        JITFusionOp::make(self.build()));
            });

    m.def("set_jit_enabled", &JITFusionOp::set_enabled);
    bool jit_supported = false;
#if MGB_JIT
    jit_supported = true;
#endif
    m.attr("jit_supported") = jit_supported;

    auto custom = submodule(m, "_custom");
    init_custom(custom);
}

// `case` labels for the parameter-type switch in make_custom_op. Both expect
// `kv` (the current dict item) and `param_val` (the target parameter) to be in
// scope at the expansion site.
// Scalar / string parameter: convert kv.second straight to `static_type`.
#define CUSTOM_CASE_TO_PARSE_NON_LIST(dyn_type, static_type)   \
    case custom::ParamDynType::dyn_type: {                     \
        param_val = py::handle(kv.second).cast<static_type>(); \
        break;                                                 \
    }

// List parameter: kv.second must be a Python list; each element is converted
// to the vector's element type and appended in order.
#define CUSTOM_CASE_TO_PARSE_LIST(dyn_type, static_type)                           \
    case custom::ParamDynType::dyn_type: {                                         \
        using elem_type = custom::get_vector_template_arg_type<static_type>::type; \
        static_type parsed;                                                        \
        for (auto& item : py::handle(kv.second).cast<py::list>()) {                \
            parsed.push_back(py::handle(item).cast<elem_type>());                  \
        }                                                                          \
        param_val = parsed;                                                        \
        break;                                                                     \
    }

// METH_FASTCALL entry point `_make_custom_op(op_name, params)`: create the
// custom op `op_name`, fill its parameters from the `params` dict (unknown
// names are logged and skipped) and return it wrapped as a Python OpDef.
// This is a raw CPython function, so C++ exceptions must not escape: they are
// translated into Python exceptions and NULL is returned.
PyObject* make_custom_op(PyObject* self, PyObject** args, Py_ssize_t nargs) {
    try {
#if MGB_CUSTOM_OP
        if (nargs != 2) {
            PyErr_Format(
                    PyExc_TypeError,
                    "_make_custom_op expects 2 arguments (op_name, params), got %zd",
                    nargs);
            return nullptr;
        }
        auto op_name = py::handle(args[0]).cast<std::string>();
        auto kwargs = py::handle(args[1]).cast<py::dict>();

        std::shared_ptr<OpDef> opdef =
                CustomOpDefFactory::inst()->create_opdef(op_name);
        auto& custom_opdef = static_cast<mgb::imperative::CustomOpDef&>(*opdef);
        auto& param = custom_opdef.param();

        for (auto&& kv : kwargs) {
            std::string param_name = py::handle(kv.first).cast<std::string>();

            if (!param.exist(param_name)) {
                mgb_log_warn(
                        "op %s have no param named %s, ignore this param parsed from "
                        "python",
                        op_name.c_str(), param_name.c_str());
                continue;
            }

            auto& param_val = param[param_name];
            switch (param_val.type()) {
                CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
                CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PARSE_NON_LIST)
                CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                CUSTOM_FOR_BOOL_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                CUSTOM_FOR_STRING_LIST_PARAMTYPE(CUSTOM_CASE_TO_PARSE_LIST)
                default: {
                    mgb_assert(
                            false, "param dtype of %s:%s is invalid", op_name.c_str(),
                            param_name.c_str());
                }
            }
        }

        PyTypeObject* pytype = &PyOpType(OpDef);
        PyObject* obj = pytype->tp_alloc(pytype, 0);
        if (obj == nullptr) {
            return nullptr;  // allocation failure is already reported
        }
        reinterpret_cast<PyOp(OpDef)*>(obj)->op = opdef;
        return obj;
#else
        mgb_assert(
                false,
                "Custom Op is disabled now, please build megengine with Custom Op "
                "open");
        return nullptr;
#endif
    } catch (py::error_already_set& e) {
        e.restore();
        return nullptr;
    } catch (const py::builtin_exception& e) {
        e.set_error();
        return nullptr;
    } catch (const std::exception& e) {
        PyErr_SetString(PyExc_RuntimeError, e.what());
        return nullptr;
    }
}

#undef CUSTOM_CASE_TO_PARSE_LIST
#undef CUSTOM_CASE_TO_PARSE_NON_LIST

// Load the custom-op library `name` from `path` and return the op names it
// provides. Fails via mgb_assert when built without Custom Op support.
py::list install_custom(const std::string& name, const std::string& path) {
#if MGB_CUSTOM_OP
    py::list installed;
    for (const auto& op_name : custom::LibManager::inst()->install(name, path)) {
        installed.append(op_name);
    }
    return installed;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build megengine with Custom Op open");
    return py::list();
#endif
}

bool uninstall_custom(const std::string& name) {
750
#if MGB_CUSTOM_OP
751
    return custom::LibManager::inst()->uninstall(name);
752
#else
M
Megvii Engine Team 已提交
753 754 755
    mgb_assert(
            false,
            "Custom Op is disabled now, please build megengine with Custom Op open");
756 757
    return false;
#endif
758 759 760
}

// Names of all registered custom ops, as a Python list. Fails via mgb_assert
// when built without Custom Op support.
py::list get_custom_op_list(void) {
#if MGB_CUSTOM_OP
    const std::vector<std::string> registered = CustomOpDefFactory::inst()->op_list();
    py::list names;
    for (const auto& op_name : registered) {
        names.append(op_name);
    }
    return names;
#else
    mgb_assert(
            false,
            "Custom Op is disabled now, please build megengine with Custom Op open");
    return py::list();
#endif
}

#ifndef METH_FASTCALL
M
Megvii Engine Team 已提交
778 779 780 781 782
PyObject* py35_make_custom_op(PyObject* self, PyObject* args) {
    auto* arr = &PyTuple_GET_ITEM(args, 0);
    auto size = PyTuple_GET_SIZE(args);
    return make_custom_op(self, arr, size);
};
783 784
#endif

785
// Populate the `_custom` submodule: library install/uninstall/listing, the
// libstdc++ ABI tag this module was built with, and the raw `_make_custom_op`
// CPython function.
void init_custom(pybind11::module m) {
    m.def("_install", &install_custom);
    m.def("_uninstall", &uninstall_custom);
    m.def("_get_custom_op_list", &get_custom_op_list);
    m.def("get_custom_op_abi_tag", [](void) -> int {
        int ret = 0;
#ifdef _GLIBCXX_USE_CXX11_ABI
        ret = _GLIBCXX_USE_CXX11_ABI;
#endif
        return ret;
    });

    // The function object keeps a pointer to this PyMethodDef, so it needs
    // static storage duration.
    static PyMethodDef method_def = {
#ifdef METH_FASTCALL
            "_make_custom_op", reinterpret_cast<PyCFunction>(make_custom_op),
            METH_FASTCALL, ""
#else
            "_make_custom_op", py35_make_custom_op, METH_VARARGS, ""
#endif
    };
    // PyCFunction_NewEx returns a new reference: take ownership so it is not
    // leaked after setattr adds its own reference, and surface a NULL result
    // as a Python exception instead of storing a null attribute.
    auto func = pybind11::reinterpret_steal<pybind11::object>(
            PyCFunction_NewEx(&method_def, nullptr, nullptr));
    if (!func) {
        throw pybind11::error_already_set();
    }
    pybind11::setattr(m, method_def.ml_name, func);
}