提交 4cffb0a3 编写于 作者: Z zhoufeng

New control sink support dynamic loss scale

Signed-off-by: Nzhoufeng <zhoufeng54@huawei.com>
上级 71dce2f5
......@@ -69,7 +69,7 @@ CNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNo
MS_EXCEPTION_IF_NULL(cnode);
auto inputs = cnode->inputs();
(void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(plant_inputs));
} else if (AnfAlgo::IsTupleOutput(input_node)) {
} else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) {
ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes);
} else {
dyn_input_sizes.push_back(-1);
......
......@@ -68,8 +68,9 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
return nullptr;
}
if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(),
[](const AnfNodePtr &node) { return AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); })) {
if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) {
return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node);
})) {
return ConvertTupleInputToMakeTuple(func_graph, cnode);
}
return nullptr;
......
......@@ -18,6 +18,7 @@
#include <memory>
#include "session/ascend_control_parser.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/union_find_set.h"
static constexpr size_t kCNodePrim = 0;
static constexpr size_t kCNodeCallArg = 1;
......@@ -57,6 +58,110 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
}
}
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
const NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(kg.get()) != memo->end()) {
return;
}
memo->insert(kg.get());
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs = kg->real_inputs();
for (auto &iter : real_inputs) {
auto &para = iter.first;
if (para->isa<Parameter>()) {
union_find_set->Add(para);
}
for (auto &arg : iter.second) {
if (!arg->isa<Parameter>()) {
continue;
}
union_find_set->Add(arg);
}
}
for (auto &child : kg->child_graph_order()) {
InitUnionFindSet(NOT_NULL(child), union_find_set, memo);
}
}
static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
const NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(kg.get()) != memo->end()) {
return;
}
memo->insert(kg.get());
const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs = kg->real_inputs();
for (auto &iter : real_inputs) {
auto &para = iter.first;
for (auto &arg : iter.second) {
if (!arg->isa<Parameter>()) {
continue;
}
union_find_set->Union(arg, para);
}
}
for (auto &child : kg->child_graph_order()) {
UnionParentParameter(NOT_NULL(child), union_find_set, memo);
}
}
static UnionFindSet<AnfNodePtr> MakeUnionFindSet(NotNull<KernelGraphPtr> root_kg) {
UnionFindSet<AnfNodePtr> result;
std::set<KernelGraphPtr> memo;
InitUnionFindSet(root_kg, NOT_NULL(&result), NOT_NULL(&memo));
memo.clear();
UnionParentParameter(root_kg, NOT_NULL(&result), NOT_NULL(&memo));
return result;
}
static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> main_parameter,
const std::set<AnfNodePtr> &parameter_reuse_set,
const NotNull<std::set<KernelGraphPtr> *> memo) {
if (parameter_reuse_set.empty()) {
MS_LOG(EXCEPTION) << "parameter_reuse_set is empty.";
}
if (memo->find(kg.get()) != memo->end()) {
return;
}
memo->insert(kg.get());
for (auto &para : parameter_reuse_set) {
if (para == main_parameter.get()) {
continue;
}
MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to "
<< main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get());
kg->ReplaceNode(NOT_NULL(para), main_parameter);
}
for (auto &child : kg->child_graph_order()) {
RecursiveReplaceNode(NOT_NULL(child), main_parameter, parameter_reuse_set, memo);
}
}
static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) {
auto parameter_reuse_sets = parameter_set->GetSets();
for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) {
if (parameter_reuse_set.size() <= 1) {
continue;
}
AnfNodePtr main_parameter = key;
std::set<AnfNodePtr> root_inputs_set;
const auto &root_inputs_vector = root_kg->inputs();
root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end());
for (auto &node : parameter_reuse_set) {
if (root_inputs_set.find(node) == root_inputs_set.end()) {
continue;
}
main_parameter = node;
}
std::set<KernelGraphPtr> memo;
RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo));
}
}
void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
std::set<KernelGraphPtr> memo;
ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo));
......@@ -68,6 +173,11 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) {
}
graph_id_map[g->graph_id()] = g;
}
// Make UnionFindSet
UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg);
// Reuse Parameter
ReuseParameter(kg, NOT_NULL(&parameter_set));
// Insert Assign
ChildGraphDataAssign(graph_id_map);
}
......@@ -324,29 +434,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
InsertDependToGraph(kg, NOT_NULL(assign_node));
}
void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph,
NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param) {
if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) {
MS_LOG(INFO) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " is a tuple";
CNodePtr cnode_arg = arg.get()->cast<CNodePtr>();
CNodePtr cnode_param = param.get()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode_arg);
MS_EXCEPTION_IF_NULL(cnode_param);
if (cnode_arg->size() != cnode_param->size()) {
MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " size " << cnode_arg->size() << " but Param "
<< param->DebugString() << " size " << cnode_param->size();
}
for (size_t i = 1; i < cnode_param->size(); ++i) {
LinkArgsToParam(to_graph, target_graph, NOT_NULL(cnode_arg->input(i)), NOT_NULL(cnode_param->input(i)));
}
} else if (arg->isa<CNode>()) {
InsertAssignToGraph(target_graph, arg, param);
} else {
MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " unknown type.";
}
}
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
std::set<KernelGraphPtr> memo;
(void)RecurseGraph(root_graph, NOT_NULL(&memo));
......
......@@ -52,9 +52,6 @@ class AscendControlParser {
const CNodePtr &last_label);
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
static void LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph,
NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param);
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start);
......
......@@ -224,14 +224,6 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> &parameters,
MS_LOG(INFO) << "Parameter and arg are same";
continue;
}
// if arg is a parameter ,then reuse this parameter
if (args[i]->isa<Parameter>()) {
MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id()
<< " reuse parameter:" << args[i]->DebugString()
<< " of graph:" << AnfAlgo::GetGraphId(args[i].get());
child_graph->ReplaceNode(parameters[i], args[i]);
continue;
}
child_graph->SetRealInput(parameters[i], args[i]);
}
}
......@@ -412,7 +404,6 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
VectorRef *const outputs) {
MS_LOG(INFO) << "start";
auto kernel_graph = GetGraph(graph_id);
DumpIR("./run_graph.ir", kernel_graph);
MS_EXCEPTION_IF_NULL(kernel_graph);
// if none of child graph and no anf output exists
if (!kernel_graph->executable()) {
......@@ -1134,7 +1125,7 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId
MS_EXCEPTION_IF_NULL(backend_arg);
MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString()
<< "] will be replaced.";
to_graph->ReplaceNode(backend_parameter, backend_arg);
to_graph->ReplaceNode(NOT_NULL(backend_parameter), NOT_NULL(backend_arg));
return;
}
MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node"
......
......@@ -587,9 +587,7 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
return false;
}
void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node) {
MS_EXCEPTION_IF_NULL(old_anf_node);
MS_EXCEPTION_IF_NULL(new_anf_node);
void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) {
MS_EXCEPTION_IF_NULL(inputs_);
auto it = node_output_edges_.find(old_anf_node);
if (it != node_output_edges_.end()) {
......@@ -604,16 +602,16 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
continue;
}
for (size_t i = 1; i < output_node_inputs.size(); i++) {
if (output_node_inputs[i] == old_anf_node) {
if (output_node_inputs[i] == old_anf_node.get()) {
output_cnode->set_input(i, new_anf_node);
}
}
// update graph inputs
for (size_t i = 0; i < inputs_->size(); i++) {
if ((*inputs_)[i] == old_anf_node) {
if ((*inputs_)[i] == old_anf_node.get()) {
MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
<< ",new graph input:" << new_anf_node->DebugString();
(*inputs_)[i] = new_anf_node;
(*inputs_)[i] = new_anf_node.get();
break;
}
}
......@@ -621,7 +619,7 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
// update front to backend map
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
// update output depend relations
node_output_edges_[new_anf_node] = it->second;
node_output_edges_[new_anf_node.get()] = it->second;
(void)node_output_edges_.erase(old_anf_node);
}
// update graph inputs in child graph
......@@ -633,7 +631,7 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited.";
iter->second = it_real_inputs->second;
} else {
real_inputs_[new_anf_node] = it_real_inputs->second;
real_inputs_[new_anf_node.get()] = it_real_inputs->second;
}
// erase old parameter in map
real_inputs_.erase(old_anf_node);
......@@ -697,7 +695,6 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr &parameter) {
void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "Update graph id: " << graph_id_;
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_map;
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> replace_list;
for (auto &it : real_inputs_) {
auto parameter = it.first;
MS_EXCEPTION_IF_NULL(parameter);
......@@ -722,16 +719,9 @@ void KernelGraph::UpdateCallRealInput() {
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
<< " insert real input:" << new_real_input->DebugString();
(void)real_inputs.insert(new_real_input);
if (new_real_input->isa<Parameter>()) {
replace_list.emplace_back(parameter, new_real_input);
parameter = new_real_input;
}
}
real_inputs_map[parameter] = real_inputs;
}
for (auto [parameter, arg] : replace_list) {
ReplaceNode(parameter, arg);
}
real_inputs_ = real_inputs_map;
}
......
......@@ -99,7 +99,7 @@ class KernelGraph : public FuncGraph {
std::vector<bool> *MutableValidInputs() { return &valid_inputs_; }
std::vector<bool> valid_inputs() const { return valid_inputs_; }
// replace node in graph
void ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node);
void ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node);
// set stream label of graph
void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; }
// get stream label of graph
......
......@@ -459,6 +459,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph)
if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
continue;
} else if (IsValueNode<FuncGraph>(anf)) {
continue;
}
MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
}
......@@ -613,6 +615,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
if (ExistSummaryNode(graph.get())) {
graph->set_summary_node_exist(true);
}
opt::BackendCommonOptimization(graph);
return graph;
}
......@@ -626,7 +629,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para
auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
if (backend_parameter == nullptr) {
// for example "def f(x,y,z) {return x + y}", parameter z in unused
CreateNewParameterFromParameter(parameter, false, graph);
CreateNewParameterFromParameter(parameter, true, graph);
MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
continue;
}
......
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
#define MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
#include <map>
#include <set>
namespace mindspore {
template <class T>
class UnionFindSet {
public:
UnionFindSet() : union_find_set_() {}
void Add(const T &elem) {
if (union_find_set_.find(elem) != union_find_set_.end()) {
return;
}
union_find_set_[elem] = elem;
}
T Find(const T &key) {
T key_parent = key;
auto iter = union_find_set_.find(key_parent);
if (iter == union_find_set_.end()) {
MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent;
}
while (key_parent != iter->second) {
key_parent = iter->second;
iter = union_find_set_.find(key_parent);
if (iter == union_find_set_.end()) {
MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent;
}
}
T tmp = key;
T tmp_parent;
while (tmp != key_parent) {
iter = union_find_set_.find(tmp);
if (iter == union_find_set_.end()) {
MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << tmp;
}
tmp_parent = iter->second;
union_find_set_[tmp] = key_parent;
tmp = tmp_parent;
}
return key_parent;
}
void Union(const T &left, const T &right) { union_find_set_[Find(left)] = Find(right); }
std::map<T, std::set<T>> GetSets() {
std::map<T, std::set<T>> result;
for (auto &iter : union_find_set_) {
(void)Find(iter.first);
}
for (auto &iter : union_find_set_) {
T parent = Find(iter.first);
result[parent].insert(iter.first);
}
return result;
}
private:
std::map<T, T> union_find_set_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册