From f1563d2d37cf2f6bcc3400a94dec0b249aff84ff Mon Sep 17 00:00:00 2001
From: huanghui
Date: Tue, 21 Jul 2020 20:37:41 +0800
Subject: [PATCH] insert memcpy async if hccl op cascade

---
 .../ascend/ascend_backend_optimization.cc    |   2 +
 .../insert_memcpy_async_for_cascade.cc       | 114 ++++++++++++++++++
 .../insert_memcpy_async_for_cascade.h        |  39 ++++++
 .../insert_memcpy_async_for_hccl_op.cc       |  97 ++++++++-------
 .../insert_memcpy_async_for_hccl_op.h        |   2 +-
 .../insert_memcpy_async_for_hccl_op_test.cc  |  49 +++++++-
 .../insert_memcpy_async_for_hccl_op.py       |  43 +++++--
 7 files changed, 288 insertions(+), 58 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc
 create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index d88727df4..67ebf57f2 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -87,6 +87,7 @@
 #include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
 #include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
 #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
+#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
 #include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
 #include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h"
 #include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h"
@@ -340,6 +341,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   other_pm->AddPass(std::make_shared());
   other_pm->AddPass(std::make_shared());
   other_pm->AddPass(std::make_shared());
+  other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
   other_pm->AddPass(std::make_shared());
   other_pm->AddPass(std::make_shared());
   optimizer->AddPassManager(other_pm);
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc
new file mode 100644
index 000000000..0f1946926
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.cc
@@ -0,0 +1,114 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
+#include <vector>
+#include <memory>
+#include <string>
+#include "utils/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "frontend/optimizer/opt.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+bool IsPartOutputsOfHcclOp(const AnfNodePtr &node, const CNodePtr &cur_hccl, const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(cur_hccl);
+  MS_EXCEPTION_IF_NULL(graph);
+  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto prev_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
+  MS_EXCEPTION_IF_NULL(prev_node);
+  if (!AnfAlgo::IsCommunicationOp(prev_node)) {
+    return false;
+  }
+  auto prev_hccl_op = prev_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(prev_hccl_op);
+
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto &node_users = manager->node_users();
+  auto iter = node_users.find(prev_hccl_op);
+  if (iter == node_users.end()) {
+    MS_LOG(EXCEPTION) << "node has no output in manager";
+  }
+  for (const auto &node_index : iter->second) {
+    AnfNodePtr output = node_index.first;
+    MS_EXCEPTION_IF_NULL(output);
+    if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
+      bool is_contain = false;
+      for (size_t i = 1; i < cur_hccl->size(); ++i) {
+        if (cur_hccl->input(i) == output) {
+          is_contain = true;
+          break;
+        }
+      }
+      if (!is_contain) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+}  // namespace
+
+AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(hccl_node);
+  std::vector<AnfNodePtr> memcpy_async_list;
+  std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
+  for (size_t i = 1; i < hccl_node->size(); ++i) {
+    auto input = hccl_node->input(i);
+    MS_EXCEPTION_IF_NULL(input);
+    // when input is also a hccl op and just part outputs of it linking with cur_hccl_op
+    if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) {
+      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
+      auto kernel_info = std::make_shared<device::KernelInfo>();
+      memcpy_async->set_kernel_info(kernel_info);
+      MS_EXCEPTION_IF_NULL(kernel_select_);
+      kernel_select_->SelectKernel(memcpy_async->cast<CNodePtr>());
+      new_inputs.push_back(memcpy_async);
+      memcpy_async_list.push_back(memcpy_async);
+    } else {
+      new_inputs.push_back(input);
+    }
+  }
+
+  if (!memcpy_async_list.empty()) {
+    CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
+    new_hccl_node->set_inputs(new_inputs);
+    return new_hccl_node;
+  }
+  return nullptr;
+}
+
+const AnfNodePtr InsertMemcpyAsyncForCascade::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                      const EquivPtr &) const {
+  if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  if (!AnfAlgo::IsCommunicationOp(node)) {
+    return nullptr;
+  }
+  return InsertMemcpyAsync(func_graph, cnode);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h
new file mode 100644
index 000000000..e1a29f574
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class InsertMemcpyAsyncForCascade : public PatternProcessPass {
+ public:
+  explicit InsertMemcpyAsyncForCascade(bool multigraph = true)
+      : PatternProcessPass("insert_memcpy_async_for_cascade", multigraph),
+        kernel_select_(std::make_shared<KernelSelect>()) {}
+  ~InsertMemcpyAsyncForCascade() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
+  KernelSelectPtr kernel_select_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_OP_CASCADE_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
index 2585006be..d5c1da153 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
@@ -32,12 +32,17 @@ const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNe
 bool IsParameterOrValueNode(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
-  return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>();
+  auto real_node = kernel_with_index.first;
+  MS_EXCEPTION_IF_NULL(real_node);
+  if (real_node->isa<Parameter>()) {
+    return true;
+  }
+  return real_node->isa<ValueNode>();
 }
 
-void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) {
+void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
+                     const FuncGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(hccl_node);
-  MS_EXCEPTION_IF_NULL(memcpy_async);
   MS_EXCEPTION_IF_NULL(graph);
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -48,49 +53,62 @@ void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async,
   }
   // find hccl_node's output which is a control depend
   for (const auto &node_index : iter->second) {
-    AnfNodePtr output = node_index.first;
-    int output_index = node_index.second;
-    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
-      CNodePtr control_depend = output->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(control_depend);
-      std::vector<AnfNodePtr> new_inputs;
-      for (size_t i = 0; i < control_depend->size(); ++i) {
-        if (i == IntToSize(output_index)) {
-          new_inputs.push_back(memcpy_async);
-        } else {
-          new_inputs.push_back(control_depend->input(i));
-        }
+    if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) {
+      continue;
+    }
+    CNodePtr control_depend = node_index.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(control_depend);
+    std::vector<AnfNodePtr> new_inputs;
+    for (size_t i = 0; i < control_depend->size(); ++i) {
+      if (i == IntToSize(node_index.second)) {
+        std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+        make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
+        make_tuple_inputs.emplace_back(hccl_node);
+        auto make_tuple = graph->NewCNode(make_tuple_inputs);
+        MS_EXCEPTION_IF_NULL(make_tuple);
+        new_inputs.push_back(make_tuple);
+      } else {
+        new_inputs.push_back(control_depend->input(i));
       }
-      control_depend->set_inputs(new_inputs);
     }
+    control_depend->set_inputs(new_inputs);
   }
 }
 }  // namespace
 
-bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const {
+bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
+                                                  const CNodePtr &cur_node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(input);
+  MS_EXCEPTION_IF_NULL(cur_node);
   // when input is a parameter or is a value node
   if (IsParameterOrValueNode(input)) {
     return true;
   }
-  // when input is a Ref or some special cnodes
-  if (kernel_query_->IsTbeRef(input) ||
-      kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
-    return true;
-  }
+  if (input->isa<CNode>()) {
+    auto manager = graph->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    auto &node_users = manager->node_users();
 
-  auto manager = graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto &node_users = manager->node_users();
-  auto iter = node_users.find(input);
-  if (iter == node_users.end()) {
-    MS_LOG(EXCEPTION) << "node has no output in manager";
-  }
-  // when input is used by others
-  if (iter->second.size() > 1) {
-    return true;
+    // when input is a Ref cnode
+    if (kernel_query_->IsTbeRef(input)) {
+      return true;
+    }
+
+    // when input is some special cnodes
+    if (kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
+      return true;
+    }
+
+    // when input is used by others
+    auto iter = node_users.find(input);
+    if (iter == node_users.end()) {
+      MS_LOG(EXCEPTION) << "node has no output in manager";
+    }
+    if (iter->second.size() > 1) {
+      return true;
+    }
   }
   return false;
 }
@@ -98,21 +116,20 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
 void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(hccl_node);
-  bool has_insert_memcpy = false;
-  AnfNodePtr memcpy_async = nullptr;
+  std::vector<AnfNodePtr> memcpy_async_list;
   std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
   for (size_t i = 1; i < hccl_node->size(); ++i) {
     auto input = hccl_node->input(i);
-    if (NeedInsertMemcpy(graph, input)) {
-      memcpy_async = CreateMemcpyAsyncOp(graph, input);
-      has_insert_memcpy = true;
+    if (NeedInsertMemcpy(graph, input, hccl_node)) {
+      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
       new_inputs.push_back(memcpy_async);
+      memcpy_async_list.push_back(memcpy_async);
     } else {
       new_inputs.push_back(input);
    }
   }
 
-  if (has_insert_memcpy) {
+  if (!memcpy_async_list.empty()) {
     CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
     new_hccl_node->set_inputs(new_inputs);
     auto manager = graph->manager();
@@ -122,9 +139,7 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
     MS_LOG(DEBUG) << "end replace";
 
     // transer hccl op's control to the memcpy_async
replace"; // transer hccl op's control to the memcpy_async - if (hccl_node->size() == 2) { - TransferControl(new_hccl_node, memcpy_async, graph); - } + TransferControl(new_hccl_node, memcpy_async_list, graph); } } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h index 35e5f1b82..e69866c0b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h @@ -32,7 +32,7 @@ class InsertMemcpyAsyncForHcclOp : public PatternProcessPass { private: void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; - bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const; + bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &cur_node) const; KernelQueryPtr kernel_query_; }; } // namespace opt diff --git a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc index 103d0f21a..e3441d45e 100644 --- a/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc @@ -22,6 +22,7 @@ #include "utils/utils.h" #include "backend/kernel_compiler/kernel_build_info.h" #include "backend/optimizer/common/optimizer.h" +#include "ir/param_value.h" #define private public #define protected public #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" @@ -44,12 +45,10 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery { ~MockInsertMemcpyForHcclKernelQuery() override = default; bool IsTbeRef(const AnfNodePtr &node) override { MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - if (cnode == nullptr) { + if (!node->isa()) { return false; } - auto name = AnfAlgo::GetCNodeName(cnode); - return name == "ApplyMomentum"; + return AnfAlgo::GetCNodeName(node->cast()) == "ApplyMomentum"; } }; @@ -105,6 +104,11 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) { AbstractBasePtrList args_spec_list{x_abstract}; auto kg = GetKernelGraph(g, args_spec_list); EXPECT_NE(kg, nullptr); + for (auto p : kg->parameters()) { + auto param = p->cast(); + EXPECT_NE(param, nullptr); + param->set_default_param(std::make_shared()); + } auto optimizer = std::make_shared(); auto pm = std::make_shared(); @@ -146,10 +150,16 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { ASSERT_TRUE(g != nullptr); std::vector shp_x{1, 64, 112, 112}; auto x_abstract = std::make_shared(kFloat32, shp_x); - AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; + AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; auto kg = GetKernelGraph(g, args_spec_list); EXPECT_NE(kg, nullptr); + for (auto p : kg->parameters()) { + auto param = p->cast(); + EXPECT_NE(param, nullptr); + param->set_default_param(std::make_shared()); + } + auto optimizer = std::make_shared(); auto pm = std::make_shared(); auto pass = std::make_shared(); @@ -161,5 +171,34 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after"); EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); } + +TEST_F(TestHWInsertMemcpyForHccl, test_cond5) { + get_py_fun_.SetDoResolve(true); + 
+  ASSERT_TRUE(g != nullptr);
+  std::vector shp_x{1, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract};
+  auto kg = GetKernelGraph(g, args_spec_list);
+  EXPECT_NE(kg, nullptr);
+
+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
+  pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>();
+  pm->AddPass(pass);
+  optimizer->AddPassManager(pm);
+  auto new_graph = optimizer->Optimize(kg);
+  kg->SetExecOrderByDefault();
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py
index 7ffcfd057..082c8144f 100644
--- a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py
+++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py
@@ -17,6 +17,7 @@ from mindspore.ops import Primitive
 from mindspore.ops import operations as P
 
 all_reduce = P.AllReduce()
+broadcast = P.Broadcast(1)
 memcpy_async = Primitive('memcpy_async')
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
@@ -101,20 +102,40 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag):
     fns = FnDict()
 
     @fns
-    def before(a, b, c, d, e):
-        res1 = apply_momentun(a, b, c, d, e)
-        res2 = all_reduce(a)
-        res = control_depend(res1, res2)
-        res = make_tuple(res, res2)
+    def before(a, b):
+        x = relu(a)
+        y = all_reduce(b)
+        res = control_depend(x, y)
         return res
 
     @fns
-    def after(a, b, c, d, e):
-        res1 = apply_momentun(a, b, c, d, e)
-        res2 = memcpy_async(a)
-        res3 = all_reduce(res2)
-        res = control_depend(res1, res2)
-        res = make_tuple(res, res3)
+    def after(a, b):
+        x = relu(a)
+        y1 = memcpy_async(b)
+        y2 = all_reduce(y1)
+        res = control_depend(x, make_tuple(y1, y2))
+        return make_tuple(res)
+
+    return fns[tag]
+
+
+def test_insert_memcpy_async_for_hccl_op_cond5(tag):
+    fns = FnDict()
+
+    @fns
+    def before(a, b, c):
+        x = relu(a)
+        y = broadcast((b, c))
+        res = control_depend(x, y)
+        return res
+
+    @fns
+    def after(a, b, c):
+        x = relu(a)
+        m1 = memcpy_async(b)
+        m2 = memcpy_async(c)
+        y = broadcast(m1, m2)
+        res = control_depend(x, make_tuple(m1, m2, y))
         return make_tuple(res)
 
     return fns[tag]
-- 
GitLab
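For reference, the graph pattern targeted by the new InsertMemcpyAsyncForCascade pass can be sketched in the same gtest_input style as insert_memcpy_async_for_hccl_op.py above. The sketch below is illustrative only and is not part of the patch: the entry name test_insert_memcpy_async_for_cascade is hypothetical, broadcast merely stands in for any multi-output communication op, and the exact shape of the "after" graph in the real pipeline also depends on later passes such as InsertMemcpyAsyncForHcclOp. It shows the condition checked by IsPartOutputsOfHcclOp: the second communication op consumes only part of the previous communication op's outputs, so a memcpy_async is inserted on that edge.

# Illustrative sketch only (not part of the patch above): a hypothetical
# gtest_input entry for InsertMemcpyAsyncForCascade, written in the style of
# insert_memcpy_async_for_hccl_op.py.
from mindspore.ops import Primitive
from mindspore.ops import operations as P

all_reduce = P.AllReduce()
broadcast = P.Broadcast(1)
memcpy_async = Primitive('memcpy_async')
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')


class FnDict:
    def __init__(self):
        self.fnDict = {}

    def __call__(self, fn):
        self.fnDict[fn.__name__] = fn

    def __getitem__(self, name):
        return self.fnDict[name]


def test_insert_memcpy_async_for_cascade(tag):
    fns = FnDict()

    @fns
    def before(a, b):
        # broadcast is a multi-output hccl op; all_reduce consumes only its
        # first output while the second output goes straight to the graph
        # output, so IsPartOutputsOfHcclOp() holds for the all_reduce input.
        x = broadcast((a, b))
        x0 = tuple_getitem(x, 0)
        y = all_reduce(x0)
        return make_tuple(y, tuple_getitem(x, 1))

    @fns
    def after(a, b):
        # The cascade pass rebuilds all_reduce with a memcpy_async inserted on
        # the edge coming from the previous hccl op.
        x = broadcast((a, b))
        x0 = tuple_getitem(x, 0)
        m = memcpy_async(x0)
        y = all_reduce(m)
        return make_tuple(y, tuple_getitem(x, 1))

    return fns[tag]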