Unverified commit b35ec181, authored by: Jianghai, committed by: GitHub

[Auto Parallel] Add matmul flops (#47816)

* add matmul_flops

* add dropout,layer_norm ...

* add dropout,layer_norm ...

* add elementwise_flops

* add dict check

* add unit tests

* add equation for flops computation

* regularized annotations
Parent commit d983fc34
@@ -227,6 +227,75 @@ class TestFLOPSAPI(unittest.TestCase):
def test_flops(self):
self.assertTrue(flops('relu', {'X': [[12, 12]]}, {'output': 4}) == 144)
self.assertTrue(flops('dropout', {}, {'output': 4}) == 0)
self.assertTrue(
flops(
'transpose2',
{
'X': [[12, 12, 12]],
},
{},
)
== 0
)
self.assertTrue(
flops(
'reshape2',
{
'X': [[12, 12, 12]],
},
{},
)
== 0
)
self.assertTrue(
flops(
'unsqueeze2',
{
'X': [[12, 12, 12]],
},
{},
)
== 0
)
self.assertTrue(
flops(
'layer_norm',
{'Bias': [[128]], 'Scale': [[128]], 'X': [[32, 128, 28, 28]]},
{'epsilon': 0.01},
)
== 32 * 128 * 28 * 28 * 8
)
self.assertTrue(
flops(
'elementwise_add', {'X': [[12, 12, 12]], 'Y': [[2, 2, 12]]}, {}
)
== 12 * 12 * 12
)
self.assertTrue(
flops('gelu', {'X': [[12, 12, 12]]}, {}) == 5 * 12 * 12 * 12
)
self.assertTrue(
flops(
'matmul',
{'X': [[3, 12, 12, 8]], 'Y': [[12, 12, 8]]},
{'transpose_X': False, 'transpose_Y': True},
)
== 3 * 12 * 12 * 12 * 2 * 8
)
self.assertTrue(
flops(
'matmul_v2',
{'X': [[3, 12, 12, 8]], 'Y': [[12, 12, 8]]},
{'trans_x': False, 'trans_y': True},
)
== 3 * 12 * 12 * 12 * 2 * 8
)
self.assertTrue(
flops('relu', {'X': [[12, 12, 12]]}, {}) == 12 * 12 * 12
)
self.assertTrue(
flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12
)
if __name__ == '__main__':
    unittest.main()
@@ -40,7 +40,11 @@ def flops(op_type: str, input_shapes: dict, attrs: dict) -> int:
return 0
else:
func = _FLOPS_COMPUTE_FUNC_MAP[op_type]
return func(input_shapes, attrs)
try:
flops = func(input_shapes, attrs)
except Exception:
return 0
return flops
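# A minimal usage sketch of the dispatcher above: an op type without a
# registered compute function (the op name below is hypothetical) is expected
# to fall back to 0, as does any compute function that raises.
assert flops('some_unregistered_op', {}, {}) == 0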
def register_flops(op_type):
@@ -58,9 +62,182 @@ def register_flops(op_type):
@register_flops("dropout")
def _dropout_flops(input_shapes, attrs):
"""FLOPs computation for dropout op.
For dropout(input):
equation: flops = 0
"""
return 0
def _elementwise_flops_compute(input_shapes, attrs):
input_x = input_shapes.get("X")[0]
input_y = input_shapes.get("Y")[0]
dim_x = len(input_x)
dim_y = len(input_y)
dim_output = max(dim_x, dim_y)
output = []
for i in range(dim_output):
in_x = input_x[dim_x - 1 - i] if i < dim_x else 1
in_y = input_y[dim_y - 1 - i] if i < dim_y else 1
output.append(max(in_x, in_y))
return prod(output)
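# Worked example of the per-dimension max rule above, mirroring the
# elementwise_add unit test: X = [12, 12, 12] and Y = [2, 2, 12] give an
# output shape of [12, 12, 12], so flops = 12 * 12 * 12 = 1728.
assert _elementwise_flops_compute(
    {'X': [[12, 12, 12]], 'Y': [[2, 2, 12]]}, {}
) == 12 * 12 * 12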
@register_flops("elementwise_add")
def _elementwise_add_flops(input_shapes, attrs):
"""FLOPs computation for elementwise_add op.
For elementwise_add(input, other):
input_shapes = [shape_of_input, shape_of_other]
shape_of_input = [dim1, dim2, dim3 ...]
shape_of_other = [odim1, odim2, odim3 ...]
equation: flops = max(dim1, odim1) * max(dim2, odim2) * ..., i.e. the number of elements in the broadcast output.
"""
return _elementwise_flops_compute(input_shapes, attrs)
@register_flops("elementwise_mul")
def _elementwise_mul_flops(input_shapes, attrs):
"""FLOPs computation for elementwise_mul op.
For elementwise_mul(input, other):
input_shapes = [shape_of_input, shape_of_other]
shape_of_input = [dim1, dim2, dim3 ...]
shape_of_other = [odim1, odim2, odim3 ...]
equation: flops = max(dim1, odim1) * max(dim2, odim2) * ..., i.e. the number of elements in the broadcast output.
"""
return _elementwise_flops_compute(input_shapes, attrs)
@register_flops("elementwise_div")
def _elementwise_div_flops(input_shapes, attrs):
"""FLOPs computation for elementwise_div op.
For elementwise_div(input, other):
input_shapes = [shape_of_input, shape_of_other]
shape_of_input = [dim1, dim2, dim3 ...]
shape_of_other = [odim1, odim2, odim3 ...]
equation: flops = max(dim1, odim1) * max(dim2, odim2) * ..., i.e. the number of elements in the broadcast output.
"""
return _elementwise_flops_compute(input_shapes, attrs)
@register_flops("gelu")
def _gelu_flops(input_shapes, attrs):
"""FLOPs computation for gelu op.
For gelu(input):
equation: flops = 5 * numel, where numel is the total number of elements in the input tensor.
"""
input = input_shapes.get('X')[0]
return prod(input) * 5
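# Worked example, mirroring the gelu unit test above:
# flops = 5 * 12 * 12 * 12 = 8640.
assert _gelu_flops({'X': [[12, 12, 12]]}, {}) == 5 * 12 * 12 * 12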
@register_flops("layer_norm")
def _layer_norm_flops(input_shapes, attrs):
"""FLOPs computation for layer_norm op.
For layer_norm(input):
equation:
1) without epsilon: flops = 7 * numel, where numel is the total number of elements in the input tensor.
2) with epsilon: flops = 8 * numel.
"""
input = input_shapes.get('X')[0]
flops = prod(input) * 7
if attrs.get('epsilon'):
flops += prod(input)
return flops
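# Worked example, mirroring the layer_norm unit test above: with an epsilon
# attribute present, flops = 8 * numel(X) = 8 * 32 * 128 * 28 * 28 = 25690112.
assert _layer_norm_flops(
    {'Bias': [[128]], 'Scale': [[128]], 'X': [[32, 128, 28, 28]]},
    {'epsilon': 0.01},
) == 32 * 128 * 28 * 28 * 8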
@register_flops("matmul")
def _matmul_flops(input_shapes, attrs):
"""FLOPs computation for matmul op.
For matmul(input, other):
input_shapes = [shape_of_input, shape_of_other]
shape_of_input = [dim_1, dim_2, ..., dim_(n-1), dim_n]        length: n
shape_of_other = [odim_1, odim_2, ..., odim_(m-1), odim_m]    length: m
suppose n >= m and dim_n == odim_(m-1) (the contracted dimension):
shape_of_output = [broadcast of the leading batch dims, dim_(n-1), odim_m]
equation: flops = 2 * numel(output) * dim_n
"""
x_shape = input_shapes.get("X")[0]
y_shape = input_shapes.get("Y")[0]
if attrs.get('transpose_X') or attrs.get('transpose_x'):
x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
if attrs.get('transpose_Y') or attrs.get('transpose_y'):
y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
dim_x = len(x_shape)
dim_y = len(y_shape)
output_len = max(dim_x, dim_y)
output_shape = []
for idx in range(output_len, 2, -1):
x_idx = x_shape[dim_x - idx] if idx <= dim_x else 1
y_idx = y_shape[dim_y - idx] if idx <= dim_y else 1
output_shape.append(max(x_idx, y_idx))
macs = prod(output_shape) * x_shape[-2] * x_shape[-1] * y_shape[-1]
return 2 * macs
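# Worked example, mirroring the matmul unit test above:
# X = [3, 12, 12, 8], Y = [12, 12, 8], transpose_Y=True -> Y acts as [12, 8, 12].
# Broadcast batch dims = [3, 12]; M = 12, K = 8, N = 12, so
# flops = 2 * (3 * 12) * 12 * 8 * 12 = 82944.
assert _matmul_flops(
    {'X': [[3, 12, 12, 8]], 'Y': [[12, 12, 8]]},
    {'transpose_X': False, 'transpose_Y': True},
) == 3 * 12 * 12 * 12 * 2 * 8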
@register_flops("matmul_v2")
def _matmul_v2_flops(input_shapes, attrs):
"""FLOPs computation for matmul_v2 op.
For matmul_v2(input, other):
input_shapes = [shape_of_input, shape_of_other]
shape_of_input = [dim_1, dim_2, ..., dim_(n-1), dim_n]        length: n
shape_of_other = [odim_1, odim_2, ..., odim_(m-1), odim_m]    length: m
suppose n >= m and dim_n == odim_(m-1) (the contracted dimension):
shape_of_output = [broadcast of the leading batch dims, dim_(n-1), odim_m]
equation: flops = 2 * numel(output) * dim_n
"""
x_shape = input_shapes.get('X')[0]
y_shape = input_shapes.get('Y')[0]
if attrs.get('trans_x'):
x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
if attrs.get('trans_y'):
y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
dim_x = len(x_shape)
dim_y = len(y_shape)
output_len = max(dim_x, dim_y)
output_shape = []
for idx in range(output_len, 2, -1):
x_idx = x_shape[dim_x - idx] if idx <= dim_x else 1
y_idx = y_shape[dim_y - idx] if idx <= dim_y else 1
output_shape.append(max(x_idx, y_idx))
macs = prod(output_shape) * x_shape[-2] * x_shape[-1] * y_shape[-1]
return 2 * macs
@register_flops("relu")
def _relu_flops(input_shapes, attrs):
"""FLOPs computation for relu op.
For relu(input):
equation: flops = numel, the total number of elements in the input tensor.
"""
return prod(input_shapes.get('X')[0])
@register_flops("reshape2")
def _reshape2_flops(input_shapes, attrs):
"""FLOPs computation for reshape2 op.
For reshape2(input):
equation: flops = 0
"""
return 0
@register_flops("softmax")
def _softmax_flops(input_shapes, attrs):
"""FLOPs computation for softmax op.
For softmax(input):
equation: flops = 3 * numel, where numel is the total number of elements in the input tensor.
"""
input = input_shapes.get('X')[0]
return prod(input) * 3
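# Worked example, mirroring the softmax unit test above:
# flops = 3 * 12 * 12 * 12 = 5184.
assert _softmax_flops({'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12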
@register_flops("transpose2")
def _transpose2_flops(input_shapes, attrs):
"""FLOPs computation for transpose2 op.
For transpose2(input):
equation: flops = 0
"""
return 0