未验证 提交 08c5f4c1 编写于 作者: J Jianghai 提交者: GitHub

[Auto Parallel]Add Embedding flops (#47978)

* c_embedding

* add annotations

* add annotations

* revision

* revise attrs
上级 9218e742
......@@ -297,6 +297,10 @@ class TestFLOPSAPI(unittest.TestCase):
self.assertTrue(
flops('softmax', {'X': [[12, 12, 12]]}, {}) == 3 * 12 * 12 * 12
)
self.assertTrue(
flops('c_embedding', {'Ids': [[12, 12]], 'W': [[12, 12, 3]]}, {})
== 0
)
if __name__ == '__main__':
......
......@@ -60,6 +60,15 @@ def register_flops(op_type):
return register
@register_flops("c_embedding")
def _c_embedding_flops(input_shapes, attrs):
"""FLOPs computation for c_embedding op.
For c_embedding(input):
equation: flops = 0
"""
return 0
@register_flops("dropout")
def _dropout_flops(input_shapes, attrs):
"""FLOPs computation for dropout op.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册