提交 aa80712a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!973 Enhance add_memcpy_async pass

Merge pull request !973 from huanghui/enhance-add-memcpy-async-pass
...@@ -23,6 +23,12 @@ ...@@ -23,6 +23,12 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
// Reports whether the node obtained from AnfAlgo::VisitKernel(node, 0) is a
// Parameter or a ValueNode.
//
// `node` is checked with MS_EXCEPTION_IF_NULL before use. The node returned by
// VisitKernel is dereferenced without a further null check.
bool InputIsParameterOrValueNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // Only the node half of the (node, output index) pair matters here.
  const auto visited_node = AnfAlgo::VisitKernel(node, 0).first;
  if (visited_node->isa<Parameter>()) {
    return true;
  }
  return visited_node->isa<ValueNode>();
}
const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) { const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
...@@ -39,7 +45,8 @@ const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, ...@@ -39,7 +45,8 @@ const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph,
if (manager->node_users().find(input) == manager->node_users().end()) { if (manager->node_users().find(input) == manager->node_users().end()) {
MS_LOG(EXCEPTION) << "node has no output in manager"; MS_LOG(EXCEPTION) << "node has no output in manager";
} }
if (manager->node_users()[input].size() > 1) { // when input is used by others or is a parameter or is a value node, insert a memcpy_async
if (manager->node_users()[input].size() > 1 || InputIsParameterOrValueNode(input)) {
replace = true; replace = true;
new_inputs.push_back(CreateMemcpyAsyncOp(graph, input)); new_inputs.push_back(CreateMemcpyAsyncOp(graph, input));
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册