diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py
index c7b2a37ca9a2b2fa4cce4b2952f64635e61d063d..690cda07ef9854bdb9594342b0ba834b9c55016a 100644
--- a/imperative/python/megengine/utils/module_stats.py
+++ b/imperative/python/megengine/utils/module_stats.py
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import contextlib
 from functools import partial
 
 import numpy as np
@@ -87,30 +88,20 @@ def disable_receptive_field():
 
 
 @register_flops(
-    m.Conv1d, m.Conv2d, m.Conv3d,
+    m.Conv1d, m.Conv2d, m.Conv3d, m.ConvTranspose2d, m.LocalConv2d, m.DeformableConv2d
 )
 def flops_convNd(module: m.Conv2d, inputs, outputs):
     bias = 1 if module.bias is not None else 0
-    group = module.groups
-    ic = inputs[0].shape[1]
-    oc = outputs[0].shape[1]
-    goc = oc // group
-    gic = ic // group
-    N = outputs[0].shape[0]
-    HW = np.prod(outputs[0].shape[2:])
     # N x Cout x H x W x (Cin x Kw x Kh + bias)
-    return N * HW * goc * (gic * np.prod(module.kernel_size) + bias)
-
-
-@register_flops(m.ConvTranspose2d)
-def flops_deconvNd(module: m.ConvTranspose2d, inputs, outputs):
-    return np.prod(inputs[0].shape) * outputs[0].shape[1] * np.prod(module.kernel_size)
+    return np.prod(outputs[0].shape) * (
+        module.in_channels // module.groups * np.prod(module.kernel_size) + bias
+    )
 
 
 @register_flops(m.Linear)
 def flops_linear(module: m.Linear, inputs, outputs):
-    bias = 1 if module.bias is not None else 0
-    return np.prod(outputs[0].shape) * module.in_features + bias
+    bias = module.out_features if module.bias is not None else 0
+    return np.prod(outputs[0].shape) * module.in_features + bias
 
 
 @register_flops(m.BatchMatMulActivation)
@@ -340,6 +331,31 @@ def module_stats(
             param_stats["name"] = name + "-b"
             params.append(param_stats)
 
+    @contextlib.contextmanager
+    def adjust_stats(module, training=False):
+        """Adjust module to training/eval mode temporarily.
+
+        Args:
+            module (M.Module): used module.
+            training (bool): training mode. True for train mode, False for eval mode.
+        """
+
+        def recursive_backup_stats(module, mode):
+            for m in module.modules():
+                # save prev status to _prev_training
+                m._prev_training = m.training
+                m.train(mode, recursive=False)
+
+        def recursive_recover_stats(module):
+            for m in module.modules():
+                # recover prev status and delete attribute
+                m.training = m._prev_training
+                delattr(m, "_prev_training")
+
+        recursive_backup_stats(module, mode=training)
+        yield module
+        recursive_recover_stats(module)
+
     # multiple inputs to the network
     if not isinstance(input_size[0], tuple):
         input_size = [input_size]
@@ -355,8 +371,9 @@ def module_stats(
     )
     inputs = [zeros(in_size, dtype=np.float32) for in_size in input_size]
 
-    model.eval()
-    model(*inputs)
+    with adjust_stats(model, training=False) as model:
+        model(*inputs)
+
     for h in hooks:
         h.remove()