From b1baee60f58588aca309c85b9b80db52ee29c506 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 7 Apr 2021 12:35:55 +0800
Subject: [PATCH] feat(imperative/utils): add optimize-for-inference interface
 for opgraph

GitOrigin-RevId: 9f93f821905dc05e3968247129920a0a1d43712f
---
 imperative/python/megengine/utils/network.py | 92 ++++++++++++--------
 1 file changed, 55 insertions(+), 37 deletions(-)

diff --git a/imperative/python/megengine/utils/network.py b/imperative/python/megengine/utils/network.py
index 4a3f5ac1..80121d33 100644
--- a/imperative/python/megengine/utils/network.py
+++ b/imperative/python/megengine/utils/network.py
@@ -11,7 +11,7 @@ import fnmatch
 import itertools
 import re
 from collections import OrderedDict
-from typing import Dict, List
+from typing import Dict, List, Sequence
 
 import numpy as np
 
@@ -87,6 +87,58 @@ class Network:
             for o in opr.outputs:
                 self.all_vars_map[o.var.id] = o
 
+    def optimize_for_inference(self, dest_vars, **kwargs):
+        r"""
+        Applies the optimize_for_inference pass to the operator graph.
+
+        :param dest_vars: list of output vars in the operator graph
+
+        :Keyword Arguments:
+
+            * enable_io16xc32 --
+              whether to use float16 for I/O between oprs and use
+              float32 as internal computation precision. Note the output var would be
+              changed to float16.
+            * enable_ioc16 --
+              whether to use float16 for both I/O and computation
+              precision.
+
+            * enable_hwcd4 --
+              whether to use NHWCD4 data layout. This is faster on some
+              OpenCL backend.
+            * enable_nchw88 --
+              whether to use NCHW88 data layout, currently
+              used in X86 AVX backend.
+            * enable_nchw44 --
+              whether to use NCHW44 data layout, currently
+              used in arm backend.
+            * enable_nchw44_dot --
+              whether to use NCHW44_dot data layout, currently
+              used in armv8.2+dotprod backend.
+            * enable_nchw4 --
+              whether to use NCHW4 data layout, currently
+              used in nvidia backend (based on cudnn).
+            * enable_nchw32 --
+              whether to use NCHW32 data layout, currently
+              used in nvidia backend with tensorcore (based on cudnn).
+            * enable_chwn4 --
+              whether to use CHWN4 data layout, currently
+              used in nvidia backend with tensorcore.
+
+            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
+              into one opr.
+            * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
+              input for inference on nvidia backend (this optimization pass will
+              result in mismatch of the precision of output of training and
+              inference)
+        """
+
+        if not isinstance(dest_vars, Sequence):
+            dest_vars = [dest_vars]
+        dest_vars = list(G.VarNode(var.var) for var in dest_vars)
+        new_vars = G.optimize_for_inference(dest_vars, **kwargs)
+        return list(self._get_var(var) for var in new_vars)
+
     def dump(
         self,
         file,
@@ -126,42 +178,8 @@
 
         :Keyword Arguments:
 
-            * enable_io16xc32 --
-              whether to use float16 for I/O between oprs and use
-              float32 as internal computation precision. Note the output var would be
-              changed to float16.
-            * enable_ioc16 --
-              whether to use float16 for both I/O and computation
-              precision.
-
-            * enable_hwcd4 --
-              whether to use NHWCD4 data layout. This is faster on some
-              OpenCL backend.
-            * enable_nchw88 --
-              whether to use NCHW88 data layout, currently
-              used in X86 AVX backend.
-            * enable_nchw44 --
-              whether to use NCHW44 data layout, currently
-              used in arm backend.
-            * enable_nchw44_dot --
-              whether to use NCHW44_dot data layout, currently
-              used in armv8.2+dotprod backend.
-            * enable_nchw4 --
-              whether to use NCHW4 data layout, currently
-              used in nvidia backend(based on cudnn).
-            * enable_nchw32 --
-              whether to use NCHW32 data layout, currently
-              used in nvidia backend with tensorcore(based on cudnn).
-            * enable_chwn4 --
-              whether to use CHWN4 data layout, currently
-              used in nvidia backend with tensorcore.
-
-            * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
-              into one opr.
-            * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
-              input for inference on nvidia backend(this optimization pass will
-              result in mismatch of the precision of output of training and
-              inference)
+        See also :py:meth:`optimize_for_inference`.
+
         """
 
         self._compile()
-- 
GitLab
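
A minimal usage sketch of the new interface follows (editorial note, not part of the patch). Only the optimize_for_inference call itself comes from this diff; Network.load, the output_vars attribute, and the "model.mge" path are assumptions about the surrounding megengine.utils.network module.

    from megengine.utils.network import Network

    # Assumed: a graph previously dumped to disk by MegEngine.
    net = Network.load("model.mge")

    # The call added by this patch: run the optimize-for-inference pass on the
    # graph's output vars, here requesting float16 I/O with float32 compute.
    new_outputs = net.optimize_for_inference(net.output_vars, enable_io16xc32=True)

    # A list of the transformed output vars is returned; a single var (not wrapped
    # in a sequence) is also accepted and treated as a one-element list.
    print(new_outputs)

Since dump() previously documented these keyword arguments inline and its docstring now simply points at optimize_for_inference, the same keyword arguments are intended to work through either entry point.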