提交 8e11204a 编写于 作者: M Megvii Engine Team

perf(mge/module): optimize conv_bn qat module to improve performance

GitOrigin-RevId: 9415b83d9c248907b96adeecac0b7f0dfb664c81
上级 51fa530d
......@@ -132,13 +132,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
if self.conv.bias is not None:
orig_conv = orig_conv + self.conv.bias
# calculate batch norm
bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
num_elements_per_channel = conv.size / conv.shape[1]
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
conv = self.bn(orig_conv)
return conv
@classmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册