diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py
index b58efe6951924c270379ba530c17f6e3b03142a5..011d9fc4dfd0596a2004c7642ad74eaeb1c08757 100644
--- a/python/paddle/fluid/tests/unittests/test_profiler.py
+++ b/python/paddle/fluid/tests/unittests/test_profiler.py
@@ -227,6 +227,75 @@ class TestFLOPSAPI(unittest.TestCase):
     def test_flops(self):
         self.assertTrue(flops('relu', {'X': [[12, 12]]}, {'output': 4}) == 144)
         self.assertTrue(flops('dropout', {}, {'output': 4}) == 0)
+        self.assertTrue(
+            flops(
+                'transpose2',
+                {
+                    'X': [[12, 12, 12]],
+                },
+                {},
+            )
+            == 0
+        )
+        self.assertTrue(
+            flops(
+                'reshape2',
+                {
+                    'X': [[12, 12, 12]],
+                },
+                {},
+            )
+            == 0
+        )
+        self.assertTrue(
+            flops(
+                'unsqueeze2',
+                {
+                    'X': [[12, 12, 12]],
+                },
+                {},
+            )
+            == 0
+        )
+        self.assertTrue(
+            flops(
+                'layer_norm',
+                {'Bias': [[128]], 'Scale': [[128]], 'X': [[32, 128, 28, 28]]},
+                {'epsilon': 0.01},
+            )
+            == 32 * 128 * 28 * 28 * 8
+        )
+        self.assertTrue(
+            flops(
+                'elementwise_add', {'X': [[12, 12, 12]], 'Y': [[2, 2, 12]]}, {}
+            )
+            == 12 * 12 * 12
+        )
+        self.assertTrue(
+            flops('gelu', {'X': [[12, 12, 12]]}, {}) == 5 * 12 * 12 * 12
+        )
+        self.assertTrue(
+            flops(
+                'matmul',
+                {'X': [[3, 12, 12, 8]], 'Y': [[12, 12, 8]]},
+                {'transpose_X': False, 'transpose_Y': True},
+            )
+            == 3 * 12 * 12 * 12 * 2 * 8
+        )
+        self.assertTrue(
+            flops(
+                'matmul_v2',
+                {'X': [[3, 12, 12, 8]], 'Y': [[12, 12, 8]]},
+                {'trans_x': False, 'trans_y': True},
+            )
+            == 3 * 12 * 12 * 12 * 2 * 8
+        )
+        self.assertTrue(
+            flops('relu', {'X': [[12, 12, 12]]}, {}) == 12 * 12 * 12
+        )
+        self.assertTrue(
+            flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12
+        )
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/utils/flops.py b/python/paddle/utils/flops.py
index 4dfe200f59be90aff0b4b1d163841ed531c885ca..cfcdf940569fae9e4c16dfed568d882bf67ab9c4 100644
--- a/python/paddle/utils/flops.py
+++ b/python/paddle/utils/flops.py
@@ -40,7 +40,11 @@ def flops(op_type: str, input_shapes: dict, attrs: dict) -> int:
         return 0
     else:
         func = _FLOPS_COMPUTE_FUNC_MAP[op_type]
-        return func(input_shapes, attrs)
+        try:
+            flops = func(input_shapes, attrs)
+        except Exception:
+            return 0
+        return flops
 
 
 def register_flops(op_type):
@@ -58,9 +62,182 @@ def register_flops(op_type):
 
 @register_flops("dropout")
 def _dropout_flops(input_shapes, attrs):
+    """FLOPs computation for dropout op.
+    For dropout(input):
+    equation: flops = 0
+    """
     return 0
 
 
+def _elementwise_flops_compute(input_shapes, attrs):
+    input_x = input_shapes.get("X")[0]
+    input_y = input_shapes.get("Y")[0]
+    dim_x = len(input_x)
+    dim_y = len(input_y)
+    dim_output = max(dim_x, dim_y)
+    output = []
+    for i in range(dim_output):
+        in_x = input_x[dim_x - 1 - i] if i < dim_x else 1
+        in_y = input_y[dim_y - 1 - i] if i < dim_y else 1
+        output.append(max(in_x, in_y))
+    return prod(output)
+
+
+@register_flops("elementwise_add")
+def _elementwise_add_flops(input_shapes, attrs):
+    """FLOPs computation for elementwise_add op.
+    For elementwise_add(input, other):
+    input_shapes = [shape_of_input, shape_of_other]
+    shape_of_input = [dim1, dim2, dim3 ...]
+    shape_of_other = [odim1, odim2, odim3 ...]
+    equation: flops = max(dim1, odim1) * max(dim2, odim2) * max()...
+    """
+    return _elementwise_flops_compute(input_shapes, attrs)
+
+
+@register_flops("elementwise_mul")
+def _elementwise_mul_flops(input_shapes, attrs):
+    """FLOPs computation for elementwise_mul op.
+    For elementwise_mul(input, other):
+    input_shapes = [shape_of_input, shape_of_other]
+    shape_of_input = [dim1, dim2, dim3 ...]
+    shape_of_other = [odim1, odim2, odim3 ...]
+    equation: flops = max(dim1, odim1) * max(dim2, odim2) * max()...
+    """
+    return _elementwise_flops_compute(input_shapes, attrs)
+
+
+@register_flops("elementwise_div")
+def _elementwise_div_flops(input_shapes, attrs):
+    """FLOPs computation for elementwise_div op.
+    For elementwise_div(input, other):
+    input_shapes = [shape_of_input, shape_of_other]
+    shape_of_input = [dim1, dim2, dim3 ...]
+    shape_of_other = [odim1, odim2, odim3 ...]
+    equation: flops = max(dim1, odim1) * max(dim2, odim2) * max()...
+    """
+    return _elementwise_flops_compute(input_shapes, attrs)
+
+
+@register_flops("gelu")
+def _gelu_flops(input_shapes, attrs):
+    """FLOPs computation for gelu op.
+    For gelu(input):
+    equation: flops = 5 * numel (the total number of elements in the input tensor).
+    """
+    input = input_shapes.get('X')[0]
+    return prod(input) * 5
+
+
+@register_flops("layer_norm")
+def _layer_norm_flops(input_shapes, attrs):
+    """FLOPs computation for layer_norm op.
+    For layer_norm(input):
+    equation:
+    1): WITHOUT epsilon: flops = 7 * numel (the total number of elements in the input tensor).
+    2): WITH epsilon: flops = 8 * numel (the total number of elements in the input tensor).
+    """
+    input = input_shapes.get('X')[0]
+    flops = prod(input) * 7
+    if attrs.get('epsilon'):
+        flops += prod(input)
+    return flops
+
+
+@register_flops("matmul")
+def _matmul_flops(input_shapes, attrs):
+    """FLOPs computation for matmul op.
+    For matmul(input, other):
+    input_shapes = [shape_of_input, shape_of_other]
+    shape_of_input = [dim1, dim2 ... dim_n_1, dim_n] length: n
+    shape_of_other = [odim1, odim2 ... odim(n-m) ... odim_m_1, odim_m] length: m
+    suppose n > m and dim_n = odim_m_1:
+    shape_of_output = [dim1, dim2 ... max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1)) ... dim_n_1, odim_m]
+    equation: flops = 2 * numel(output) * dim_n
+    """
+    x_shape = input_shapes.get("X")[0]
+    y_shape = input_shapes.get("Y")[0]
+    if attrs.get('transpose_X') or attrs.get('transpose_x'):
+        x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
+
+    if attrs.get('transpose_Y') or attrs.get('transpose_y'):
+        y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
+    dim_x = len(x_shape)
+    dim_y = len(y_shape)
+    output_len = max(dim_x, dim_y)
+    output_shape = []
+
+    for idx in range(output_len, 2, -1):
+        x_idx = x_shape[dim_x - idx] if idx <= dim_x else 1
+        y_idx = y_shape[dim_y - idx] if idx <= dim_y else 1
+        output_shape.append(max(x_idx, y_idx))
+
+    macs = prod(output_shape) * x_shape[-2] * x_shape[-1] * y_shape[-1]
+    return 2 * macs
+
+
+@register_flops("matmul_v2")
+def _matmul_v2_flops(input_shapes, attrs):
+    """FLOPs computation for matmul_v2 op.
+    For matmul_v2(input, other):
+    input_shapes = [shape_of_input, shape_of_other]
+    shape_of_input = [dim1, dim2 ... dim_n_1, dim_n] length: n
+    shape_of_other = [odim1, odim2 ... odim(n-m) ... odim_m_1, odim_m] length: m
+    suppose n > m and dim_n = odim_m_1:
+    shape_of_output = [dim1, dim2 ...
+        max(dim(n-m), odim(n-m)), max(dim(n-m+1), odim(n-m+1)) ... dim_n_1, odim_m]
+    equation: flops = 2 * numel(output) * dim_n
+    """
+    x_shape = input_shapes.get('X')[0]
+    y_shape = input_shapes.get('Y')[0]
+    if attrs.get('trans_x'):
+        x_shape[-1], x_shape[-2] = x_shape[-2], x_shape[-1]
+    if attrs.get('trans_y'):
+        y_shape[-1], y_shape[-2] = y_shape[-2], y_shape[-1]
+    dim_x = len(x_shape)
+    dim_y = len(y_shape)
+    output_len = max(dim_x, dim_y)
+    output_shape = []
+    for idx in range(output_len, 2, -1):
+        x_idx = x_shape[dim_x - idx] if idx <= dim_x else 1
+        y_idx = y_shape[dim_y - idx] if idx <= dim_y else 1
+        output_shape.append(max(x_idx, y_idx))
+
+    macs = prod(output_shape) * x_shape[-2] * x_shape[-1] * y_shape[-1]
+    return 2 * macs
+
+
 @register_flops("relu")
 def _relu_flops(input_shapes, attrs):
+    """FLOPs computation for relu op.
+    For relu(input):
+    equation: flops = numel (the total number of elements in the input tensor).
+    """
     return prod(input_shapes.get('X')[0])
+
+
+@register_flops("reshape2")
+def _reshape2_flops(input_shapes, attrs):
+    """FLOPs computation for reshape2 op.
+    For reshape2(input):
+    equation: flops = 0
+    """
+    return 0
+
+
+@register_flops("softmax")
+def _softmax_flops(input_shapes, attrs):
+    """FLOPs computation for softmax op.
+    For softmax(input):
+    equation: flops = 3 * numel (the total number of elements in the input tensor).
+    """
+    input = input_shapes.get('X')[0]
+    return prod(input) * 3
+
+
+@register_flops("transpose2")
+def _transpose2_flops(input_shapes, attrs):
+    """FLOPs computation for transpose2 op.
+    For transpose2(input):
+    equation: flops = 0
+    """
+    return 0
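
Not part of the patch itself: a minimal usage sketch of the patched flops() helper, assuming the updated paddle.utils.flops module is on the import path. The op name in the last call is hypothetical, included only to illustrate the zero-FLOPs fallback.

    from paddle.utils.flops import flops

    # matmul with Y transposed: X [3, 12, 12, 8] x Y^T [12, 8, 12] yields an
    # output of shape [3, 12, 12, 12], so flops = 2 * numel(output) * 8 = 82944,
    # matching the 3 * 12 * 12 * 12 * 2 * 8 expected by the test above.
    print(
        flops(
            'matmul',
            {'X': [[3, 12, 12, 8]], 'Y': [[12, 12, 8]]},
            {'transpose_X': False, 'transpose_Y': True},
        )
    )  # 82944

    # Unregistered ops (and any compute function that raises) report 0 rather
    # than propagating an error; 'my_custom_op' is a hypothetical name.
    print(flops('my_custom_op', {}, {}))  # 0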