From 08c5f4c18d2cd4e6485eadc7b9791a70f8e16932 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 1 Dec 2022 09:38:53 +0800 Subject: [PATCH] [Auto Parallel]Add Embedding flops (#47978) * c_embedding * add annotations * add annotations * revision * revise attrs --- python/paddle/fluid/tests/unittests/test_profiler.py | 4 ++++ python/paddle/utils/flops.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py index e248a4f7f5..5fbedfaaa7 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler.py +++ b/python/paddle/fluid/tests/unittests/test_profiler.py @@ -297,6 +297,10 @@ class TestFLOPSAPI(unittest.TestCase): self.assertTrue( flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12 ) + self.assertTrue( + flops('c_embedding', {'Ids': [[12, 12]], 'W': [[12, 12, 3]]}, {}) + == 0 + ) if __name__ == '__main__': diff --git a/python/paddle/utils/flops.py b/python/paddle/utils/flops.py index cfcdf94056..114ca6d9ab 100644 --- a/python/paddle/utils/flops.py +++ b/python/paddle/utils/flops.py @@ -60,6 +60,15 @@ def register_flops(op_type): return register +@register_flops("c_embedding") +def _c_embedding_flops(input_shapes, attrs): + """FLOPs computation for c_embedding op. + For c_embedding(input): + equation: flops = 0 + """ + return 0 + + @register_flops("dropout") def _dropout_flops(input_shapes, attrs): """FLOPs computation for dropout op. -- GitLab