From 4cffb0a321daef861fa63fccd1a3e613e085e961 Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Wed, 3 Jun 2020 09:37:47 +0800 Subject: [PATCH] New control sink support dynamic loss scale Signed-off-by: zhoufeng --- .../convert_tuple_input_to_dynamic_input.cc | 2 +- .../pass/convert_tuple_output_to_maketuple.cc | 5 +- .../ccsrc/session/ascend_control_parser.cc | 133 +++++++++++++++--- .../ccsrc/session/ascend_control_parser.h | 3 - mindspore/ccsrc/session/ascend_session.cc | 11 +- mindspore/ccsrc/session/kernel_graph.cc | 22 +-- mindspore/ccsrc/session/kernel_graph.h | 2 +- mindspore/ccsrc/session/session_basic.cc | 5 +- mindspore/ccsrc/utils/union_find_set.h | 85 +++++++++++ 9 files changed, 211 insertions(+), 57 deletions(-) create mode 100644 mindspore/ccsrc/utils/union_find_set.h diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc index ccc4fd526..ab2395b1f 100644 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc +++ b/mindspore/ccsrc/pre_activate/pass/convert_tuple_input_to_dynamic_input.cc @@ -69,7 +69,7 @@ CNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNo MS_EXCEPTION_IF_NULL(cnode); auto inputs = cnode->inputs(); (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(plant_inputs)); - } else if (AnfAlgo::IsTupleOutput(input_node)) { + } else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) { ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes); } else { dyn_input_sizes.push_back(-1); diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc index 66b3dc1d8..c6a53c544 100644 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc @@ -68,8 +68,9 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { return nullptr; } - if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), - [](const AnfNodePtr &node) { return AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); })) { + if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { + return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); + })) { return ConvertTupleInputToMakeTuple(func_graph, cnode); } return nullptr; diff --git a/mindspore/ccsrc/session/ascend_control_parser.cc b/mindspore/ccsrc/session/ascend_control_parser.cc index 949b1af2a..55dfbcbb3 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/session/ascend_control_parser.cc @@ -18,6 +18,7 @@ #include #include "session/ascend_control_parser.h" #include "session/anf_runtime_algorithm.h" +#include "utils/union_find_set.h" static constexpr size_t kCNodePrim = 0; static constexpr size_t kCNodeCallArg = 1; @@ -57,6 +58,110 @@ void AscendControlParser::ChildGraphDataAssign(const std::map kg, const NotNull *> union_find_set, + const NotNull *> memo) { + if (memo->find(kg.get()) != memo->end()) { + return; + } + memo->insert(kg.get()); + + const std::map> &real_inputs = kg->real_inputs(); + for (auto &iter : real_inputs) { + auto ¶ = iter.first; + if (para->isa()) { + union_find_set->Add(para); + } + for (auto &arg : iter.second) { + if (!arg->isa()) { + continue; + } + union_find_set->Add(arg); + } + } + for (auto &child : kg->child_graph_order()) { + InitUnionFindSet(NOT_NULL(child), union_find_set, memo); + } +} + +static void UnionParentParameter(NotNull kg, const NotNull *> union_find_set, + const NotNull *> memo) { + if (memo->find(kg.get()) != memo->end()) { + return; + } + memo->insert(kg.get()); + const std::map> &real_inputs = kg->real_inputs(); + for (auto &iter : real_inputs) { + auto ¶ = iter.first; + for (auto &arg : iter.second) { + if (!arg->isa()) { + continue; + } + union_find_set->Union(arg, para); + } + } + for (auto &child : kg->child_graph_order()) { + UnionParentParameter(NOT_NULL(child), union_find_set, memo); + } +} + +static UnionFindSet MakeUnionFindSet(NotNull root_kg) { + UnionFindSet result; + std::set memo; + InitUnionFindSet(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); + memo.clear(); + UnionParentParameter(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); + return result; +} + +static void RecursiveReplaceNode(NotNull kg, NotNull main_parameter, + const std::set ¶meter_reuse_set, + const NotNull *> memo) { + if (parameter_reuse_set.empty()) { + MS_LOG(EXCEPTION) << "parameter_reuse_set is empty."; + } + if (memo->find(kg.get()) != memo->end()) { + return; + } + memo->insert(kg.get()); + + for (auto ¶ : parameter_reuse_set) { + if (para == main_parameter.get()) { + continue; + } + MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to " + << main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get()); + kg->ReplaceNode(NOT_NULL(para), main_parameter); + } + + for (auto &child : kg->child_graph_order()) { + RecursiveReplaceNode(NOT_NULL(child), main_parameter, parameter_reuse_set, memo); + } +} + +static void ReuseParameter(NotNull root_kg, NotNull *> parameter_set) { + auto parameter_reuse_sets = parameter_set->GetSets(); + for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { + if (parameter_reuse_set.size() <= 1) { + continue; + } + + AnfNodePtr main_parameter = key; + std::set root_inputs_set; + const auto &root_inputs_vector = root_kg->inputs(); + root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end()); + for (auto &node : parameter_reuse_set) { + if (root_inputs_set.find(node) == root_inputs_set.end()) { + continue; + } + + main_parameter = node; + } + + std::set memo; + RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo)); + } +} + void AscendControlParser::LinkGraph(NotNull kg) { std::set memo; ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); @@ -68,6 +173,11 @@ void AscendControlParser::LinkGraph(NotNull kg) { } graph_id_map[g->graph_id()] = g; } + // Make UnionFindSet + UnionFindSet parameter_set = MakeUnionFindSet(kg); + // Reuse Parameter + ReuseParameter(kg, NOT_NULL(¶meter_set)); + // Insert Assign ChildGraphDataAssign(graph_id_map); } @@ -324,29 +434,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull kg, NotNul InsertDependToGraph(kg, NOT_NULL(assign_node)); } -void AscendControlParser::LinkArgsToParam(NotNull to_graph, NotNull target_graph, - NotNull arg, NotNull param) { - if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) { - MS_LOG(INFO) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " is a tuple"; - CNodePtr cnode_arg = arg.get()->cast(); - CNodePtr cnode_param = param.get()->cast(); - MS_EXCEPTION_IF_NULL(cnode_arg); - MS_EXCEPTION_IF_NULL(cnode_param); - if (cnode_arg->size() != cnode_param->size()) { - MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " size " << cnode_arg->size() << " but Param " - << param->DebugString() << " size " << cnode_param->size(); - } - - for (size_t i = 1; i < cnode_param->size(); ++i) { - LinkArgsToParam(to_graph, target_graph, NOT_NULL(cnode_arg->input(i)), NOT_NULL(cnode_param->input(i))); - } - } else if (arg->isa()) { - InsertAssignToGraph(target_graph, arg, param); - } else { - MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " unknown type."; - } -} - void AscendControlParser::ExecutorValidate(NotNull root_graph) { std::set memo; (void)RecurseGraph(root_graph, NOT_NULL(&memo)); diff --git a/mindspore/ccsrc/session/ascend_control_parser.h b/mindspore/ccsrc/session/ascend_control_parser.h index 037077766..cee3816a6 100644 --- a/mindspore/ccsrc/session/ascend_control_parser.h +++ b/mindspore/ccsrc/session/ascend_control_parser.h @@ -52,9 +52,6 @@ class AscendControlParser { const CNodePtr &last_label); static std::tuple ParsePartial(NotNull node); - static void LinkArgsToParam(NotNull to_graph, NotNull target_graph, - NotNull arg, NotNull param); - static void InsertAssignToGraph(NotNull kg, NotNull from, NotNull to); static CNodePtr GetNextRealKernel(const std::vector &list, size_t start); diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index c23763b2b..0665a6f76 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -224,14 +224,6 @@ static void BindCallArgsWithParameter(const std::vector ¶meters, MS_LOG(INFO) << "Parameter and arg are same"; continue; } - // if arg is a parameter ,then reuse this parameter - if (args[i]->isa()) { - MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id() - << " reuse parameter:" << args[i]->DebugString() - << " of graph:" << AnfAlgo::GetGraphId(args[i].get()); - child_graph->ReplaceNode(parameters[i], args[i]); - continue; - } child_graph->SetRealInput(parameters[i], args[i]); } } @@ -412,7 +404,6 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vectorexecutable()) { @@ -1134,7 +1125,7 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId MS_EXCEPTION_IF_NULL(backend_arg); MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString() << "] will be replaced."; - to_graph->ReplaceNode(backend_parameter, backend_arg); + to_graph->ReplaceNode(NOT_NULL(backend_parameter), NOT_NULL(backend_arg)); return; } MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node" diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 99d53a7f2..8fa29ae20 100644 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -587,9 +587,7 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) { return false; } -void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node) { - MS_EXCEPTION_IF_NULL(old_anf_node); - MS_EXCEPTION_IF_NULL(new_anf_node); +void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull new_anf_node) { MS_EXCEPTION_IF_NULL(inputs_); auto it = node_output_edges_.find(old_anf_node); if (it != node_output_edges_.end()) { @@ -604,16 +602,16 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf continue; } for (size_t i = 1; i < output_node_inputs.size(); i++) { - if (output_node_inputs[i] == old_anf_node) { + if (output_node_inputs[i] == old_anf_node.get()) { output_cnode->set_input(i, new_anf_node); } } // update graph inputs for (size_t i = 0; i < inputs_->size(); i++) { - if ((*inputs_)[i] == old_anf_node) { + if ((*inputs_)[i] == old_anf_node.get()) { MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString() << ",new graph input:" << new_anf_node->DebugString(); - (*inputs_)[i] = new_anf_node; + (*inputs_)[i] = new_anf_node.get(); break; } } @@ -621,7 +619,7 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf // update front to backend map FrontBackendlMapUpdate(old_anf_node, new_anf_node); // update output depend relations - node_output_edges_[new_anf_node] = it->second; + node_output_edges_[new_anf_node.get()] = it->second; (void)node_output_edges_.erase(old_anf_node); } // update graph inputs in child graph @@ -633,7 +631,7 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited."; iter->second = it_real_inputs->second; } else { - real_inputs_[new_anf_node] = it_real_inputs->second; + real_inputs_[new_anf_node.get()] = it_real_inputs->second; } // erase old parameter in map real_inputs_.erase(old_anf_node); @@ -697,7 +695,6 @@ std::set KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { void KernelGraph::UpdateCallRealInput() { MS_LOG(INFO) << "Update graph id: " << graph_id_; std::map> real_inputs_map; - std::vector> replace_list; for (auto &it : real_inputs_) { auto parameter = it.first; MS_EXCEPTION_IF_NULL(parameter); @@ -722,16 +719,9 @@ void KernelGraph::UpdateCallRealInput() { MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " insert real input:" << new_real_input->DebugString(); (void)real_inputs.insert(new_real_input); - if (new_real_input->isa()) { - replace_list.emplace_back(parameter, new_real_input); - parameter = new_real_input; - } } real_inputs_map[parameter] = real_inputs; } - for (auto [parameter, arg] : replace_list) { - ReplaceNode(parameter, arg); - } real_inputs_ = real_inputs_map; } diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index 497bc8df9..2cd7a2340 100644 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -99,7 +99,7 @@ class KernelGraph : public FuncGraph { std::vector *MutableValidInputs() { return &valid_inputs_; } std::vector valid_inputs() const { return valid_inputs_; } // replace node in graph - void ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node); + void ReplaceNode(NotNull old_anf_node, NotNull new_anf_node); // set stream label of graph void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; } // get stream label of graph diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index 6564e6148..d3befcefe 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -459,6 +459,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); continue; + } else if (IsValueNode(anf)) { + continue; } MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; } @@ -613,6 +615,7 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP if (ExistSummaryNode(graph.get())) { graph->set_summary_node_exist(true); } + opt::BackendCommonOptimization(graph); return graph; } @@ -626,7 +629,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector ¶ auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter); if (backend_parameter == nullptr) { // for example "def f(x,y,z) {return x + y}", parameter z in unused - CreateNewParameterFromParameter(parameter, false, graph); + CreateNewParameterFromParameter(parameter, true, graph); MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString(); continue; } diff --git a/mindspore/ccsrc/utils/union_find_set.h b/mindspore/ccsrc/utils/union_find_set.h new file mode 100644 index 000000000..1c98c73b9 --- /dev/null +++ b/mindspore/ccsrc/utils/union_find_set.h @@ -0,0 +1,85 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ +#define MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ + +#include +#include + +namespace mindspore { +template +class UnionFindSet { + public: + UnionFindSet() : union_find_set_() {} + void Add(const T &elem) { + if (union_find_set_.find(elem) != union_find_set_.end()) { + return; + } + + union_find_set_[elem] = elem; + } + + T Find(const T &key) { + T key_parent = key; + auto iter = union_find_set_.find(key_parent); + if (iter == union_find_set_.end()) { + MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent; + } + while (key_parent != iter->second) { + key_parent = iter->second; + iter = union_find_set_.find(key_parent); + if (iter == union_find_set_.end()) { + MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent; + } + } + + T tmp = key; + T tmp_parent; + while (tmp != key_parent) { + iter = union_find_set_.find(tmp); + if (iter == union_find_set_.end()) { + MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << tmp; + } + tmp_parent = iter->second; + union_find_set_[tmp] = key_parent; + tmp = tmp_parent; + } + return key_parent; + } + + void Union(const T &left, const T &right) { union_find_set_[Find(left)] = Find(right); } + + std::map> GetSets() { + std::map> result; + for (auto &iter : union_find_set_) { + (void)Find(iter.first); + } + for (auto &iter : union_find_set_) { + T parent = Find(iter.first); + result[parent].insert(iter.first); + } + return result; + } + + private: + std::map union_find_set_; +}; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ -- GitLab