tensor.cpp 53.9 KB
Newer Older
1 2 3 4
/**
 * \file imperative/python/src/tensor.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
5
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
6 7 8 9 10 11
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12
#include "megbrain/common.h"
M
Megvii Engine Team 已提交
13
#include "megbrain/dtype.h"
14
#include "megbrain/imperative/ops/autogen.h"
M
Megvii Engine Team 已提交
15 16
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
17
#include "megbrain/imperative/profiler.h"
18
#include "megbrain/imperative/transformations/dim_expansion.h"
19
#include "megbrain/imperative/transformations/dtype_promote.h"
20 21 22 23 24 25
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
26
#include "megbrain/opr/io.h"
27
#include "megbrain/plugin/profiler.h"
28
#include "megbrain/utils/stats.h"
29
#include "megdnn/algorithm_cache.h"
30

31
#include "./common.h"
M
Megvii Engine Team 已提交
32
#include "./grad.h"
33
#include "./graph_rt.h"
34
#include "./helper.h"
M
Megvii Engine Team 已提交
35 36 37
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
38
#include "./tensor_utils.h"
39
#include "./transformation.h"
40

41
#include <object.h>
42 43
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
44 45
#include <pybind11/pytypes.h>
#include <pyerrors.h>
46
#include <iterator>
47
#include <range/v3/all.hpp>
48
#include <string>
49 50 51

#include <unordered_map>

52 53
#include "../../src/impl/mgb_cg_impl.h"

54
namespace py = pybind11;
55
namespace views = ranges::views;
56 57 58

namespace mgb::imperative::python {

59 60
namespace {
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
61 62 63

struct SymbolVarContext {
    TransformationContext context;
64 65
    std::shared_ptr<SymbolTransformation> symbol_tsf;
    std::shared_ptr<ScalarTransformation> scalar_tsf;
66
    std::shared_ptr<DTypePromoteTransformation> dtype_promote_tsf;
67
    std::shared_ptr<DimExpansionTransformation> dim_expansion_tsf;
68

69 70 71
    SymbolVarContext(cg::ComputingGraph* graph) {
        symbol_tsf = std::make_shared<SymbolTransformation>(graph);
        scalar_tsf = std::make_shared<ScalarTransformation>();
72
        dtype_promote_tsf = std::make_shared<DTypePromoteTransformation>();
73
        dim_expansion_tsf = std::make_shared<DimExpansionTransformation>();
74 75 76 77
        Transformation::swap_context(context);
    }

    void init() {
78 79
        symbol_tsf->register_at(Transformation::top());
        scalar_tsf->register_at(Transformation::top());
80
        dtype_promote_tsf->register_at(Transformation::top());
81
        dim_expansion_tsf->register_at(Transformation::top());
82 83
    }

84 85 86 87 88 89 90 91
    ValueRef symvar2val(py::handle py_symbol_var) {
        auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
        ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node);
        if (symbol_var->is_scalar) {
            value = scalar_tsf->value_type().make(value);
        }
        return value;
    }
92

93 94 95 96 97 98 99 100 101 102 103
    py::object val2symvar(py::handle typeobj, ValueRef value) {
        bool is_scalar = false;
        if (auto* scalar_value = value.as(scalar_tsf->value_type())) {
            value = scalar_value->value();
            is_scalar = true;
        }
        auto* node = value.cast(symbol_tsf->value_type()).node();
        auto py_symbol_var =
                typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
        py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
        return py_symbol_var;
104 105
    }

106 107
    ~SymbolVarContext() { Transformation::swap_context(context); }
};
108

109 110
}  // namespace

111 112
interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
113
pybind11::handle py_device_type = nullptr;
114
PyObject* cpp_use_symbolic_shape;
115 116 117 118 119 120 121

/*!
 * \brief Store the python callable `pyf` in the `cpp_use_symbolic_shape`
 * global.
 *
 * Formerly generated by a single-use macro that stored the borrowed
 * `pyf.ptr()`, which could dangle once python dropped its last reference.
 * This function now holds a strong reference. The reference it took on a
 * previous call is released once it has been replaced. Only references
 * acquired here are released, via the `owned` tracker below.
 */
void set_cpp_use_symbolic_shape(py::object pyf) {
    // Plain pointer (no destructor), so nothing is released after the
    // interpreter has been finalized.
    static PyObject* owned = nullptr;
    // Take the new reference before dropping the old one, in case both are
    // the same object.
    PyObject* fresh = pyf.inc_ref().ptr();
    Py_XDECREF(owned);
    owned = fresh;
    cpp_use_symbolic_shape = fresh;
}
122

123 124 125
// Declared here and defined elsewhere. py_apply uses them on its op inputs
// to choose the dtype (_dtype_promotion) and comp node (_get_device) when
// converting non-tensor python inputs into constant tensors.
PyArray_Descr* _dtype_promotion(PyObject* const* args, size_t nargs);
CompNode _get_device(PyObject* const* args, size_t nargs);

M
Megvii Engine Team 已提交
126 127
/*!
 * \brief Apply an OpDef to inputs (fast-call entry point).
 *
 * args[0] is the OpDef and args[1..nargs) are its inputs.
 *  - If some inputs are PySymbolVar (and not Tensor), the op is applied inside
 *    a SymbolVarContext on their graph. All of them must share one graph, and
 *    the outputs are created with the python type of a SymbolVar input.
 *  - Otherwise every input must be a Tensor, and the outputs are new Tensors.
 * In both modes, when DTypePromoteCfg::convert_input_enabled is set and the
 * op is an Elemwise, remaining inputs are converted into constant tensors
 * using the promoted dtype / device of all inputs. Otherwise a TypeError is
 * raised.
 *
 * \return new reference to a tuple of outputs, or nullptr with an error set.
 */
PyObject* py_apply(
        PyObject* self, PyObject* const* args, size_t nargs /* , PyObject* kwnames */) {
    try {
        // if (kwnames && PyTuple_GET_SIZE(kwnames)) {
        //     PyErr_SetString(PyExc_TypeError, "keyword argument not allowed");
        //     return nullptr;
        // }
        if (nargs < 2) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "py_apply expects one Op and at least one tensor "
                    "as argument");
            return nullptr;
        }

        auto* py_op = args[0];

        // From here on args/nargs describe only the op inputs.
        ++args;
        --nargs;

        auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
        SmallVector<ValueRef, 8> tensors(nargs);

        // Mark SymbolVar inputs and find the single graph they belong to.
        SmallVector<bool, 8> is_symbol_var(nargs, false);
        ComputingGraph* cg = nullptr;
        for (size_t i = 0; i < nargs; ++i) {
            if ((!TensorWrapper::try_cast(args[i])) &&
                py::isinstance<PySymbolVar>(py::handle(args[i]))) {
                is_symbol_var[i] = true;
                ComputingGraph* cur_cg =
                        py::handle(args[i]).cast<PySymbolVar*>()->m_node->owner_graph();
                if (cg == nullptr) {
                    cg = cur_cg;
                } else {
                    mgb_assert(
                            cg == cur_cg,
                            "py_apply: SymbolVar inputs belong to different graphs");
                }
            }
        }

        mgb::CompNode target_cn;
        mgb::DType target_dtype;

        // Convert non-tensor input i into a constant tensor. The promoted
        // dtype and target device are computed once, from all inputs, on
        // first use.
        auto convert_pyinput_to_tensor = [&](size_t i) -> ValueRef {
            if (!target_dtype.valid()) {
                target_dtype = npy::dtype_np2mgb_descr(_dtype_promotion(args, nargs));
                target_cn = _get_device(args, nargs);
            }
            HostTensorND ht(target_cn);
            ht = npy::np2tensor(args[i], npy::Meth::copy_into(&ht), target_dtype);
            if (PyArray_Check(args[i]) || PyList_Check(args[i])) {  // non scalar
                // py_tuple is not allowed here because of tracing
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, target_cn, ht.layout()),
                        HostStorage::make(ht.storage()))[0];
            } else {  // scalar
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, target_cn, target_dtype, {}),
                        HostStorage::make(ht.storage()))[0];
            }
        };

        if (cg != nullptr) {
            // swap to a special context to reuse scalar handle
            SymbolVarContext context(cg);
            context.init();
            // Index of the last SymbolVar input; cg != nullptr guarantees that
            // at least one exists.
            size_t symbol_var_idx = nargs;
            for (size_t i = 0; i < nargs; ++i) {
                if (is_symbol_var[i]) {
                    symbol_var_idx = i;
                    tensors[i] = context.symvar2val(args[i]);
                } else if (
                        DTypePromoteCfg::convert_input_enabled &&
                        op->same_type<Elemwise>()) {
                    tensors[i] = convert_pyinput_to_tensor(i);
                } else {
                    PyErr_SetString(
                            PyExc_TypeError, "py_apply expects tensor as inputs");
                    return nullptr;
                }
            }
            mgb_assert(symbol_var_idx < nargs);
            auto outputs = imperative::apply(*op, tensors);
            auto ret = pybind11::tuple(outputs.size());
            auto typeobj = py::handle(args[symbol_var_idx]).get_type();
            for (size_t i = 0; i < outputs.size(); ++i) {
                ret[i] = context.val2symvar(typeobj, outputs[i]);
            }
            return ret.release().ptr();
        }

        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                tensors[i] = tw->m_tensor->data();
            } else if (
                    DTypePromoteCfg::convert_input_enabled &&
                    op->same_type<Elemwise>()) {
                tensors[i] = convert_pyinput_to_tensor(i);
            } else {
                PyErr_SetString(PyExc_TypeError, "py_apply expects tensor as inputs");
                return nullptr;
            }
        }

        auto outputs = imperative::apply(*op, tensors);
        size_t nout = outputs.size();
        auto ret = py::tuple(nout);
        for (size_t i = 0; i < nout; ++i) {
            ret[i] = TensorWrapper::make(py_tensor_type, std::move(outputs[i]));
        }
        return ret.release().ptr();
    }
    PYEXT17_TRANSLATE_EXC_RET(nullptr)
}

239 240 241 242 243 244 245 246 247 248 249 250 251 252
namespace {

/*!
 * \brief Return the exact python type object that corresponds to the pybind11
 * wrapper type T (py::int_, py::float_, py::tuple or py::list).
 *
 * Callers compare with `obj.get_type().is(py_type<T>())`. That is an exact
 * type test: unlike isinstance, it rejects subclasses such as bool.
 *
 * Any other T is rejected at compile time. The previous
 * `static_assert(std::is_same_v<T, T>)` was always true, so an unsupported T
 * compiled and fell off the end of this non-void function.
 */
template <typename T>
py::handle py_type() {
    if constexpr (std::is_same_v<T, py::int_>) {
        return reinterpret_cast<PyObject*>(&PyLong_Type);
    } else if constexpr (std::is_same_v<T, py::float_>) {
        return reinterpret_cast<PyObject*>(&PyFloat_Type);
    } else if constexpr (std::is_same_v<T, py::tuple>) {
        return reinterpret_cast<PyObject*>(&PyTuple_Type);
    } else if constexpr (std::is_same_v<T, py::list>) {
        return reinterpret_cast<PyObject*>(&PyList_Type);
    } else {
        // Dependent on T, so it only fires when an unsupported T is used.
        static_assert(!std::is_same_v<T, T>, "py_type: unsupported pybind11 type");
    }
}

template <typename T>
auto scalar2storage(T val, CompNode cn, DType dtype) {
    using max_ctype_t = DTypeScalar::max_ctype;
    DTypeScalar scalar(dtype);
    scalar.set_retain_dtype(val);
    HostTensorStorage storage(cn);
    auto* raw_ptr = reinterpret_cast<dt_byte*>(new max_ctype_t());
    std::shared_ptr<dt_byte> raw_storage = {
            raw_ptr, [](dt_byte* ptr) { delete reinterpret_cast<max_ctype_t*>(ptr); }};
    storage.only_reset_raw_storage(cn, dtype.size(), raw_storage, 0);
    std::memcpy(storage.ptr(), scalar.storage(), dtype.size());
    return HostStorage::make(std::move(storage));
}

/*!
 * \brief Copy up to MEGDNN_MAX_NDIM scalars into a new host storage on `cn`
 * whose element type is `ctype` (its size must equal `dtype.size()`).
 *
 * Both preconditions are now checked before allocating, and the buffer is
 * owned by a shared_ptr from the moment it is created. Previously the raw
 * array was allocated before the dtype-size assertion and leaked when that
 * assertion threw.
 */
template <typename ctype>
auto vec2storage(Span<DTypeScalar> vec, CompNode cn, DType dtype) {
    mgb_assert(vec.size() <= MEGDNN_MAX_NDIM);
    mgb_assert(sizeof(ctype) == dtype.size());
    // TODO: use storage cache and modify ConstTensorCache to return (Host, Device)
    // The buffer is always sized for MEGDNN_MAX_NDIM elements.
    std::shared_ptr<dt_byte> raw_storage(
            reinterpret_cast<dt_byte*>(new ctype[MEGDNN_MAX_NDIM]),
            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); });
    auto* items = reinterpret_cast<ctype*>(raw_storage.get());
    for (size_t i = 0; i < vec.size(); ++i) {
        items[i] = vec[i].get_cast<ctype>();
    }
    HostTensorStorage storage(cn);
    storage.only_reset_raw_storage(cn, sizeof(ctype) * vec.size(), raw_storage, 0);
    return HostStorage::make(std::move(storage));
}

//! Shape, dtype and host storage of a value parsed from a python object.
struct HostTensorArgs {
    ValueShape shape;
    DType dtype;
    HostStorage::ref_t storage;

    //! Build a HostTensorND on the default CPU comp node with this shape and
    //! dtype, whose raw storage is reset to `storage`.
    HostTensorND as_tensor_nd() const {
        HostTensorND nd(CompNode::default_cpu(), shape.as_tensor_shape(), dtype);
        nd.only_reset_raw_storage(*storage);
        return nd;
    }
};

template <typename seq_type, typename ctype>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    auto size = obj.size();
    if (size > MEGDNN_MAX_NDIM) {
        return false;
    }
    ctype items[size];
    for (size_t i = 0; i < size; ++i) {
        py::handle item = obj[i];
        if (item.get_type().is(py_type<py::int_>())) {
            items[i] = (ctype)(dt_int32)item.template cast<py::int_>();
        } else if (item.get_type().is(py_type<py::float_>())) {
            items[i] = (ctype)(dt_float32)item.template cast<py::float_>();
        } else {
            return false;
314
        }
315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354
    }
    mgb_assert(sizeof(ctype) == dtype.size());
    auto* raw_ptr = new ctype[size];
    std::shared_ptr<dt_byte> raw_storage = {
            reinterpret_cast<dt_byte*>(raw_ptr),
            [](dt_byte* ptr) { delete[] reinterpret_cast<ctype*>(ptr); }};
    HostTensorStorage storage(cn);
    storage.only_reset_raw_storage(cn, sizeof(ctype) * size, raw_storage, 0);
    std::memcpy(storage.ptr(), items, sizeof(ctype) * size);
    ret.dtype = dtype;
    ret.shape = {size};
    ret.storage = HostStorage::make(std::move(storage));
    return true;
}

/*!
 * \brief Convert a python tuple/list of exact int/float items into a 1-d host
 * value, inferring the dtype.
 *
 * The result is Int32 if every item is an int, Float32 if any item is a
 * float, and Float32 for an empty sequence.
 *
 * \return false if the sequence has more than MEGDNN_MAX_NDIM items or an
 *         item whose exact type is neither int nor float; true otherwise.
 */
template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, HostTensorArgs& ret) {
    auto size = obj.size();
    if (size > MEGDNN_MAX_NDIM) {
        return false;
    }
    // Fixed-capacity buffer. The former `DTypeScalar items[size]` was a
    // variable-length array of a non-POD type: a GCC extension that clang
    // rejects.
    DTypeScalar items[MEGDNN_MAX_NDIM];
    DType dtype;
    for (size_t i = 0; i < size; ++i) {
        auto&& item = obj[i];
        if (item.get_type().is(py_type<py::int_>())) {
            items[i] = (dt_int32)item.template cast<py::int_>();
            if (!dtype.valid()) {
                dtype = dtype::Int32();
            } else if (dtype != dtype::Int32() && dtype != dtype::Float32()) {
                return false;
            }
        } else if (item.get_type().is(py_type<py::float_>())) {
            items[i] = (dt_float32)item.template cast<py::float_>();
            if (!dtype.valid()) {
                dtype = dtype::Float32();
            } else if (dtype == dtype::Int32()) {
                // a float anywhere promotes the whole sequence to Float32
                dtype = dtype::Float32();
            } else if (dtype != dtype::Float32()) {
                return false;
            }
        } else {
            return false;
        }
    }
    if (!dtype.valid()) {
        dtype = dtype::Float32();
    }
    ret.dtype = dtype;
    ret.shape = {size};
    if (dtype == dtype::Int32()) {
        ret.storage = vec2storage<dt_int32>({items, size}, cn, dtype);
    } else if (dtype == dtype::Float32()) {
        ret.storage = vec2storage<dt_float32>({items, size}, cn, dtype);
    } else {
        mgb_assert(false);
    }
    return true;
}

/*!
 * \brief Dispatch a tuple/list conversion on the requested dtype.
 *
 * With no dtype, the element dtype is inferred. Int32 and Float32 use the
 * typed converters. Any other dtype returns false.
 */
template <typename seq_type>
bool pyseq2hval(seq_type obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    if (!dtype.valid()) {
        return pyseq2hval<seq_type>(obj, cn, ret);
    }
    if (dtype == dtype::Int32()) {
        return pyseq2hval<seq_type, dt_int32>(obj, cn, dtype, ret);
    }
    if (dtype == dtype::Float32()) {
        return pyseq2hval<seq_type, dt_float32>(obj, cn, dtype, ret);
    }
    return false;
}

/*!
 * \brief Fallback conversion of a python object to host tensor args through
 * numpy.
 *
 * `obj` arrives as a py::array (callers pass a py::object, which is converted
 * to the parameter type). It is copied into a new HostTensorND on `cn` with
 * npy::np2tensor. When `dtype` is invalid, the dtype chosen by np2tensor is
 * used. The array's dimensions are appended to `ret.shape`, which is assumed
 * to start with ndim == 0 (TODO confirm for every caller).
 *
 * \return always true; errors are reported through exceptions (e.g. the
 *         mgb_assert below).
 */
bool pyarr2hval(py::array obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    auto data = obj.cast<py::array>();
    auto strides = data.strides();
    // A zero stride marks an axis that does not advance in memory
    // (presumably a broadcast view, e.g. from np.broadcast_to).
    bool need_squeeze = false;
    for (size_t i = 0; i < data.ndim(); ++i) {
        if (strides[i] == 0) {
            need_squeeze = true;
            break;
        }
    }
    if (need_squeeze) {
        // Remember the original shape, squeeze length-1 axes, then resize back
        // to the remembered shape.
        // NOTE(review): confirm this always yields non-zero strides and that
        // resize() accepts arrays that do not own their data.
        std::vector<size_t> shape;
        for (size_t i = 0; i < data.ndim(); ++i) {
            shape.push_back(data.shape(i));
        }
        data = data.squeeze();
        data.resize(shape);
    }
    HostTensorND retnd(cn);
    retnd = npy::np2tensor(data.ptr(), npy::Meth::copy_into(&retnd), dtype);
    if (!dtype.valid()) {
        dtype = retnd.dtype();
    }
    mgb_assert(
            retnd.layout().is_empty() || retnd.layout().is_contiguous(),
            "host value should be continuous");
    for (size_t i = 0; i < data.ndim(); ++i) {
        ret.shape[ret.shape.ndim++] = data.shape(i);
    }
    ret.dtype = dtype;
    ret.storage = HostStorage::make(retnd.storage());
    return true;
}

//! Store a python int as a single-element host value. The dtype defaults to
//! Int32; `ret.shape` is left as it was. Always returns true.
bool pyint2hval(py::int_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    ret.dtype = dtype.valid() ? dtype : DType(dtype::Int32());
    ret.storage = scalar2storage(static_cast<dt_int32>(obj), cn, ret.dtype);
    return true;
}

//! Store a python float as a single-element host value. The dtype defaults
//! to Float32; `ret.shape` is left as it was. Always returns true.
bool pyfloat2hval(py::float_ obj, CompNode cn, DType dtype, HostTensorArgs& ret) {
    ret.dtype = dtype.valid() ? dtype : DType(dtype::Float32());
    ret.storage = scalar2storage(static_cast<dt_float32>(obj), cn, ret.dtype);
    return true;
}

/*!
 * \brief Convert the `data` argument of Tensor(...) into host tensor args.
 *
 * Fast paths handle objects whose *exact* type is float, int, tuple or list
 * (checked in that order). isinstance would also accept subclasses, e.g.
 * isinstance(True, int) is True, so bool does not take the int path. None
 * becomes an empty list. Everything else, including sequences the fast
 * paths reject, goes through the numpy fallback.
 */
HostTensorArgs pyobj2hval(py::object obj, CompNode cn, DType dtype) {
    HostTensorArgs ret;
    bool converted = false;
    auto type = obj.get_type();
    if (type.is(py_type<py::float_>())) {
        converted = pyfloat2hval(py::float_(obj), cn, dtype, ret);
    } else if (type.is(py_type<py::int_>())) {
        converted = pyint2hval(py::int_(obj), cn, dtype, ret);
    } else if (type.is(py_type<py::tuple>())) {
        converted = pyseq2hval<py::tuple>(py::tuple(obj), cn, dtype, ret);
    } else if (type.is(py_type<py::list>())) {
        converted = pyseq2hval<py::list>(py::list(obj), cn, dtype, ret);
    } else if (obj.is_none()) {
        obj = py::list(0);
    }
    if (!converted) {
        converted = pyarr2hval(obj, cn, dtype, ret);
    }
    mgb_assert(converted);
    return ret;
}

//! Description of one argument accepted by a python-facing constructor.
struct PyArgDesc {
    // keyword name of the argument
    const char* name;
    // produces the value used when the argument is not supplied
    py::object (*default_value)();
};

//! Ordered argument descriptions plus a keyword-name -> position lookup.
struct PyArgDescs {
    std::vector<PyArgDesc> items;
    // Maps a keyword to its index into `items`. The name2idx used in this
    // file returns -1 for unknown names.
    ssize_t (*name2idx)(const char* name);
};

/*!
 * \brief Expand positional `args` to one slot per entry of `descs.items`.
 *
 * Missing trailing slots receive their declared default. Returns `args`
 * itself when every slot is already present.
 */
py::tuple parse_args(py::tuple args, const PyArgDescs& descs) {
    const size_t given = args.size();
    const size_t total = descs.items.size();
    mgb_assert(given <= total, "too many args");
    if (given == total) {
        return args;
    }
    py::tuple filled(total);
    for (size_t slot = 0; slot < total; ++slot) {
        if (slot < given) {
            filled[slot] = args[slot];
        } else {
            filled[slot] = descs.items[slot].default_value();
        }
    }
    return filled;
}

/*!
 * \brief Merge positional `args` and keyword `kwargs` into a tuple with one
 * slot per entry of `descs.items`; slots not supplied get their default.
 *
 * Throws (via mgb_assert) when there are too many arguments, when a keyword
 * is unknown, or when a keyword names a slot already given positionally.
 *
 * Previously `mgb_assert(index >= nr_args)` compared the signed index with an
 * unsigned count. An unknown keyword (index -1) was converted to SIZE_MAX
 * and passed the check, then failed obscurely inside the tuple assignment.
 * The variable-length `has_value` array (non-standard C++) is now a
 * std::vector.
 */
py::tuple parse_args_and_kwargs(
        py::tuple args, py::dict kwargs, const PyArgDescs& descs) {
    size_t nr_args = args.size();
    size_t nr_kwargs = kwargs.size();
    size_t nr_items = descs.items.size();
    mgb_assert(nr_args + nr_kwargs <= nr_items, "too many args");
    if (nr_args == nr_items) {
        // implies nr_kwargs == 0: every slot is supplied positionally
        return args;
    }
    py::tuple ret(nr_items);
    for (size_t i = 0; i < nr_args; ++i) {
        ret[i] = args[i];
    }
    // has_value[i]: slot i was supplied as a keyword argument
    std::vector<bool> has_value(nr_items, false);
    for (auto&& [k, v] : kwargs) {
        auto key = py::str(k).cast<std::string>();
        ssize_t index = descs.name2idx(key.c_str());
        // Check the sign before any signed/unsigned comparison.
        mgb_assert(
                index >= 0 && static_cast<size_t>(index) < nr_items,
                "unexpected keyword argument '%s'", key.c_str());
        mgb_assert(
                static_cast<size_t>(index) >= nr_args,
                "argument '%s' given both by position and by keyword", key.c_str());
        ret[index] = v;
        has_value[index] = true;
    }
    for (size_t i = nr_args; i < nr_items; ++i) {
        if (!has_value[i]) {
            ret[i] = descs.items[i].default_value();
        }
    }
    return ret;
}

/*!
 * \brief Load the comp node named `name`, caching the most recent lookup per
 * thread.
 *
 * The cache is updated only after CompNode::load succeeds. Previously the
 * name was stored first, so a throwing load left the *previous* comp node
 * cached under the failing name, and retrying that name returned it
 * silently. The valid() check stops an empty name from matching the
 * default-constructed (invalid) initial entry.
 */
CompNode as_comp_node(const std::string& name) {
    thread_local struct {
        std::string name;
        CompNode cn;
    } cached;
    if (!cached.cn.valid() || cached.name != name) {
        CompNode loaded = CompNode::load(name);
        cached.name = name;
        cached.cn = loaded;
    }
    return cached.cn;
}

/*!
 * \brief Resolve the `device` argument of Tensor(...) to a CompNode.
 *
 * None or a str is treated as a device name. A missing name uses the default
 * device. If the Tensor class has a non-None `dmap_callback`, the name is
 * first mapped through it. A Device object is unwrapped via its `_cn`
 * attribute, and the result must be a CompNode object.
 */
CompNode as_comp_node(py::object py_device) {
    const bool by_name = py_device.is_none() || py::str::check_(py_device);
    if (!by_name) {
        if (py::isinstance(py_device, py_device_type)) {
            py_device = py_device.attr("_cn");
        }
        mgb_assert(py::isinstance(py_device, py_comp_node_type));
        return py_device.cast<CompNode>();
    }
    auto tensor_cls = py::handle(reinterpret_cast<PyObject*>(py_tensor_type));
    auto dmap_callback = tensor_cls.attr("dmap_callback");
    if (py_device.is_none()) {
        if (dmap_callback.is_none()) {
            // nothing to map: use the default device name as-is
            return as_comp_node(get_default_device());
        }
        py_device = py::str(get_default_device());
    }
    if (!dmap_callback.is_none()) {
        py_device = dmap_callback(py_device);
    }
    return as_comp_node(py::str(py_device).cast<std::string>());
}
561

562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642
template <char... Chars>
bool compare_cstr(const char* cstr) {
    return (((*cstr++) == Chars) && ...) && *cstr == '\0';
}

ssize_t name2idx(const char* name) {
    const char* ch = name;
    // TODO: trie
    // clang-format off
    switch (*ch++) {
    case 'd':
        switch (*ch++) {
        // data
        case 'a': return compare_cstr<'t', 'a'>(ch) ? 0 : -1;
        // dtype
        case 't': return compare_cstr<'y', 'p', 'e'>(ch) ? 1 : -1;
        // device
        case 'e': return compare_cstr<'v', 'i', 'c', 'e'>(ch) ? 2 : -1;
        }
    case 'i':
        // is_const
        return compare_cstr<'s', '_', 'c', 'o', 'n', 's', 't'>(ch) ? 3 : -1;
    case 'n':
        switch (*ch++) {
        // no_cache
        case 'o': return compare_cstr<'_', 'c', 'a', 'c', 'h', 'e'>(ch) ? 4 : -1;
        // name
        case 'a': return compare_cstr<'m', 'e'>(ch) ? 5 : -1;
        }
    }
    // clang-format on
    return -1;
}

}  // namespace

/*!
 * \brief Construct a Tensor from python (data, dtype, device, is_const,
 * no_cache, name), given positionally and/or as keywords.
 *
 * When `data` is already a Tensor, its handle is copied. Otherwise the data
 * is converted to a host value and a tensor is created: Const if is_const,
 * else Unique if no_cache, else Common. It is then renamed when `name` is
 * non-empty.
 */
TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    static PyArgDescs descs = {
            {
                    {"data", []() -> py::object { return py::none(); }},
                    {"dtype", []() -> py::object { return py::none(); }},
                    {"device", []() -> py::object { return py::none(); }},
                    {"is_const", []() -> py::object { return py::bool_(false); }},
                    {"no_cache", []() -> py::object { return py::bool_(false); }},
                    {"name", []() -> py::object { return py::none(); }},
            },
            name2idx};
    py::detail::loader_life_support life_sup;  // FIXME!!!required to cast DType
    auto tup = py::reinterpret_borrow<py::tuple>(args);
    tup = kwargs ? parse_args_and_kwargs(
                           tup, py::reinterpret_borrow<py::dict>(kwargs), descs)
                 : parse_args(tup, descs);
    mgb_assert(tup.size() == 6);

    if (auto* source = try_cast(tup[0].ptr())) {
        m_tensor = source->m_tensor->copy();
    } else {
        // Arguments are read in the same order as before: data, dtype,
        // is_const, no_cache, name, device.
        auto data = tup[0];
        DType dtype = tup[1].cast<DType>();
        bool is_const = tup[3].cast<bool>();
        bool no_cache = tup[4].cast<bool>();
        std::string name =
                tup[5].is_none() ? std::string() : tup[5].cast<std::string>();
        CompNode cn = as_comp_node(tup[2]);

        CreateTensor::Kind kind = CreateTensor::Common;
        if (is_const) {
            kind = CreateTensor::Const;
        } else if (no_cache) {
            kind = CreateTensor::Unique;
        }
        auto hval = pyobj2hval(data, cn, dtype);
        auto created = imperative::apply(
                CreateTensor(kind, cn, hval.dtype, hval.shape), hval.storage)[0];
        m_tensor.emplace(created);

        if (!name.empty()) {
            m_tensor->reset(imperative::apply(RenameValue(name), m_tensor->data())[0]);
        }
    }
    mgb_assert(m_tensor->data());
}

648
//! Return a new reference to the object stored by set_module_trace_info for
//! this tensor's value. Raises AttributeError when none is stored.
PyObject* TensorWrapper::module_trace_info() {
    auto info = module_trace_info_map.try_get(m_tensor->data());
    if (info && info->ptr()) {
        return info->inc_ref().ptr();
    }
    PyErr_SetString(
            PyExc_AttributeError,
            "Has no attribute named \'_NodeMixin__node\', please "
            "set it first");
    return nullptr;
}

//! Attach `obj` to this tensor's value; the map holds a strong reference.
void TensorWrapper::set_module_trace_info(PyObject* obj) {
    // TODO: erase when obj == nullptr
    auto& slot = module_trace_info_map[m_tensor->data()];
    slot = py::reinterpret_borrow<py::object>(obj);
}

666 667 668 669 670
//! Set the tensor's name from a python str.
void TensorWrapper::_set_name(PyObject* dest) {
    std::string new_name = py::reinterpret_borrow<py::object>(dest).cast<std::string>();
    m_tensor->set_name(new_name);
}
671

672 673
//! Debug description (to_string of the unwrapped value) as a python str.
PyObject* TensorWrapper::_detail() {
    auto description = m_tensor->data().unwrap().to_string();
    return py::str(description).release().ptr();
}

676 677
//! Call watch() on the tensor's underlying value.
void TensorWrapper::_watch() {
    auto&& value = m_tensor->data();
    value.watch();
}

680
//! The tensor shape as a tuple of ints, or None when no shape is available.
PyObject* TensorWrapper::shape() {
    auto shape = m_tensor->shape();
    if (!shape) {
        Py_RETURN_NONE;
    }
    const size_t ndim = shape->ndim;
    py::tuple dims(ndim);
    for (size_t axis = 0; axis < ndim; ++axis) {
        dims[axis] = shape->at(axis);
    }
    return dims.release().ptr();
}

//! The tensor's dtype, converted to a python object.
PyObject* TensorWrapper::dtype() {
    py::object result = py::cast(m_tensor->dtype());
    return result.release().ptr();
}

//! The tensor's comp node, converted to a python object.
PyObject* TensorWrapper::device() {
    py::object result = py::cast(m_tensor->comp_node());
    return result.release().ptr();
}

/*!
 * \brief Return the tensor's host value as a numpy array.
 *
 * The array shares memory with the host value when possible (TRY_SHARE).
 * Scalar tensors are passed through PyArray_Squeeze. Raises ValueError when
 * no host value is available.
 */
PyObject* TensorWrapper::numpy() {
    auto host_value = m_tensor->numpy();
    if (!host_value) {
        PyErr_SetString(PyExc_ValueError, "tensor invalid");
        return nullptr;
    }
    auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(
            host_value->as_nd(true), npy::ShareType::TRY_SHARE));
    if (!host_value->shape().is_scalar()) {
        return arr.release().ptr();
    }
    mgb_assert(PyArray_Check(arr.ptr()));
    // PyArray_Squeeze returns a new reference; `arr` drops its own on exit.
    return PyArray_Squeeze(reinterpret_cast<PyArrayObject*>(arr.ptr()));
}

//! Rebind this tensor to the value of another Tensor; raises TypeError for
//! non-Tensor arguments.
void TensorWrapper::reset(PyObject* tensor) {
    auto* other = TensorWrapper::try_cast(tensor);
    if (other == nullptr) {
        throw py::type_error("expect Tensor");
    }
    m_tensor->reset(other->m_tensor->data());
}

724
//! Return a new Tensor wrapping the result of applying DetachGrad to this
//! tensor's value.
PyObject* TensorWrapper::detach() {
    auto outputs = imperative::apply(DetachGrad(), m_tensor->data());
    return TensorWrapper::make(py_tensor_type, outputs[0]).release().ptr();
}

M
Megvii Engine Team 已提交
729
//! The tensor's device value (as_nd(true)) converted to a python object.
PyObject* TensorWrapper::_dev_tensor() {
    // TODO: handle scalar
    auto dev_value = m_tensor->data().dev_tensor();
    return py::cast(dev_value->as_nd(true)).release().ptr();
}

void TensorWrapper::_drop() {
736
    imperative::apply(DTRCommand(DTRCommand::Drop), m_tensor->data());
737 738
}

739
//! Python bool (new reference) telling whether the tensor is a scalar.
PyObject* TensorWrapper::isscalar() {
    return PyBool_FromLong(m_tensor->is_scalar() ? 1 : 0);
}

struct TensorWeakRef {
748
    ValueWeakRef data;
749

750
    TensorWeakRef(const TensorWrapper& tw) : data(tw.m_tensor->data()) {}
751 752

    py::object operator()() {
753
        if (auto p = data.lock()) {
754
            return TensorWrapper::make(py_tensor_type, p);
755 756 757 758 759
        }
        return py::none();
    }
};

760 761 762 763 764 765 766 767 768 769 770 771 772
// MGE_PY_INTERFACE(NAME, FUNC) builds a PyMethodDef entry. With METH_FASTCALL
// available, FUNC is registered directly. Otherwise each FUNC listed below
// gets a METH_VARARGS adapter `py35_FUNC`, which forwards the argument
// tuple's items as (pointer, count).
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
#else
#define WRAP_FUNC_PY35(FUNC)                                \
    PyObject* py35_##FUNC(PyObject* self, PyObject* args) { \
        auto* arr = &PyTuple_GET_ITEM(args, 0);             \
        auto size = PyTuple_GET_SIZE(args);                 \
        return FUNC(self, arr, size);                       \
    }
// Each expansion is a complete function definition, so no trailing ';'
// (a stray one is an empty declaration, flagged by -Wextra-semi).
WRAP_FUNC_PY35(py_apply)
WRAP_FUNC_PY35(dtype_promotion)
WRAP_FUNC_PY35(get_device)
WRAP_FUNC_PY35(make_shape_tuple)
WRAP_FUNC_PY35(getitem_cpp)
WRAP_FUNC_PY35(setitem_cpp)
WRAP_FUNC_PY35(split_cpp)
WRAP_FUNC_PY35(expand_dims_cpp)
WRAP_FUNC_PY35(squeeze_cpp)
WRAP_FUNC_PY35(transpose_cpp)
WRAP_FUNC_PY35(broadcast_cpp)
WRAP_FUNC_PY35(reshape_cpp)
WRAP_FUNC_PY35(adaptive_pool2d_cpp)
WRAP_FUNC_PY35(Const)
WRAP_FUNC_PY35(astype_cpp)
WRAP_FUNC_PY35(matmul_cpp)
WRAP_FUNC_PY35(batched_matmul_cpp)
WRAP_FUNC_PY35(convert_single_value_cpp)
WRAP_FUNC_PY35(convert_inputs_cpp)
WRAP_FUNC_PY35(astensor1d_cpp)
WRAP_FUNC_PY35(pixel_shuffle_cpp)
#undef WRAP_FUNC_PY35
#define MGE_PY_INTERFACE(NAME, FUNC) \
    { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr }
#endif

796
/**
 * \brief Register the imperative tensor runtime on the Python module \p m.
 *
 * - creates the global interpreter channel and publishes it through
 *   `interpreter_for_py`;
 * - permanently registers the Eval / Scalar / DTypePromote / DimExpansion
 *   transformations;
 * - installs a translator turning interpreter::AsyncError into a Python
 *   `AsyncError` whose cause is the translated nested exception;
 * - exposes `Tensor`, `TensorWeakRef`, `SymbolVar`, `GradKey`, `Trace`,
 *   `TraceError` and a collection of free functions on \p m.
 */
void init_tensor(py::module m) {
    imperative::Tensor::static_initialize();

    static auto& transformations = TransformationManager::get_instance();

    using Segment = TransformationManager::Segment;

    using Channel = interpreter::Interpreter::Channel;

    // The channel is owned by a global resource; everything below only
    // borrows the raw pointer.
    auto* channel =
            imperative::ResourceManager::create_global<std::unique_ptr<Channel>>(
                    interpreter::Interpreter::inst().create_channel())
                    ->get();
    interpreter_for_py = channel;
    // The guards returned by register_at() are intentionally release()d so
    // these transformations are never unregistered. The shared_ptr wrapping
    // `channel` has a no-op deleter because the global resource owns it.
    MGB_MARK_USED_VAR(
            transformations
                    .register_at<Segment::Eval>(
                            std::make_shared<InterpreterTransformation>(
                                    std::shared_ptr<Channel>(channel, [](Channel*) {})))
                    .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::Scalar>(
                                      std::make_shared<ScalarTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DTypePromote>(
                                      std::make_shared<DTypePromoteTransformation>())
                              .release());
    MGB_MARK_USED_VAR(transformations
                              .register_at<Segment::DimExpansion>(
                                      std::make_shared<DimExpansionTransformation>())
                              .release());

    static py::exception<interpreter::AsyncError> py_async_error(
            m, "AsyncError", PyExc_RuntimeError);
    py::register_exception_translator([](std::exception_ptr p) {
        try {
            if (p)
                std::rethrow_exception(p);
        } catch (const interpreter::AsyncError& e) {
            // Translate the nested (original) exception first so it becomes
            // the pending Python error, then chain it under AsyncError.
            pyext17::pybind11_translate_exception(e.nested_ptr());
            if (PyErr_Occurred()) {
                PyObject *exc, *val, *tb;
                PyErr_Fetch(&exc, &val, &tb);
                PyErr_NormalizeException(&exc, &val, &tb);
                if (tb) {
                    PyException_SetTraceback(val, tb);
                }
                // Call the exception *type* (object::operator()) to build an
                // instance instead of setting the error string directly.
                auto val2 = py_async_error.py::object::operator()(
                        "An async error is reported. See above for the actual cause."
                        " Hint: This is where it is reported, not where it happened."
                        " You may call `megengine.config.async_level = 0` "
                        "to get better error reporting.");
                PyException_SetCause(
                        val2.ptr(), val);  // PyException_SetCause steals reference
                Py_XDECREF(exc);
                Py_XDECREF(tb);
                PyErr_Restore(
                        py_async_error.inc_ref().ptr(), val2.release().ptr(), nullptr);
            } else {
                py_async_error("Unknown async error");
            }
        }
    });

    auto* tensor_type =
            TensorWrapper::wrap_t::type()
                    .def<&TensorWrapper::numpy>("numpy")
                    .def_getset<&TensorWrapper::shape>("shape")
                    .def_getset<&TensorWrapper::dtype>("dtype")
                    .def_getset<&TensorWrapper::device>("device")
                    .def<&TensorWrapper::reset>("_reset")
                    .def<&TensorWrapper::isscalar>("_isscalar")
                    .def<&TensorWrapper::detach>("detach")
                    // TODO: remove this
                    .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
                    .def<&TensorWrapper::_drop>("_drop")
                    .def<&TensorWrapper::_detail>("_detail")
                    .def<&TensorWrapper::_set_name>("_set_name")
                    .def<&TensorWrapper::_watch>("_watch")
                    .def_getset<
                            &TensorWrapper::module_trace_info,
                            &TensorWrapper::set_module_trace_info>("_NodeMixin__node")
                    .finalize();
    if (!tensor_type)
        throw py::error_already_set();
    py::setattr(m, "Tensor", tensor_type);

    py::class_<TensorWeakRef>(m, "TensorWeakRef")
            .def(py::init<const TensorWrapper&>())
            .def("__call__", &TensorWeakRef::operator());

    py::class_<PySymbolVar, std::shared_ptr<PySymbolVar>>(m, "SymbolVar")
            .def_property_readonly(
                    "dtype", [](PySymbolVar* v) { return v->m_node->dtype(); })
            .def_property(
                    "var", [](PySymbolVar* v) { return v->m_node; },
                    [](PySymbolVar* s, cg::VarNode* v) { s->m_node = v; })
            .def_property_readonly(
                    "device", [](PySymbolVar* v) { return v->m_node->comp_node(); })
            .def_property_readonly(
                    "graph", [](PySymbolVar* v) { return v->m_node->owner_graph(); })
            .def_property_readonly(
                    "shape",
                    [](PySymbolVar* v) -> const TensorShape* {
                        auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                        return mgr.infer_shape_fallible(v->m_node);
                    })
            .def("numpy",
                 [](PySymbolVar* v) {
                     // Only constant / runtime-static values can be inferred
                     // without running the graph.
                     auto&& mgr = v->m_node->owner_graph()->static_infer_manager();
                     auto&& type = mgr.get_infer_type(v->m_node);
                     using InferType = cg::static_infer::InferType;
                     if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) {
                         throw py::value_error("value invalid!");
                     }
                     auto* val = mgr.infer_value_fallible(v->m_node);
                     if (!val) {
                         throw py::value_error("value invalid!");
                     }
                     auto np_val = py::cast(*val).attr("numpy")();
                     return np_val;
                 })
            .def("_isscalar", [](PySymbolVar* v) { return v->is_scalar; })
            .def(py::init([](cg::VarNode* node) {
                     return std::make_shared<PySymbolVar>(node);
                 }),
                 py::arg() = nullptr);

    static PyMethodDef method_defs[] = {
            MGE_PY_INTERFACE(apply, py_apply),
            MGE_PY_INTERFACE(dtype_promotion, dtype_promotion),
            MGE_PY_INTERFACE(get_device, get_device),
            MGE_PY_INTERFACE(make_shape_tuple, make_shape_tuple),
            MGE_PY_INTERFACE(getitem_cpp, getitem_cpp),
            MGE_PY_INTERFACE(setitem_cpp, setitem_cpp),
            MGE_PY_INTERFACE(split_cpp, split_cpp),
            MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp),
            MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp),
            MGE_PY_INTERFACE(transpose_cpp, transpose_cpp),
            MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp),
            MGE_PY_INTERFACE(reshape_cpp, reshape_cpp),
            MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp),
            MGE_PY_INTERFACE(Const, Const),
            MGE_PY_INTERFACE(astype_cpp, astype_cpp),
            MGE_PY_INTERFACE(matmul_cpp, matmul_cpp),
            MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp),
            MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp),
            MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp),
            MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp),
            MGE_PY_INTERFACE(pixel_shuffle_cpp, pixel_shuffle_cpp),
            {nullptr, nullptr, 0, nullptr}};
    // Expose each entry (skipping the null sentinel) as a module attribute.
    for (auto&& def : method_defs) {
        if (def.ml_meth != nullptr) {
            auto* func = PyCFunction_NewEx(&def, nullptr, nullptr);
            if (!func)
                throw py::error_already_set();
            py::setattr(m, def.ml_name, func);
        }
    }

    // Drain the Python task queue with the GIL released.
    static constexpr auto sync_py_task_q = [] {
        py::gil_scoped_release _;
        py_task_q.wait_all_task_finish();
    };

    m.def("clear_candidates", [channel]() { channel->clear_candidates(); });
    m.def("set_option", [channel](std::string name, size_t value) {
        channel->set_option(name, value);
    });
    m.def("get_option",
          [channel](std::string name) { return channel->get_option(name); });
    m.def("push_scope", [channel](std::string name) {
        Transformation::push_scope(name);
        channel->push_scope(name);
    });
    m.def("pop_scope", [channel](std::string name) {
        channel->pop_scope(name);
        Transformation::pop_scope(name);
    });
    m.def("start_profile", [channel](imperative::Profiler::options_t options) {
        channel->sync();
        imperative::Profiler::load_options(std::move(options));
        imperative::Profiler::start_profile();
        channel->start_profile();
    });
    // Returns a one-shot dumper: the collected bundle is moved out on the
    // first call, so later calls raise instead of dereferencing null.
    m.def("stop_profile", [channel]() -> std::function<void(std::string, std::string)> {
        channel->stop_profile();
        channel->sync();
        imperative::Profiler::stop_profile();
        auto results = std::make_shared<imperative::Profiler::bundle_t>(
                imperative::Profiler::collect());
        return [results = std::move(results)](
                       std::string basename, std::string format) mutable {
            if (!results) {
                throw py::value_error("profile results have already been dumped");
            }
            imperative::Profiler::dump_profile(basename, format, std::move(*results));
            results = nullptr;
        };
    });
    m.def("sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        sync_py_task_q();
    });
    m.def("full_sync", [channel]() {
        if (channel->check_available()) {
            channel->sync();
        }
        CompNode::sync_all();
        CompNode::foreach ([](CompNode cn) {
            auto err = cn.check_async_error();
            mgb_assert(!err, "%s", err->what());
        });
        sync_py_task_q();
    });
    m.def("close", [channel]() {
        channel->close();
        sync_py_task_q();
    });

    py::handle grad_key_type =
            GradKeyWrapper::wrap_t::type()
                    .def<&GradKeyWrapper::attach>("attach")
                    .def<&GradKeyWrapper::is_attached_to>("is_attached_to")
                    .def_getset<&GradKeyWrapper::get_name, &GradKeyWrapper::set_name>(
                            "name")
                    .def<&GradKeyWrapper::enter>("enter")
                    .def<&GradKeyWrapper::exit>("exit")
                    .def<&GradKeyWrapper::suppress>("suppress")
                    .def<&GradKeyWrapper::resume>("resume")
                    .finalize();
    if (!grad_key_type)
        throw py::error_already_set();
    py::setattr(m, "GradKey", grad_key_type);
    m.def("backward", &GradKeyWrapper::backward);
    m.def("get_backward_closure", &GradKeyWrapper::get_backward_closure);

    // The type objects are kept alive for the process lifetime (inc_ref with
    // no matching dec_ref).
    m.def("set_py_tensor_type", [](py::object type_obj) {
        py_tensor_type = reinterpret_cast<PyTypeObject*>(type_obj.inc_ref().ptr());
    });

    m.def("set_py_device_type",
          [](py::object type_obj) { py_device_type = type_obj.inc_ref(); });

    /**
     * \brief trace proxy
     *
     * Driven from Python through enter()/exit():
     *  - while no trace_result exists, a TracingTransformation records the
     *    computation (plus a LazyEvalTransformation when `symbolic`); exit()
     *    stores the resulting TraceResult;
     *  - afterwards, enter() builds a CompiledTransformation from the stored
     *    TraceResult once and replays it on every subsequent enter()/exit().
     */
    struct Trace {
        bool symbolic = false;
        bool no_exec = false;
        bool capture_as_const = false;
        bool profile = false;
        bool record_input_shapes = false;
        py::function options_visitor;
        std::shared_ptr<TracingTransformation> tracing;
        std::shared_ptr<CompiledTransformation> compiled;
        std::shared_ptr<LazyEvalTransformation> lazy_eval;
        // (graph id, profiler) — recreated whenever the compiled graph changes
        std::pair<size_t, std::shared_ptr<GraphProfiler>> profiler;
        std::optional<TraceResult> trace_result;
        std::function<bool(py::object, py::object)> array_comparator;
        // Resetting a guard unregisters the corresponding transformation.
        std::unique_ptr<CleanupGuard<>> tracing_guard;
        std::unique_ptr<CleanupGuard<>> compiled_guard;
        std::unique_ptr<CleanupGuard<>> lazy_eval_guard;

        /// Compare two host values: shapes must match; single-element values
        /// compare by item(), everything else goes through array_comparator
        /// on numpy views of both tensors.
        bool compare_value(ValueRef lhs, ValueRef rhs) {
            auto lvalue = lhs.cast_ref<HostValue>();
            auto rvalue = rhs.cast_ref<HostValue>();
            if (lvalue->shape() != rvalue->shape()) {
                return false;
            }
            if (lvalue->shape().total_nr_elems() == 1) {
                return lvalue->item() == rvalue->item();
            }
            HostTensorND lnd = lvalue->as_nd(true);
            HostTensorND rnd = rvalue->as_nd(true);
            auto larr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE));
            auto rarr = py::reinterpret_steal<py::array>(
                    npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE));
            return array_comparator(larr, rarr);
        }

        /// Create the tracing/compiled transformation as needed and register
        /// it at Segment::Trace (and lazy eval at Segment::Eval).
        void enter() {
            auto& self = *this;
            if (!self.trace_result) {  // untraced
                self.tracing = std::make_shared<TracingTransformation>(
                        self.capture_as_const, self.record_input_shapes);
                if (self.symbolic) {
                    self.lazy_eval =
                            std::make_shared<LazyEvalTransformation>(self.no_exec);
                    self.options_visitor(py::cast(&self.lazy_eval->options()));
                }
            } else if (!self.compiled) {  // traced but not compiled
                self.compiled = std::make_shared<CompiledTransformation>(
                        *self.trace_result, self.record_input_shapes);
                self.compiled->set_value_comparator(
                        [this](ValueRef lhs, ValueRef rhs) {
                            return compare_value(lhs, rhs);
                        });
                self.options_visitor(py::cast(&self.compiled->options()));
                self.compiled->compile();
            }
            // register transformations
            if (self.compiled) {
                if (self.profile) {
                    auto& current_graph = self.compiled->graph();
                    if (self.profiler.first != current_graph.id()) {
                        // graph changed
                        self.profiler = std::make_pair(
                                current_graph.id(),
                                std::make_shared<GraphProfiler>(&current_graph));
                    }
                }
                compiled_guard =
                        transformations.register_at<Segment::Trace>(self.compiled);
                // start execute because InputCallback depends
                self.compiled->execute();
            } else if (self.tracing) {
                tracing_guard =
                        transformations.register_at<Segment::Trace>(self.tracing);
                if (self.lazy_eval) {
                    lazy_eval_guard =
                            transformations.register_at<Segment::Eval>(self.lazy_eval);
                }
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        /// Unregister the active transformation; after tracing, store the
        /// TraceResult and surface any lazy-eval exception.
        void exit() {
            auto& self = *this;
            if (self.tracing) {
                tracing_guard.reset();
                self.trace_result = self.tracing->get_result();
                self.tracing.reset();
                if (self.lazy_eval) {
                    auto lazy_eval = std::move(self.lazy_eval);
                    lazy_eval_guard.reset();
                    lazy_eval->check_exception();
                }
            } else if (self.compiled) {
                compiled_guard.reset();
                self.compiled->wait();
            } else {
                mgb_throw(MegBrainError, "invalid state: neither tracing nor compiled");
            }
        }

        /// Dump the recorded trace into \p graph; inputs/outputs are looked up
        /// by their mark (e.g. "arg_0", "kwarg_xxx", "output_0").
        VarNodeArray dump(
                std::shared_ptr<ComputingGraph> graph,
                std::vector<std::tuple<std::string, std::string, TensorShape>> inputs,
                std::vector<std::pair<std::string, std::string>> outputs,
                bool prefer_input_names) {
            auto& self = *this;
            mgb_assert(self.trace_result);
            // mark is like "arg_0", "kwarg_xxx", "output_0" ...
            std::unordered_map<std::string, size_t> mark2var;
            for (size_t i = 0; i < self.trace_result->vars.size(); ++i) {
                auto& name = self.trace_result->vars[i].mark;
                if (!name.empty()) {
                    mark2var[name] = i;
                }
            }
            std::vector<std::tuple<size_t, std::string, TensorShape>> input_vars;
            std::vector<std::pair<size_t, std::string>> output_vars;
            for (auto&& [input_mark, input_name, input_shape] : inputs) {
                mgb_assert(input_shape.ndim, "input shape invalid");
                input_vars.push_back(
                        {mark2var.at(input_mark), input_name, input_shape});
            }
            for (auto&& [output_name, repr] : outputs) {
                output_vars.push_back({mark2var.at(output_name), repr});
            }
            self.options_visitor(py::cast(&graph->options()));
            auto vars = self.trace_result->dump(
                    *graph, input_vars, output_vars, prefer_input_names);
            return vars;
        }
    };

    py::class_<Trace>(m, "Trace")
            .def(py::init<>())
            .def_readwrite("record_input_shapes", &Trace::record_input_shapes)
            .def_readwrite("array_comparator", &Trace::array_comparator)
            .def_readwrite("profile", &Trace::profile)
            // None until the trace has been compiled
            .def_property_readonly(
                    "options",
                    [](Trace& self) -> ComputingGraph::Options* {
                        if (self.compiled) {
                            return &self.compiled->options();
                        }
                        return nullptr;
                    })
            .def("get_profile",
                 [](Trace& self) -> py::object {
                     if (self.profiler.second && self.compiled) {
                         auto json = self.profiler.second->to_json_full(
                                 self.compiled->graph().current_comp_seq());
                         return py::str(json->to_string());
                     } else {
                         return py::none();
                     }
                 })
            .def_readwrite("symbolic", &Trace::symbolic)
            .def_readwrite("capture_as_const", &Trace::capture_as_const)
            .def_readwrite("no_exec", &Trace::no_exec)
            .def_readwrite("options_visitor", &Trace::options_visitor)
            .def("enter", &Trace::enter)
            .def("exit", &Trace::exit)
            .def("dump", &Trace::dump)
            // Temporarily unregister the active trace transformation ...
            .def("begin_excluded_region",
                 [](Trace& self) {
                     mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                     if (self.tracing) {
                         self.tracing_guard.reset();
                     } else if (self.compiled) {
                         self.compiled_guard.reset();
                     }
                 })
            // ... and register it again.
            .def("end_excluded_region", [](Trace& self) {
                mgb_assert(bool(self.tracing) ^ bool(self.compiled));
                if (self.tracing) {
                    self.tracing_guard =
                            transformations.register_at<Segment::Trace>(self.tracing);
                } else if (self.compiled) {
                    self.compiled_guard =
                            transformations.register_at<Segment::Trace>(self.compiled);
                }
            });

    m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
        // Apply `op` with an empty (shape {0}) Int32 shape tensor as the
        // target shape, on the same device as the input.
        auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
            auto make_scalar_shape = [&](CompNode device) {
                return imperative::apply(
                        CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
                        HostStorage::make(device))[0];
            };
            return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
        };
        if (py::isinstance<PySymbolVar>(tensor)) {
            auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
            SymbolVarContext context(graph);
            context.init();
            auto output = reduce_to_scalar(
                    *op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor));
            auto typeobj = tensor.get_type();
            return context.val2symvar(typeobj, output);
        } else {
            auto* tw = TensorWrapper::try_cast(tensor.ptr());
            if (!tw) {
                throw py::type_error("reduce_to_scalar expects a Tensor or SymbolVar");
            }
            auto output = reduce_to_scalar(
                    *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
            return TensorWrapper::make(py_tensor_type, output);
        }
    });

    // Replace the tensor's value in place with a TraceMarkVar-marked copy.
    m.def("name_tensor", [](std::string name, py::object tensor) {
        auto* tw = TensorWrapper::try_cast(tensor.ptr());
        if (!tw) {
            throw py::type_error("name_tensor expects a Tensor");
        }
        auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
        tw->m_tensor->reset(output);
    });

    m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
        SmallVector<ValueRef> values(tensors.size());
        for (size_t i = 0; i < tensors.size(); ++i) {
            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto outputs = imperative::apply(GetGradKey(), values);
        return outputs[0].is<GradKeyValue>();
    });

    m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
        SmallVector<ValueRef> values(tensors.size());
        for (size_t i = 0; i < tensors.size(); ++i) {
            values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto output = imperative::apply(GetGradKey(), values)[0];
        if (!output) {
            return py::none();
        }
        return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
                GradKeyWrapper::get(output.cast<GradKeyValue>())));
    });

    // Attach a Python backward function to `outputs`, computed from `inputs`;
    // returns new Tensors wrapping the SetGrad outputs.
    m.def("set_grad", [](py::function backward_fn, std::vector<py::object> inputs,
                         std::vector<py::object> outputs) {
        GenericFunction generic_backward_fn =
                [backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
            // Missing gradients are passed to / returned from Python as None.
            py::list output_grad_tws;
            for (auto&& output_grad : output_grads) {
                if (output_grad) {
                    output_grad_tws.append(
                            TensorWrapper::make(py_tensor_type, output_grad));
                } else {
                    output_grad_tws.append(py::none());
                }
            }
            py::tuple input_grad_tws = backward_fn(*output_grad_tws);
            ValueRefList input_grads(input_grad_tws.size());
            for (size_t i = 0; i < input_grad_tws.size(); ++i) {
                auto input_grad_tw = input_grad_tws[i];
                if (!input_grad_tw.is_none()) {
                    input_grads[i] =
                            py::cast<TensorWrapper>(input_grad_tw).m_tensor->data();
                } else {
                    input_grads[i] = {};
                }
            }
            return input_grads;
        };
        // SetGrad receives inputs followed by outputs in one flat list.
        SmallVector<ValueRef> values(inputs.size() + outputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            values[i + inputs.size()] =
                    outputs[i].cast<TensorWrapper>().m_tensor->data();
        }
        auto wrapped_output_values =
                imperative::apply(SetGrad(generic_backward_fn, inputs.size()), values);
        std::vector<py::object> wrapped_outputs;
        mgb_assert(wrapped_output_values.size() == outputs.size());
        for (auto&& output_value : wrapped_output_values) {
            wrapped_outputs.push_back(
                    TensorWrapper::make(py_tensor_type, output_value));
        }
        return wrapped_outputs;
    });

    static py::function module_trace_hook;

    // Lazily create the module-trace transformation on first use (requires
    // set_module_trace_hook to have been called) and register it permanently.
    static auto get_module_trace = [] {
        static std::shared_ptr<ModuleTraceTransformation> module_trace_transformation;
        if (!module_trace_transformation) {
            mgb_assert(module_trace_hook);
            module_trace_transformation =
                    std::make_shared<ModuleTraceTransformation>(module_trace_hook);
            MGB_MARK_USED_VAR(transformations
                                      .register_at<Segment::ModuleTrace>(
                                              module_trace_transformation)
                                      .release());
        }
        return module_trace_transformation;
    };

    m.def("set_cpp_use_symbolic_shape", &set_cpp_use_symbolic_shape);

    m.def("set_module_tracing", [] { get_module_trace()->enable(); });

    m.def("unset_module_tracing", [] { get_module_trace()->disable(); });

    m.def("is_tracing_module", [] { return get_module_trace()->enabled(); });

    m.def("set_module_trace_hook", [](py::function function) {
        module_trace_hook = function;
        module_trace_hook.inc_ref();
    });

    // Drop the hook at interpreter shutdown, before static destruction.
    auto atexit = py::module::import("atexit");
    atexit.attr("register")(py::cpp_function([]() { module_trace_hook = {}; }));

    m.def("begin_record_values", [] { Value::begin_record_values(); });

    m.def("end_record_values", [] {
        std::vector<std::pair<size_t, std::string>> reprs;
        auto values = Value::end_record_values();
        for (auto&& value : values) {
            reprs.push_back({value.id(), value.to_string()});
        }
        return reprs;
    });

    m.def("print_stats", [] { Stats::print(); });

    m.def("reset_stats", [] { Stats::reset(); });

    // The _set_* helpers return the previous value.
    m.def("_get_convert_inputs",
          []() -> bool { return DTypePromoteCfg::convert_input_enabled; });
    m.def("_set_convert_inputs", [](bool flag) -> bool {
        bool ret = DTypePromoteCfg::convert_input_enabled;
        DTypePromoteCfg::convert_input_enabled = flag;
        return ret;
    });
    m.def("_get_amp_dtype_autocast",
          []() -> bool { return DTypePromoteCfg::amp_dtype_autocast_enabled; });
    m.def("_set_amp_dtype_autocast", [](bool flag) -> bool {
        bool ret = DTypePromoteCfg::amp_dtype_autocast_enabled;
        DTypePromoteCfg::amp_dtype_autocast_enabled = flag;
        return ret;
    });

    // Lower-case a dtype name; characters go through unsigned char so
    // ::tolower never sees a negative value (UB otherwise).
    static auto lowercase = [](std::string s) {
        std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
            return static_cast<char>(::tolower(c));
        });
        return s;
    };

    static auto get_amp_prec_dtype = [](bool is_high) -> std::string {
        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
                                : DTypePromoteCfg::amp_low_prec_dtype;
        mgb_assert(target.category() == DTypeCategory::FLOAT);
        return lowercase(target.name());
    };

    // Set the AMP high/low precision dtype; returns the previous dtype name.
    static auto set_amp_prec_dtype = [](bool is_high,
                                        std::string dtype_name) -> std::string {
        DType& target = is_high ? DTypePromoteCfg::amp_high_prec_dtype
                                : DTypePromoteCfg::amp_low_prec_dtype;
        std::string ret = target.name();

        if (dtype_name == "float32") {
            target = dtype::Float32();
        } else if (dtype_name == "float16") {
            target = dtype::Float16();
        } else if (dtype_name == "bfloat16") {
            target = dtype::BFloat16();
        } else {
            mgb_assert(
                    false, "casted type of amp should be float, but you give %s\n",
                    dtype_name.c_str());
        }

        return lowercase(ret);
    };

    m.def("_get_amp_high_prec_dtype",
          []() -> std::string { return get_amp_prec_dtype(true); });
    m.def("_set_amp_high_prec_dtype", [](std::string dtype_name) -> std::string {
        return set_amp_prec_dtype(true, dtype_name);
    });
    m.def("_get_amp_low_prec_dtype",
          []() -> std::string { return get_amp_prec_dtype(false); });
    m.def("_set_amp_low_prec_dtype", [](std::string dtype_name) -> std::string {
        return set_amp_prec_dtype(false, dtype_name);
    });

    m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); });

    py::register_exception<TraceError>(m, "TraceError");
}

#undef MGE_PY_INTERFACE

}  // namespace mgb::imperative::python