From b6c8ef99c574cef748c39da50d3fbd127f6d2016 Mon Sep 17 00:00:00 2001 From: phlrain Date: Thu, 24 Feb 2022 09:12:14 +0000 Subject: [PATCH] update sig; test=develop --- paddle/phi/ops/compat/embedding_sig.cc | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/paddle/phi/ops/compat/embedding_sig.cc b/paddle/phi/ops/compat/embedding_sig.cc index 0b8473e4197..350da0b13c8 100644 --- a/paddle/phi/ops/compat/embedding_sig.cc +++ b/paddle/phi/ops/compat/embedding_sig.cc @@ -17,15 +17,29 @@ namespace phi { KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + if (ctx.IsDenseTensorInput("W")) { + LOG(ERROR) << "is dense here"; + return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + } else { + LOG(ERROR) << "is selcted rows"; + return KernelSignature( + "sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + } } KernelSignature EmbeddingGradOpArgumentMapping( const ArgumentMappingContext& ctx) { - return KernelSignature("embedding_grad", - {"Ids", "W", GradVarName("Out")}, - {"padding_idx"}, - {GradVarName("W")}); + if (ctx.IsDenseTensorInput("W")) { + return KernelSignature("embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } else { + return KernelSignature("sparse_weight_embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } } } // namespace phi -- GitLab