diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py index e888d3c09c8950e03f789d202bb1e733ba771aa1..3de9944ae2d6dee2213e2a770378003c6517a5c7 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler.py +++ b/python/paddle/fluid/tests/unittests/test_profiler.py @@ -24,6 +24,8 @@ import paddle.fluid.layers as layers import paddle.fluid.core as core import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2 +from paddle.utils.flops import flops + class TestProfiler(unittest.TestCase): @classmethod @@ -221,6 +223,12 @@ class TestProfilerAPIError(unittest.TestCase): self.assertTrue(global_profiler != prof) +class TestFLOPSAPI(unittest.TestCase): + def test_flops(self): + self.assertTrue(flops('relu', ([12, 12],), output=4) == 144) + self.assertTrue(flops('dropout', ([12, 12],), **{'output': 4}) == 0) + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/python/paddle/utils/flops.py b/python/paddle/utils/flops.py new file mode 100644 index 0000000000000000000000000000000000000000..16cad95a360e7314aa3c180dd3ddb788469fd557 --- /dev/null +++ b/python/paddle/utils/flops.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from numpy import prod + +_FLOPS_COMPUTE_FUNC_MAP = {} + + +def flops(op_type: str, input_shapes: tuple, **attrs) -> int: + """ + count flops for operation. + + Args: + op_type (str): the type of operation. + input_shapes (tuple): the shapes of inputs. + attrs (dict): the attributes of the operation. + + Returns: + the total flops of the operation. + """ + + if op_type not in _FLOPS_COMPUTE_FUNC_MAP: + return 0 + else: + func = _FLOPS_COMPUTE_FUNC_MAP[op_type] + return func(input_shapes, **attrs) + + +def register_flops(op_type): + """ + register flops computation function for operation. + """ + + def register(func): + global _FLOPS_COMPUTE_FUNC_MAP + _FLOPS_COMPUTE_FUNC_MAP[op_type] = func + return func + + return register + + +@register_flops("dropout") +def _dropout_flops(input_shapes, **attrs): + return 0 + + +@register_flops("relu") +def _relu_flops(input_shapes, **attrs): + return prod(input_shapes[0])