提交 c0e1504d 编写于 作者: S ShawnXuan

branch id OK


Former-commit-id: 3a320ecf4e94cc9d2a26942a9b3375125bdf510a
上级 d28e2ad5
......@@ -223,9 +223,6 @@ void LogicalGraph::SetDepthAndBranchId4Nodes() {
});
}
}
ForEachLogicalNode<LogicalNode>([&](LogicalNode* node) {
LOG(INFO) << node->op_vec().at(0)->op_name() << " " << node->depth();
});
// set branch id for nodes
// 1. get longest path
......@@ -236,12 +233,51 @@ void LogicalGraph::SetDepthAndBranchId4Nodes() {
break;
}
}
LOG(INFO) << "longest path size " << max_depth << " " << longest_path_stack.size();
// 2. set branch_id = 0 for ops on longest path
HashMap<LogicalNode*, bool> has_branch_id;
std::queue<LogicalNode*> queue4branch_id;
while (!longest_path_stack.empty()) {
LogicalNode* node = longest_path_stack.top();
longest_path_stack.pop();
LOG(INFO) << node->op_vec().at(0)->op_name() << " " << node->depth();
// LOG(INFO) << node->op_vec().at(0)->op_name() << " " << node->depth();
has_branch_id[node] = true;
}
// 3. push other nodes to queue
ForEachNode([&](LogicalNode* node) {
if (!has_branch_id[node]) queue4branch_id.push(node);
});
// 4. every depth has 1 node on longest path
HashMap<int, int> depth2branch_num;
for (int i = 0; i < max_depth; ++i) depth2branch_num[i + 1] = 1;
// 5. process other nodes
while (!queue4branch_id.empty()) {
LogicalNode* cur_node = queue4branch_id.front();
queue4branch_id.pop();
bool should_set_branch_id = false;
ForEachInNode(cur_node, [&](LogicalNode* in_node) {
if (!should_set_branch_id) should_set_branch_id = has_branch_id[in_node];
});
ForEachOutNode(cur_node, [&](LogicalNode* Out_node) {
if (!should_set_branch_id) should_set_branch_id = has_branch_id[Out_node];
});
if (should_set_branch_id) {
cur_node->set_branch_id(depth2branch_num[cur_node->depth()]);
depth2branch_num[cur_node->depth()]++;
has_branch_id[cur_node] = true;
} else {
queue4branch_id.push(cur_node);
}
}
// print depth for check
ForEachLogicalNode<LogicalNode>([&](LogicalNode* node) {
LOG(INFO) << node->op_vec().at(0)->op_name() << " " << node->depth() << " "
<< node->branch_id();
});
}
template<typename NodeType>
......@@ -261,6 +297,7 @@ bool LogicalGraph::GetLongestPath(
if (!get_it) longest_path_stack.pop();
return get_it;
}
void LogicalGraph::FixSharedModelNodes(
const HashMap<std::string, std::vector<LogicalNode*>>& op_name2nodes) {
HashSet<std::string> all_shared_model_op_names;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册