/**
 * \file imperative/python/src/trace.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./trace.h"
#include "./helper.h"
#include "megbrain/imperative/ops/autogen.h"

namespace py = pybind11;

namespace mgb::imperative::python {

apply_result_t apply_trace(ApplyContext& ctx) {
    apply_result_t outputs;

    if (ctx.backward) {
25
        // reach here when compiled=True
26
        // call megbrain_graph.py apply(BackwardGraph, *args)
27 28
        auto args = py::tuple(ctx.nargs + 1);
        args[0] = py::cast(ctx.op);
29
        for (size_t i = 0; i < ctx.nargs; i++) {
30
            args[i + 1] = py::cast(ctx.args[i]->m_var);
31
        }
32 33
        py::object ret = py::reinterpret_steal<py::object>(
                PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
34
        if (!ret) throw py::error_already_set();
35 36 37

        // assumption: python function always returns PyList
        auto tup = py::reinterpret_borrow<py::list>(ret);
38
        for (size_t i = 0; i < tup.size(); i++) {
39
            auto pitem = tup[i].cast<cg::VarNode*>();
40 41 42 43 44
            outputs.emplace_back(std::make_shared<Tensor>(pitem));
        }
        return outputs;
    }

45
    PyObject* pyf;
46 47 48 49 50 51 52 53
    if (is_compiled) {
        // run apply in compiled mode, step 2, 3, etc
        pyf = cpp_apply_compiled_mode;
    } else {
        // run first step, both symbolic and non symbolic
        pyf = cpp_apply_with_tracing;
    }

54 55
    auto args = py::tuple(ctx.nargs + 1);
    args[0] = py::cast(ctx.op);
56
    for (size_t i = 0; i < ctx.nargs; i++) {
57
        args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
58
    }
59 60 61
    auto pyout = PyObject_Call(pyf, args.ptr(), nullptr);
    if (!pyout) throw py::error_already_set();
    auto ret = py::reinterpret_steal<py::object>(pyout);
62 63 64 65

    // assumption: python function always returns PyList
    auto tup = py::reinterpret_borrow<py::list>(ret);
    for (auto i = 0; i < tup.size(); i++) {
66
        auto tw = TensorWrapper::try_cast(tup[i].ptr());
67 68 69 70 71 72
        outputs.emplace_back(tw->m_tensor);
    }
    return outputs;
}

} // namespace mgb::imperative::python