提交 b6c8ef99 编写于 作者: P phlrain

update sig; test=develop

上级 e037504b
...@@ -17,15 +17,29 @@ ...@@ -17,15 +17,29 @@
namespace phi { namespace phi {
KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("W")) {
LOG(ERROR) << "is dense here";
return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
} else {
LOG(ERROR) << "is selcted rows";
return KernelSignature(
"sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
}
} }
KernelSignature EmbeddingGradOpArgumentMapping( KernelSignature EmbeddingGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("W")) {
return KernelSignature("embedding_grad", return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")}, {"Ids", "W", GradVarName("Out")},
{"padding_idx"}, {"padding_idx"},
{GradVarName("W")}); {GradVarName("W")});
} else {
return KernelSignature("sparse_weight_embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
}
} }
} // namespace phi } // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册