未验证 提交 0668650f 编写于 作者: G Ghost Screaming 提交者: GitHub

Add c_embedding forward compat op. (#56377)

* Add c_embedding forward compat op.

* Fix some bugs.

* Polish code style.
上级 1f94081d
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
namespace phi { namespace phi {
KernelSignature CEmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("c_embedding", {"W", "Ids"}, {"start_index"}, {"Out"});
}
KernelSignature CEmbeddingGradOpArgumentMapping( KernelSignature CEmbeddingGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
...@@ -23,8 +26,9 @@ KernelSignature CEmbeddingGradOpArgumentMapping( ...@@ -23,8 +26,9 @@ KernelSignature CEmbeddingGradOpArgumentMapping(
{"start_index"}, {"start_index"},
{"W@GRAD"}); {"W@GRAD"});
} }
} // namespace phi } // namespace phi
PD_REGISTER_ARG_MAPPING_FN(c_embedding, phi::CEmbeddingOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(c_embedding_grad, PD_REGISTER_ARG_MAPPING_FN(c_embedding_grad,
phi::CEmbeddingGradOpArgumentMapping); phi::CEmbeddingGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册