From 6316a03c6776d3af022c3acebaf78f3c3ce304a9 Mon Sep 17 00:00:00 2001 From: huanghui Date: Thu, 23 Jul 2020 17:45:35 +0800 Subject: [PATCH] deal tuple getitem control for new added memcpy --- .../hccl/hccl_kernel_metadata.cc | 6 ++ .../insert_memcpy_async_for_hccl_op.cc | 59 ++++++++++++------- 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc index 55742d383..b2283e5c3 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc @@ -20,11 +20,17 @@ #include "utils/utils.h" #include "backend/kernel_compiler/hccl/hcom_util.h" #include "backend/session/anf_runtime_algorithm.h" +#include "frontend/parallel/context.h" namespace mindspore { namespace kernel { namespace { std::string GetKernelFormat(const CNodePtr &kernel_node, size_t index) { + auto parallel_context_instance = parallel::ParallelContext::GetInstance(); + MS_EXCEPTION_IF_NULL(parallel_context_instance); + if (parallel_context_instance->enable_parallel_optimizer()) { + return kOpFormat_DEFAULT; + } const std::set kReduceNoSupportedSet = {kOpFormat_FRAC_Z, kOpFormat_FRACTAL_Z_C04, kOpFormat_C1HWNCoC0}; auto op_name = AnfAlgo::GetCNodeName(kernel_node); auto format = AnfAlgo::GetPrevNodeOutputFormat(kernel_node, index); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc index d5c1da153..b0bdfd30c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc @@ -40,6 +40,38 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) { return real_node->isa(); } +void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node, + const std::vector &memcpy_async_list) { + MS_EXCEPTION_IF_NULL(control_depend); + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(hccl_node); + std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; + make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end()); + make_tuple_inputs.emplace_back(hccl_node); + auto make_tuple = graph->NewCNode(make_tuple_inputs); + MS_EXCEPTION_IF_NULL(make_tuple); + control_depend->set_input(IntToSize(index), make_tuple); +} + +void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node, + const std::vector &memcpy_async_list) { + MS_EXCEPTION_IF_NULL(tuple_getitem); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + auto iter = node_users.find(tuple_getitem); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "node has no output in manager"; + } + for (const auto &node_index : iter->second) { + AnfNodePtr output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { + SetInput(output->cast(), node_index.second, graph, hccl_node, memcpy_async_list); + } + } +} + void TransferControl(const CNodePtr &hccl_node, const std::vector &memcpy_async_list, const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(hccl_node); @@ -53,25 +85,13 @@ void TransferControl(const CNodePtr &hccl_node, const std::vector &m } // find hccl_node's output which is a control depend for (const auto &node_index : iter->second) { - if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) { - continue; - } - CNodePtr control_depend = node_index.first->cast(); - MS_EXCEPTION_IF_NULL(control_depend); - std::vector new_inputs; - for (size_t i = 0; i < control_depend->size(); ++i) { - if (i == IntToSize(node_index.second)) { - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; - make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end()); - make_tuple_inputs.emplace_back(hccl_node); - auto make_tuple = graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - new_inputs.push_back(make_tuple); - } else { - new_inputs.push_back(control_depend->input(i)); - } + AnfNodePtr output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { + SetInput(output->cast(), node_index.second, graph, hccl_node, memcpy_async_list); + } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) { + DealControlForGetitem(output->cast(), graph, hccl_node, memcpy_async_list); } - control_depend->set_inputs(new_inputs); } } } // namespace @@ -148,11 +168,10 @@ const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_gr if (func_graph == nullptr || node == nullptr || !node->isa()) { return nullptr; } - auto cnode = node->cast(); if (!AnfAlgo::IsCommunicationOp(node)) { return nullptr; } - InsertMemcpyAsync(func_graph, cnode); + InsertMemcpyAsync(func_graph, node->cast()); return nullptr; } } // namespace opt -- GitLab