diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 803c910d1d46ab41b26f1727a27b50c2da66de98..4e01e9003f63ef73b2607325623aebf35306fbea 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -68,9 +68,7 @@ ParameterPtr FuncGraph::add_parameter() { void FuncGraph::add_parameter(const ParameterPtr &p) { if (manager_.lock()) { - std::vector new_params = parameters_; - new_params.push_back(p); - manager_.lock()->SetParameters(shared_from_base(), new_params); + manager_.lock()->AddParameter(shared_from_base(), p); } else { parameters_.push_back(p); } @@ -82,12 +80,8 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { p->set_name(name); p->debug_info()->set_name(name); - std::vector new_params = parameters_; - // append parameter - new_params.push_back(p); - if (manager_.lock()) { - manager_.lock()->SetParameters(shared_from_base(), new_params); + manager_.lock()->AddParameter(shared_from_base(), p); } else { parameters_.push_back(p); } diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index 5f09dfe6b585e50fc617dfcf85b791c8cead5bc2..b1be892a53928ad19d819a24a1dc000c61b406b3 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -158,6 +158,7 @@ class FuncGraph : public FuncGraphBase { const std::vector ¶meters() const { return parameters_; } virtual ParameterPtr add_parameter(); void add_parameter(const ParameterPtr &p); + void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); } void set_parameters(const std::vector ¶ms) { parameters_ = params; } // add a weight parameter with specific name ParameterPtr AddWeightParameter(const std::string &name); diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index 291a752405f3151ecd684e677ce8dfdd6d77e843..cf56500aeae6a984a64a63c90e060d0e287ec225 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -420,6 +420,12 @@ void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector &changes, EdgeTupl for (auto &iter : changes) { auto operation = iter.op; auto args = iter.args; - if (operation == Change::kTxSetEdge) { - auto edge = args.cast(); - auto old_node = edge.root_node->input(edge.index); - (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; - (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; - (*rms)[old_node] += 1; - (*adds)[edge.new_node] += 1; - edge.root_node->set_input(edge.index, edge.new_node); - } else if (operation == Change::kTxSetParams) { - auto param = args.cast(); - MS_EXCEPTION_IF_NULL(param.func_graph); - auto old_parameters = param.func_graph->parameters(); - for (auto &p : param.params) { - (*adds)[p] += 1; - } - for (auto &p : old_parameters) { - (*rms)[p] += 1; - } - param.func_graph->set_parameters(param.params); + switch (operation) { + case Change::kTxSetEdge: { + auto edge = args.cast(); + auto old_node = edge.root_node->input(edge.index); + (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; + (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; + (*rms)[old_node] += 1; + (*adds)[edge.new_node] += 1; + edge.root_node->set_input(edge.index, edge.new_node); + } break; + case Change::kTxSetParams: { + auto param = args.cast(); + MS_EXCEPTION_IF_NULL(param.func_graph); + auto old_parameters = param.func_graph->parameters(); + for (auto &p : param.params) { + (*adds)[p] += 1; + } + for (auto &p : old_parameters) { + (*rms)[p] += 1; + } + param.func_graph->set_parameters(param.params); + } break; + case Change::kTxAddParam: { + auto param = args.cast(); + MS_EXCEPTION_IF_NULL(param.func_graph); + (*adds)[param.param] += 1; + auto param_node = param.param->cast(); + param.func_graph->append_parameter(param_node); + } break; + default: + break; } } } @@ -599,6 +617,10 @@ void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector { void KeepRoots(const std::vector &roots = {}); void RemoveRoots(); void SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters); + void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr ¶meter); void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); @@ -400,6 +401,7 @@ class FuncGraphTransaction { // set parameters of a func graph void SetParameters(FuncGraphPtr fg, const std::vector ¶ms); + void AddParameter(FuncGraphPtr fg, const AnfNodePtr ¶m); // replace old_node with new_node bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); @@ -427,6 +429,18 @@ struct ArgsOfSetParams { } }; +// args for add param +struct ArgsOfAddParam { + FuncGraphPtr func_graph; + AnfNodePtr param; + bool operator==(const ArgsOfAddParam &other) const { return &other == this; } + + friend std::ostream &operator<<(std::ostream &os, const ArgsOfAddParam &) { + os << "[ArgsOfAddParam]"; + return os; + } +}; + // args for set edge struct ArgsOfSetEdge { CNodePtr root_node; @@ -441,7 +455,7 @@ struct ArgsOfSetEdge { }; struct Change { - enum OpName { kTxSetParams, kTxSetEdge }; + enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam }; OpName op; Any args; Change(OpName name, const Any ¶) : op(name), args(para) {}