Unverified commit c5137b22 authored by Jianghai, committed by GitHub

[Auto Parallel] Add All Relu Flops (#48083)

* relu flops all

* add annotations and tests

* revision for codestyle
Parent 18c0a002
@@ -293,9 +293,6 @@ class TestFLOPSAPI(unittest.TestCase):
            )
            == 3 * 12 * 12 * 12 * 2 * 8
        )
        self.assertTrue(
            flops('relu', {'X': [[12, 12, 12]]}, {}) == 12 * 12 * 12
        )
        self.assertTrue(
            flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12
        )
@@ -303,6 +300,56 @@ class TestFLOPSAPI(unittest.TestCase):
            flops('c_embedding', {'Ids': [[12, 12]], 'W': [[12, 12, 3]]}, {})
            == 0
        )
        self.assertTrue(
            flops(
                'elu',
                {
                    'X': [[12, 12]],
                },
                {},
            )
            == 144
        )
        self.assertTrue(
            flops(
                'leaky_relu',
                {
                    'X': [[12, 12]],
                },
                {},
            )
            == 144
        )
        self.assertTrue(
            flops(
                'prelu',
                {
                    'X': [[12, 12]],
                },
                {},
            )
            == 144
        )
        self.assertTrue(
            flops(
                'relu6',
                {
                    'X': [[12, 12]],
                },
                {},
            )
            == 144
        )
        self.assertTrue(
            flops(
                'silu',
                {
                    'X': [[12, 12]],
                },
                {},
            )
            == 144
        )
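
        # Not part of the commit: a minimal equivalent sketch that exercises
        # every relu-family op registered by this change in one loop; 'relu'
        # itself is included since it now shares the same numel-based rule.
        for act_op in ('elu', 'leaky_relu', 'prelu', 'relu', 'relu6', 'silu'):
            self.assertTrue(flops(act_op, {'X': [[12, 12]]}, {}) == 144)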
if __name__ == '__main__':
......
@@ -73,7 +73,7 @@ def _c_embedding_flops(input_shapes, attrs):
def _dropout_flops(input_shapes, attrs):
    """FLOPs computation for dropout op.
    For dropout(input):
        equation: flops = 0
    """
    return 0
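
# Illustrative only (not in the commit): dropout masks elements rather than
# doing arithmetic, so the dispatcher reports zero FLOPs for it, assuming
# this module is importable as paddle.utils.flops:
#
#   >>> from paddle.utils.flops import flops
#   >>> flops('dropout', {'X': [[32, 128]]}, {'dropout_prob': 0.5})
#   0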
@@ -191,7 +191,7 @@ def _matmul_v2_flops(input_shapes, attrs):
    """FLOPs computation for matmul_v2 op.
    For matmul_v2(input, other):
        input_shapes = [shape_of_input, shape_of_other]
        shape_of_input = [dim1, dim2 ... dim_n_1, dim_n], length: n
        shape_of_other = [odim1, odim2 ... odim(n-m) ... odim_m_1, dim_m], length: m
        suppose n > m and dim_n = odim_m_1:
        shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1)) ... dim_n_1, dim_m]
@@ -216,13 +216,43 @@
    return 2 * macs
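
# Illustrative only (not in the commit): a worked instance of the formula
# above. For input shape [2, 3, 4] and other shape [4, 5] the output has
# shape [2, 3, 5]; macs = 2 * 3 * 5 * 4 = 120, so flops = 2 * macs = 240.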
@register_flops("relu")
def _relu_flops(input_shapes, attrs):
"""FLOPs computation for relu op.
For relu(input):
def _relu_class_flops(input_shapes, attrs):
"""FLOPs computation for relu_like ops.
For elu/leaky_relu/prelu/relu/relu6/silu (input):
equation: flops = (numel)total number of elements in the input tensor.
"""
return prod(input_shapes.get('X')[0])
input = input_shapes.get('X')[0]
return prod(input)
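
# Illustrative only (not in the commit): for an input of shape [12, 12],
# every op routed through _relu_class_flops reports prod([12, 12]) == 144
# FLOPs, which is exactly what the assertions added in the test diff check.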
@register_flops("elu")
def _elu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("leaky_relu")
def _leaky_relu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("prelu")
def _prelu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("relu")
def _relu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("relu6")
def _relu6_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
@register_flops("silu")
def _silu_flops(input_shapes, attrs):
return _relu_class_flops(input_shapes, attrs)
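
# Illustrative only (not in the commit): the dispatch these wrappers rely
# on, sketched with assumed names ("_FLOPS_FUNCS" is hypothetical; the real
# registry lives elsewhere in this module). register_flops stores each
# handler in a dict keyed by op type, and flops() looks it up:
#
#   _FLOPS_FUNCS = {}
#
#   def register_flops(op_type):
#       def wrapper(func):
#           _FLOPS_FUNCS[op_type] = func
#           return func
#       return wrapper
#
#   def flops(op_type, input_shapes, attrs):
#       func = _FLOPS_FUNCS.get(op_type)
#       return func(input_shapes, attrs) if func is not None else 0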
@register_flops("reshape2")
......