#include "./imperative_rt.h" #include #include #include #include #include #include "megbrain/imperative.h" #include "megbrain/imperative/interpreter.h" #include "megbrain/imperative/ops/opr_attr.h" #include "./helper.h" namespace py = pybind11; using namespace mgb; using namespace imperative; using namespace interpreter; namespace { std::optional, std::vector, std::vector>> make_backward_graph( const OpDef& opdef, std::vector inputs, std::vector input_requires_grad, std::vector output_has_grad) { auto res = OpDef::make_backward_graph(opdef, SmallVector(inputs.begin(), inputs.end()), SmallVector(input_requires_grad.begin(), input_requires_grad.end()), SmallVector(output_has_grad.begin(), output_has_grad.end())); if (res.backward) { return std::optional, std::vector, std::vector>>{ std::in_place, res.backward, res.save_for_backward, res.input_has_grad}; } else { return {}; } } } // namespace void init_imperative_rt(py::module m) { py::class_(m, "Interpreter") .def("put", [](Interpreter::Channel& self, py::array data, DType dtype, CompNode cn) { if (!cn.valid()) { cn = CompNode::load("xpux"); } constexpr int size_threshhold = TensorShape::MAX_NDIM; if (data.size() > size_threshhold) { return self.put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); } else { HostTensorND ret(cn); return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); } }, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none()) .def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) { return self.del(handle); }) .def("get_value", [](Interpreter::Channel& self, Interpreter::Handle handle) { PyObject* optr = npy::ndarray_from_tensor(self.get_value(handle), npy::ShareType::TRY_SHARE); return py::reinterpret_steal(optr); }) .def("get_dtype", &Interpreter::Channel::get_dtype) .def("get_device", &Interpreter::Channel::get_device) .def("get_shape", &Interpreter::Channel::get_shape) .def("_get_dev_tensor", &Interpreter::Channel::get_dev_tensor) .def("apply_op", &Interpreter::Channel::apply_op) .def("sync", &Interpreter::Channel::sync); std::unique_ptr ch = Interpreter::inst().create_channel(); m.attr("interpreter") = py::detail::make_caster::cast( std::move(ch), py::return_value_policy::move, {}); for (auto name : {"put", "delete", "get_value", "get_dtype", "get_device", "get_shape", "_get_dev_tensor", "apply_op"}) { m.attr(name) = m.attr("interpreter").attr(name); } m.def("sync", [m]() { m.attr("interpreter").attr("sync")(); py_task_q.wait_all_task_finish(); }); m.def("make_backward_graph", &make_backward_graph); py::class_>(m, "OpDef") .def("ctype", [](const OpDef& opdef) { if (auto attr = opdef.try_cast_final()) { return attr->type.c_str(); } return opdef.dyn_typeinfo()->name; }) .def("__eq__", [](const OpDef& lhs, const OpDef& rhs) { return lhs.is_same(rhs); }) .def("__hash__", &OpDef::hash); }