提交 056f9f6d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2861 use addparam to replace setparam to reduce overhead

Merge pull request !2861 from xychow/optimize-setparam-to-addparam
......@@ -68,9 +68,7 @@ ParameterPtr FuncGraph::add_parameter() {
void FuncGraph::add_parameter(const ParameterPtr &p) {
if (manager_.lock()) {
std::vector<AnfNodePtr> new_params = parameters_;
new_params.push_back(p);
manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
} else {
parameters_.push_back(p);
}
......@@ -82,12 +80,8 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
p->set_name(name);
p->debug_info()->set_name(name);
std::vector<AnfNodePtr> new_params = parameters_;
// append parameter
new_params.push_back(p);
if (manager_.lock()) {
manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
} else {
parameters_.push_back(p);
}
......
......@@ -158,6 +158,7 @@ class FuncGraph : public FuncGraphBase {
const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
virtual ParameterPtr add_parameter();
void add_parameter(const ParameterPtr &p);
void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
// add a weight parameter with specific name
ParameterPtr AddWeightParameter(const std::string &name);
......
......@@ -420,6 +420,12 @@ void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<A
tr.Commit();
}
void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter) {
auto tr = Transact();
tr.AddParameter(fg, parameter);
tr.Commit();
}
bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
auto tr = Transact();
bool success = tr.Replace(old_node, new_node);
......@@ -532,25 +538,37 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl
for (auto &iter : changes) {
auto operation = iter.op;
auto args = iter.args;
if (operation == Change::kTxSetEdge) {
auto edge = args.cast<ArgsOfSetEdge>();
auto old_node = edge.root_node->input(edge.index);
(*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1;
(*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1;
(*rms)[old_node] += 1;
(*adds)[edge.new_node] += 1;
edge.root_node->set_input(edge.index, edge.new_node);
} else if (operation == Change::kTxSetParams) {
auto param = args.cast<ArgsOfSetParams>();
MS_EXCEPTION_IF_NULL(param.func_graph);
auto old_parameters = param.func_graph->parameters();
for (auto &p : param.params) {
(*adds)[p] += 1;
}
for (auto &p : old_parameters) {
(*rms)[p] += 1;
}
param.func_graph->set_parameters(param.params);
switch (operation) {
case Change::kTxSetEdge: {
auto edge = args.cast<ArgsOfSetEdge>();
auto old_node = edge.root_node->input(edge.index);
(*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1;
(*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1;
(*rms)[old_node] += 1;
(*adds)[edge.new_node] += 1;
edge.root_node->set_input(edge.index, edge.new_node);
} break;
case Change::kTxSetParams: {
auto param = args.cast<ArgsOfSetParams>();
MS_EXCEPTION_IF_NULL(param.func_graph);
auto old_parameters = param.func_graph->parameters();
for (auto &p : param.params) {
(*adds)[p] += 1;
}
for (auto &p : old_parameters) {
(*rms)[p] += 1;
}
param.func_graph->set_parameters(param.params);
} break;
case Change::kTxAddParam: {
auto param = args.cast<ArgsOfAddParam>();
MS_EXCEPTION_IF_NULL(param.func_graph);
(*adds)[param.param] += 1;
auto param_node = param.param->cast<ParameterPtr>();
param.func_graph->append_parameter(param_node);
} break;
default:
break;
}
}
}
......@@ -599,6 +617,10 @@ void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfN
changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
}
void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr &param) {
changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param});
}
bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node);
......
......@@ -310,6 +310,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
void RemoveRoots();
void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters);
void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
......@@ -400,6 +401,7 @@ class FuncGraphTransaction {
// set parameters of a func graph
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params);
void AddParameter(FuncGraphPtr fg, const AnfNodePtr &param);
// replace old_node with new_node
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
......@@ -427,6 +429,18 @@ struct ArgsOfSetParams {
}
};
// args for add param
struct ArgsOfAddParam {
FuncGraphPtr func_graph;
AnfNodePtr param;
bool operator==(const ArgsOfAddParam &other) const { return &other == this; }
friend std::ostream &operator<<(std::ostream &os, const ArgsOfAddParam &) {
os << "[ArgsOfAddParam]";
return os;
}
};
// args for set edge
struct ArgsOfSetEdge {
CNodePtr root_node;
......@@ -441,7 +455,7 @@ struct ArgsOfSetEdge {
};
struct Change {
enum OpName { kTxSetParams, kTxSetEdge };
enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam };
OpName op;
Any args;
Change(OpName name, const Any &para) : op(name), args(para) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册