提交 d3e786ef 编写于 作者: M Megvii Engine Team 提交者: “wenjuan”

feat(imperative): load_nerwork_and_run enable weight preprocess

GitOrigin-RevId: 0642b237a7595624f847d6c154802c853b092bd8
上级 c6ff878d
...@@ -121,6 +121,9 @@ def run_model(args, graph, inputs, outputs, data): ...@@ -121,6 +121,9 @@ def run_model(args, graph, inputs, outputs, data):
# must use level0 to avoid unintended opr modification # must use level0 to avoid unintended opr modification
graph.options.graph_opt_level = 0 graph.options.graph_opt_level = 0
if args.weight_preprocess:
graph.enable_weight_preprocess()
logger.info("input tensors: ") logger.info("input tensors: ")
for k, v in data.items(): for k, v in data.items():
logger.info(" {}: {}".format(k, v.shape)) logger.info(" {}: {}".format(k, v.shape))
...@@ -161,8 +164,8 @@ def run_model(args, graph, inputs, outputs, data): ...@@ -161,8 +164,8 @@ def run_model(args, graph, inputs, outputs, data):
func.wait() func.wait()
return [oup_node.get_value().numpy() for oup_node in output_dict.values()] return [oup_node.get_value().numpy() for oup_node in output_dict.values()]
if args.warm_up: for i in range(args.warm_up):
logger.info("warming up") logger.info("warming up {}".format(i))
run() run()
total_time = 0 total_time = 0
...@@ -276,8 +279,9 @@ def main(): ...@@ -276,8 +279,9 @@ def main():
) )
parser.add_argument( parser.add_argument(
"--warm-up", "--warm-up",
action="store_true", type=int,
help="warm up model before do timing " " for better estimation", default=0,
help="times of warm up model before do timing " " for better estimation",
) )
parser.add_argument( parser.add_argument(
"--verbose", "--verbose",
...@@ -394,6 +398,13 @@ def main(): ...@@ -394,6 +398,13 @@ def main():
parser.add_argument( parser.add_argument(
"--custom-op-lib", type=str, help="path of the custom op", "--custom-op-lib", type=str, help="path of the custom op",
) )
parser.add_argument(
"--weight-preprocess",
action="store_true",
help="Execute operators with weight preprocess, which can"
"optimize the operator execution time with algo of winograd,"
"im2col ,etc.,but it may consume more memory.",
)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -253,6 +253,10 @@ void init_graph_rt(py::module m) { ...@@ -253,6 +253,10 @@ void init_graph_rt(py::module m) {
} }
return graph.compile(spec); return graph.compile(spec);
}) })
.def("enable_weight_preprocess",
[](cg::ComputingGraph& graph) {
graph.options().graph_opt.enable_weight_preprocess();
})
.def_property_readonly( .def_property_readonly(
"options", "options",
py::overload_cast<>(&cg::ComputingGraph::options)); py::overload_cast<>(&cg::ComputingGraph::options));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册