提交 df0d89ed 编写于 作者: Z zhousiyi

use addparam to optimize setparam to reduce manager overhead

上级 f42e36b6
...@@ -68,9 +68,7 @@ ParameterPtr FuncGraph::add_parameter() { ...@@ -68,9 +68,7 @@ ParameterPtr FuncGraph::add_parameter() {
void FuncGraph::add_parameter(const ParameterPtr &p) { void FuncGraph::add_parameter(const ParameterPtr &p) {
if (manager_.lock()) { if (manager_.lock()) {
std::vector<AnfNodePtr> new_params = parameters_; manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
new_params.push_back(p);
manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params);
} else { } else {
parameters_.push_back(p); parameters_.push_back(p);
} }
...@@ -82,12 +80,8 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { ...@@ -82,12 +80,8 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) {
p->set_name(name); p->set_name(name);
p->debug_info()->set_name(name); p->debug_info()->set_name(name);
std::vector<AnfNodePtr> new_params = parameters_;
// append parameter
new_params.push_back(p);
if (manager_.lock()) { if (manager_.lock()) {
manager_.lock()->SetParameters(shared_from_base<FuncGraph>(), new_params); manager_.lock()->AddParameter(shared_from_base<FuncGraph>(), p);
} else { } else {
parameters_.push_back(p); parameters_.push_back(p);
} }
......
...@@ -158,6 +158,7 @@ class FuncGraph : public FuncGraphBase { ...@@ -158,6 +158,7 @@ class FuncGraph : public FuncGraphBase {
const std::vector<AnfNodePtr> &parameters() const { return parameters_; } const std::vector<AnfNodePtr> &parameters() const { return parameters_; }
virtual ParameterPtr add_parameter(); virtual ParameterPtr add_parameter();
void add_parameter(const ParameterPtr &p); void add_parameter(const ParameterPtr &p);
void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); }
void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; } void set_parameters(const std::vector<AnfNodePtr> &params) { parameters_ = params; }
// add a weight parameter with specific name // add a weight parameter with specific name
ParameterPtr AddWeightParameter(const std::string &name); ParameterPtr AddWeightParameter(const std::string &name);
......
...@@ -420,6 +420,12 @@ void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<A ...@@ -420,6 +420,12 @@ void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector<A
tr.Commit(); tr.Commit();
} }
void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter) {
auto tr = Transact();
tr.AddParameter(fg, parameter);
tr.Commit();
}
bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
auto tr = Transact(); auto tr = Transact();
bool success = tr.Replace(old_node, new_node); bool success = tr.Replace(old_node, new_node);
...@@ -532,25 +538,37 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl ...@@ -532,25 +538,37 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl
for (auto &iter : changes) { for (auto &iter : changes) {
auto operation = iter.op; auto operation = iter.op;
auto args = iter.args; auto args = iter.args;
if (operation == Change::kTxSetEdge) { switch (operation) {
auto edge = args.cast<ArgsOfSetEdge>(); case Change::kTxSetEdge: {
auto old_node = edge.root_node->input(edge.index); auto edge = args.cast<ArgsOfSetEdge>();
(*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1; auto old_node = edge.root_node->input(edge.index);
(*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1; (*rm_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, old_node))] += 1;
(*rms)[old_node] += 1; (*add_edges)[std::make_pair(edge.root_node, std::make_pair(edge.index, edge.new_node))] += 1;
(*adds)[edge.new_node] += 1; (*rms)[old_node] += 1;
edge.root_node->set_input(edge.index, edge.new_node); (*adds)[edge.new_node] += 1;
} else if (operation == Change::kTxSetParams) { edge.root_node->set_input(edge.index, edge.new_node);
auto param = args.cast<ArgsOfSetParams>(); } break;
MS_EXCEPTION_IF_NULL(param.func_graph); case Change::kTxSetParams: {
auto old_parameters = param.func_graph->parameters(); auto param = args.cast<ArgsOfSetParams>();
for (auto &p : param.params) { MS_EXCEPTION_IF_NULL(param.func_graph);
(*adds)[p] += 1; auto old_parameters = param.func_graph->parameters();
} for (auto &p : param.params) {
for (auto &p : old_parameters) { (*adds)[p] += 1;
(*rms)[p] += 1; }
} for (auto &p : old_parameters) {
param.func_graph->set_parameters(param.params); (*rms)[p] += 1;
}
param.func_graph->set_parameters(param.params);
} break;
case Change::kTxAddParam: {
auto param = args.cast<ArgsOfAddParam>();
MS_EXCEPTION_IF_NULL(param.func_graph);
(*adds)[param.param] += 1;
auto param_node = param.param->cast<ParameterPtr>();
param.func_graph->append_parameter(param_node);
} break;
default:
break;
} }
} }
} }
...@@ -599,6 +617,10 @@ void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfN ...@@ -599,6 +617,10 @@ void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector<AnfN
changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params});
} }
void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr &param) {
changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param});
}
bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) {
MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(old_node);
MS_EXCEPTION_IF_NULL(new_node); MS_EXCEPTION_IF_NULL(new_node);
......
...@@ -310,6 +310,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { ...@@ -310,6 +310,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
void KeepRoots(const std::vector<FuncGraphPtr> &roots = {}); void KeepRoots(const std::vector<FuncGraphPtr> &roots = {});
void RemoveRoots(); void RemoveRoots();
void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters); void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters);
void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter);
void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false);
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value);
...@@ -400,6 +401,7 @@ class FuncGraphTransaction { ...@@ -400,6 +401,7 @@ class FuncGraphTransaction {
// set parameters of a func graph // set parameters of a func graph
void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params); void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params);
void AddParameter(FuncGraphPtr fg, const AnfNodePtr &param);
// replace old_node with new_node // replace old_node with new_node
bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
...@@ -427,6 +429,18 @@ struct ArgsOfSetParams { ...@@ -427,6 +429,18 @@ struct ArgsOfSetParams {
} }
}; };
// args for add param
struct ArgsOfAddParam {
FuncGraphPtr func_graph;
AnfNodePtr param;
bool operator==(const ArgsOfAddParam &other) const { return &other == this; }
friend std::ostream &operator<<(std::ostream &os, const ArgsOfAddParam &) {
os << "[ArgsOfAddParam]";
return os;
}
};
// args for set edge // args for set edge
struct ArgsOfSetEdge { struct ArgsOfSetEdge {
CNodePtr root_node; CNodePtr root_node;
...@@ -441,7 +455,7 @@ struct ArgsOfSetEdge { ...@@ -441,7 +455,7 @@ struct ArgsOfSetEdge {
}; };
struct Change { struct Change {
enum OpName { kTxSetParams, kTxSetEdge }; enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam };
OpName op; OpName op;
Any args; Any args;
Change(OpName name, const Any &para) : op(name), args(para) {} Change(OpName name, const Any &para) : op(name), args(para) {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册