diff --git a/paddle/phi/ops/compat/embedding_sig.cc b/paddle/phi/ops/compat/embedding_sig.cc index 0b8473e419705d665cb94f47453e2134be4e99e7..350da0b13c8ae454d89c3c558b3f16c2fd476bd4 100644 --- a/paddle/phi/ops/compat/embedding_sig.cc +++ b/paddle/phi/ops/compat/embedding_sig.cc @@ -17,15 +17,29 @@ namespace phi { KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + if (ctx.IsDenseTensorInput("W")) { + LOG(ERROR) << "is dense here"; + return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + } else { + LOG(ERROR) << "is selcted rows"; + return KernelSignature( + "sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + } } KernelSignature EmbeddingGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("embedding_grad", - {"Ids", "W", GradVarName("Out")}, - {"padding_idx"}, - {GradVarName("W")}); + if (ctx.IsDenseTensorInput("W")) { + return KernelSignature("embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } else { + return KernelSignature("sparse_weight_embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } } } // namespace phi