提交 6c4c4ca6 编写于 作者: M Megvii Engine Team

Merge pull request #459 from Qsingle:fix_overflow_of_flops_calculate

GitOrigin-RevId: c1333e2089c5aa9379d64a3e37194e750704b02e
......@@ -90,7 +90,7 @@ def flops_convNd(module: M.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return np.prod(outputs[0].shape) * (
module.in_channels // module.groups * np.prod(module.kernel_size) + bias
float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias
)
......
......@@ -487,7 +487,7 @@ def flops_conv(opnode: ConvolutionForward, inputs, outputs):
NCHW = np.prod(outputs[0].shape)
bias = 1 if isinstance(opnode, ConvBiasForward) else 0
# N x Cout x H x W x (Cin x Kw x Kh)
return NCHW * (num_input * kw * kh + bias)
return NCHW * (float(num_input * kw * kh) + bias)
@register_receptive_field(ConvolutionForward, ConvBiasForward)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册