diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py index e248a4f7f5b1dd6a01444597672896084da90a4b..5fbedfaaa7ff0c3c33476141b0233d509d9d39f5 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler.py +++ b/python/paddle/fluid/tests/unittests/test_profiler.py @@ -297,6 +297,10 @@ class TestFLOPSAPI(unittest.TestCase): self.assertTrue( flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12 ) + self.assertTrue( + flops('c_embedding', {'Ids': [[12, 12]], 'W': [[12, 12, 3]]}, {}) + == 0 + ) if __name__ == '__main__': diff --git a/python/paddle/utils/flops.py b/python/paddle/utils/flops.py index cfcdf940569fae9e4c16dfed568d882bf67ab9c4..114ca6d9ab6c77bc1676b374e91c58b2b154986d 100644 --- a/python/paddle/utils/flops.py +++ b/python/paddle/utils/flops.py @@ -60,6 +60,15 @@ def register_flops(op_type): return register +@register_flops("c_embedding") +def _c_embedding_flops(input_shapes, attrs): + """FLOPs computation for c_embedding op. + For c_embedding(input): + equation: flops = 0 + """ + return 0 + + @register_flops("dropout") def _dropout_flops(input_shapes, attrs): """FLOPs computation for dropout op.