diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index cc4a4f4ee6dcf39544fca7313233a83360395dc1..e8a1c31feeebc49daf41f3c76adb37aadef06d27 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -78,6 +78,14 @@ class Graph(_imperative_rt.ComputingGraph):
         opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
         return opnode.outputs[0]
 
+    def make_h2d(self, *, dtype, device):
+        device = as_device(device).to_c()
+        return self._wrap(_imperative_rt.make_h2d(self, device, dtype))
+
+
+def dump(*args):
+    return _imperative_rt.dump_graph([i._node for i in args])
+
 
 class VarNode(TensorBase):
     def __init__(self, node: _imperative_rt.VarNode):
@@ -92,6 +100,14 @@ class VarNode(TensorBase):
     def op(self):
         return self.graph._wrap(self._node.owner)
 
+    @property
+    def name(self):
+        return self._node.name
+
+    @name.setter
+    def name(self, name):
+        self._node.name = name
+
     @property
     def dtype(self):
         return self._node.dtype
@@ -118,6 +134,14 @@ class OpNode:
     def graph(self) -> Graph:
         return self._node.graph
 
+    @property
+    def name(self):
+        return self._node.name
+
+    @name.setter
+    def name(self, name):
+        self._node.name = name
+
     @property
     def inputs(self):
         return tuple(map(self.graph._wrap, self._node.inputs))
diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp
index 27899e026b85eb32e7989fdba22c781bb6e6e746..f827466f44f07e1967ffec9c25c80bed92899998 100644
--- a/imperative/python/src/graph_rt.cpp
+++ b/imperative/python/src/graph_rt.cpp
@@ -11,6 +11,7 @@
 
 #include "./graph_rt.h"
 
+#include "megbrain/serialization/serializer.h"
 #include "megbrain/imperative/opr_utility.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/basic_arith.h"
@@ -47,7 +48,8 @@ void init_graph_rt(py::module m) {
     py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
         .def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();})
         .def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();})
-        .def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_))
+        .def_property("name", py::overload_cast<>(&VarNode::name, py::const_),
+                      py::overload_cast<std::string>(&VarNode::name))
         .def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();})
         .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
         .def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
@@ -75,7 +77,8 @@ void init_graph_rt(py::module m) {
 
     py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
         .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
-        .def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_))
+        .def_property("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_),
+                      py::overload_cast<std::string>(&cg::OperatorNodeBase::name))
         .def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
             return to_tuple(opr->input());
         })
@@ -99,6 +102,15 @@ void init_graph_rt(py::module m) {
         })
         .def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options));
 
+    m.def("dump_graph", [](const std::vector<cg::VarNode*>& dest_vars) {
+        using namespace mgb::serialization;
+        std::vector<uint8_t> buf;
+        auto dumper = GraphDumper::make(OutputFile::make_vector_proxy(&buf));
+        SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
+        dumper->dump(symvars);
+        return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
+    });
+
 #define CURRENT_CLASS cg::ComputingGraph::Options
 
     auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
@@ -198,6 +210,20 @@ void init_graph_rt(py::module m) {
         return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
     });
 
+    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) {
+        if (!cn.valid()) {
+            throw py::type_error("device must be valid");
+        }
+        if (!dtype.valid()) {
+            throw py::type_error("dtype must be valid");
+        }
+        OperatorNodeConfig config;
+        if (name) {
+            config.name(*name);
+        }
+        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node();
+    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none());
+
     m.def("input_callback", [input_callback](std::function<void(DeviceTensorND)> callback,
                                              const CompNode& comp_node,
                                              const DType& dtype,