Unverified commit 351d37d9, authored by J Jianghai, committed by GitHub

[Auto Parallel] Add conv2d and pool flops (#48084)

* add pool flops

* add annotations and tests
Parent 35f3c258
@@ -356,6 +356,32 @@ class TestFLOPSAPI(unittest.TestCase):
            )
            == 144
        )
        self.assertTrue(
            flops(
                'pool',
                {'X': [[12, 12]]},
                {},
            )
            == 12 * 12
        )
        self.assertTrue(
            flops(
                'conv2d',
                {
                    'Bias': [],
                    'Filter': [[3, 3, 2, 2]],
                    'Input': [[8, 3, 4, 4]],
                    'ResidualData': [],
                },
                {
                    'dilations': [1, 1],
                    'groups': 1,
                    'paddings': [1, 1],
                    'strides': [1, 1],
                },
            )
            == 14400
        )
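
# Worked check of the conv2d expectation above (a standalone sketch, not part
# of the diff): with Input [8, 3, 4, 4], Filter [3, 3, 2, 2], paddings [1, 1],
# strides [1, 1], dilations [1, 1] and groups 1, each spatial output dim is
# (4 + 2*1 - (1*(2 - 1) + 1)) // 1 + 1 = 5.
macs_per_position = (2 * 2) * 3 * (3 // 1)  # kernel area * in_channels * filters_per_channel = 36
active_elements = 8 * 5 * 5  # batch_size * numel(output spatial dims) = 200
assert 2 * macs_per_position * active_elements == 14400  # 'Bias' is empty, so no bias term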
if __name__ == '__main__':
......
@@ -69,6 +69,85 @@ def _c_embedding_flops(input_shapes, attrs):
    return 0
@register_flops("conv2d")
def _conv2d_flops(input_shapes, attrs):
"""FLOPs computation for conv2d op.
For conv2d(input,filter):
active_elements = batch_size * numel(output)
conv_flops = 2 * macs_per_position_conv * active_elements
bias_flops = out_channels * active_elements
equation: flops = conv_flops + bias_flops
"""
    bias = (
        input_shapes.get('Bias')[0]
        if len(input_shapes.get('Bias')) > 0
        else None
    )
    input = input_shapes.get('Input')[0]
    weight = input_shapes.get('Filter')[0]
    padding = attrs.get('paddings')
    stride = attrs.get('strides')
    dilation = attrs.get('dilations')
    groups = attrs.get('groups')

    batch_size = input[0]
    in_channels = input[1]
    out_channels = weight[0]
    kernel_dims = list(weight[2:])
    input_dims = list(input[2:])
    length = len(input_dims)

    paddings = (
        padding
        if isinstance(padding, list)
        else [
            padding,
        ]
        * length
    )
    strides = (
        stride
        if isinstance(stride, list)
        else [
            stride,
        ]
        * length
    )
    dilations = (
        dilation
        if isinstance(dilation, list)
        else [
            dilation,
        ]
        * length
    )

    output_dims = []
    for idx, input_dim in enumerate(input_dims):
        output_dim = (
            input_dim
            + 2 * paddings[idx]
            - (dilations[idx] * (kernel_dims[idx] - 1) + 1)
        ) // strides[idx] + 1
        output_dims.append(output_dim)

    filters_per_channel = out_channels // groups
    macs_conv_per_position = (
        prod(kernel_dims) * in_channels * filters_per_channel
    )
    active_elements = batch_size * prod(output_dims)
    overall_conv_macs = macs_conv_per_position * active_elements
    overall_conv_flops = 2 * overall_conv_macs

    overall_bias_flops = 0
    if bias is not None:
        overall_bias_flops = out_channels * active_elements

    return overall_conv_flops + overall_bias_flops
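
# Note: grouped convolution enters the count only through filters_per_channel,
# so the per-position MAC count is the dense count divided by groups. A small
# standalone sketch with hypothetical depthwise-style shapes (not from this diff):
kernel_area, in_channels, out_channels, groups = 3 * 3, 4, 4, 4
filters_per_channel = out_channels // groups  # 1 filter per group
macs = kernel_area * in_channels * filters_per_channel  # 9 * 4 * 1 = 36
assert macs == kernel_area * in_channels * out_channels // groups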
@register_flops("dropout")
def _dropout_flops(input_shapes, attrs):
"""FLOPs computation for dropout op.
@@ -195,7 +274,7 @@ def _matmul_v2_flops(input_shapes, attrs):
    shape_of_other = [odim1, odim2 ... odim(n-m) ... odim_m_1, dim_m] length: m
    suppose n > m and dim_n = odim_m_1:
    shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1)) ... dim_n_1, dim_m]
-   equation: flops = 2 * numel(output) * dim_n
+   equation: flops = 2 * numel(outputs) * dim_n
    """
    x_shape = input_shapes.get('X')[0]
    y_shape = input_shapes.get('Y')[0]
@@ -281,3 +360,13 @@ def _transpose2_flops(input_shapes, attrs):
    equation: flops = 0
    """
    return 0
@register_flops("pool")
def _pool_flops(input_shapes, attrs):
"""FLOPs computation for pool op.
For pool(input):
equation: flops = (numel)total number of elements in the input tensor.
"""
input = input_shapes.get('X')[0]
return prod(input)
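
# Sanity sketch for the pool rule above (standalone; the import path is an
# assumption based on the test module, which this diff does not show):
from paddle.utils.flops import flops
assert flops('pool', {'X': [[12, 12]]}, {}) == 12 * 12  # one op per input element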