Commit 8a918717 authored by Megvii Engine Team

feat(mgb): add megengine inference

GitOrigin-RevId: 6ffec6b418273b13ab76c0ef30a4749219b17aef
Parent 379a28f9
@@ -156,6 +156,9 @@ def run_model(args, graph, inputs, outputs, data):
    func = graph.compile(outputs)
    if args.get_static_mem_info:
        func.get_static_memory_alloc_info(args.get_static_mem_info)

    def run():
        if not args.embed_input:
            for key in inp_dict:
@@ -389,6 +392,11 @@ def main():
        help="embed input data as SharedDeviceTensor in model, "
        "to remove memory copy for inputs",
    )
    parser.add_argument(
        "--get-static-mem-info",
        type=str,
        help="Record the static graph's static memory info.",
    )
    args = parser.parse_args()
    if args.verbose:
......
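For context, the flag's value is a path-like string: when it is set, run_model records the static memory allocation plan right after graph.compile. A minimal, hedged sketch of the same hook used directly (script name, variable names, and the output path below are illustrative placeholders, not defined by this commit):

    # Hedged sketch, not part of this diff.  From the command line the new flag
    # would be passed roughly as (script name is illustrative):
    #     python3 load_network_and_run.py model.mge --get-static-mem-info ./static_mem_info
    # Programmatically, run_model does the equivalent of:
    func = graph.compile(outputs)  # `graph` / `outputs` built as in run_model above
    func.get_static_memory_alloc_info("./static_mem_info")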
@@ -215,7 +215,10 @@ void init_graph_rt(py::module m) {
                }
            }
            return ret;
        });
        })
        .def("get_static_memory_alloc_info",
             &cg::AsyncExecutable::get_static_memory_alloc_info,
             py::call_guard<py::gil_scoped_release>());
    auto PyComputingGraph = py::class_<cg::ComputingGraph, std::shared_ptr<cg::ComputingGraph>>(m, "ComputingGraph")
        .def(py::init(py::overload_cast<>(&cg::ComputingGraph::make)))
......
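On the binding side, get_static_memory_alloc_info is added to the existing pybind11 wrapper around cg::AsyncExecutable, and py::call_guard<py::gil_scoped_release> drops the GIL for the duration of the call, so a long dump does not stall other Python threads. A hedged illustration of what that allows a caller to do (`func` is assumed to be the object returned by ComputingGraph.compile, as in the tool above):

    # Hedged sketch, not part of this diff: because the binding releases the GIL,
    # the dump can run on a worker thread while the main thread keeps working.
    import threading

    t = threading.Thread(
        target=func.get_static_memory_alloc_info, args=("./static_mem_info",)
    )
    t.start()
    # ... other Python work can proceed here while the record is written ...
    t.join()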