diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc
index bb708e02a2390d1232e641bd699878c67554c640..446a67d2fc4a0a78132c05bb8d27e7d0926b8d19 100644
--- a/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/enhancer/add_memcpy_async.cc
@@ -23,6 +23,12 @@
 namespace mindspore {
 namespace opt {
 namespace {
+bool InputIsParameterOrValueNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_with_index = AnfAlgo::VisitKernel(node, 0);
+  return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>();
+}
+
 const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
@@ -39,7 +45,8 @@ const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph,
     if (manager->node_users().find(input) == manager->node_users().end()) {
       MS_LOG(EXCEPTION) << "node has no output in manager";
     }
-    if (manager->node_users()[input].size() > 1) {
+    // when input is used by others or is a parameter or is a value node, insert a memcpy_async
+    if (manager->node_users()[input].size() > 1 || InputIsParameterOrValueNode(input)) {
       replace = true;
       new_inputs.push_back(CreateMemcpyAsyncOp(graph, input));
     } else {