提交 196fdbe1 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4280 adapt input to attr between cpu and aicpu embeddinglookup

Merge pull request !4280 from wuxuejian/embedding_input_adapt
......@@ -50,6 +50,12 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
if (!ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(AnfAlgo::GetCNodeName(cnode), &reg)) {
continue;
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookup->name() ||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimEmbeddingLookupCommGrad->name()) {
if (!AnfAlgo::HasNodeAttr(kAttrPrimitiveTarget, cnode)) {
continue;
}
}
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
}
return node;
......
......@@ -26,7 +26,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class Net(nn.Cell):
def __init__(self, offset):
super(Net, self).__init__()
self.embedding = P.EmbeddingLookup()
self.embedding = P.EmbeddingLookup().add_prim_attr("primitive_target", "CPU")
self.offset = offset
def construct(self, param, index):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册