Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0cefb946
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0cefb946
编写于
7月 11, 2018
作者:
Y
Yan Chunwei
提交者:
GitHub
7月 11, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add topological sortting (#12059)
上级
f9202447
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
188 addition
and
3 deletion
+188
-3
paddle/fluid/inference/analysis/data_flow_graph.cc
paddle/fluid/inference/analysis/data_flow_graph.cc
+85
-1
paddle/fluid/inference/analysis/data_flow_graph.h
paddle/fluid/inference/analysis/data_flow_graph.h
+36
-0
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
+67
-2
未找到文件。
paddle/fluid/inference/analysis/data_flow_graph.cc
浏览文件 @
0cefb946
...
...
@@ -90,6 +90,20 @@ std::string DataFlowGraph::DotString() const {
return
dot
.
Build
();
}
std
::
string
DataFlowGraph
::
HumanReadableInfo
(
bool
show_values
,
bool
show_functions
)
const
{
std
::
stringstream
values
,
functions
;
for
(
auto
&
n
:
nodes
.
nodes
())
{
if
(
show_values
&&
n
->
IsValue
())
{
values
<<
n
->
repr
()
<<
"
\n
"
;
}
if
(
show_functions
&&
n
->
IsFunction
())
{
functions
<<
n
->
repr
()
<<
"
\n
"
;
}
}
return
"Values:
\n
"
+
values
.
str
()
+
"
\n\n
"
+
"Functions:
\n
"
+
functions
.
str
();
}
//
// NodesBFSIterator
//
...
...
@@ -146,7 +160,7 @@ bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==(
if
((
!
queue_
.
empty
())
&&
(
!
other
.
queue_
.
empty
()))
{
return
queue_
.
front
()
==
other
.
queue_
.
front
()
&&
visited_
.
size
()
==
other
.
visited_
.
size
();
// here need to check the
// equality of queue and
// equality of queue and
// visited. Just a light but week implementation.
}
return
false
;
...
...
@@ -208,6 +222,76 @@ Node *GraphTraits<DataFlowGraph>::NodesDFSIterator::operator->() {
return
stack_
.
top
();
}
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
::
NodesTSIterator
(
const
std
::
vector
<
Node
*>
&
source
)
{
PADDLE_ENFORCE
(
!
source
.
empty
(),
"Start points of topological sorting should not be empty!"
);
std
::
unordered_set
<
Node
*>
visited
;
std
::
unordered_set
<
Node
*>
to_visit
{
source
.
begin
(),
source
.
end
()};
std
::
vector
<
Node
*>
inlink_visited
;
while
(
!
to_visit
.
empty
())
{
std
::
vector
<
Node
*>
queue
(
to_visit
.
begin
(),
to_visit
.
end
());
for
(
auto
*
p
:
queue
)
{
inlink_visited
.
clear
();
std
::
copy_if
(
p
->
inlinks
.
begin
(),
p
->
inlinks
.
end
(),
std
::
back_inserter
(
inlink_visited
),
[
&
](
Node
*
x
)
{
return
visited
.
count
(
x
);
});
if
(
inlink_visited
.
size
()
==
p
->
inlinks
.
size
())
{
sorted_
.
push_back
(
p
);
for
(
auto
*
_
:
p
->
outlinks
)
{
if
(
!
visited
.
count
(
_
))
{
to_visit
.
insert
(
_
);
}
}
to_visit
.
erase
(
p
);
visited
.
insert
(
p
);
}
}
}
}
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
::
NodesTSIterator
(
const
paddle
::
inference
::
analysis
::
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
&
other
)
:
sorted_
(
other
.
sorted_
),
cursor_
(
other
.
cursor_
)
{}
Node
&
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
::
operator
*
()
{
PADDLE_ENFORCE_LT
(
cursor_
,
sorted_
.
size
());
return
*
sorted_
[
cursor_
];
}
paddle
::
inference
::
analysis
::
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
&
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
::
operator
++
()
{
if
(
++
cursor_
>=
sorted_
.
size
())
{
sorted_
.
clear
();
cursor_
=
0
;
}
return
*
this
;
}
paddle
::
inference
::
analysis
::
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
&
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
::
operator
=
(
const
paddle
::
inference
::
analysis
::
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
&
other
)
{
cursor_
=
other
.
cursor_
;
sorted_
=
other
.
sorted_
;
return
*
this
;
}
bool
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
::
operator
==
(
const
paddle
::
inference
::
analysis
::
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
&
other
)
{
return
sorted_
==
other
.
sorted_
&&
cursor_
==
other
.
cursor_
;
}
Node
*
GraphTraits
<
DataFlowGraph
>::
NodesTSIterator
::
operator
->
()
{
PADDLE_ENFORCE_LT
(
cursor_
,
sorted_
.
size
());
return
sorted_
[
cursor_
];
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/analysis/data_flow_graph.h
浏览文件 @
0cefb946
...
...
@@ -48,6 +48,9 @@ struct DataFlowGraph {
// Output a DOT graph file for debug.
std
::
string
DotString
()
const
;
std
::
string
HumanReadableInfo
(
bool
show_values
=
true
,
bool
show_functions
=
true
)
const
;
private:
// Remove duplicate edges and so on.
void
Clean
();
...
...
@@ -107,6 +110,32 @@ struct GraphTraits<DataFlowGraph> {
std
::
unordered_set
<
Node
*>
visited_
;
};
// Topological sorting iterator on nodes.
struct
NodesTSIterator
:
public
std
::
iterator
<
std
::
forward_iterator_tag
,
Node
*>
{
NodesTSIterator
()
=
default
;
explicit
NodesTSIterator
(
const
std
::
vector
<
Node
*>
&
source
);
NodesTSIterator
(
NodesTSIterator
&&
other
)
:
sorted_
(
std
::
move
(
other
.
sorted_
)),
cursor_
(
other
.
cursor_
)
{
other
.
cursor_
=
0
;
}
NodesTSIterator
(
const
NodesTSIterator
&
other
);
Node
&
operator
*
();
NodesTSIterator
&
operator
++
();
// TODO(Superjomn) current implementation just compare the first
// element, need to compare the graph and all the elements in the queue and
// set.
NodesTSIterator
&
operator
=
(
const
NodesTSIterator
&
other
);
bool
operator
==
(
const
NodesTSIterator
&
other
);
bool
operator
!=
(
const
NodesTSIterator
&
other
)
{
return
!
(
*
this
==
other
);
}
Node
*
operator
->
();
private:
std
::
vector
<
Node
*>
sorted_
;
int
cursor_
{
0
};
};
explicit
GraphTraits
(
DataFlowGraph
*
graph
)
:
graph_
(
graph
)
{}
// default use BFS to visit the nodes.
...
...
@@ -119,17 +148,24 @@ struct GraphTraits<DataFlowGraph> {
iterator_range
<
NodesDFSIterator
>
nodes_in_DFS
()
{
return
iterator_range
<
NodesDFSIterator
>
(
nodes_dfs_begin
(),
nodes_dfs_end
());
}
iterator_range
<
NodesTSIterator
>
nodes_in_TS
()
{
return
iterator_range
<
NodesTSIterator
>
(
nodes_ts_begin
(),
nodes_ts_end
());
}
private:
NodesBFSIterator
nodes_bfs_begin
()
{
return
NodesBFSIterator
(
graph_
->
inputs
);
}
NodesBFSIterator
nodes_bfs_end
()
{
return
NodesBFSIterator
();
}
NodesDFSIterator
nodes_dfs_begin
()
{
return
NodesDFSIterator
(
graph_
->
inputs
);
}
NodesDFSIterator
nodes_dfs_end
()
{
return
NodesDFSIterator
();
}
NodesTSIterator
nodes_ts_begin
()
{
return
NodesTSIterator
(
graph_
->
inputs
);
}
NodesTSIterator
nodes_ts_end
()
{
return
NodesTSIterator
();
}
private:
DataFlowGraph
*
graph_
;
};
...
...
paddle/fluid/inference/analysis/data_flow_graph_tester.cc
浏览文件 @
0cefb946
...
...
@@ -24,11 +24,11 @@ TEST(DataFlowGraph, BFS) {
auto
dfg
=
ProgramDescToDFG
(
desc
);
dfg
.
Build
();
for
(
auto
*
in
:
dfg
.
inputs
)
{
for
(
auto
*
in
:
dfg
.
inputs
)
{
LOG
(
INFO
)
<<
"inputs: "
<<
in
->
name
()
<<
" "
<<
static_cast
<
int
>
(
in
->
type
());
}
for
(
auto
*
out
:
dfg
.
outputs
)
{
for
(
auto
*
out
:
dfg
.
outputs
)
{
LOG
(
INFO
)
<<
"outputs: "
<<
out
->
name
()
<<
" "
<<
static_cast
<
int
>
(
out
->
type
());
}
...
...
@@ -57,6 +57,71 @@ TEST(DataFlowGraph, DFS) {
ASSERT_EQ
(
count
,
dfg
.
nodes
.
size
());
}
// Topological sorting.
/*
* Graph topology
* inputs: 0, 1, 2
* 0 -> 4
* 0 -> 5
* 1 -> 6
* 2 -> 7
* 4 -> 5
* 4 -> 7
* 4 -> 3
* 7 -> 3
*/
TEST
(
DataFlowGraph
,
TS
)
{
DataFlowGraph
graph
;
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
auto
*
node
=
graph
.
nodes
.
Create
(
Node
::
Type
::
kValue
);
node
->
SetName
(
"node-"
+
std
::
to_string
(
i
));
}
auto
add_link
=
[
&
](
int
i
,
int
j
)
{
Node
*
source
=
graph
.
nodes
.
GetMutable
(
i
);
Node
*
target
=
graph
.
nodes
.
GetMutable
(
j
);
target
->
inlinks
.
push_back
(
source
);
source
->
outlinks
.
push_back
(
target
);
};
graph
.
inputs
.
push_back
(
graph
.
nodes
.
GetMutable
(
0
));
graph
.
inputs
.
push_back
(
graph
.
nodes
.
GetMutable
(
1
));
graph
.
inputs
.
push_back
(
graph
.
nodes
.
GetMutable
(
2
));
add_link
(
0
,
4
);
add_link
(
0
,
5
);
add_link
(
1
,
6
);
add_link
(
2
,
7
);
add_link
(
4
,
5
);
add_link
(
4
,
7
);
add_link
(
4
,
3
);
add_link
(
7
,
3
);
auto
its
=
GraphTraits
<
DataFlowGraph
>
(
&
graph
).
nodes_in_TS
();
std
::
vector
<
int
>
sorted_ids
;
for
(
auto
it
=
its
.
begin
();
it
!=
its
.
end
();
++
it
)
{
LOG
(
INFO
)
<<
it
->
name
();
sorted_ids
.
push_back
(
it
->
id
());
}
// Assert a occurs prior to b in the sorted_ids.
auto
assert_positive_sequence_pair
=
[
&
](
int
a
,
int
b
)
{
auto
a_offset
=
std
::
find
(
sorted_ids
.
begin
(),
sorted_ids
.
end
(),
a
);
auto
b_offset
=
std
::
find
(
sorted_ids
.
begin
(),
sorted_ids
.
end
(),
b
);
ASSERT_LT
(
a_offset
,
b_offset
);
};
assert_positive_sequence_pair
(
2
,
7
);
assert_positive_sequence_pair
(
7
,
3
);
assert_positive_sequence_pair
(
4
,
3
);
assert_positive_sequence_pair
(
0
,
4
);
assert_positive_sequence_pair
(
0
,
5
);
assert_positive_sequence_pair
(
1
,
6
);
assert_positive_sequence_pair
(
4
,
5
);
assert_positive_sequence_pair
(
4
,
7
);
}
}
// namespace analysis
}
// namespace inference
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录