tensor.cpp 74.9 KB
Newer Older
1
#include "megbrain/common.h"
M
Megvii Engine Team 已提交
2
#include "megbrain/dtype.h"
3
#include "megbrain/imperative/backtrace.h"
4
#include "megbrain/imperative/cpp_cupti.h"
5
#include "megbrain/imperative/dispatch.h"
6
#include "megbrain/imperative/ops/autogen.h"
M
Megvii Engine Team 已提交
7
#include "megbrain/imperative/ops/backward_graph.h"
8
#include "megbrain/imperative/ops/opr_attr.h"
M
Megvii Engine Team 已提交
9
#include "megbrain/imperative/ops/utility.h"
10
#include "megbrain/imperative/profiler.h"
11
#include "megbrain/imperative/transformation.h"
12
#include "megbrain/imperative/transformations/complex.h"
13
#include "megbrain/imperative/transformations/dim_expansion.h"
14
#include "megbrain/imperative/transformations/dtype_promote.h"
15
#include "megbrain/imperative/transformations/eval.h"
16
#include "megbrain/imperative/transformations/format.h"
17
#include "megbrain/imperative/transformations/group_comm.h"
18 19 20 21 22
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
23
#include "megbrain/opr/io.h"
24
#include "megbrain/plugin/profiler.h"
25
#include "megbrain/utils/stats.h"
26
#include "megdnn/algorithm_cache.h"
27

28
#include "./common.h"
29 30
#include "./dlpack.h"
#include "./dlpack_convertor.h"
31
#include "./external_convert.h"
M
Megvii Engine Team 已提交
32
#include "./grad.h"
33
#include "./graph_rt.h"
34
#include "./helper.h"
M
Megvii Engine Team 已提交
35 36 37
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
38
#include "./tensor_utils.h"
39
#include "./transformation.h"
40

41
#include <object.h>
42 43
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
44 45
#include <pybind11/pytypes.h>
#include <pyerrors.h>
46
#include <iterator>
47
#include <range/v3/all.hpp>
48
#include <string>
49 50 51

#include <unordered_map>

52
#include "../../src/impl/mgb_cg_impl.h"
53
#include "./backtrace.h"
54

55 56
#include <iostream>

57
namespace py = pybind11;
58
namespace views = ranges::views;
59 60 61

namespace mgb::imperative::python {

62 63
interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
64
PyTypeObject* py_varnode_type = nullptr;
65
PyTypeObject* py_external_type = nullptr;
66
pybind11::handle py_device_type = nullptr;
67
PyObject* cpp_use_symbolic_shape;
68 69 70 71 72 73 74

// Defines `void set_<mode>(py::object)`, which stores the given Python object in
// the global `<mode>`.
//
// The global keeps its own strong reference. The argument only lives for the
// duration of the call, so storing a borrowed pointer (as before) could dangle
// once Python drops its last reference. The previously stored object, if any,
// is released only after the new one is installed, which is safe even when both
// are the same object.
#define REGISTE_APPLY_FUNC(mode)      \
    void set_##mode(py::object pyf) { \
        PyObject* previous = mode;    \
        mode = pyf.release().ptr();   \
        Py_XDECREF(previous);         \
    }

REGISTE_APPLY_FUNC(cpp_use_symbolic_shape)

#undef REGISTE_APPLY_FUNC
75

76 77 78
// Defined elsewhere. Both take the raw fastcall argument array. py_apply uses
// them to pick the dtype (via npy::dtype_np2mgb_descr) and comp node for
// converting non-tensor operands. Presumably they compute the promoted dtype
// and the common device of all arguments — TODO confirm against definitions.
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs);
CompNode _get_device(PyObject* const* args, size_t nargs);

M
Megvii Engine Team 已提交
79 80
/**
 * Python entry point (fastcall convention) that applies an OpDef to inputs.
 *
 * args[0] must be the OpDef; the remaining arguments are its inputs. Tensor
 * wrappers are used directly. Other inputs are accepted only for Elemwise /
 * ElemwiseMultiType ops when DTypePromoteCfg::convert_input_enabled is set.
 * They are converted to constant tensors using the dtype from _dtype_promotion
 * and the comp node from _get_device, both computed once over all inputs.
 *
 * Outputs are wrapped with py_varnode_type if any input is an instance of it,
 * otherwise with py_tensor_type.
 *
 * Returns a new tuple of outputs, or nullptr with a Python exception set.
 */
PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "py_apply expects one Op and at least one tensor "
                    "as argument");
            return nullptr;
        }

        auto* py_op = args[0];

        ++args;
        --nargs;

        auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
        SmallVector<ValueRef, 8> tensors(nargs);

        mgb::CompNode target_cn;
        mgb::DType target_dtype;

        auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef {
            //! operand in elemwise can't be None. Reject it before dtype
            //! promotion and numpy conversion, which would otherwise run on it
            //! first and fail (or succeed) with an unrelated result.
            if (args[i] == Py_None) {
                throw py::type_error("the operand is None and is not supported.");
            }
            // dtype and device are determined lazily, once, on the first
            // non-tensor operand, from all inputs.
            if (!target_dtype.valid()) {
                target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs));
                target_cn = _get_device(args, nargs);
            }
            HostTensorND ht(target_cn);
            ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
            record_py_backtrace();
            if (PyArray_Check(args[i]) || PyList_Check(args[i])) {  // non scalar
                // py_tuple is not allowed here because of tracing
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
                        HostStorage::make(ht.storage()))[0];
            }
            // scalar: created with an empty shape `{}`
            return imperative::apply(
                    CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}),
                    HostStorage::make(ht.storage()))[0];
        };

        bool is_varnode_apply = false;
        for (size_t i = 0; i < nargs; ++i) {
            if (PyObject_TypeCheck(args[i], py_varnode_type)) {
                is_varnode_apply = true;
            }
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                tensors[i] = tw->m_tensor->data();
            } else if (
                    DTypePromoteCfg::convert_input_enabled &&
                    (op->same_type<Elemwise>() || op->same_type<ElemwiseMultiType>())) {
                tensors[i] = convert_pyinput_to_tensor(i);
            } else {
                PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
                return nullptr;
            }
        }
        record_py_backtrace();
        auto outputs = imperative::apply(*op, tensors);
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        PyTypeObject* py_type = is_varnode_apply ? py_varnode_type : py_tensor_type;
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(py_type, std::move(outputs[i]));
        }
        return ret.release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
156 157 158 159 160
/// Returns the FrameInfo that get_frameinfo_from_pyframe() builds for the
/// Python frame currently being executed (PyEval_GetFrame()).
FrameInfoPtr get_current_frameinfo() {
    return get_frameinfo_from_pyframe(PyEval_GetFrame());
}
161

162 163 164 165 166 167 168 169 170 171 172 173 174 175
namespace {

/// Maps a pybind11 wrapper type to the exact built-in Python type object it
/// corresponds to. It is used for exact-type checks
/// (`obj.get_type().is(py_type<T>())`), which deliberately do not match
/// subclasses such as bool for int.
///
/// Only py::int_, py::float_, py::tuple and py::list are supported; any other T
/// is now rejected at compile time. The previous `static_assert(is_same_v<T, T>)`
/// was always true, so an unsupported T compiled into a function that fell off
/// its end without returning a value (undefined behavior).
template <typename T>
py::handle py_type() {
    if constexpr (std::is_same_v<T, py::int_>) {
        return reinterpret_cast<PyObject*>(&PyLong_Type);
    } else if constexpr (std::is_same_v<T, py::float_>) {
        return reinterpret_cast<PyObject*>(&PyFloat_Type);
    } else if constexpr (std::is_same_v<T, py::tuple>) {
        return reinterpret_cast<PyObject*>(&PyTuple_Type);
    } else if constexpr (std::is_same_v<T, py::list>) {
        return reinterpret_cast<PyObject*>(&PyList_Type);
    } else {
        // dependent-false: only fires when this branch is instantiated
        static_assert(!std::is_same_v<T, T>, "py_type: unsupported pybind11 type");
    }
}

template <typename T>
auto scalar2storage(T val, CompNode cn, DType dtype) {
    using max_ctype_t = DTypeScalar::max_ctype;
    DTypeScalar scalar(dtype);
    scalar.set_retain_dtype(val);
    HostTensorStorage storage(cn);
    auto* raw_ptr = reinterpret_cast<dt_byte*>(new max_ctype_t());
    std::shared_ptr<dt_byte> raw_storage = {
            raw_ptr, [](dt_byte* ptr) { delete reinterpret_cast<max_ctype_t*>(ptr); }};
    storage.only_reset_raw_storage(cn, dtype.size(), raw_storage, 0);
    std::memcpy(storage.ptr(), scalar.storage(), dtype.size());
    return HostStorage::make(std::move(storage));
}

/// Copies the scalars in `vec` (at most MEGDNN_MAX_NDIM of them) into a new
/// host storage on `cn`, converting each to `ctype`. sizeof(ctype) must equal
/// dtype.size().
///
/// The size check now runs before allocation, and the buffer is owned by a
/// shared_ptr before any element conversion. A failing assertion or a throwing
/// get_cast() therefore no longer leaks the raw array.
template <typename ctype>
auto vec2storage(Span<DTypeScalar> vec, CompNode cn, DType dtype) {
    mgb_assert(vec.size() <= MEGDNN_MAX_NDIM);
    mgb_assert(sizeof(ctype) == dtype.size());
    // TODO: use storage cache and modify ConstTensorCache to return (Host, Device)
    // The buffer always holds MEGDNN_MAX_NDIM elements; only vec.size() are exposed.
    auto* raw_ptr = new ctype[MEGDNN_MAX_NDIM];
    std::shared_ptr<dt_byte> raw_storage = {
            reinterpret_cast<dt_byte*>(raw_ptr),
            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
    for (size_t i = 0; i < vec.size(); ++i) {
        raw_ptr[i] = vec[i].get_cast<ctype>();
    }
    HostTensorStorage storage(cn);
    storage.only_reset_raw_storage(cn, sizeof(ctype) * vec.size(), raw_storage, 0);
    return HostStorage::make(std::move(storage));
}

/// Host-side description of a tensor built from a Python object: its shape,
/// dtype and the host storage holding the data.
struct HostTensorArgs {
    ValueShape shape;
    DType dtype;
    HostStorage::ref_t storage;

    /// Views the stored buffer as a HostTensorND on the default CPU comp node,
    /// attaching it through only_reset_raw_storage (presumably without a copy).
    HostTensorND as_tensor_nd() const {
        HostTensorND nd(CompNode::default_cpu(), shape.as_tensor_shape(), dtype);
        nd.only_reset_raw_storage(*storage);
        return nd;
    }
};

template <typename seq_type, typename ctype>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    auto size = obj.size();
    if (size > MEGDNN_MAX_NDIM) {
        return false;
    }
    ctype items[size];
    for (size_t i = 0; i < size; ++i) {
        py::handle item = obj[i];
        if (item.get_type().is(py_type<py::int_>())) {
            items[i] = (ctype)(dt_int32)item.template cast<py::int_>();
        } else if (item.get_type().is(py_type<py::float_>())) {
            items[i] = (ctype)(dt_float32)item.template cast<py::float_>();
        } else {
            return false;
237
        }
238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277
    }
    mgb_assert(sizeof(ctype) == dtype.size());
    auto* raw_ptr = new ctype[size];
    std::shared_ptr<dt_byte> raw_storage = {
            reinterpret_cast<dt_byte*>(raw_ptr),
            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
    HostTensorStorage storage(cn);
    storage.only_reset_raw_storage(cn, sizeof(ctype) * size, raw_storage, 0);
    std::memcpy(storage.ptr(), items, sizeof(ctype) * size);
    ret.dtype = dtype;
    ret.shape = {size};
    ret.storage = HostStorage::make(std::move(storage));
    return true;
}

template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, HostTensorArgs& ret) {
    auto size = obj.size();
    if (size > MEGDNN_MAX_NDIM) {
        return false;
    }
    DTypeScalar items[size];
    DType dtype;
    for (size_t i = 0; i < size; ++i) {
        auto&& item = obj[i];
        if (item.get_type().is(py_type<py::int_>())) {
            items[i] = (dt_int32)item.template cast<py::int_>();
            if (!dtype.valid()) {
                dtype = dtype::Int32();
            } else if (dtype != dtype::Int32() && dtype != dtype::Float32()) {
                return false;
            }
        } else if (item.get_type().is(py_type<py::float_>())) {
            items[i] = (dt_float32)item.template cast<py::float_>();
            if (!dtype.valid()) {
                dtype = dtype::Float32();
            } else if (dtype == dtype::Int32()) {
                dtype = dtype::Float32();
            } else if (dtype != dtype::Float32()) {
                return false;
278
            }
279
        } else {
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468
            return false;
        }
    }
    if (!dtype.valid()) {
        dtype = dtype::Float32();
    }
    ret.dtype = dtype;
    ret.shape = {size};
    if (dtype == dtype::Int32()) {
        ret.storage = vec2storage<dt_int32>({items, size}, cn, dtype);
    } else if (dtype == dtype::Float32()) {
        ret.storage = vec2storage<dt_float32>({items, size}, cn, dtype);
    } else {
        mgb_assert(false);
    }
    return true;
}

/// Converts a Python tuple/list according to the requested dtype.
///
/// Int32 and Float32 use the typed converter. An invalid (unspecified) dtype
/// lets the items decide between Int32 and Float32. Any other dtype is not
/// handled here and returns false, so the caller can fall back to numpy.
template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    if (dtype == dtype::Int32()) {
        return pyseq2hval<seq_type, dt_int32>(obj, cn, dtype, ret);
    }
    if (dtype == dtype::Float32()) {
        return pyseq2hval<seq_type, dt_float32>(obj, cn, dtype, ret);
    }
    if (!dtype.valid()) {
        return pyseq2hval<seq_type>(obj, cn, ret);
    }
    return false;
}

/// Converts an array-like object (already coerced to py::array) into host
/// tensor args on `cn` using npy::np2tensor.
///
/// If any stride is 0 (e.g. a broadcast view), the array is first squeezed and
/// then resized back to its original shape before conversion. When `dtype` is
/// invalid, the dtype chosen by np2tensor is used. Always returns true; failures
/// surface as exceptions or assertions.
bool pyarr2hval(py::array obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    auto arr = obj.cast<py::array>();
    const auto* arr_strides = arr.strides();
    const size_t ndim = arr.ndim();

    bool has_zero_stride = false;
    for (size_t axis = 0; axis < ndim && !has_zero_stride; ++axis) {
        has_zero_stride = arr_strides[axis] == 0;
    }
    if (has_zero_stride) {
        std::vector<size_t> original_shape;
        for (size_t axis = 0; axis < ndim; ++axis) {
            original_shape.push_back(arr.shape(axis));
        }
        arr = arr.squeeze();
        arr.resize(original_shape);
    }

    HostTensorND host_nd(cn);
    host_nd = npy::np2tensor(arr.ptr(), npy::Meth::copy_into(&host_nd), dtype);
    mgb_assert(
            host_nd.layout().is_empty() || host_nd.layout().is_contiguous(),
            "host value should be continuous");
    for (size_t axis = 0; axis < static_cast<size_t>(arr.ndim()); ++axis) {
        ret.shape[ret.shape.ndim++] = arr.shape(axis);
    }
    ret.dtype = dtype.valid() ? dtype : host_nd.dtype();
    ret.storage = HostStorage::make(host_nd.storage());
    return true;
}

/// Converts an exact Python int into scalar host tensor args. The value is read
/// as dt_int32; an invalid `dtype` defaults to Int32. Always returns true.
bool pyint2hval(py::int_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    const DType target = dtype.valid() ? dtype : DType(dtype::Int32());
    ret.dtype = target;
    ret.storage = scalar2storage(static_cast<dt_int32>(obj), cn, target);
    return true;
}

/// Converts an exact Python float into scalar host tensor args. The value is
/// read as dt_float32; an invalid `dtype` defaults to Float32. Always returns
/// true.
bool pyfloat2hval(py::float_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    const DType target = dtype.valid() ? dtype : DType(dtype::Float32());
    ret.dtype = target;
    ret.storage = scalar2storage(static_cast<dt_float32>(obj), cn, target);
    return true;
}

/// Converts an arbitrary Python object into host tensor args on `cn`.
///
/// Fast paths, matched on the exact Python type in this order: float → scalar,
/// int → scalar, tuple / list → 1-D tensor. Subclasses such as bool
/// (isinstance(True, int) is True) intentionally miss the fast paths. Anything
/// else, including a sequence the fast path rejects, goes through numpy via
/// pyarr2hval. None is treated as an empty list.
HostTensorArgs pyobj2hval(py::object obj, CompNode cn, DType dtype) {
    HostTensorArgs ret;
    auto type = obj.get_type();
    bool converted = false;
    if (type.is(py_type<py::float_>())) {
        converted = pyfloat2hval(py::float_(obj), cn, dtype, ret);
    } else if (type.is(py_type<py::int_>())) {
        converted = pyint2hval(py::int_(obj), cn, dtype, ret);
    } else if (type.is(py_type<py::tuple>())) {
        converted = pyseq2hval<py::tuple>(py::tuple(obj), cn, dtype, ret);
    } else if (type.is(py_type<py::list>())) {
        converted = pyseq2hval<py::list>(py::list(obj), cn, dtype, ret);
    } else if (obj.is_none()) {
        obj = py::list(0);
    }
    if (!converted) {
        converted = pyarr2hval(obj, cn, dtype, ret);
    }
    mgb_assert(converted);
    return ret;
}

/// Description of one accepted argument: its keyword name and a factory that
/// produces the value used when the argument is not supplied.
struct PyArgDesc {
    const char* name;
    py::object (*default_value)();
};

/// Ordered list of accepted arguments, plus a lookup from keyword name to its
/// position in `items`. The lookup is expected to return -1 for unknown names,
/// as name2idx in this file does.
struct PyArgDescs {
    std::vector<PyArgDesc> items;
    ssize_t (*name2idx)(const char* name);
};

/// Expands positional `args` to exactly descs.items.size() entries, filling
/// missing trailing arguments with their declared defaults.
///
/// Returns `args` itself when it is already complete, and asserts when too many
/// arguments were given.
py::tuple parse_args(py::tuple args, const PyArgDescs& descs) {
    const size_t given = args.size();
    const size_t expected = descs.items.size();
    mgb_assert(given <= expected, "too many args");
    if (given == expected) {
        return args;
    }
    py::tuple filled(expected);
    for (size_t pos = 0; pos < expected; ++pos) {
        if (pos < given) {
            filled[pos] = args[pos];
        } else {
            filled[pos] = descs.items[pos].default_value();
        }
    }
    return filled;
}

/// Merges positional `args` and keyword `kwargs` into a tuple with exactly
/// descs.items.size() entries, in declaration order. Entries supplied by
/// neither get their declared default value.
///
/// Asserts on too many arguments, on an unknown keyword, and on a keyword that
/// names a position already filled positionally. Previously an unknown keyword
/// (name2idx() == -1) passed the `index >= nr_args` check because -1 was
/// converted to a huge unsigned value, and was then used as an out-of-bounds
/// index. The bookkeeping array is also no longer a non-standard VLA.
py::tuple parse_args_and_kwargs(
        py::tuple args, py::dict kwargs, const PyArgDescs& descs) {
    size_t nr_args = args.size();
    size_t nr_kwargs = kwargs.size();
    size_t nr_items = descs.items.size();
    mgb_assert(nr_args + nr_kwargs <= nr_items, "too many args");
    if (nr_args == nr_items) {
        return args;
    }
    py::tuple ret(nr_items);
    for (size_t i = 0; i < nr_args; ++i) {
        ret[i] = args[i];
    }
    // has_value[i] != 0 when slot i was supplied by a keyword
    std::vector<char> has_value(nr_items, 0);
    for (auto&& [k, v] : kwargs) {
        auto key = py::str(k).cast<std::string>();
        ssize_t index = descs.name2idx(key.c_str());
        mgb_assert(index >= 0, "unexpected keyword argument: %s", key.c_str());
        size_t slot = static_cast<size_t>(index);
        mgb_assert(slot < nr_items, "unexpected keyword argument: %s", key.c_str());
        mgb_assert(
                slot >= nr_args, "got multiple values for argument: %s", key.c_str());
        ret[slot] = v;
        has_value[slot] = 1;
    }
    for (size_t i = nr_args; i < nr_items; ++i) {
        if (!has_value[i]) {
            ret[i] = descs.items[i].default_value();
        }
    }
    return ret;
}

/// Resolves a device name to a CompNode, caching the most recent lookup per
/// thread.
///
/// The cache is updated only after CompNode::load() returns. Previously the
/// name was stored first, so a throwing load left that name paired with the
/// previously cached comp node, and a retry with the same name silently
/// returned the wrong device.
CompNode as_comp_node(const std::string& name) {
    thread_local struct {
        std::string name;
        CompNode cn;
    } cached;
    if (cached.name != name) {
        CompNode loaded = CompNode::load(name);
        cached.cn = loaded;
        cached.name = name;
    }
    return cached.cn;
}

/// Resolves a Python device specification to a CompNode.
///
/// For None or a str: if the Tensor type's `dmap_callback` class attribute is
/// set, the device name is passed through it (None is first replaced by the
/// default device name). Otherwise None means the default device and a str is
/// used as given. Any other object is unwrapped via `_cn` when it is an
/// instance of py_device_type, and must then be a CompNode.
/// (Removed an unused `std::optional<std::string>` local.)
CompNode as_comp_node(py::object py_device) {
    if (py_device.is_none() || py::str::check_(py_device)) {
        auto cls = py::handle(reinterpret_cast<PyObject*>(py_tensor_type));
        auto dmap_callback = cls.attr("dmap_callback");
        std::string name;
        if (dmap_callback.is_none() && py_device.is_none()) {
            name = get_default_device();
        } else {
            if (py_device.is_none()) {
                py_device = py::str(get_default_device());
            }
            if (!dmap_callback.is_none()) {
                py_device = dmap_callback(py_device);
            }
            name = py::str(py_device).cast<std::string>();
        }
        return as_comp_node(name);
    }
    if (py::isinstance(py_device, py_device_type)) {
        py_device = py_device.attr("_cn");
    }
    mgb_assert(py::isinstance(py_device, py_comp_node_type));
    return py_device.cast<CompNode>();
}
484

485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513
template <char... Chars>
bool compare_cstr(const char* cstr) {
    return (((*cstr++) == Chars) && ...) && *cstr == '\0';
}

/// Maps a Tensor constructor keyword to its position: data=0, dtype=1,
/// device=2, is_const=3, no_cache=4, name=5, format=6. Returns -1 for any
/// other name.
///
/// Every inner switch now ends in `return -1`. Previously an unmatched second
/// character fell through into the next outer case, so e.g. "dxs_const" was
/// accepted as is_const. A bare "d" or "n" also advanced past the terminator
/// and was then compared beyond the end of the string.
ssize_t name2idx(const char* name) {
    const char* ch = name;
    // TODO: trie
    // clang-format off
    switch (*ch++) {
    case 'd':
        switch (*ch++) {
        // data
        case 'a': return compare_cstr<'t', 'a'>(ch) ? 0 : -1;
        // dtype
        case 't': return compare_cstr<'y', 'p', 'e'>(ch) ? 1 : -1;
        // device
        case 'e': return compare_cstr<'v', 'i', 'c', 'e'>(ch) ? 2 : -1;
        }
        return -1;
    case 'i':
        // is_const
        return compare_cstr<'s', '_', 'c', 'o', 'n', 's', 't'>(ch) ? 3 : -1;
    case 'n':
        switch (*ch++) {
        // no_cache
        case 'o': return compare_cstr<'_', 'c', 'a', 'c', 'h', 'e'>(ch) ? 4 : -1;
        // name
        case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1;
        }
        return -1;
    case 'f':
        // format
        return compare_cstr<'o', 'r', 'm', 'a', 't'>(ch) ? 6 : -1;
    }
    // clang-format on
    return -1;
}

}  // namespace

/// Python-facing Tensor constructor.
///
/// Accepts `(data, dtype, device, is_const, no_cache, name, format)`
/// positionally and/or by keyword; defaults are listed in `descs`.
///
/// When `data` is already a Tensor, its value is shared, or copied when
/// no_cache is true. Any given dtype / device / format must match it,
/// is_const must be false, and name is ignored on that path.
///
/// Otherwise a new value is created from one of:
///   - an external object (only when py_external_type is registered),
///   - a VarNode,
///   - any host value accepted by pyobj2hval.
/// The new value is renamed if `name` is given.
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    static PyArgDescs descs = {
            {
                    {"data", []() -> py::object { return py::none(); }},
                    {"dtype", []() -> py::object { return py::none(); }},
                    {"device", []() -> py::object { return py::none(); }},
                    {"is_const", []() -> py::object { return py::bool_(false); }},
                    {"no_cache", []() -> py::object { return py::bool_(false); }},
                    {"name", []() -> py::object { return py::none(); }},
                    {"format", []() -> py::object { return py::none(); }},
            },
            name2idx};
    py::detail::loader_life_support life_sup;  // FIXME!!!required to cast DType
    auto given = py::reinterpret_borrow<py::tuple>(args);
    py::tuple tup = kwargs ? parse_args_and_kwargs(
                                     given, py::reinterpret_borrow<py::dict>(kwargs),
                                     descs)
                           : parse_args(given, descs);
    mgb_assert(tup.size() == 7);

    if (auto* source = try_cast(tup[0].ptr())) {
        // Wrapping an existing tensor: share its value and validate options.
        m_tensor = source->m_tensor;
        // TODO: merge two path in arg parse
        if (!tup[1].is_none()) {
            auto wanted_dtype = tup[1].cast<DType>();
            mgb_assert(
                    wanted_dtype == m_tensor->dtype(), "dtype mismatch: %s vs %s",
                    wanted_dtype.name(), m_tensor->dtype().name());
        }
        if (!tup[2].is_none()) {
            auto wanted_device = as_comp_node(tup[2]);
            mgb_assert(
                    wanted_device == m_tensor->comp_node(), "device mismatch: %s vs %s",
                    wanted_device.to_string().c_str(),
                    m_tensor->comp_node().to_string().c_str());
        }
        mgb_assert(!tup[3].cast<bool>(), "expect is_const == False, got True");
        if (tup[4].cast<bool>()) {
            // always copy because it's hard to tell whether this tensor is cached
            m_tensor = m_tensor->copy();
        }
        // tup[5] (name) is ignored on this path
        if (!tup[6].is_none()) {
            Format wanted_format = tup[6].cast<std::string>();
            mgb_assert(
                    wanted_format == m_tensor->format(), "format mismatch: %s vs %s",
                    wanted_format.to_string().c_str(),
                    m_tensor->format().to_string().c_str());
        }
    } else {
        auto data = tup[0];
        DType dtype = tup[1].cast<DType>();
        CompNode cn = as_comp_node(tup[2]);
        bool is_const = tup[3].cast<bool>();
        bool no_cache = tup[4].cast<bool>();
        std::string name;
        if (!tup[5].is_none()) {
            name = tup[5].cast<std::string>();
        }
        Format format;
        if (!tup[6].is_none()) {
            format = tup[6].cast<std::string>();
        }

        {
            // is_const takes precedence over no_cache
            CreateTensor::Kind kind = CreateTensor::Common;
            if (is_const) {
                kind = CreateTensor::Const;
            } else if (no_cache) {
                kind = CreateTensor::Unique;
            }

            ValueRef created;
            if (py_external_type != nullptr &&
                PyObject_TypeCheck(py::handle(data).ptr(), py_external_type)) {
                created = imperative::apply(
                        CreateExternalWrapper(data, cn),
                        Span<ValueRef>(nullptr, nullptr))[0];
            } else if (py::isinstance(data, Py_Varnode)) {
                cg::VarNode* var_node = py::handle(data).cast<cg::VarNode*>();
                created = imperative::apply(
                        CreateNode(var_node), Span<ValueRef>(nullptr, nullptr))[0];
            } else {
                auto&& host_args = pyobj2hval(data, cn, dtype);
                created = imperative::apply(
                        CreateTensor(kind, cn, host_args.dtype, host_args.shape, format),
                        host_args.storage)[0];
            }
            m_tensor.emplace(created);
        }

        if (!name.empty()) {
            m_tensor->reset(imperative::apply(RenameValue(name), m_tensor->data())[0]);
        }
    }
    mgb_assert(m_tensor->data());
}

620
/// Returns a new reference to the module-trace info recorded for this tensor's
/// value. Sets AttributeError and returns nullptr when nothing (or a null
/// object) is recorded.
PyObject* TensorWrapper::module_trace_info() {
    auto info = ModuleTraceTransformation::module_trace_info_map.try_get(
            m_tensor->data());
    if (info && info->ptr()) {
        return info->inc_ref().ptr();
    }
    PyErr_SetString(
            PyExc_AttributeError,
            "Has no attribute named \'_NodeMixin__node\', please "
            "set it first");
    return nullptr;
}

/// Stores `obj` (borrowed; the map keeps its own reference) as the
/// module-trace info for this tensor's value in module_trace_info_map.
void TensorWrapper::set_module_trace_info(PyObject* obj) {
    // TODO: erase when obj == nullptr
    auto info = py::reinterpret_borrow<py::object>(obj);
    ModuleTraceTransformation::module_trace_info_map[m_tensor->data()] = info;
}

641 642 643 644 645 646
void TensorWrapper::_set_format(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    auto format = py_dest.cast<std::string>();
    m_tensor->set_format(format);
}

647 648 649
void TensorWrapper::_set_name(PyObject* dest) {
    auto py_dest = py::reinterpret_borrow<py::object>(dest);
    auto name = py_dest.cast<std::string>();
650

651 652
    m_tensor->set_name(name);
}
653

654 655
PyObject* TensorWrapper::_detail() {
    return py::str(m_tensor->data().unwrap().to_string()).release().ptr();
656 657
}

658 659
void TensorWrapper::_watch() {
    m_tensor->data().watch();
660 661
}

662
/// Returns the shape as a Python tuple of ints, or None when it is unknown.
PyObject* TensorWrapper::shape() {
    auto maybe_shape = m_tensor->shape();
    if (!maybe_shape) {
        Py_RETURN_NONE;
    }
    const size_t ndim = maybe_shape->ndim;
    py::tuple dims(ndim);
    for (size_t axis = 0; axis < ndim; ++axis) {
        dims[axis] = maybe_shape->at(axis);
    }
    return dims.release().ptr();
}

/// Returns the dtype converted to a Python object via pybind11.
PyObject* TensorWrapper::dtype() {
    py::object result = py::cast(m_tensor->dtype());
    return result.release().ptr();
}

/// Returns the comp node converted to a Python object via pybind11.
PyObject* TensorWrapper::device() {
    py::object result = py::cast(m_tensor->comp_node());
    return result.release().ptr();
}

/// Returns the format as a Python str (Format::to_string()).
PyObject* TensorWrapper::format() {
    py::object result = py::cast(m_tensor->format().to_string());
    return result.release().ptr();
}

687
/// Returns the value as a numpy array (sharing memory when possible). Scalar
/// tensors are passed through PyArray_Squeeze.
///
/// Raises ValueError when no host value is available. The message includes a
/// symbolic-trace hint when more than one transformation is registered in the
/// Eval segment.
PyObject* TensorWrapper::numpy() {
    auto hv = m_tensor->numpy();
    if (!hv) {
        const bool has_extra_eval_transforms =
                TransformationManager::get_instance()
                        .segments[TransformationManager::Segment::Eval]
                        .size() > 1;
        const char* message =
                has_extra_eval_transforms
                        ? "tensor invalid, can not infer value of this tensor under "
                          "trace(symbolic=True). You can try to use trace(symbolic=False) to "
                          "avoid this issue."
                        : "tensor invalid";
        PyErr_SetString(PyExc_ValueError, message);
        return nullptr;
    }
    auto arr = py::reinterpret_steal<py::array>(
            npy::ndarray_from_tensor(hv->as_nd(true), npy::ShareType::TRY_SHARE));
    if (!hv->shape().is_scalar()) {
        return arr.release().ptr();
    }
    mgb_assert(PyArray_Check(arr.ptr()));
    return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
}

/// Makes this wrapper refer to the value held by another Tensor `tensor`.
/// Raises TypeError when `tensor` is not a Tensor.
void TensorWrapper::reset(PyObject* tensor) {
    auto* other = TensorWrapper::try_cast(tensor);
    if (other == nullptr) {
        throw py::type_error("expect Tensor");
    }
    m_tensor->reset(other->m_tensor->data());
}

720
/// Returns a new Tensor wrapping the result of applying DetachGrad to this
/// tensor's value.
PyObject* TensorWrapper::detach() {
    auto grad_free = imperative::apply(DetachGrad(), m_tensor->data())[0];
    py::object wrapped = TensorWrapper::make(py_tensor_type, grad_free);
    return wrapped.release().ptr();
}

M
Megvii Engine Team 已提交
725
PyObject* TensorWrapper::_dev_tensor() {
726 727 728
    auto dv = m_tensor->data().dev_tensor();
    // TODO: handle scalar
    return py::cast(dv->as_nd(true)).release().ptr();
729 730 731
}

void TensorWrapper::_drop() {
732
    imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data());
733 734
}

735
/// Returns Python True if the tensor is a scalar, False otherwise (new ref).
PyObject* TensorWrapper::isscalar() {
    return PyBool_FromLong(m_tensor->is_scalar() ? 1 : 0);
}

/// Returns the tensor's value id converted to a Python object.
PyObject* TensorWrapper::value_id() {
    py::object id = py::cast(m_tensor->value_id());
    return id.release().ptr();
}

747 748 749 750 751 752 753 754 755 756 757 758 759 760
/// Returns the graph VarNode behind this tensor (via GetVarVal) as a Python
/// object.
PyObject* TensorWrapper::_var() {
    TypedValueRef<NodeValue> node_value =
            imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
    return py::cast(node_value->node()).release().ptr();
}

/// Returns the computing graph that owns this tensor's VarNode (via GetVarVal)
/// as a Python object.
PyObject* TensorWrapper::_graph() {
    TypedValueRef<NodeValue> node_value =
            imperative::apply(GetVarVal(), m_tensor->data())[0].as_ref<NodeValue>();
    return py::cast(node_value->graph()).release().ptr();
}

761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781
/// Returns a new reference to the Python object wrapped by this tensor's
/// external value (obtained via GetExternalVal).
PyObject* TensorWrapper::_external_obj() {
    TypedValueRef<PyobjectValue> external =
            imperative::apply(GetExternalVal(), m_tensor->data())[0]
                    .as_ref<PyobjectValue>();
    py::object wrapped = external->object();
    return wrapped.release().ptr();
}

/// Returns Python True if this tensor's value has the value type of the first
/// transformation in the ExternalConvert segment, False otherwise.
///
/// Asserts that such a transformation is registered (indexing an empty segment
/// was undefined behavior) and enabled. The base→derived downcast now uses
/// static_cast instead of reinterpret_cast.
PyObject* TensorWrapper::_is_external_value() {
    auto&& external_tsf =
            TransformationManager::get_instance()
                    .segments[TransformationManager::Segment::ExternalConvert];
    mgb_assert(
            external_tsf.size() > 0, "ExternalConvert transformation is not registered");
    auto* tsf = static_cast<ExternalConvertTransformation*>(external_tsf[0].get());
    mgb_assert(tsf->enabled());
    auto valueref = m_tensor->data();
    if (valueref.is(tsf->value_type())) {
        Py_RETURN_TRUE;
    } else {
        Py_RETURN_FALSE;
    }
}

782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810
/// PyCapsule destructor for "dltensor" capsules created by tensor_to_dlpack().
///
/// A capsule still named "dltensor" was never consumed, so the managed tensor
/// is released here through its deleter. The DLPack spec allows `deleter` to
/// be NULL, in which case there is nothing to call (it was previously called
/// unconditionally).
void dlpack_capsule_destructor(PyObject* data) {
    if (!PyCapsule_IsValid(data, "dltensor")) {
        // early out, see DLPack spec: if a consuming library sets the capsule
        // name to something else, they own it and we don't need to do anything
        return;
    }
    auto* dlMTensor =
            static_cast<DLManagedTensor*>(PyCapsule_GetPointer(data, "dltensor"));
    if (dlMTensor != nullptr && dlMTensor->deleter != nullptr) {
        dlMTensor->deleter(dlMTensor);
    }
}

/// Exports a Tensor as a DLPack capsule named "dltensor", whose destructor is
/// dlpack_capsule_destructor.
///
/// Raises TypeError when `tensor` is not a Tensor; a null wrapper was
/// previously dereferenced. If PyCapsule_New fails, its Python error is left
/// set, nullptr is returned, and the exported tensor is released through its
/// deleter instead of leaking.
PyObject* tensor_to_dlpack(PyObject* tensor) {
    TensorWrapper* wrapper = TensorWrapper::try_cast(tensor);
    if (!wrapper) {
        throw py::type_error("expect Tensor");
    }
    DLManagedTensor* dlMTensor = to_dlpack(wrapper->m_tensor->data());
    PyObject* capsule = PyCapsule_New(dlMTensor, "dltensor", dlpack_capsule_destructor);
    if (capsule == nullptr && dlMTensor != nullptr && dlMTensor->deleter != nullptr) {
        dlMTensor->deleter(dlMTensor);
    }
    return capsule;
}

/// Imports an unconsumed "dltensor" capsule as a Tensor on stream id `stream`.
///
/// The capsule is renamed to "used_dltensor" to mark it consumed. Raises the
/// pending Python error when the capsule is invalid or was already consumed
/// (PyCapsule_GetPointer returns NULL; previously that NULL was passed on).
/// Raises TypeError when `stream` is not an int, and the pending OverflowError
/// when it does not fit in a C long.
PyObject* tensor_from_dlpack(PyObject* data, PyObject* stream) {
    auto* dlMTensor =
            static_cast<DLManagedTensor*>(PyCapsule_GetPointer(data, "dltensor"));
    if (dlMTensor == nullptr) {
        throw py::error_already_set();
    }
    if (!PyLong_Check(stream)) {
        throw py::type_error("expect int");
    }
    long stream_id = PyLong_AsLong(stream);
    if (stream_id == -1 && PyErr_Occurred()) {
        throw py::error_already_set();
    }
    int sid = static_cast<int>(stream_id);
    PyCapsule_SetName(data, "used_dltensor");
    auto tensor = from_dlpack(dlMTensor, sid);
    return TensorWrapper::make(py_tensor_type, std::move(tensor)).release().ptr();
}

811
struct TensorWeakRef {
812
    ValueWeakRef data;
813

814
    TensorWeakRef(const TensorWrapper& tw) : data(tw.m_tensor->data()) {}
815 816

    py::object operator()() {
817
        if (auto p = data.lock()) {
818
            return TensorWrapper::make(py_tensor_type, p);
819 820 821 822 823
        }
        return py::none();
    }
};

824 825 826 827 828 829 830 831 832 833
// MGE_PY_INTERFACE(NAME, FUNC) expands to a PyMethodDef initializer exposing
// FUNC to Python under NAME.
//
// When METH_FASTCALL is available, the fastcall function
// `(PyObject* self, PyObject* const* args, size_t nargs)` is registered
// directly. Otherwise a `py35_<FUNC>` adapter is generated for each function
// below. The adapter receives a METH_VARARGS tuple and forwards the tuple's
// item array and size to FUNC.
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(fn)                                                   \
    PyObject* py35_##fn(PyObject* self, PyObject* args) {                    \
        return fn(self, &PyTuple_GET_ITEM(args, 0), PyTuple_GET_SIZE(args)); \
    }

WRAP_FUNC_PY35(py_apply);
WRAP_FUNC_PY35(dtype_promotion);
WRAP_FUNC_PY35(get_device);
WRAP_FUNC_PY35(make_shape_tuple);
WRAP_FUNC_PY35(getitem_cpp);
WRAP_FUNC_PY35(setitem_cpp);
WRAP_FUNC_PY35(split_cpp);
WRAP_FUNC_PY35(expand_dims_cpp);
WRAP_FUNC_PY35(squeeze_cpp);
WRAP_FUNC_PY35(transpose_cpp);
WRAP_FUNC_PY35(broadcast_cpp);
WRAP_FUNC_PY35(reshape_cpp);
WRAP_FUNC_PY35(adaptive_pool2d_cpp);
WRAP_FUNC_PY35(Const);
WRAP_FUNC_PY35(astype_cpp);
WRAP_FUNC_PY35(matmul_cpp);
WRAP_FUNC_PY35(batched_matmul_cpp);
WRAP_FUNC_PY35(convert_single_value_cpp);
WRAP_FUNC_PY35(convert_inputs_cpp);
WRAP_FUNC_PY35(astensor1d_cpp);
WRAP_FUNC_PY35(pixel_shuffle_cpp);
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif

861
void init_tensor(py::module m) {
862
    imperative::Tensor::static_initialize();
863
    init_backtrace_tss_key();
864
    // Transformations
865 866 867 868
    static auto& transformations = TransformationManager::get_instance();

    using Segment = TransformationManager::Segment;

869 870 871 872 873 874
    using Channel = interpreter::Interpreter::Channel;

    auto* channel =
            imperative::ResourceManager::create_global<std::unique_ptr<Channel>>(
                    interpreter::Interpreter::inst().create_channel())
                    ->get();
875
    interpreter_for_py = channel;
876 877 878 879 880 881 882 883 884 885
    MGB_MARK_USED_VAR(
            transformations
                    .register_at<Segment::Eval>(
                            std::make_shared<InterpreterTransformation>(
                                    std::shared_ptr<Channel>(channel, [](Channel*) {})))
                    .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Scalar>(
                                      std::make_shared<ScalarTransformation>())
                              .release());
886 887 888 889
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Symbol>(
                                      std::make_shared<SymbolTransformation>())
                              .release());
890 891 892 893 894 895 896 897
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DTypePromote>(
                                      std::make_shared<DTypePromoteTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DimExpansion>(
                                      std::make_shared<DimExpansionTransformation>())
                              .release());
898 899 900 901
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Complex>(
                                      std::make_shared<ComplexTransformation>())
                              .release());
902 903 904 905
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Grad>(
                                      std::make_shared<GradTransformationGuard>())
                              .release());
906 907 908
    auto format_trans = std::make_shared<FormatTransformation>();
    MGB_MARK_USED_VAR(
            transformations.register_at<Segment::Format>(format_trans).release());
909

M
Megvii Engine Team 已提交
910 911
    static py::exception<interpreter::AsyncError> py_async_error(
            m, "AsyncError", PyExc_RuntimeError);
912 913
    py::register_exception_translator([](std::exception_ptr p) {
        try {
M
Megvii Engine Team 已提交
914 915
            if (p)
                std::rethrow_exception(p);
916 917 918 919 920 921 922 923 924 925
        } catch (const interpreter::AsyncError& e) {
            pyext17::pybind11_translate_exception(e.nested_ptr());
            if (PyErr_Occurred()) {
                PyObject *exc, *val, *tb;
                PyErr_Fetch(&exc, &val, &tb);
                PyErr_NormalizeException(&exc, &val, &tb);
                if (tb) {
                    PyException_SetTraceback(val, tb);
                }
                auto val2 = py_async_error.py::object::operator()(
M
Megvii Engine Team 已提交
926 927
                        "An async error is reported. See above for the actual cause."
                        " Hint: This is where it is reported, not where it happened."
928
                        " You may call `megengine.config.async_level = 0 "
M
Megvii Engine Team 已提交
929 930 931
                        "to get better error reporting.");
                PyException_SetCause(
                        val2.ptr(), val);  // PyException_SetCause steals reference
932 933
                Py_XDECREF(exc);
                Py_XDECREF(tb);
M
Megvii Engine Team 已提交
934 935
                PyErr_Restore(
                        py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
936 937 938 939 940 941
            } else {
                py_async_error("Unkown async error");
            }
        }
    });

942
    // Tensor
M
Megvii Engine Team 已提交
943 944 945 946 947 948
    auto* tensor_type =
            TensorWrapper::wrap_t::type()
                    .def<&TensorWrapper::numpy>("numpy")
                    .def_getset<&TensorWrapper::shape>("shape")
                    .def_getset<&TensorWrapper::dtype>("dtype")
                    .def_getset<&TensorWrapper::device>("device")
949
                    .def<&TensorWrapper::format>("format")
M
Megvii Engine Team 已提交
950 951 952
                    .def<&TensorWrapper::reset>("_reset")
                    .def<&TensorWrapper::isscalar>("_isscalar")
                    .def<&TensorWrapper::detach>("detach")
953
                    // TODO: remove this
M
Megvii Engine Team 已提交
954 955
                    .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
                    .def<&TensorWrapper::_drop>("_drop")
956
                    .def<&TensorWrapper::_detail>("_detail")
957
                    .def<&TensorWrapper::_set_format>("_set_format")
958 959
                    .def<&TensorWrapper::_set_name>("_set_name")
                    .def<&TensorWrapper::_watch>("_watch")
960 961
                    .def<&TensorWrapper::_var>("var")
                    .def<&TensorWrapper::_graph>("graph")
962
                    .def<&TensorWrapper::value_id>("value_id")
963 964
                    .def<&TensorWrapper::_is_external_value>("_is_external_value")
                    .def<&TensorWrapper::_external_obj>("_external_obj")
M
Megvii Engine Team 已提交
965 966 967 968 969 970
                    .def_getset<
                            &TensorWrapper::module_trace_info,
                            &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
                    .finalize();
    if (!tensor_type)
        throw py::error_already_set();
971
    py::setattr(m, "Tensor", tensor_type);
972 973 974 975

    auto* tracekey_type = TraceKeyWrapper::wrap_t::type().finalize();
    py::setattr(m, "tracekey", tracekey_type);

976 977 978 979 980
    py::enum_<Format::Type>(m, "FormatType")
            .value("DEFAULT", Format::Type::DEFAULT)
            .value("NCHW", Format::Type::NCHW)
            .value("NHWC", Format::Type::NHWC)
            .export_values();
981 982

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
M
Megvii Engine Team 已提交
983
            .def(py::init<const TensorWrapper&>())
984
            .def("__call__", &TensorWeakRef::operator());
985

986
    static PyMethodDef method_defs[] = {
987 988 989
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
990 991 992
            MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
            MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
            MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
993
            MGE_PY_INTERFACE(split_cpp, split_cpp),
994
            MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
995
            MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
996
            MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
997 998
            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
999
            MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp),
1000
            MGE_PY_INTERFACE(Const, Const),
1001
            MGE_PY_INTERFACE(astype_cpp, astype_cpp),
1002 1003
            MGE_PY_INTERFACE(matmul_cpp, matmul_cpp),
            MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp),
1004 1005
            MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
            MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
1006
            MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
1007
            MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
1008
            {nullptr, nullptr, 0, nullptr}};
M
Megvii Engine Team 已提交
1009
    for (auto&& def : method_defs) {
1010 1011
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
M
Megvii Engine Team 已提交
1012 1013
            if (!func)
                throw py::error_already_set();
1014 1015 1016
            py::setattr(m, def.ml_name, func);
        }
    }
1017

1018 1019 1020 1021
    static constexpr auto sync_py_task_q = [] {
        py::gil_scoped_release _;
        py_task_q.wait_all_task_finish();
    };
1022

1023
    m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
1024 1025
    m.def("set_option", [channel](std::string name, size_t value) {
        channel->set_option(name, value);
M
Megvii Engine Team 已提交
1026
    });
1027
    m.def("get_option",
1028 1029 1030 1031 1032
          [channel](std::string name) { return channel->get_option(name); });
    m.def("push_scope", [channel](std::string name) {
        Transformation::push_scope(name);
        channel->push_scope(name);
    });
1033 1034 1035 1036
    m.def("record_scope", [](py::object frame, std::string name) {
        mgb_assert(PyFrame_Check(frame.ptr()));
        record_scope((PyFrameObject*)frame.ptr(), std::move(name));
    });
1037 1038 1039 1040
    m.def("pop_scope", [channel](std::string name) {
        channel->pop_scope(name);
        Transformation::pop_scope(name);
    });
1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063
    std::unordered_map<std::string, ScopeType> str2scopetype = {
            {"default", ScopeType::DEFAULT},
            {"module", ScopeType::MODULE},
            {"tensor_method", ScopeType::TENSOR_METHOD},
            {"functional", ScopeType::FUNCTIONAL},
            {"backward", ScopeType::BACKWARD}};

    m.def("push_scope_with_type",
          [channel, str2scopetype](std::string name, std::string type) {
              if (str2scopetype.find(type) == str2scopetype.end()) {
                  throw py::value_error("unsupport scope type");
              } else {
                  channel->push_scope(name, str2scopetype.find(type)->second);
              }
          });
    m.def("pop_scope_with_type",
          [channel, str2scopetype](std::string name, std::string type) {
              if (str2scopetype.find(type) == str2scopetype.end()) {
                  throw py::value_error("unsupport scope type");
              } else {
                  channel->pop_scope(name, str2scopetype.find(type)->second);
              }
          });
1064 1065 1066 1067 1068 1069 1070 1071 1072
    m.def("start_profile", [channel](imperative::Profiler::options_t options) {
        channel->sync();
        imperative::Profiler::load_options(std::move(options));
        imperative::Profiler::start_profile();
        channel->start_profile();
    });
    m.def("stop_profile", [channel]() -> std::function<void(std::string, std::string)> {
        channel->stop_profile();
        channel->sync();
1073
        CompNode::sync_all();
1074 1075 1076 1077 1078 1079 1080 1081
        imperative::Profiler::stop_profile();
        auto results = std::make_shared<imperative::Profiler::bundle_t>(
                imperative::Profiler::collect());
        return [results = results](std::string basename, std::string format) mutable {
            imperative::Profiler::dump_profile(basename, format, std::move(*results));
            results = nullptr;
        };
    });
1082 1083 1084 1085
    m.def("stop_step", [channel]() {
        imperative::Profiler::stop_step();
        channel->stop_step();
    });
1086 1087 1088
    m.def("enable_cupti", &cupti::enable);
    m.def("disable_cupti", &cupti::disable);
    m.def("cupti_available", &cupti::available);
1089 1090 1091 1092 1093 1094 1095

    static std::unique_ptr<CleanupGuard<>> group_comm_guard;
    m.def("group_start", []() {
        auto commtrans = std::make_shared<GroupCommTransformation>();
        group_comm_guard = transformations.register_at<Segment::GroupComm>(commtrans);
    });
    m.def("group_end", []() { group_comm_guard.reset(); });
1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113
    m.def("sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        sync_py_task_q();
    });
    m.def("full_sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        CompNode::sync_all();
        CompNode::foreach ([](CompNode cn) {
            auto err = cn.check_async_error();
            mgb_assert(!err, "%s", err->what());
        });
        sync_py_task_q();
    });
    m.def("close", [channel]() {
1114 1115 1116 1117 1118 1119 1120 1121 1122
        // sync channel and compnode before close to ensure all tasks have been completed
        if (channel->check_available()) {
            channel->sync();
        }
        CompNode::sync_all();
        CompNode::foreach ([](CompNode cn) {
            auto err = cn.check_async_error();
            mgb_assert(!err, "%s", err->what());
        });
1123 1124
        channel->close();
        sync_py_task_q();
M
Megvii Engine Team 已提交
1125 1126
    });

1127 1128 1129 1130 1131 1132 1133 1134
    py::class_<GradSlotPtr>(m, "GradSlot")
            .def_property_readonly("grad", [](GradSlotPtr& self) -> py::object {
                if (self->grad())
                    return TensorWrapper::make(py_tensor_type, self->grad());
                else
                    return py::none();
            });

1135
    // GradTransformation
M
Megvii Engine Team 已提交
1136 1137 1138 1139 1140 1141
    py::handle grad_key_type =
            GradKeyWrapper::wrap_t::type()
                    .def<&GradKeyWrapper::attach>("attach")
                    .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
                    .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
                            "name")
1142 1143 1144 1145
                    .def<&GradKeyWrapper::enter>("enter")
                    .def<&GradKeyWrapper::exit>("exit")
                    .def<&GradKeyWrapper::suppress>("suppress")
                    .def<&GradKeyWrapper::resume>("resume")
M
Megvii Engine Team 已提交
1146 1147 1148
                    .finalize();
    if (!grad_key_type)
        throw py::error_already_set();
1149
    py::setattr(m, "GradKey", grad_key_type);
1150
    m.def("backward", &GradKeyWrapper::backward);
1151
    m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure);
1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174
    m.def("get_grad_slot", [](py::object tensor) -> py::object {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        if (tw) {
            auto rst = imperative::apply(GetGradSlot(), tw->m_tensor->data());
            if (rst.size() == 1) {
                GradSlotPtr slot = rst[0].cast<GradSlotValue>();
                return py::cast(slot);
            } else {
                return py::none();
            }
        }

        return py::none();
    });
    m.def("get_handle_id", [](py::object tensor) -> py::object {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        if (tw) {
            auto rst = imperative::apply(GetId(), tw->m_tensor->data());
            int id = rst[0].cast<IntegerValue>();
            return py::cast(id);
        }
        return py::none();
    });
1175

1176 1177 1178 1179
    m.def("set_py_tensor_type", [](py::object type_obj) {
        py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });

1180 1181 1182 1183
    m.def("set_py_varnode_type", [](py::object type_obj) {
        py_varnode_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });

1184 1185 1186 1187
    m.def("set_py_external_type", [](py::object type_obj) {
        py_external_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });

1188 1189 1190
    m.def("set_py_device_type",
          [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });

1191 1192 1193 1194 1195 1196 1197 1198 1199 1200
    /**
     * \brief trace proxy
     *
     */
    struct Trace {
        bool symbolic = false;
        bool no_exec = false;
        bool capture_as_const = false;
        bool profile = false;
        bool record_input_shapes = false;
1201 1202 1203
        bool without_host = false;
        bool check_external = true;
        bool remove_unused_data_required = true;
1204 1205 1206 1207 1208 1209 1210
        py::function options_visitor;
        std::shared_ptr<TracingTransformation> tracing;
        std::shared_ptr<CompiledTransformation> compiled;
        std::shared_ptr<LazyEvalTransformation> lazy_eval;
        std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
        std::optional<TraceResult> trace_result;
        std::function<bool(py::object, py::object)> array_comparator;
1211 1212 1213
        std::unique_ptr<CleanupGuard<>> tracing_guard;
        std::unique_ptr<CleanupGuard<>> compiled_guard;
        std::unique_ptr<CleanupGuard<>> lazy_eval_guard;
1214 1215
        std::unordered_map<size_t, size_t> inpmark_to_id;
        std::unordered_map<size_t, size_t> outmark_to_id;
1216 1217

        bool compare_value(ValueRef lhs, ValueRef rhs) {
1218 1219
            auto lvalue = lhs.cast_ref<HostValue>();
            auto rvalue = rhs.cast_ref<HostValue>();
1220
            if (lvalue->shape() != rvalue->shape()) {
1221 1222
                return false;
            }
1223
            if (lvalue->shape().total_nr_elems() == 1) {
1224 1225 1226 1227
                return lvalue->item() == rvalue->item();
            }
            HostTensorND lnd = lvalue->as_nd(true);
            HostTensorND rnd = rvalue->as_nd(true);
1228
            auto larr = py::reinterpret_steal<py::array>(
1229
                    npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE));
1230
            auto rarr = py::reinterpret_steal<py::array>(
1231
                    npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE));
1232 1233 1234
            return array_comparator(larr, rarr);
        }

1235 1236 1237 1238 1239 1240 1241 1242 1243 1244
        void mark_input(size_t mark, size_t id) {
            trace_result->vars[id].inp_marker.insert(mark);
            mgb_assert(inpmark_to_id.find(mark) == inpmark_to_id.end());
            inpmark_to_id[mark] = id;
        }
        void mark_output(size_t mark, size_t id) {
            trace_result->vars[id].out_marker.insert(mark);
            mgb_assert(outmark_to_id.find(mark) == outmark_to_id.end());
            outmark_to_id[mark] = id;
        }
1245 1246 1247 1248 1249
        void enter() {
            auto& self = *this;
            if (!self.trace_result) {  // untraced
                self.tracing = std::make_shared<TracingTransformation>(
                        self.capture_as_const, self.record_input_shapes);
1250 1251
                if (self.without_host)
                    self.tracing->enable_record_all_shapes();
1252 1253 1254 1255 1256 1257 1258 1259 1260
                if (self.symbolic) {
                    self.lazy_eval =
                            std::make_shared<LazyEvalTransformation>(self.no_exec);
                    self.options_visitor(py::cast(&self.lazy_eval->options()));
                }
            } else if (!self.compiled) {  // traced but not compiled
                using namespace std::placeholders;
                self.compiled = std::make_shared<CompiledTransformation>(
                        *self.trace_result, self.record_input_shapes);
1261 1262 1263
                self.compiled->set_value_comparator(
                        std::bind(&Trace::compare_value, this, _1, _2));
                self.options_visitor(py::cast(&self.compiled->options()));
1264 1265 1266
                try {
                    self.compiled->compile();
                } catch (const std::exception& e) {
1267
                    mgb_log_error("error in trace: %s", e.what());
1268
                }
1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280
            }
            // register transformations
            if (self.compiled) {
                if (self.profile) {
                    auto& current_graph = self.compiled->graph();
                    if (self.profiler.first != self.compiled->graph().id()) {
                        // graph changed
                        self.profiler = std::make_pair(
                                current_graph.id(),
                                std::make_shared<GraphProfiler>(&current_graph));
                    }
                }
1281 1282 1283 1284 1285
                if (!without_host)
                    compiled_guard =
                            transformations.register_at<Segment::Trace>(self.compiled);
                else
                    self.compiled->set_pc_to_end();
1286 1287 1288
                // start execute because InputCallback depends
                self.compiled->execute();
            } else if (self.tracing) {
1289 1290
                tracing_guard =
                        transformations.register_at<Segment::Trace>(self.tracing);
1291
                if (self.lazy_eval) {
1292 1293
                    lazy_eval_guard =
                            transformations.register_at<Segment::Eval>(self.lazy_eval);
1294 1295 1296 1297 1298 1299 1300 1301 1302
                }
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        void exit() {
            auto& self = *this;
            if (self.tracing) {
1303
                tracing_guard.reset();
1304 1305 1306 1307 1308
                if (self.without_host) {
                    self.tracing->postprocess_trace_result();
                    self.inpmark_to_id = self.tracing->inpmark_to_id;
                    self.outmark_to_id = self.tracing->outmark_to_id;
                }
1309
                self.trace_result = self.tracing->get_result();
1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328
                if (self.without_host) {
                    for (auto&& var : self.trace_result->vars) {
                        var.shape_required = false;
                        var.value_required = false;
                        if (var.data_required && var.out_marker.empty() &&
                            remove_unused_data_required)
                            var.data_required = false;
                        if (var.inp_marker.empty() &&
                            var.kind == TraceResult::VarInfo::Kind::External) {
                            if (var.bound_data) {
                                var.kind = TraceResult::VarInfo::Kind::Constant;
                            } else if (self.check_external) {
                                throw std::runtime_error(
                                        "have some unknown input tensors in trace "
                                        "result");
                            }
                        }
                    }
                }
1329 1330 1331
                self.tracing.reset();
                if (self.lazy_eval) {
                    auto lazy_eval = std::move(self.lazy_eval);
1332
                    lazy_eval_guard.reset();
1333 1334 1335
                    lazy_eval->check_exception();
                }
            } else if (self.compiled) {
1336 1337
                if (!without_host)
                    compiled_guard.reset();
1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380
                self.compiled->wait();
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        VarNodeArray dump(
                std::shared_ptr<ComputingGraph> graph,
                std::vector<std::tuple<std::string, std::string, TensorShape>> inputs,
                std::vector<std::pair<std::string, std::string>> outputs,
                bool prefer_input_names) {
            auto& self = *this;
            mgb_assert(self.trace_result);
            // mark is like "arg_0", "kwarg_xxx", "output_0" ...
            std::unordered_map<std::string, size_t> mark2var;
            for (size_t i = 0; i < self.trace_result->vars.size(); ++i) {
                auto& name = self.trace_result->vars[i].mark;
                if (!name.empty()) {
                    mark2var[name] = i;
                }
            }
            std::vector<std::tuple<size_t, std::string, TensorShape>> input_vars;
            std::vector<std::pair<size_t, std::string>> output_vars;
            for (auto&& [input_mark, input_name, input_shape] : inputs) {
                mgb_assert(input_shape.ndim, "input shape invalid");
                input_vars.push_back(
                        {mark2var.at(input_mark), input_name, input_shape});
            }
            for (auto&& [output_name, repr] : outputs) {
                output_vars.push_back({mark2var.at(output_name), repr});
            }
            self.options_visitor(py::cast(&graph->options()));
            auto vars = self.trace_result->dump(
                    *graph, input_vars, output_vars, prefer_input_names);
            return vars;
        }
    };

    py::class_<Trace>(m, "Trace")
            .def(py::init<>())
            .def_readwrite("record_input_shapes", &Trace::record_input_shapes)
            .def_readwrite("array_comparator", &Trace::array_comparator)
            .def_readwrite("profile", &Trace::profile)
1381 1382 1383 1384
            .def_readwrite("without_host", &Trace::without_host)
            .def_readwrite("check_external", &Trace::check_external)
            .def_readwrite(
                    "remove_unused_data_required", &Trace::remove_unused_data_required)
1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410
            .def_property_readonly(
                    "options",
                    [](Trace& self) {
                        if (self.compiled) {
                            return &self.compiled->options();
                        } else {
                            return (ComputingGraph::Options*)nullptr;
                        }
                    })
            .def("get_profile",
                 [](Trace& self) -> py::object {
                     if (self.profiler.second && self.compiled) {
                         auto json = self.profiler.second->to_json_full(
                                 self.compiled->graph().current_comp_seq());
                         return py::str(json->to_string());
                     } else {
                         return py::none();
                     }
                 })
            .def_readwrite("symbolic", &Trace::symbolic)
            .def_readwrite("capture_as_const", &Trace::capture_as_const)
            .def_readwrite("no_exec", &Trace::no_exec)
            .def_readwrite("options_visitor", &Trace::options_visitor)
            .def("enter", &Trace::enter)
            .def("exit", &Trace::exit)
            .def("dump", &Trace::dump)
1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469
            .def("set_execption",
                 [](Trace& self, std::string error) {
                     if (self.compiled) {
                         auto exc = std::make_exception_ptr(TraceError(error));
                         self.compiled->set_exception(exc);
                     }
                 })
            .def("compiled", [](Trace& self) { return bool(self.compiled); })
            .def("put_data",
                 [](Trace& self, int mark, py::object tensor) {
                     auto id = self.inpmark_to_id[mark];
                     auto&& varinfo = self.trace_result->vars[id];
                     mgb_assert(varinfo.kind == TraceResult::VarInfo::Kind::External);
                     auto&& accessor = self.compiled->get_accessor_by_id(id);
                     auto* tw = TensorWrapper::try_cast(tensor.ptr());
                     mgb_assert(tw);
                     accessor.data_setter(
                             tw->m_tensor->data().dev_tensor()->as_nd(true));
                 })
            .def("put_datas",
                 [](Trace& self, std::unordered_map<int, py::object> inps) {
                     for (auto&& inp : inps) {
                         auto&& mark = inp.first;
                         auto&& tensor = inp.second;
                         auto id = self.inpmark_to_id[mark];
                         auto&& varinfo = self.trace_result->vars[id];
                         mgb_assert(
                                 varinfo.kind == TraceResult::VarInfo::Kind::External);
                         auto&& accessor = self.compiled->get_accessor_by_id(id);
                         auto* tw = TensorWrapper::try_cast(tensor.ptr());
                         mgb_assert(tw);
                         accessor.data_setter(
                                 tw->m_tensor->data().dev_tensor()->as_nd(true));
                     }
                 })
            // Reads the data of the output var registered under `mark` from the
            // compiled trace and wraps it in a new Python Tensor.
            .def("get_data",
                 [](Trace& self, int mark) {
                     // find() instead of operator[]: an unknown mark must not be
                     // silently inserted into the table and mapped to a default id.
                     auto id_it = self.outmark_to_id.find(mark);
                     mgb_assert(
                             id_it != self.outmark_to_id.end(),
                             "unknown output mark %d", mark);
                     auto id = id_it->second;
                     auto&& varinfo = self.trace_result->vars[id];
                     mgb_assert(varinfo.data_required);
                     auto&& accessor = self.compiled->get_accessor_by_id(id);
                     mgb_assert(accessor.data_getter);
                     auto dev_value = DeviceValue::make(accessor.data_getter());
                     return TensorWrapper::make(
                             py_tensor_type,
                             imperative::apply(
                                     CreateTensor(
                                             CreateTensor::Common, dev_value->device(),
                                             dev_value->dtype(), dev_value->shape()),
                                     DeviceStorage::make(dev_value->storage()))[0]);
                 })
            // Read-only views of the recorded op sequence, the var table and the
            // mark -> var id tables; each getter returns its member by value.
            .def_property_readonly(
                    "ops", [](Trace& trace) { return trace.trace_result->seq; })
            .def_property_readonly(
                    "vars", [](Trace& trace) { return trace.trace_result->vars; })
            .def_property_readonly(
                    "inpmark_to_id", [](Trace& trace) { return trace.inpmark_to_id; })
            .def_property_readonly(
                    "outmark_to_id", [](Trace& trace) { return trace.outmark_to_id; })
1470 1471 1472 1473
            .def("begin_excluded_region",
                 [](Trace& self) {
                     mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                     if (self.tracing) {
1474
                         self.tracing_guard.reset();
1475
                     } else if (self.compiled) {
1476
                         self.compiled_guard.reset();
1477
                     }
M
Megvii Engine Team 已提交
1478
                 })
1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606
            // Registers the active transformation (tracing or compiled) at
            // Segment::Trace again and keeps the returned guard; counterpart of
            // begin_excluded_region.
            .def("end_excluded_region",
                 [](Trace& self) {
                     // Exactly one of tracing / compiled must be set.
                     mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                     if (self.tracing) {
                         self.tracing_guard =
                                 transformations.register_at<Segment::Trace>(
                                         self.tracing);
                         return;
                     }
                     if (self.compiled) {
                         self.compiled_guard =
                                 transformations.register_at<Segment::Trace>(
                                         self.compiled);
                     }
                 })
            .def("mark_output", &Trace::mark_output)
            .def("mark_input", &Trace::mark_input);
    using VarInfo = TraceResult::VarInfo;
    using OpKind = TraceResult::SeqItem::OpKind;
    // Python-visible spelling of each VarInfo::Kind (exposed as VarInfo.kind).
    const std::unordered_map<VarInfo::Kind, std::string> kind2str{
            {VarInfo::Kind::Internal, "internal"},
            {VarInfo::Kind::External, "external"},
            {VarInfo::Kind::Constant, "const"},
    };
    // Python-visible spelling of each SeqItem::OpKind; also used in reverse to
    // parse the op_kind argument of OpInfo(...).
    const std::unordered_map<OpKind, std::string> opkind2str{
            {OpKind::Unknown, "unknown"},
            {OpKind::TraceMarkVar, "trace_mark_var"},
            {OpKind::IOMarkVar, "io_mark_var"},
            {OpKind::CreateTensor, "create_tensor"},
            {OpKind::Rename, "rename"},
    };
    // Python view of a traced variable (TraceResult::VarInfo).
    py::class_<VarInfo>(m, "VarInfo")
            .def_property_readonly("shape", [](VarInfo& info) { return info.shape; })
            .def_property_readonly(
                    "value_required",
                    [](VarInfo& info) { return info.value_required; })
            .def_property_readonly(
                    "shape_required",
                    [](VarInfo& info) { return info.shape_required; })
            .def_readwrite("data_required", &VarInfo::data_required)
            .def("set_external",
                 [](VarInfo& info) { info.kind = VarInfo::Kind::External; })
            // The bound data as a numpy array (ShareType::TRY_SHARE), or None
            // when nothing is bound.
            .def_property_readonly(
                    "bound_data",
                    [](VarInfo& info) -> py::object {
                        if (!info.bound_data) {
                            return py::none();
                        }
                        return py::reinterpret_steal<py::array>(
                                npy::ndarray_from_tensor(
                                        info.bound_data.numpy()->as_nd(true),
                                        npy::ShareType::TRY_SHARE));
                    })
            // Byte is reported to Python as Uint8.
            .def_property_readonly(
                    "dtype",
                    [](VarInfo& info) -> DType {
                        DType dt = static_cast<DType>(*info.dtype);
                        if (dt == dtype::Byte()) {
                            dt = dtype::Uint8();
                        }
                        return dt;
                    })
            .def_property_readonly(
                    "device",
                    [](VarInfo& info) { return static_cast<CompNode>(*info.device); })
            .def_property_readonly("id", [](VarInfo& info) { return info.id; })
            .def_property_readonly(
                    "handle_id", [](VarInfo& info) { return info.handle_id; })
            .def_property_readonly("name", [](VarInfo& info) { return info.name; })
            .def_property_readonly("mark", [](VarInfo& info) { return info.mark; })
            .def_property_readonly(
                    "inp_mark", [](VarInfo& info) { return info.inp_marker; })
            .def_property_readonly(
                    "out_mark", [](VarInfo& info) { return info.out_marker; })
            .def_property_readonly("kind", [kind2str](VarInfo& info) {
                auto found = kind2str.find(info.kind);
                return found->second;
            });
    using SeqItem = TraceResult::SeqItem;
    auto json = py::module::import("json");

    // Python view of one recorded trace step (TraceResult::SeqItem).
    py::class_<SeqItem>(m, "OpInfo")
            .def(py::init([opkind2str](
                                  std::shared_ptr<OpDef> op,
                                  const SmallVector<size_t>& inputs,
                                  const SmallVector<size_t>& outputs,
                                  const std::string& op_kind) {
                // Map the kind name back to the enum; names absent from
                // opkind2str fall back to Unknown. Names are unique, so stop at
                // the first match.
                SeqItem::OpKind enum_op_kind = SeqItem::OpKind::Unknown;
                for (auto&& kv : opkind2str) {
                    if (op_kind == kv.second) {
                        enum_op_kind = kv.first;
                        break;
                    }
                }
                return SeqItem{op, inputs, outputs, enum_op_kind};
            }))
            // OprAttr ops report their type string; other ops are returned as the
            // OpDef itself; op-less items report their kind name.
            .def_property_readonly(
                    "op",
                    [opkind2str](SeqItem& self) -> py::object {
                        if (self.op) {
                            if (auto* op = self.op->try_cast_final<OprAttr>()) {
                                return py::cast(op->type);
                            }
                            return py::cast(self.op);
                        } else
                            return py::cast(opkind2str.find(self.kind)->second);
                    })
            .def_property_readonly("inputs", [](SeqItem& self) { return self.inputs; })
            .def_property_readonly(
                    "outputs", [](SeqItem& self) { return self.outputs; })
            .def_property_readonly(
                    "type",
                    [opkind2str](SeqItem& self) -> py::object {
                        if (self.op)
                            return py::cast(self.op->type_name());
                        else
                            return py::cast(opkind2str.find(self.kind)->second);
                    })
            .def_property_readonly(
                    "kind",
                    [opkind2str](SeqItem& self) {
                        return opkind2str.find(self.kind)->second;
                    })
            // For OprAttr ops, the operator param string parsed with json.loads;
            // for other ops, the result of their Python __getstate__; an empty
            // dict when there is no op.
            .def_property_readonly("param", [json](SeqItem& self) -> py::object {
                if (self.op) {
                    if (auto* op = self.op->try_cast_final<OprAttr>()) {
                        auto param =
                                op->mgb_param(_imperative_sm_opr_footprint_ptr.get())
                                        ->to_string();
                        return json.attr("loads")(py::cast(param));
                    } else {
                        auto pyop = py::cast(self.op);
                        return pyop.attr("__getstate__")();
                    }
                }
                return py::dict();
            });
    // Applies TraceMarkVar(name) to the tensor and rebinds the wrapper to the
    // result in place.
    m.def("name_tensor", [](std::string name, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        mgb_assert(tw, "Arg_1 should be Tensor!");
        auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
        tw->m_tensor->reset(output);
    });

    // Applies TraceMarkVar(name) and returns the result as a new Tensor, leaving
    // the argument unchanged.
    m.def("get_marked_tensor", [](std::string name, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        // try_cast yields null for non-Tensor arguments; fail loudly instead of
        // dereferencing it.
        mgb_assert(tw, "Arg_1 should be Tensor!");
        auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
        return TensorWrapper::make(py_tensor_type, output);
    });

    // Applies IOMarkVar(mark, Input) and returns the result as a new Tensor.
    m.def("get_marked_input_tensor", [](int mark, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        // try_cast yields null for non-Tensor arguments; fail loudly instead of
        // dereferencing it.
        mgb_assert(tw, "Arg_1 should be Tensor!");
        auto output = imperative::apply(
                IOMarkVar(mark, IOMarkVar::Kind::Input), tw->m_tensor->data())[0];
        return TensorWrapper::make(py_tensor_type, output);
    });

    // Applies IOMarkVar(mark, Input) and rebinds the wrapper to the result in place.
    m.def("marked_input_tensor", [](int mark, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        // try_cast yields null for non-Tensor arguments; fail loudly instead of
        // dereferencing it.
        mgb_assert(tw, "Arg_1 should be Tensor!");
        auto output = imperative::apply(
                IOMarkVar(mark, IOMarkVar::Kind::Input), tw->m_tensor->data())[0];
        tw->m_tensor->reset(output);
    });

    // Applies IOMarkVar(mark, Output) and returns the result as a new Tensor.
    m.def("get_marked_output_tensor", [](int mark, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        // try_cast yields null for non-Tensor arguments; fail loudly instead of
        // dereferencing it.
        mgb_assert(tw, "Arg_1 should be Tensor!");
        auto output = imperative::apply(
                IOMarkVar(mark, IOMarkVar::Kind::Output), tw->m_tensor->data())[0];
        return TensorWrapper::make(py_tensor_type, output);
    });

    // True when GetGradKey applied to the given tensors yields a GradKeyValue.
    m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
        SmallVector<ValueRef> values(tensors.size());
        for (size_t i = 0; i < tensors.size(); ++i) {
            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto outputs = imperative::apply(GetGradKey(), values);
        return outputs[0].is<GradKeyValue>();
    });

    // Returns the Python GradKey wrapper for the tensors' grad key, or None if
    // GetGradKey produces an empty value.
    m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
        SmallVector<ValueRef> values(tensors.size());
        for (size_t i = 0; i < tensors.size(); ++i) {
            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto output = imperative::apply(GetGradKey(), values)[0];
        if (!output) {
            return py::none();
        }
        return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
                GradKeyWrapper::get(output.cast<GradKeyValue>())));
    });

1670
    m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs,
1671 1672
                         std::vector<py::object> outputs) {
        GenericFunction generic_backward_fn =
1673
                [backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
1674 1675 1676 1677 1678 1679 1680 1681 1682 1683
            py::list output_grad_tws;
            for (auto&& output_grad : output_grads) {
                if (output_grad) {
                    output_grad_tws.append(
                            TensorWrapper::make(py_tensor_type, output_grad));
                } else {
                    output_grad_tws.append(py::none());
                }
            }
            py::tuple input_grad_tws = backward_fn(*output_grad_tws);
1684 1685 1686
            ValueRefList input_grads(input_grad_tws.size());
            for (size_t i = 0; i < input_grad_tws.size(); ++i) {
                auto input_grad_tw = input_grad_tws[i];
1687
                if (!input_grad_tw.is_none()) {
1688 1689
                    input_grads[i] =
                            py::cast<TensorWrapper>(input_grad_tw).m_tensor->data();
1690
                } else {
1691
                    input_grads[i] = {};
1692 1693 1694 1695
                }
            }
            return input_grads;
        };
1696
        SmallVector<ValueRef> values(inputs.size() + outputs.size());
1697 1698
        for (size_t i = 0; i < inputs.size(); ++i) {
            values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
1699
        }
1700 1701 1702
        for (size_t i = 0; i < outputs.size(); ++i) {
            values[i + inputs.size()] =
                    outputs[i].cast<TensorWrapper>().m_tensor->data();
1703
        }
1704 1705
        auto wrapped_output_values =
                imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values);
1706 1707 1708 1709 1710 1711 1712 1713 1714
        std::vector<py::object> wrapped_outputs;
        mgb_assert(wrapped_output_values.size() == outputs.size());
        for (auto&& output_value : wrapped_output_values) {
            wrapped_outputs.push_back(
                    TensorWrapper::make(py_tensor_type, output_value));
        }
        return wrapped_outputs;
    });

1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725
    m.def("add_backward_callback", [](py::function callback) {
        ValueRef id = IntegerValue::make(0);
        GenericFunction generic_function =
                [callback](Span<ValueRef> inputs) -> ValueRefList {
            callback();
            return {};
        };
        auto output_values =
                imperative::apply(InsertGradCallback(generic_function), id);
    });

1726
    // ModuleTraceTransformation
1727 1728
    static py::function module_trace_hook;

1729 1730
    static auto get_module_trace = [] {
        static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
1731 1732 1733 1734
        if (!module_trace_transformation) {
            mgb_assert(module_trace_hook);
            module_trace_transformation =
                    std::make_shared<ModuleTraceTransformation>(module_trace_hook);
1735 1736 1737 1738
            MGB_MARK_USED_VAR(transformations
                                      .register_at<Segment::ModuleTrace>(
                                              module_trace_transformation)
                                      .release());
1739
        }
1740 1741
        return module_trace_transformation;
    };
1742

1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760
    static py::function external_convert_hook;

    static auto get_external_convert = [] {
        static std::shared_ptr<ExternalConvertTransformation>
                external_convert_transformation;
        if (!external_convert_transformation) {
            mgb_assert(external_convert_hook);
            external_convert_transformation =
                    std::make_shared<ExternalConvertTransformation>(
                            external_convert_hook);
            MGB_MARK_USED_VAR(transformations
                                      .register_at<Segment::ExternalConvert>(
                                              external_convert_transformation)
                                      .release());
        }
        return external_convert_transformation;
    };

1761 1762
    m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);

1763 1764 1765
    m.def("set_module_tracing", [=] { get_module_trace()->enable(); });

    m.def("unset_module_tracing", [=] { get_module_trace()->disable(); });
1766

1767
    m.def("is_tracing_module", [=] { return get_module_trace()->enabled(); });
1768 1769 1770 1771 1772 1773

    m.def("set_external_convert", [=] { get_external_convert()->enable(); });

    m.def("unset_external_convert", [=] { get_external_convert()->disable(); });

    m.def("is_external_convert", [=] { return get_external_convert()->enabled(); });
1774 1775 1776 1777 1778 1779
    m.def("set_python_backtrace_enabled", &set_python_backtrace_enabled);
    m.def("set_transformation_backtrace_enabled",
          &set_transformation_backtrace_enabled);
    m.def("_mge_backtrace", &get_py_backtrace);
    // Exposes FrameInfoCache::get_instance() as an integer id (presumably the
    // instance's address).
    m.def("_get_frame_cache_id",
          []() { return (size_t)FrameInfoCache::get_instance(); });
    // Store the hook and take one extra reference on it (inc_ref).
    // NOTE(review): that extra reference is not released anywhere in this
    // function (the atexit handler only resets the stored object), and a second
    // call leaks a reference to the previous hook — presumably intentional to keep
    // the callable alive through interpreter teardown; confirm.
    m.def("set_module_trace_hook", [](py::function function) {
        module_trace_hook = function;
        module_trace_hook.inc_ref();
    });

    m.def("set_external_convert_hook", [](py::function function) {
        external_convert_hook = function;
        external_convert_hook.inc_ref();
    });

    auto atexit = py::module::import("atexit");
1791 1792 1793 1794
    atexit.attr("register")(py::cpp_function([]() {
        module_trace_hook = {};
        external_convert_hook = {};
    }));
1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805
    m.def("begin_record_values", [] { Value::begin_record_values(); });

    // Returns (id(), to_string()) for every value handed back by
    // Value::end_record_values().
    m.def("end_record_values", [] {
        std::vector<std::pair<size_t, std::string>> result;
        for (auto&& recorded : Value::end_record_values()) {
            result.emplace_back(recorded.id(), recorded.to_string());
        }
        return result;
    });

1806
    m.def("print_stats", [] { Stats::print(); });
1807

1808
    m.def("reset_stats", [] { Stats::reset(); });
1809

1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866
    // Getters/setters for DTypePromoteCfg flags; each setter returns the value the
    // flag had before the call.
    m.def("_get_convert_inputs",
          []() -> bool { return DTypePromoteCfg::convert_input_enabled; });
    m.def("_set_convert_inputs", [](bool flag) -> bool {
        const bool previous = DTypePromoteCfg::convert_input_enabled;
        DTypePromoteCfg::convert_input_enabled = flag;
        return previous;
    });
    m.def("_get_amp_dtype_autocast",
          []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });
    m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
        const bool previous = DTypePromoteCfg::amp_dtype_autocast_enabled;
        DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
        return previous;
    });

    // Returns the lower-cased name of the configured high-precision (is_high) or
    // low-precision AMP dtype; the dtype must be a float type.
    static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
                                : DTypePromoteCfg::amp_low_prec_dtype;
        mgb_assert(target.category() == DTypeCategory::FLOAT);
        std::string ret = target.name();
        // Qualified std::transform instead of relying on ADL; go through
        // unsigned char because ::tolower is undefined for negative char values.
        std::transform(ret.begin(), ret.end(), ret.begin(), [](unsigned char c) {
            return static_cast<char>(::tolower(c));
        });
        return ret;
    };

    // Sets the high-precision (is_high) or low-precision AMP dtype from
    // "float32", "float16" or "bfloat16" and returns the lower-cased name of the
    // previous dtype. Any other name fails the assertion before the setting is
    // touched.
    static auto set_amp_prec_dtype = [](bool is_high,
                                        std::string dtype_name) -> std::string {
        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
                                : DTypePromoteCfg::amp_low_prec_dtype;
        std::string ret = target.name();

        if (dtype_name == "float32") {
            target = dtype::Float32();
        } else if (dtype_name == "float16") {
            target = dtype::Float16();
        } else if (dtype_name == "bfloat16") {
            target = dtype::BFloat16();
        } else {
            mgb_assert(
                    false, "casted type of amp should be float, but you give %s\n",
                    dtype_name.c_str());
        }

        // Qualified std::transform instead of relying on ADL; go through
        // unsigned char because ::tolower is undefined for negative char values.
        std::transform(ret.begin(), ret.end(), ret.begin(), [](unsigned char c) {
            return static_cast<char>(::tolower(c));
        });
        return ret;
    };

    // Python accessors for the AMP precision dtypes; the setters return the
    // previous dtype name.
    m.def("_get_amp_high_prec_dtype", [] { return get_amp_prec_dtype(true); });
    m.def("_set_amp_high_prec_dtype", [](std::string name) {
        return set_amp_prec_dtype(true, name);
    });
    m.def("_get_amp_low_prec_dtype", [] { return get_amp_prec_dtype(false); });
    m.def("_set_amp_low_prec_dtype", [](std::string name) {
        return set_amp_prec_dtype(false, name);
    });

1867 1868
    m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); });

1869 1870 1871 1872 1873 1874
    // FormatTransformation: toggle/query its auto-convert flag.
    m.def("set_auto_format_convert",
          [format_trans](bool enabled) { format_trans->set_auto_convert(enabled); });
    m.def("get_auto_format_convert",
          [format_trans]() { return format_trans->get_auto_convert(); });

    m.def("_to_dlpack", [](py::object tensor) {
        return py::reinterpret_steal<py::object>(tensor_to_dlpack(tensor.ptr()));
    });

    m.def("_from_dlpack", [](py::object data, py::object stream) {
        return py::reinterpret_steal<py::object>(
                tensor_from_dlpack(data.ptr(), stream.ptr()));
    });
1883
    py::register_exception<TraceError>(m, "TraceError");
1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908

    // Complex helpers. try_cast yields null for non-Tensor arguments, so each
    // argument is checked before it is dereferenced.
    m.def("create_complex", [](py::object real, py::object imag) {
        auto* real_tw = TensorWrapper::try_cast(real.ptr());
        auto* imag_tw = TensorWrapper::try_cast(imag.ptr());
        mgb_assert(real_tw && imag_tw, "create_complex expects two Tensors");
        return TensorWrapper::make(
                py_tensor_type,
                imperative::apply(
                        CreateComplex(), real_tw->m_tensor->data(),
                        imag_tw->m_tensor->data())[0]);
    });

    m.def("get_real", [](py::object complex) {
        auto* tw = TensorWrapper::try_cast(complex.ptr());
        mgb_assert(tw, "get_real expects a Tensor");
        return TensorWrapper::make(
                py_tensor_type, imperative::apply(GetReal(), tw->m_tensor->data())[0]);
    });

    m.def("get_imag", [](py::object complex) {
        auto* tw = TensorWrapper::try_cast(complex.ptr());
        mgb_assert(tw, "get_imag expects a Tensor");
        return TensorWrapper::make(
                py_tensor_type, imperative::apply(GetImag(), tw->m_tensor->data())[0]);
    });
}

#undef MGE_PY_INTERFACE

}  // namespace mgb::imperative::python