diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index e0e3085fd7f2960bab469ca2fdfbe8e040b11dfb..8fdb62e7850eff02a746f8dfc50877ab43061026 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -83,9 +83,7 @@ def disable_receptive_field(): _receptive_field_enabled = False -@register_flops( - M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d -) +@register_flops(M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d) def flops_convNd(module: M.Conv2d, inputs, outputs): bias = 1 if module.bias is not None else 0 # N x Cout x H x W x (Cin x Kw x Kh + bias) @@ -93,13 +91,16 @@ def flops_convNd(module: M.Conv2d, inputs, outputs): float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias ) + @register_flops(M.ConvTranspose2d) def flops_convNdTranspose(module: M.Conv2d, inputs, outputs): bias = 1 if module.bias is not None else 0 # N x Cout x H x W x (Cin x Kw x Kh + bias) - return np.prod(inputs[0].shape) * ( - module.out_channels // module.groups * np.prod(module.kernel_size) - ) + np.prod(outputs[0].shape) * bias + return ( + np.prod(inputs[0].shape) + * (module.out_channels // module.groups * np.prod(module.kernel_size)) + + np.prod(outputs[0].shape) * bias + ) @register_flops(