未验证 提交 c5137b22 编写于 作者: J Jianghai 提交者: GitHub

[Auto Parallel] Add All Relu Flops (#48083)

* relu flops all

* add annotations and tests

* revision for codestyle
上级 18c0a002
...@@ -293,9 +293,6 @@ class TestFLOPSAPI(unittest.TestCase): ...@@ -293,9 +293,6 @@ class TestFLOPSAPI(unittest.TestCase):
) )
== 3 * 12 * 12 * 12 * 2 * 8 == 3 * 12 * 12 * 12 * 2 * 8
) )
self.assertTrue(
flops('relu', {'X': [[12, 12, 12]]}, {}) == 12 * 12 * 12
)
self.assertTrue( self.assertTrue(
flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12 flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12
) )
...@@ -303,6 +300,56 @@ class TestFLOPSAPI(unittest.TestCase): ...@@ -303,6 +300,56 @@ class TestFLOPSAPI(unittest.TestCase):
flops('c_embedding', {'Ids': [[12, 12]], 'W': [[12, 12, 3]]}, {}) flops('c_embedding', {'Ids': [[12, 12]], 'W': [[12, 12, 3]]}, {})
== 0 == 0
) )
self.assertTrue(
flops(
'elu',
{
'X': [[12, 12]],
},
{},
)
== 144
)
self.assertTrue(
flops(
'leaky_relu',
{
'X': [[12, 12]],
},
{},
)
== 144
)
self.assertTrue(
flops(
'prelu',
{
'X': [[12, 12]],
},
{},
)
== 144
)
self.assertTrue(
flops(
'relu6',
{
'X': [[12, 12]],
},
{},
)
== 144
)
self.assertTrue(
flops(
'silu',
{
'X': [[12, 12]],
},
{},
)
== 144
)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -73,7 +73,7 @@ def _c_embedding_flops(input_shapes, attrs): ...@@ -73,7 +73,7 @@ def _c_embedding_flops(input_shapes, attrs):
def _dropout_flops(input_shapes, attrs): def _dropout_flops(input_shapes, attrs):
"""FLOPs computation for dropout op. """FLOPs computation for dropout op.
For dropout(input): For dropout(input):
equation: flops = 0 equation: flops = 0
""" """
return 0 return 0
...@@ -191,7 +191,7 @@ def _matmul_v2_flops(input_shapes, attrs): ...@@ -191,7 +191,7 @@ def _matmul_v2_flops(input_shapes, attrs):
"""FLOPs computation for matmul_v2 op. """FLOPs computation for matmul_v2 op.
For matmul_v2(input,other): For matmul_v2(input,other):
input_shapes = [shape_of_input, shape_of_ohther] input_shapes = [shape_of_input, shape_of_ohther]
shape_of_input = [dim1, dim2 ...dim_n_1, dim_n] length:n shape_of_input = [dim1, dim2 ...dim_n_1, dim_n] length:n
shape_of_other = [odim1, odim2 ... odim(n-m) ... odim_m_1, dim_m] length:m shape_of_other = [odim1, odim2 ... odim(n-m) ... odim_m_1, dim_m] length:m
suppose n > m and dim_n = odim_m_1: suppose n > m and dim_n = odim_m_1:
shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1))...dim_n_1, dim_m] shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1))...dim_n_1, dim_m]
...@@ -216,13 +216,43 @@ def _matmul_v2_flops(input_shapes, attrs): ...@@ -216,13 +216,43 @@ def _matmul_v2_flops(input_shapes, attrs):
return 2 * macs return 2 * macs
@register_flops("relu") def _relu_class_flops(input_shapes, attrs):
def _relu_flops(input_shapes, attrs): """FLOPs computation for relu_like ops.
"""FLOPs computation for relu op. For elu/leaky_relu/prelu/relu/relu6/silu (input):
For relu(input):
equation: flops = (numel)total number of elements in the input tensor. equation: flops = (numel)total number of elements in the input tensor.
""" """
return prod(input_shapes.get('X')[0]) input = input_shapes.get('X')[0]
return prod(input)
@register_flops("elu")
def _elu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("leaky_relu")
def _leaky_relu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("prelu")
def _prelu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("relu")
def _relu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("relu6")
def _relu6_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("silu")
def _silu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("reshape2") @register_flops("reshape2")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册