diff --git a/imperative/python/megengine/utils/module_stats.py b/imperative/python/megengine/utils/module_stats.py index 11d4d4838edf6e66c5c2da52147ed5ece0c382c0..7e06a355d152de5d7053638fffa0e3f3bd2a8745 100644 --- a/imperative/python/megengine/utils/module_stats.py +++ b/imperative/python/megengine/utils/module_stats.py @@ -90,7 +90,7 @@ def flops_convNd(module: M.Conv2d, inputs, outputs): bias = 1 if module.bias is not None else 0 # N x Cout x H x W x (Cin x Kw x Kh + bias) return np.prod(outputs[0].shape) * ( - module.in_channels // module.groups * np.prod(module.kernel_size) + bias + float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias ) diff --git a/imperative/python/megengine/utils/network_node.py b/imperative/python/megengine/utils/network_node.py index d357da9cafb7e79ee906aa180c4d719baba4ca2f..b44897e9905c3b171efbc4db63e28652b071ac74 100644 --- a/imperative/python/megengine/utils/network_node.py +++ b/imperative/python/megengine/utils/network_node.py @@ -487,7 +487,7 @@ def flops_conv(opnode: ConvolutionForward, inputs, outputs): NCHW = np.prod(outputs[0].shape) bias = 1 if isinstance(opnode, ConvBiasForward) else 0 # N x Cout x H x W x (Cin x Kw x Kh) - return NCHW * (num_input * kw * kh + bias) + return NCHW * (float(num_input * kw * kh) + bias) @register_receptive_field(ConvolutionForward, ConvBiasForward)