提交 8d2cb9bc 编写于 作者: K kswang

splitsort reorder getitem

上级 5306172f
......@@ -65,7 +65,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
result.outputs = outputs;
result.graph_id = kInvalidGraphId;
GraphId graph_id = kInvalidGraphId;
if (target != target_device_ && target != "") {
if (target != target_device_ && !target.empty()) {
CreateOtherSession(target);
graph_id = other_sess_->CompileGraph(lst, outputs);
} else {
......@@ -76,7 +76,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
MS_LOG(INFO) << "PrecompileOnly, stop run graph";
return result;
}
if (target != target_device_ && target != "") {
if (target != target_device_ && !target.empty()) {
other_sess_->BuildGraph(graph_id);
} else if (!is_multi_graph_sink_) {
target_sess_->BuildGraph(graph_id);
......@@ -279,7 +279,7 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s
VectorRef outputs;
// call ms rungraph (graphId, input ,output)
if (target != target_device_ && target != "") {
if (target != target_device_ && !target.empty()) {
other_sess_->RunGraph(g, inputs, &outputs);
} else {
target_sess_->RunGraph(g, inputs, &outputs);
......
......@@ -129,6 +129,62 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
}
}
bool IsGetItemNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
}
if (!IsValueNode<Primitive>(inputs[0])) {
return true;
}
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(inputs[0]);
return node_prim->name() == prim::kPrimTupleGetItem->name();
}
return false;
}
std::vector<AnfNodePtr> ReorderGetItemNode(const std::vector<AnfNodePtr> &nodes) {
std::vector<AnfNodePtr> result;
std::map<size_t, std::vector<AnfNodePtr>> insert_positions;
std::map<AnfNodePtr, size_t> node_positions;
for (auto &node : nodes) {
if (IsGetItemNode(node)) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "Invalid get item node";
}
auto &parent = inputs[1];
auto iter = node_positions.find(parent);
if (iter != node_positions.end()) {
size_t position = iter->second;
auto iter_nodes = insert_positions.find(position);
if (iter_nodes != insert_positions.end()) {
iter_nodes->second.push_back(node);
} else {
(void)insert_positions.insert(
std::pair<size_t, std::vector<AnfNodePtr>>(position, std::vector<AnfNodePtr>{node}));
}
continue;
}
}
result.emplace_back(node);
node_positions[node] = result.size();
}
size_t insert_num = 0;
for (auto &item : insert_positions) {
size_t position = item.first + insert_num;
(void)result.insert(result.begin() + position, item.second.begin(), item.second.end());
insert_num += item.second.size();
}
return result;
}
std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &default_target) {
std::vector<AnfNodePtr> result;
std::stack<AnfNodePtr> to_visit;
......@@ -144,8 +200,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
handle_target = next_target;
}
auto &node = to_visit.top();
to_visit.pop();
MS_EXCEPTION_IF_NULL(node);
to_visit.pop();
result.emplace_back(node);
if (!node->isa<CNode>()) {
continue;
......@@ -178,7 +234,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
}
}
std::reverse(result.begin(), result.end());
return result;
return ReorderGetItemNode(result);
}
} // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册