提交 8a918717 编写于 作者: M Megvii Engine Team

feat(mgb): add megengine inference

GitOrigin-RevId: 6ffec6b418273b13ab76c0ef30a4749219b17aef
上级 379a28f9
...@@ -156,6 +156,9 @@ def run_model(args, graph, inputs, outputs, data): ...@@ -156,6 +156,9 @@ def run_model(args, graph, inputs, outputs, data):
func = graph.compile(outputs) func = graph.compile(outputs)
if args.get_static_mem_info:
func.get_static_memory_alloc_info(args.get_static_mem_info)
def run(): def run():
if not args.embed_input: if not args.embed_input:
for key in inp_dict: for key in inp_dict:
...@@ -389,6 +392,11 @@ def main(): ...@@ -389,6 +392,11 @@ def main():
help="embed input data as SharedDeviceTensor in model, " help="embed input data as SharedDeviceTensor in model, "
"to remove memory copy for inputs", "to remove memory copy for inputs",
) )
parser.add_argument(
"--get-static-mem-info",
type=str,
help="Record the static graph's static memory info.",
)
args = parser.parse_args() args = parser.parse_args()
if args.verbose: if args.verbose:
......
...@@ -215,7 +215,10 @@ void init_graph_rt(py::module m) { ...@@ -215,7 +215,10 @@ void init_graph_rt(py::module m) {
} }
} }
return ret; return ret;
}); })
.def("get_static_memory_alloc_info",
&cg::AsyncExecutable::get_static_memory_alloc_info,
py::call_guard<py::gil_scoped_release>());
auto PyComputingGraph = py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(m, "ComputingGraph") auto PyComputingGraph = py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(m, "ComputingGraph")
.def(py::init(py::overload_cast<>(&cg::ComputingGraph::make))) .def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册