提交 76dbaa27 编写于 作者: M Megvii Engine Team

feat(mge/imperative): add name, make_h2d, dump_graph to graph runtime

GitOrigin-RevId: b8681a31a81502f12340dafd56c1b4d466b22020
上级 7336b306
......@@ -78,6 +78,14 @@ class Graph(_imperative_rt.ComputingGraph):
opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
return opnode.outputs[0]
def make_h2d(self, *, dtype, device):
device = as_device(device).to_c()
return self._wrap(_imperative_rt.make_h2d(self, device, dtype))
def dump(*args):
return _imperative_rt.dump_graph([i._node for i in args])
class VarNode(TensorBase):
def __init__(self, node: _imperative_rt.VarNode):
......@@ -92,6 +100,14 @@ class VarNode(TensorBase):
def op(self):
return self.graph._wrap(self._node.owner)
@property
def name(self):
return self._node.name
@name.setter
def name(self, name):
self._node.name = name
@property
def dtype(self):
return self._node.dtype
......@@ -118,6 +134,14 @@ class OpNode:
def graph(self) -> Graph:
return self._node.graph
@property
def name(self):
return self._node.name
@name.setter
def name(self, name):
self._node.name = name
@property
def inputs(self):
return tuple(map(self.graph._wrap, self._node.inputs))
......
......@@ -11,6 +11,7 @@
#include "./graph_rt.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"
......@@ -47,7 +48,8 @@ void init_graph_rt(py::module m) {
py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
.def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();})
.def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();})
.def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_))
.def_property("name", py::overload_cast<>(&VarNode::name, py::const_),
py::overload_cast<std::string>(&VarNode::name))
.def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();})
.def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
.def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
......@@ -75,7 +77,8 @@ void init_graph_rt(py::module m) {
py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
.def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
.def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_))
.def_property("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_),
py::overload_cast<std::string>(&cg::OperatorNodeBase::name))
.def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
return to_tuple(opr->input());
})
......@@ -99,6 +102,15 @@ void init_graph_rt(py::module m) {
})
.def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options));
m.def("dump_graph", [](const std::vector<VarNode*>& dest_vars) {
using namespace mgb::serialization;
std::vector<uint8_t> buf;
auto dumper = GraphDumper::make(OutputFile::make_vector_proxy(&buf));
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
dumper->dump(symvars);
return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
});
#define CURRENT_CLASS cg::ComputingGraph::Options
auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
......@@ -198,6 +210,20 @@ void init_graph_rt(py::module m) {
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
});
m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) {
if (!cn.valid()) {
throw py::type_error("device must be valid");
}
if (!dtype.valid()) {
throw py::type_error("dtype must be valid");
}
OperatorNodeConfig config;
if (name) {
config.name(*name);
}
return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node();
}, py::arg(), py::arg(), py::arg(), py::arg() = py::none());
m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
const CompNode& comp_node,
const DType& dtype,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册