/**
 * \file imperative/python/src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/common.h"
#include "megbrain/dtype.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/opr/io.h"

#include "./common.h"
#include "./grad.h"
#include "./graph_rt.h"
#include "./helper.h"
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
#include "./trace.h"

#include <object.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include <pyerrors.h>
#include <range/v3/all.hpp>
#include <string>

#include <unordered_map>

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

interpreter::Interpreter::Channel* interpreter_for_py;

PyObject *cpp_apply_with_tracing, *cpp_apply_const_with_tracing;
PyObject* cpp_apply_backward_varnode;
PyObject* cpp_apply_module_trace;

std::shared_ptr<Tensor> make_const(imperative::TensorPtr value) {
    if (!(ApplyContext::global_enable & Tensor::Flags::TRACE)) {
        return std::make_shared<Tensor>(
                interpreter_for_py->put(value->dev_tensor(), value->get_value()));
    }
    py::tuple tup(6);
    auto data = value->get_value();
    tup[0] = py::reinterpret_steal<py::array>(
            ndarray_from_tensor(data, npy::ShareType::MUST_SHARE));
    tup[1] = value->dtype();
    tup[2] = value->comp_node();
    tup[3] = true;
    tup[4] = false;
    tup[5] = py::none{};
    auto py_ret = PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
    if (!py_ret)
        throw py::error_already_set();
    auto py_list = py::reinterpret_steal<py::list>(py_ret);
    auto* tensor_wrapper = TensorWrapper::try_cast(py_list[0].ptr());
    return tensor_wrapper->m_tensor;
}

#define REGISTE_APPLY_FUNC(mode) \
    void set_##mode(py::object pyf) { mode = pyf.ptr(); }

REGISTE_APPLY_FUNC(cpp_apply_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing)
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode)
REGISTE_APPLY_FUNC(cpp_apply_module_trace)

#undef REGISTE_APPLY_FUNC
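
// For reference, each REGISTE_APPLY_FUNC invocation above expands to a plain
// setter, e.g. REGISTE_APPLY_FUNC(cpp_apply_with_tracing) becomes:
//
//     void set_cpp_apply_with_tracing(py::object pyf) {
//         cpp_apply_with_tracing = pyf.ptr();
//     }
//
// These setters are exposed to Python near the end of init_tensor() so the
// Python layer can install its tracing callbacks.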

Tensor::flags_t ApplyContext::global_disable = 0;
Tensor::flags_t ApplyContext::global_enable = 0;

void set_tracing() {
    ApplyContext::global_enable |= Tensor::Flags::TRACE;
}
void unset_tracing() {
    ApplyContext::global_enable &= ~Tensor::Flags::TRACE;
}

void set_module_tracing() {
    ApplyContext::global_enable |= Tensor::Flags::MODULE_TRACE;
}
void unset_module_tracing() {
    ApplyContext::global_enable &= ~Tensor::Flags::MODULE_TRACE;
}
bool is_tracing_module() {
    return ApplyContext::global_enable & Tensor::Flags::MODULE_TRACE;
}

bool skip_tracing = false;

apply_result_t apply(ApplyContext& ctx) {
    // Scalar emulation should be handled in each specific op's apply, e.g.,
    // elementwise, reduce, typecvt. Currently it is still handled on the
    // Python side; it could be moved to the C++ side if it impacts performance.
    auto flags = ctx.flags & ~ApplyContext::global_disable;
    flags = flags | ApplyContext::global_enable;

    if (flags & Tensor::Flags::SCALAR) {
        // TODO: emulate scalar
    }

    if (flags & Tensor::Flags::GRAD) {
        return apply_grad(ctx);
    }

    if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
        py::tuple pyin(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
        }
        auto f = py::getattr(op->obj, "_default_rule");
        auto pyout = py::reinterpret_steal<py::object>(
                PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
        if (!pyout)
            throw py::error_already_set();
        if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
            return {tw->m_tensor};
        }
        apply_result_t ret;
        ret.reserve(py::len(pyout));
        for (auto&& i : pyout) {
            auto* tw = TensorWrapper::try_cast(i.ptr());
            mgb_assert(tw);
            ret.push_back(tw->m_tensor);
        }
        return ret;
    }

    if (flags & Tensor::Flags::MODULE_TRACE) {
        return apply_module_trace(ctx);
    }

    if (flags & Tensor::Flags::TRACE) {
        return apply_trace(ctx);
    } else {
        SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            handles[i] = ctx.args[i]->m_handle.get();
        }

        apply_result_t outputs;

        // fast copy without really applying
        if (ctx.op->same_type<FastpathCopy>()) {
            mgb_assert(ctx.nargs == 1);
            outputs.reserve(ctx.nargs);
            outputs.emplace_back(std::make_shared<Tensor>(ctx.args[0]->m_handle));
            return outputs;
        }

        auto output_handles = interpreter_for_py->apply_op(ctx.op, handles);

        outputs.reserve(output_handles.size());
        for (auto h : output_handles) {
            outputs.emplace_back(std::make_shared<Tensor>(h));
        }
        return outputs;
    }

    mgb_assert(0);
}

PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "py_apply expects one Op and at least one tensor "
                    "as arguments");
            return nullptr;
        }

        auto* op = args[0];

        PyTypeObject* pytype = args[1]->ob_type;

        // If pytype is Parameter (or any other subclass of the Python Tensor
        // class), use its tp_base (the Python Tensor type) instead.
        if (TensorWrapper::wrap_t::type().same_pytype(pytype->tp_base->tp_base)) {
            pytype = pytype->tp_base;
        }
        ++args;
        --nargs;

        ApplyContext ctx;
        ctx.flags = 0;
        ctx.op = py::handle(op).cast<std::shared_ptr<OpDef>>();
        SmallVector<Tensor*, 64> tensors(nargs);
        ctx.args = &tensors[0];
        ctx.nargs = nargs;
        ctx.pytype = pytype;

        if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
            SmallVector<cg::VarNode*> vinputs(nargs);
            for (size_t i = 0; i < nargs; ++i) {
                vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
            }
            auto op = ctx.op.get();
            auto rst = OpDef::apply_on_var_node(*op, vinputs);
            auto ret = pybind11::tuple(rst.size());
            auto typeobj = py::handle(args[0]).get_type();
            for (size_t i = 0; i < rst.size(); ++i) {
                ret[i] = typeobj(pybind11::cast(
                        rst[i], pybind11::return_value_policy::automatic));
            }
            return ret.release().ptr();
        }

        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                auto* t = tensors[i] = tw->m_tensor.get();
                ctx.flags |= t->m_flags;
            } else {
                PyErr_SetString(
                        PyExc_TypeError,
                        ssprintf(
                                "op %s expects type Tensor as inputs, got %s actually",
                                ctx.op->make_name().c_str(), Py_TYPE(args[i])->tp_name)
                                .c_str());
                return nullptr;
            }
        }

        auto outputs = apply(ctx);
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(pytype, std::move(outputs[i]));
        }
        return ret.release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
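
// A sketch of how this entry point is reached from Python (the module path is
// an assumption based on how init_tensor() registers "apply" below):
//
//     from megengine.core._imperative_rt.core2 import apply
//     (out,) = apply(op, x, y)  # op: OpDef; x, y: Tensor
//
// The first argument is always an OpDef; every following argument must be a
// Tensor (or a SymbolVar when building a graph), and a tuple of output
// tensors is returned.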

TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    if (kwargs && PyDict_Size(kwargs)) {
        throw py::type_error("keyword argument not allowed");
    }
    auto nargs = PyTuple_Size(args);
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    if (nargs == 0) {
        throw py::type_error("too few arguments");
    }
    if (auto* t = try_cast(tup[0].ptr())) {
        if (nargs > 1) {
            throw py::type_error("expect 1 argument");
        }
        m_tensor = t->m_tensor;
    } else {
        if (nargs == 1) {
            auto arg0 = PyTuple_GetItem(args, 0);
            // for lazy_eval_tensor
            if (strstr(arg0->ob_type->tp_name, "VarNode")) {
                if (PyObject_HasAttrString(arg0, "_node")) {
                    arg0 = PyObject_GetAttrString(arg0, "_node");
                }
                m_tensor =
                        std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode*>());
            } else {
                // for DeviceTensorND
                if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) {
                    auto dv = py::handle(arg0).cast<DeviceTensorND>();
                    interpreter::Interpreter::Handle handle =
                            interpreter_for_py->put(dv, {});
                    m_tensor = std::make_shared<Tensor>(handle);
                } else {
                    throw py::type_error(
                            "single argument is not tensor, varnode or devicetensor");
                }
            }
        } else {
            py::detail::loader_life_support life_sup;  // FIXME: required to cast DType
            if (nargs != 5 && nargs != 6) {
                throw py::type_error("expect 5 or 6 arguments");
            }
            auto data = tup[0].cast<py::array>();
            DType dtype = tup[1].cast<DType>();
            CompNode cn = tup[2].cast<CompNode>();
            bool is_const = tup[3].cast<bool>();
            bool no_cache = nargs == 6 ? tup[4].cast<bool>() : false;
            std::string name;
            if (tup[nargs - 1].ptr() != Py_None)
                name = tup[nargs - 1].cast<std::string>();

            // const op
            if (is_const && (ApplyContext::global_enable == Tensor::Flags::TRACE)) {
                auto py_ret =
                        PyObject_Call(cpp_apply_const_with_tracing, tup.ptr(), nullptr);
                if (!py_ret)
                    throw py::error_already_set();
                auto py_list = py::reinterpret_steal<py::list>(py_ret);
                if (auto* t = try_cast(py_list[0].ptr())) {
                    m_tensor = t->m_tensor;
                }
                return;
            }

            interpreter::Interpreter::Handle handle;
            {
                HostTensorND ret(cn);
                handle = interpreter_for_py->put(
                        npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype),
                        no_cache);
            }

            m_tensor = std::make_shared<Tensor>(handle);
            m_tensor->user_custom_name = name;

            if (data.ndim() == 0) {
                m_tensor->m_flags |= Tensor::Flags::SCALAR;
            }
        }
    }
}
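
// The constructor above accepts three shapes of argument list (a summary, not
// additional API): an existing Tensor to share state with; a single VarNode
// or DeviceTensorND; or the 5/6-element form used by the Python layer,
// roughly (data, dtype, comp_node, is_const[, no_cache], name), where the
// last element may be None to leave the tensor unnamed.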

#define REGISTE_TENSORWRAPPER_FUNC(type, member)                        \
    PyObject* TensorWrapper::member() {                                 \
        return py::cast(m_tensor->m_trace_info.member).release().ptr(); \
    }                                                                   \
    void TensorWrapper::set_##member(PyObject* dest) {                  \
        auto py_dest = py::reinterpret_borrow<py::object>(dest);        \
        type real_dest = py_dest.cast<type>();                          \
        m_tensor->m_trace_info.member = real_dest;                      \
    }

REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle)
REGISTE_TENSORWRAPPER_FUNC(bool, recording)

#undef REGISTE_TENSORWRAPPER_FUNC

PyObject* TensorWrapper::module_trace_info() {
    if (!m_tensor->m_module_trace_info.ptr()) {
        PyErr_SetString(
                PyExc_AttributeError,
                "Has no attribute named \'_NodeMixin__node\', please "
                "set it first");
        return nullptr;
    }
    return m_tensor->m_module_trace_info.inc_ref().ptr();
}

void TensorWrapper::set_module_trace_info(PyObject* obj) {
    m_tensor->m_module_trace_info = py::reinterpret_borrow<py::object>(obj);
}

#define REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(member)    \
    PyObject* TensorWrapper::member() {                \
        if (m_tensor->m_trace_info.member) {           \
            return m_tensor->m_trace_info.member;      \
        } else {                                       \
            Py_RETURN_NONE;                            \
        }                                              \
    }                                                  \
    void TensorWrapper::set_##member(PyObject* dest) { \
        if (dest == Py_None) {                         \
            Py_XDECREF(m_tensor->m_trace_info.member); \
            m_tensor->m_trace_info.member = nullptr;   \
        } else {                                       \
            Py_INCREF(dest);                           \
            m_tensor->m_trace_info.member = dest;      \
        }                                              \
    }

REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(compiled_info)
REGISTE_TENSORWRAPPER_PYOBJECT_FUNC(trace_mixin_info)

#undef REGISTE_TENSORWRAPPER_PYOBJECT_FUNC

#define SET_GET_NAME(member)                                     \
    PyObject* TensorWrapper::member() {                          \
        return py::cast(m_tensor->member).release().ptr();       \
    }                                                            \
    void TensorWrapper::set_##member(PyObject* dest) {           \
        auto py_dest = py::reinterpret_borrow<py::object>(dest); \
        m_tensor->member = py_dest.cast<std::string>();          \
    }
SET_GET_NAME(user_custom_name)
SET_GET_NAME(automatic_name)
#undef SET_GET_NAME

PyObject* TensorWrapper::handle() {
    return py::cast(m_tensor->m_handle).release().ptr();
}

void TensorWrapper::set_handle(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    SharedHandle real_dest = py_dest.cast<SharedHandle>();
    m_tensor->m_handle = std::move(real_dest);
}

PyObject* TensorWrapper::shape() {
    // in trace-compiled mode, read the shape from compiled_info
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            return PyTuple_New(0);
        }
        PyObject* shp =
                PyObject_GetAttrString(m_tensor->m_trace_info.compiled_info, "shape");
        if (shp == Py_None) {
            throw TraceReadError("shape of this tensor is not read in trace");
        }
        return shp;
    }

    // inside a trace, if the tensor's shape is used by other operations,
    // record shape_read = true
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(
                m_tensor->m_trace_info.trace_mixin_info, "shape_read",
                py::cast(true).release().ptr());
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        return PyTuple_New(0);
    }

    TensorShape shape;
    if (m_tensor->m_var) {  // get shape from m_var
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto&& type = mgr.get_infer_type(m_tensor->m_var);
        using InferType = cg::static_infer::InferType;
        if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
            Py_RETURN_NONE;
        }
        auto* tshp = mgr.infer_shape_fallible(m_tensor->m_var);
        if (!tshp) {
            Py_RETURN_NONE;
        }
        shape = *tshp;
    } else {
        py::gil_scoped_release _;
        shape = m_tensor->shape();
    }

    if (!shape.ndim) {
        Py_RETURN_NONE;
    }
    py::tuple ret(shape.ndim);
    for (size_t i = 0; i < shape.ndim; ++i) {
        ret[i] = shape[i];
    }
    return ret.release().ptr();
}

PyObject* TensorWrapper::dtype() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->dtype()).release().ptr();
    }
    return py::cast(m_tensor->dtype()).release().ptr();
}

PyObject* TensorWrapper::device() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var->comp_node()).release().ptr();
    }
    return py::cast(m_tensor->comp_node()).release().ptr();
}

PyObject* TensorWrapper::numpy() {
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        PyObject* np_val = PyObject_CallMethod(
                m_tensor->m_trace_info.compiled_info, "numpy", nullptr);
        if (!np_val)
            throw py::error_already_set();
        if (np_val == Py_None) {
            throw TraceReadError("value of this tensor is not read in trace");
        }
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            PyObject* np_scalar =
                    PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(np_val));
            Py_DECREF(np_val);
            return np_scalar;
        }
        return np_val;
    }

    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(
                m_tensor->m_trace_info.trace_mixin_info, "value_read",
                py::cast(true).release().ptr());
    }

    if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) {
        auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager();
        auto&& type = mgr.get_infer_type(m_tensor->m_var);
        using InferType = cg::static_infer::InferType;
        if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto* val = mgr.infer_value_fallible(m_tensor->m_var);
        if (!val) {
            PyErr_SetString(PyExc_ValueError, "tensor invalid");
            return nullptr;
        }
        auto np_val = py::cast(*val).attr("numpy")();
        if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
            return PyArray_Squeeze(
                    reinterpret_cast<PyArrayObject*>(np_val.release().ptr()));
        }
        return np_val.release().ptr();
    }
    auto&& hv = [&]() {
        py::gil_scoped_release _;
        return interpreter_for_py->get_value(m_tensor->m_handle.get());
    }();
    auto arr = py::reinterpret_steal<py::array>(
            npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
    if (!arr) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }

    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        mgb_assert(PyArray_Check(arr.ptr()));
        return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
    }
    return arr.release().ptr();
}

PyObject* TensorWrapper::varnode() {
    if (m_tensor->m_var) {
        return py::cast(m_tensor->m_var).release().ptr();
    }
    Py_RETURN_NONE;
}

void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
    std::string user_custom_name = m_tensor->user_custom_name;
    std::string automatic_name = m_tensor->automatic_name;
    auto module_trace_info = m_tensor->m_module_trace_info;
    m_tensor = t->m_tensor;
    m_tensor->m_module_trace_info = module_trace_info;
    m_tensor->user_custom_name = user_custom_name;
    m_tensor->automatic_name = automatic_name;
}

void TensorWrapper::reset_varnode() {
    m_tensor->m_var = nullptr;
}

PyObject* TensorWrapper::detach() {
    PyObject* self = wrap_t::pycast(this);
    PyTypeObject* pytype = self->ob_type;

    static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
    auto new_tensor = python::apply(op, m_tensor)[0];
    new_tensor->m_grad_info_dict = {};
    auto ret = TensorWrapper::make(pytype, std::move(new_tensor));
    return ret.release().ptr();
}

PyObject* TensorWrapper::_dev_tensor() {
    if (m_tensor->m_trace_info.compiled_info != nullptr) {
        auto* dev_tensor = PyObject_CallMethod(
                m_tensor->m_trace_info.compiled_info, "_dev_tensor", nullptr);
        if (!dev_tensor)
            throw py::error_already_set();
        if (dev_tensor == Py_None) {
            throw TraceReadError("raw data of this tensor is not read in trace");
        }

        // set m_handle to make it a real tensor
        auto py_dev_tensor = py::reinterpret_borrow<py::object>(dev_tensor);
        auto sh = interpreter_for_py->put(py_dev_tensor.cast<DeviceTensorND>(), {});
        m_tensor->m_handle = SharedHandle(sh);

        // compiled info is useless after m_handle is set
        Py_DECREF(m_tensor->m_trace_info.compiled_info);
        m_tensor->m_trace_info.compiled_info = nullptr;

        return dev_tensor;
    }
    if (m_tensor->m_trace_info.recording && !skip_tracing) {
        PyObject_SetAttrString(
                m_tensor->m_trace_info.trace_mixin_info, "data_read",
                py::cast(true).release().ptr());
    }
    auto dev_tensor = [&]() {
        py::gil_scoped_release _;
        return interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get());
    }();
    return py::cast(dev_tensor).release().ptr();
}

void TensorWrapper::_drop() {
    interpreter_for_py->drop(m_tensor->m_handle.get());
}

PyObject* TensorWrapper::isscalar() {
    if (m_tensor->m_flags & Tensor::Flags::SCALAR) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

void TensorWrapper::setscalar() {
    m_tensor->m_flags |= Tensor::Flags::SCALAR;
}

void TensorWrapper::unsetscalar() {
    m_tensor->m_flags &= ~Tensor::Flags::SCALAR;
}

struct TensorWeakRef {
    std::weak_ptr<Tensor> wptr;

    TensorWeakRef(const TensorWrapper& tw) : wptr(tw.m_tensor) {}

    py::object operator()() {
        if (auto p = wptr.lock()) {
            return TensorWrapper::make(p);
        }
        return py::none();
    }
    int _use_cnt() { return wptr.use_count(); }
};
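
// TensorWeakRef is exported in init_tensor() below; a sketch of its Python
// contract: calling the ref returns the Tensor while it is still alive and
// None once the underlying object has been released, and _use_cnt() reports
// the current shared_ptr use count.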

/* ============== convert inputs ============== */

// map numpy.dtype.kind to priority
inline uint8_t category_priority(char c) {
    switch (c) {
        case 'f':
            return 3;  // floating-point
        case 'i':
            return 2;  // signed integer
        case 'u':
            return 2;  // unsigned integer
        case 'b':
            return 1;  // boolean
        default:
            return 0;
    }
}

// Returns the maximum priority among the types in the list `types`.
uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    if (types.size() == 0) {
        return 0;
    } else {
        uint8_t max_p = 0;
        for (auto&& desc : types) {
            max_p = std::max(max_p, category_priority(desc->kind));
        }
        return max_p;
    }
}
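
// Example of the category ordering: mixing an int32 tensor with a Python
// float puts the tensor list at priority 2 (integer) and the scalar list at
// priority 3 (floating point), so the floating-point side decides the result
// in _dtype_promotion below.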

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> used_types;
    for (auto&& desc : types) {
        auto&& v = category_priority(desc->kind);
        if (v == cat) {
            used_types.emplace_back(desc);
        }
    }
    mgb_assert(used_types.size() > 0, "size of used_types is 0");
    PyArray_Descr* res = used_types[0];
    Py_INCREF(res);

    for (size_t i = 1; i < used_types.size(); ++i) {
        PyArray_Descr* tmp = PyArray_PromoteTypes(used_types[i], res);
        Py_DECREF(res);
        res = tmp;
    }
    return res;
}

PyArray_Descr* scalar2dtype(PyObject* arg) {
    // Return value: New reference
    if (PyBool_Check(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_BOOL);
        return descr;
    }
    if (PyLong_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_INT32);
        return descr;
    }
    if (PyFloat_CheckExact(arg)) {
        auto&& descr = PyArray_DescrFromType(NPY_FLOAT32);
        return descr;
    }
    return nullptr;
}
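
// In short: a Python bool maps to NPY_BOOL, a plain int to NPY_INT32, and a
// plain float to NPY_FLOAT32; any other scalar yields nullptr and is simply
// skipped by the caller.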

PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs) {
    // Return value: New reference
    SmallVector<PyArray_Descr*> tensors;
    SmallVector<PyArray_Descr*> scalars;

    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }

    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        if (handle == Py_None)
            continue;
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
            Py_INCREF(descr.get());
            tensors.emplace_back(descr.get());
        } else {
            if (PyArray_Check(handle) || PyArray_CheckScalar(handle)) {
                auto&& descr = PyArray_DescrFromObject(handle, nullptr);
                tensors.emplace_back(descr);
                continue;
            }

            if (py::isinstance<PySymbolVar>(py::handle(handle))) {
                auto var = py::handle(handle).cast<PySymbolVar*>();
                mgb::DType type = var->m_node->dtype();
                auto&& descr = npy::dtype_mgb2np_descr(type);
                Py_INCREF(descr.get());
                tensors.emplace_back(descr.get());
                continue;
            }

            PyArray_Descr* descr = scalar2dtype(handle);
            if (descr) {
                scalars.emplace_back(descr);
                continue;
            }
        }
    }

    auto max_pri_scalars = max_priority(scalars);
    auto max_pri_tensors = max_priority(tensors);

    if (max_pri_scalars <= 0 && max_pri_tensors <= 0) {
        throw py::value_error("invalid input, no dtype available");
    }
    PyArray_Descr* res;
    if (max_pri_scalars > max_pri_tensors) {
        res = promote_types(scalars, max_pri_scalars);
    } else {
        res = promote_types(tensors, max_pri_tensors);
    }
    for (auto* p : tensors) {
        Py_DECREF(p);
    }
    for (auto* p : scalars) {
        Py_DECREF(p);
    }
    Py_XDECREF(tuple);
    return res;
}
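
// Worked example of the rules above: for (int32_tensor, 2.0) the tensors
// contribute int32 (priority 2) and the scalar contributes float32
// (priority 3), so the scalar side wins and the result is float32; for
// (int32_tensor, 2) both sides sit at priority 2, the tensor dtypes take
// precedence, and the result stays int32.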

CompNode _get_device(PyObject* const* args, size_t nargs) {
    bool is_tuple = false;
    PyObject* tuple = nullptr;
    if (nargs == 1 && (PyTuple_Check(args[0]) || PyList_Check(args[0]))) {
        if (PyList_Check(args[0])) {
            tuple = PyList_AsTuple(args[0]);
        } else {
            tuple = args[0];
            Py_INCREF(tuple);
        }
        nargs = PyTuple_Size(tuple);
        is_tuple = true;
    }
    bool valid = false;
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i) : args[i];
        TensorWrapper* tw = TensorWrapper::try_cast(handle);

        bool is_symvar = py::isinstance<PySymbolVar>(py::handle(handle));
        if (tw || is_symvar) {
            if (!valid) {
                cn = tw ? tw->m_tensor->comp_node()
                        : py::handle(handle).cast<PySymbolVar*>()->m_node->comp_node();
                valid = true;
            } else {
                CompNode cn1 = tw ? tw->m_tensor->comp_node()
                                  : py::handle(handle)
                                               .cast<PySymbolVar*>()
                                               ->m_node->comp_node();
                if (cn1 != cn) {
                    throw py::value_error(ssprintf(
                            "ambiguous device: %s vs %s", cn.to_string().c_str(),
                            cn1.to_string().c_str()));
                }
            }
        }
    }
    // release the tuple before either return path to avoid leaking it
    Py_XDECREF(tuple);
    if (!valid) {
        return CompNode::load(get_default_device());
    }
    return cn;
}
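
// Behavior sketch: _get_device((x_on_gpu0, y_on_gpu0)) returns gpu0;
// _get_device((x_on_cpu0, y_on_gpu0)) throws "ambiguous device"; and with no
// Tensor or SymbolVar arguments at all it falls back to the default device.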

// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        PyArray_Descr* res = _dtype_promotion(args, nargs);
        return py::cast(npy::dtype_np2mgb_descr(res)).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
    if (!nargs) {
        PyErr_SetString(PyExc_TypeError, "empty input is not allowed");
        return nullptr;
    }
    try {
        CompNode cn = _get_device(args, nargs);
        return py::cast(cn).release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
        auto* arr = &PyTuple_GET_ITEM(args, 0);             \
        auto size = PyTuple_GET_SIZE(args);                 \
        return FUNC(self, arr, size);                       \
    }
WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif
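
// On Python headers without METH_FASTCALL, the wrappers above adapt each
// fastcall-style function to METH_VARARGS: the argument tuple is unpacked
// into a contiguous PyObject* array plus a count, which is exactly the
// calling convention py_apply, dtype_promotion and get_device expect.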

void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();
    static auto sl_interpreter_for_py =
            interpreter::Interpreter::inst().create_channel();
    interpreter_for_py = sl_interpreter_for_py.get();

    static py::exception<interpreter::AsyncError> py_async_error(
            m, "AsyncError", PyExc_RuntimeError);
    py::register_exception_translator([](std::exception_ptr p) {
        try {
            if (p)
                std::rethrow_exception(p);
        } catch (const interpreter::AsyncError& e) {
            pyext17::pybind11_translate_exception(e.nested_ptr());
            if (PyErr_Occurred()) {
                PyObject *exc, *val, *tb;
                PyErr_Fetch(&exc, &val, &tb);
                PyErr_NormalizeException(&exc, &val, &tb);
                if (tb) {
                    PyException_SetTraceback(val, tb);
                }
                auto val2 = py_async_error.py::object::operator()(
                        "An async error is reported. See above for the actual cause."
                        " Hint: This is where it is reported, not where it happened."
                        " You may call `megengine.core.set_option('async_level', 0)` "
                        "to get better error reporting.");
                PyException_SetCause(
                        val2.ptr(), val);  // PyException_SetCause steals reference
                Py_XDECREF(exc);
                Py_XDECREF(tb);
                PyErr_Restore(
                        py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
            } else {
                py_async_error("Unknown async error");
            }
        }
    });

    auto* tensor_type =
            TensorWrapper::wrap_t::type()
                    .def<&TensorWrapper::numpy>("numpy")
                    .def_getset<&TensorWrapper::shape>("shape")
                    .def_getset<&TensorWrapper::dtype>("dtype")
                    .def_getset<&TensorWrapper::device>("device")
                    .def<&TensorWrapper::reset>("_reset")
                    .def<&TensorWrapper::isscalar>("_isscalar")
                    .def<&TensorWrapper::setscalar>("_setscalar")
                    .def<&TensorWrapper::unsetscalar>("_unsetscalar")
                    .def<&TensorWrapper::detach>("detach")
                    .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
                    .def<&TensorWrapper::_drop>("_drop")
                    .def<&TensorWrapper::reset_varnode>("_reset_varnode")
                    .def<&TensorWrapper::_use_cnt>("_use_cnt")
                    .def_getset<&TensorWrapper::varnode>("_varnode")
                    .def_getset<
                            &TensorWrapper::mixin_handle,
                            &TensorWrapper::set_mixin_handle>("_mixin_handle")
                    .def_getset<
                            &TensorWrapper::recording, &TensorWrapper::set_recording>(
                            "_recording")
                    .def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>(
                            "_handle")
                    .def_getset<
                            &TensorWrapper::compiled_info,
                            &TensorWrapper::set_compiled_info>("_compiled_info")
                    .def_getset<
                            &TensorWrapper::trace_mixin_info,
                            &TensorWrapper::set_trace_mixin_info>("_trace_mixin_info")
                    .def_getset<
                            &TensorWrapper::user_custom_name,
                            &TensorWrapper::set_user_custom_name>("c_name")
                    .def_getset<
                            &TensorWrapper::automatic_name,
                            &TensorWrapper::set_automatic_name>("_name")
                    .def_getset<
                            &TensorWrapper::module_trace_info,
                            &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
                    .finalize();
    if (!tensor_type)
        throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
            .def(py::init<const TensorWrapper&>())
            .def("__call__", &TensorWeakRef::operator())
            .def("_use_cnt", &TensorWeakRef::_use_cnt);

    py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
            .def_property_readonly(
                    "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
            .def_property(
                    "var", [](PySymbolVar* v) { return v->m_node; },
                    [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
            .def_property_readonly(
                    "device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
            .def_property_readonly(
                    "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
            .def_property_readonly(
                    "shape",
                    [](PySymbolVar* v) -> const TensorShape* {
                        auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                        return mgr.infer_shape_fallible(v->m_node);
                    })
            .def("numpy",
                 [](PySymbolVar* v) {
                     auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                     auto&& type = mgr.get_infer_type(v->m_node);
                     using InferType = cg::static_infer::InferType;
                     if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
                         throw py::value_error("value invalid!");
                     }
                     auto* val = mgr.infer_value_fallible(v->m_node);
                     if (!val) {
                         throw py::value_error("value invalid!");
                     }
                     auto np_val = py::cast(*val).attr("numpy")();
                     if (v->is_scalar) {
                         return py::object(py::array(np_val).squeeze());
                     }
                     return np_val;
                 })
            .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
            .def("_setscalar", [](PySymbolVar* v) { return v->is_scalar = true; })
            .def(py::init([](cg::VarNode* node) {
                     return std::make_shared<PySymbolVar>(node);
                 }),
                 py::arg() = nullptr);

    static PyMethodDef method_defs[] = {
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
            {nullptr, nullptr, 0, nullptr}};
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func)
                throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }

    static constexpr auto sync_py_task_q = [] { py_task_q.wait_all_task_finish(); };

    m.def("set_option", [](std::string name, size_t value) {
        interpreter_for_py->set_option(name, value);
    });
    m.def("clear_candidates", []() { interpreter_for_py->clear_candidates(); });
    m.def("get_option",
          [](std::string name) { return interpreter_for_py->get_option(name); });
    m.def("_set_drop_flag",
          [](bool flag) { interpreter_for_py->set_option("enable_drop", flag); });
    m.def("config_async_level", [](int level) {
        mgb_assert(level >= 0 and level <= 2, "async_level should be 0, 1 or 2");
        interpreter_for_py->set_option("async_level", level);
    });
    m.def("get_async_level",
          []() { return interpreter_for_py->get_option("async_level"); });
    m.def("set_buffer_length", [](int length) {
        mgb_assert(length >= 0 and length < 100, "buffer_length should be in [0, 100)");
        interpreter_for_py->set_option("buffer_length", length);
    });
    m.def("push_scope", [](std::string name) { interpreter_for_py->push_scope(name); });
    m.def("pop_scope", [](std::string name) { interpreter_for_py->pop_scope(name); });
    m.def(
            "start_profile",
            [](imperative::Profiler::options_t options) {
                interpreter_for_py->sync();
                imperative::Profiler::load_options(std::move(options));
                imperative::Profiler::start_profile();
                interpreter_for_py->start_profile();
            },
            py::call_guard<py::gil_scoped_release>());
    m.def(
            "stop_profile",
            []() -> std::function<void(std::string, std::string)> {
                interpreter_for_py->stop_profile();
                interpreter_for_py->sync();
                imperative::Profiler::stop_profile();
                auto results = std::make_shared<imperative::Profiler::bundle_t>(
                        imperative::Profiler::collect());
                return [results = results](
                               std::string basename, std::string format) mutable {
                    imperative::Profiler::dump_profile(
                            basename, format, std::move(*results));
                    results = nullptr;
                };
            },
            py::call_guard<py::gil_scoped_release>());
    m.def(
            "sync",
            []() {
                interpreter_for_py->sync();
                sync_py_task_q();
            },
            py::call_guard<py::gil_scoped_release>());
    m.def(
            "full_sync",
            []() {
                interpreter_for_py->sync();
                CompNode::sync_all();
                CompNode::foreach ([](CompNode cn) {
                    auto err = cn.check_async_error();
                    mgb_assert(!err, "%s", err->what());
                });
                sync_py_task_q();
            },
            py::call_guard<py::gil_scoped_release>());
    m.def(
            "close",
            []() {
                interpreter_for_py->close();
                sync_py_task_q();
            },
            py::call_guard<py::gil_scoped_release>());

    py::handle grad_key_type =
            GradKeyWrapper::wrap_t::type()
                    .def<&GradKeyWrapper::attach>("attach")
                    .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
                    .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
                            "name")
                    .def_getset<
                            &GradKeyWrapper::get_priority,
                            &GradKeyWrapper::set_priority>("priority")
                    .finalize();
    if (!grad_key_type)
        throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);

    m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing);
    m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing);
    m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode);
    m.def("set_cpp_apply_module_trace", &set_cpp_apply_module_trace);
    m.attr("skip_tracing") = &skip_tracing;

    py::class_<SharedHandle>(m, "SharedHandle")
            .def(py::init<const SharedHandle&>())
            .def("__eq__",
                 [](SharedHandle& thish, SharedHandle& thath) {
                     return (thish.get() == thath.get());
                 })
            .def("__hash__",
                 [](SharedHandle& sh) { return reinterpret_cast<int64_t>(sh.get()); });

    m.def("set_tracing", &set_tracing);
    m.def("unset_tracing", &unset_tracing);
    m.def("set_allow_higher_order_directive",
          [](bool value) { GradKey::allow_higher_order_directive = value; });
    m.def("set_module_tracing", &set_module_tracing);
    m.def("unset_module_tracing", &unset_module_tracing);
    m.def("is_tracing_module", &is_tracing_module);
}

#undef MGE_PY_INTERFACE

}  // namespace mgb::imperative::python