提交 b6c8ef99 编写于 作者: P phlrain

update sig; test=develop

上级 e037504b
......@@ -17,15 +17,29 @@
namespace phi {
KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
if (ctx.IsDenseTensorInput("W")) {
LOG(ERROR) << "is dense here";
return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
} else {
LOG(ERROR) << "is selcted rows";
return KernelSignature(
"sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
}
}
KernelSignature EmbeddingGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
if (ctx.IsDenseTensorInput("W")) {
return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
} else {
return KernelSignature("sparse_weight_embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
}
}
} // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册