diff --git a/paddle/phi/ops/compat/c_embedding_sig.cc b/paddle/phi/ops/compat/c_embedding_sig.cc index bed568433ca86eefcdb86c1b5bf37518ac9fa1d8..5287bd91d8068773a09560e4c55372125b00b2a0 100644 --- a/paddle/phi/ops/compat/c_embedding_sig.cc +++ b/paddle/phi/ops/compat/c_embedding_sig.cc @@ -15,6 +15,9 @@ #include "paddle/phi/core/compat/op_utils.h" namespace phi { +KernelSignature CEmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("c_embedding", {"W", "Ids"}, {"start_index"}, {"Out"}); +} KernelSignature CEmbeddingGradOpArgumentMapping( const ArgumentMappingContext& ctx) { @@ -23,8 +26,9 @@ KernelSignature CEmbeddingGradOpArgumentMapping( {"start_index"}, {"W@GRAD"}); } - } // namespace phi +PD_REGISTER_ARG_MAPPING_FN(c_embedding, phi::CEmbeddingOpArgumentMapping); + PD_REGISTER_ARG_MAPPING_FN(c_embedding_grad, phi::CEmbeddingGradOpArgumentMapping);