提交 6dbb2696 编写于 作者: L lichenever

fix embeddinglookup bug

上级 bc0a53cf
......@@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope);
replace_node->set_scope(scope);
PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
if (prim->name() == EMBEDDING_LOOKUP) {
auto attrs = prim->attrs();
attrs[TARGET] = MakeValue(CPU);
(void)prim->SetAttrs(attrs);
}
if (index == replace_op.size() - 1) {
(void)replace_node->set_operator_info(node->operator_info());
}
......
......@@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer
from ..cell import Cell
from ..._checkparam import Validator as validator
__all__ = ['Embedding']
__all__ = ['Embedding', 'EmbeddingLookup']
class Embedding(Cell):
r"""
......@@ -147,7 +147,7 @@ class EmbeddingLookup(Cell):
def construct(self, params, indices):
if self.target == "CPU":
out = self.embeddinglookup(params, ids, 0)
out = self.embeddinglookup(params, indices, 0)
else:
out = self.gatherv2(param, ids, 0)
out = self.gatherv2(params, indices, 0)
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册