From 25f97b76ea6415999474f0cb9e5d222074567e36 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 24 Aug 2022 13:31:39 +0800 Subject: [PATCH] feat(imperative): load_nerwork_and_run enable weight preprocess GitOrigin-RevId: 0642b237a7595624f847d6c154802c853b092bd8 --- .../megengine/tools/load_network_and_run.py | 19 +++++++++++++++---- imperative/python/src/graph_rt.cpp | 4 ++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/imperative/python/megengine/tools/load_network_and_run.py b/imperative/python/megengine/tools/load_network_and_run.py index d226ad1ae..1012e06fc 100755 --- a/imperative/python/megengine/tools/load_network_and_run.py +++ b/imperative/python/megengine/tools/load_network_and_run.py @@ -121,6 +121,9 @@ def run_model(args, graph, inputs, outputs, data): # must use level0 to avoid unintended opr modification graph.options.graph_opt_level = 0 + if args.weight_preprocess: + graph.enable_weight_preprocess() + logger.info("input tensors: ") for k, v in data.items(): logger.info(" {}: {}".format(k, v.shape)) @@ -161,8 +164,8 @@ def run_model(args, graph, inputs, outputs, data): func.wait() return [oup_node.get_value().numpy() for oup_node in output_dict.values()] - if args.warm_up: - logger.info("warming up") + for i in range(args.warm_up): + logger.info("warming up {}".format(i)) run() total_time = 0 @@ -276,8 +279,9 @@ def main(): ) parser.add_argument( "--warm-up", - action="store_true", - help="warm up model before do timing " " for better estimation", + type=int, + default=0, + help="times of warm up model before do timing " " for better estimation", ) parser.add_argument( "--verbose", @@ -394,6 +398,13 @@ def main(): parser.add_argument( "--custom-op-lib", type=str, help="path of the custom op", ) + parser.add_argument( + "--weight-preprocess", + action="store_true", + help="Execute operators with weight preprocess, which can" + "optimize the operator execution time with algo of winograd," + "im2col ,etc.,but it may consume more memory.", + ) args = parser.parse_args() diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index de7f28b99..a6085aef1 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -253,6 +253,10 @@ void init_graph_rt(py::module m) { } return graph.compile(spec); }) + .def("enable_weight_preprocess", + [](cg::ComputingGraph& graph) { + graph.options().graph_opt.enable_weight_preprocess(); + }) .def_property_readonly( "options", py::overload_cast<>(&cg::ComputingGraph::options)); -- GitLab