/**
 * \file imperative/python/src/trace.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./trace.h"
#include "./helper.h"
#include "megbrain/imperative/ops/autogen.h"

namespace py = pybind11;

namespace mgb::imperative::python {

apply_result_t apply_trace(ApplyContext& ctx) {
    apply_result_t outputs;

    if (ctx.backward) {
25
        // reach here when compiled=True
26
        // call megbrain_graph.py apply(BackwardGraph, *args)
27 28
        auto args = py::tuple(ctx.nargs + 1);
        args[0] = py::cast(ctx.op);
29
        for (size_t i = 0; i < ctx.nargs; i++) {
30
            args[i + 1] = py::cast(ctx.args[i]->m_var);
31
        }
32 33
        py::object ret = py::reinterpret_steal<py::object>(
                PyObject_Call(cpp_apply_backward_varnode, args.ptr(), nullptr));
34 35 36 37 38 39 40
        if (!ret) {
            throw py::value_error("invalid py object call");
        }

        // assumption: python function always returns PyList
        auto tup = py::reinterpret_borrow<py::list>(ret);
        for (auto i = 0; i < tup.size(); i++) {
41
            auto pitem = tup[i].cast<cg::VarNode*>();
42 43 44 45 46
            outputs.emplace_back(std::make_shared<Tensor>(pitem));
        }
        return outputs;
    }

47
    PyObject* pyf;
48 49 50 51 52 53 54 55
    if (is_compiled) {
        // run apply in compiled mode, step 2, 3, etc
        pyf = cpp_apply_compiled_mode;
    } else {
        // run first step, both symbolic and non symbolic
        pyf = cpp_apply_with_tracing;
    }

56 57
    auto args = py::tuple(ctx.nargs + 1);
    args[0] = py::cast(ctx.op);
58
    for (size_t i = 0; i < ctx.nargs; i++) {
59
        args[i + 1] = TensorWrapper::make(ctx.args[i]->shared_from_this());
60
    }
61 62
    auto ret = py::reinterpret_steal<py::object>(
            PyObject_Call(pyf, args.ptr(), nullptr));
63 64 65 66

    // assumption: python function always returns PyList
    auto tup = py::reinterpret_borrow<py::list>(ret);
    for (auto i = 0; i < tup.size(); i++) {
67
        auto tw = TensorWrapper::try_cast(tup[i].ptr());
68 69 70 71 72 73
        outputs.emplace_back(tw->m_tensor);
    }
    return outputs;
}

} // namespace mgb::imperative::python