提交 f50db0b8 编写于 作者: M Megvii Engine Team

Merge pull request #464 from fanhqme2:patch-1

GitOrigin-RevId: c08f061e2eb200532c99c1e668788e227502abe6
......@@ -84,7 +84,7 @@ def disable_receptive_field():
@register_flops(
M.Conv1d, M.Conv2d, M.Conv3d, M.ConvTranspose2d, M.LocalConv2d, M.DeformableConv2d
M.Conv1d, M.Conv2d, M.Conv3d, M.LocalConv2d, M.DeformableConv2d
)
def flops_convNd(module: M.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0
......@@ -93,6 +93,14 @@ def flops_convNd(module: M.Conv2d, inputs, outputs):
float(module.in_channels // module.groups) * np.prod(module.kernel_size) + bias
)
@register_flops(M.ConvTranspose2d)
def flops_convNdTranspose(module: M.Conv2d, inputs, outputs):
bias = 1 if module.bias is not None else 0
# N x Cout x H x W x (Cin x Kw x Kh + bias)
return np.prod(inputs[0].shape) * (
module.out_channels // module.groups * np.prod(module.kernel_size)
) + np.prod(outputs[0].shape) * bias
@register_flops(
M.batchnorm._BatchNorm, M.SyncBatchNorm, M.GroupNorm, M.LayerNorm, M.InstanceNorm,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册