From 6dbb26967e874cdd9ef6c385dfc216270dd678c0 Mon Sep 17 00:00:00 2001 From: lichenever Date: Wed, 15 Jul 2020 09:26:52 +0800 Subject: [PATCH] fix embeddinglookup bug --- mindspore/ccsrc/frontend/parallel/step_parallel.cc | 6 ++++++ mindspore/nn/layer/embedding.py | 6 +++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index e9ff347fa..6b9cfd9d3 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -611,6 +611,12 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { ScopePtr scope = node->scope(); MS_EXCEPTION_IF_NULL(scope); replace_node->set_scope(scope); + PrimitivePtr prim = GetValueNode(replace_node->input(0)); + if (prim->name() == EMBEDDING_LOOKUP) { + auto attrs = prim->attrs(); + attrs[TARGET] = MakeValue(CPU); + (void)prim->SetAttrs(attrs); + } if (index == replace_op.size() - 1) { (void)replace_node->set_operator_info(node->operator_info()); } diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index a0887886a..3c4245d70 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -21,7 +21,7 @@ from mindspore.common.initializer import initializer from ..cell import Cell from ..._checkparam import Validator as validator -__all__ = ['Embedding'] +__all__ = ['Embedding', 'EmbeddingLookup'] class Embedding(Cell): r""" @@ -147,7 +147,7 @@ class EmbeddingLookup(Cell): def construct(self, params, indices): if self.target == "CPU": - out = self.embeddinglookup(params, ids, 0) + out = self.embeddinglookup(params, indices, 0) else: - out = self.gatherv2(param, ids, 0) + out = self.gatherv2(params, indices, 0) return out -- GitLab