提交 7e5e868d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3068 [AutoParallel]Fix EmbeddingLookup bug

Merge pull request !3068 from lichen/fix_embeddinglookup
...@@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { ...@@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
ScopePtr scope = node->scope(); ScopePtr scope = node->scope();
MS_EXCEPTION_IF_NULL(scope); MS_EXCEPTION_IF_NULL(scope);
replace_node->set_scope(scope); replace_node->set_scope(scope);
PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
if (prim->name() == EMBEDDING_LOOKUP) {
auto attrs = prim->attrs();
attrs[TARGET] = MakeValue(CPU);
(void)prim->SetAttrs(attrs);
}
if (index == replace_op.size() - 1) { if (index == replace_op.size() - 1) {
(void)replace_node->set_operator_info(node->operator_info()); (void)replace_node->set_operator_info(node->operator_info());
} }
......
...@@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer ...@@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer
from ..cell import Cell from ..cell import Cell
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
__all__ = ['Embedding'] __all__ = ['Embedding', 'EmbeddingLookup']
class Embedding(Cell): class Embedding(Cell):
r""" r"""
...@@ -147,7 +147,7 @@ class EmbeddingLookup(Cell): ...@@ -147,7 +147,7 @@ class EmbeddingLookup(Cell):
def construct(self, params, indices): def construct(self, params, indices):
if self.target == "CPU": if self.target == "CPU":
out = self.embeddinglookup(params, ids, 0) out = self.embeddinglookup(params, indices, 0)
else: else:
out = self.gatherv2(param, ids, 0) out = self.gatherv2(params, indices, 0)
return out return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册