/**
 * \file imperative/python/src/imperative_rt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./imperative_rt.h"

#include <future>
#include <optional>
#include <unordered_map>
#include <variant>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>

#include "megbrain/imperative.h"
#include "megbrain/imperative/interpreter.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "./helper.h"

namespace py = pybind11;

using namespace mgb;
using namespace imperative;
using namespace interpreter;


namespace {

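// Thin wrapper around OpDef::make_backward_graph: converts the pybind-friendly
// std::vector arguments into SmallVectors and returns an empty optional when
// the op has no backward graph.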
std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>
make_backward_graph(
    const OpDef& opdef, std::vector<LogicalTensorDesc> inputs,
    std::vector<bool> input_requires_grad,
    std::vector<bool> output_has_grad) {
    auto res = OpDef::make_backward_graph(opdef,
        SmallVector<LogicalTensorDesc>(inputs.begin(), inputs.end()),
        SmallVector<bool>(input_requires_grad.begin(), input_requires_grad.end()),
        SmallVector<bool>(output_has_grad.begin(), output_has_grad.end()));
    if (res.backward) {
        return std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>{
                std::in_place, res.backward, res.save_for_backward, res.input_has_grad};
    } else {
        return {};
    }
}
} // namespace

void init_imperative_rt(py::module m) {
    py::class_<Interpreter::Channel>(m, "Interpreter")
        .def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) {
                if (!cn.valid()) {
                    cn = CompNode::load("xpux");
                }
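                // Arrays with at most MAX_NDIM elements (typically shapes or
                // scalars) are copied into a fresh host tensor; larger arrays
                // borrow the numpy buffer to avoid the copy.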
                constexpr int size_threshold = TensorShape::MAX_NDIM;
                if (data.size() > size_threshold) {
                    return self.put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype));
                } else {
                    HostTensorND ret(cn);
                    return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
                }
            }, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none())
        .def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) {
                return self.del(handle);
            })
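        // get_value converts the tensor to a numpy array, sharing storage when possible.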
        .def("get_value", [](Interpreter::Channel& self, Interpreter::Handle handle) {
                PyObject* optr = npy::ndarray_from_tensor(self.get_value(handle), npy::ShareType::TRY_SHARE);
                return py::reinterpret_steal<py::object>(optr);
            })
        .def("get_dtype", &Interpreter::Channel::get_dtype)
        .def("get_device", &Interpreter::Channel::get_device)
        .def("get_shape", &Interpreter::Channel::get_shape)
        .def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor)
        .def("apply_op", &Interpreter::Channel::apply_op)
        .def("sync", &Interpreter::Channel::sync, py::call_guard<py::gil_scoped_release>());

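    // Create the default interpreter channel and expose it, together with
    // module-level aliases of its most commonly used methods.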
    std::unique_ptr<Interpreter::Channel> ch = Interpreter::inst().create_channel();
    m.attr("interpreter") = py::detail::make_caster<decltype(ch)>::cast(
        std::move(ch), py::return_value_policy::move, {});
    for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op"}) {
        m.attr(name) = m.attr("interpreter").attr(name);
    }

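    // Module-level sync: flush the interpreter queue, then release the GIL while
    // waiting for pending Python callback tasks to finish.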
    m.def("sync", [m]() {
            m.attr("interpreter").attr("sync")();
            py::gil_scoped_release _;
            py_task_q.wait_all_task_finish();
         });

    m.def("make_backward_graph", &make_backward_graph);

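    // OpDef is hashable and comparable from Python; ctype() reports the OprAttr
    // type string when available and otherwise falls back to the dynamic type name.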
    py::class_<OpDef, std::shared_ptr<OpDef>>(m, "OpDef")
        .def("ctype", [](const OpDef& opdef) {
                if (auto attr = opdef.try_cast_final<OprAttr>()) {
                    return attr->type.c_str();
                }
                return opdef.dyn_typeinfo()->name;
            })
        .def("__eq__", [](const OpDef& lhs, const OpDef& rhs) {
                return lhs.is_same(rhs);
            })
        .def("__hash__", &OpDef::hash);
}